How do you make your foundation model equivariant and robust to known transformations without re-training from scratch?
Deep learning has witnessed tremendous growth in the past decade. Still, as we strive for more nuanced understanding and performance improvements, one challenge emerges clearly: how do we ensure our models understand data transformations? Enter equivariance, an idea that can help our networks maintain consistent behaviour under data transformations. But with the rise of large pretrained models, how do we make them equivariant without changing their architecture or retraining them from scratch with data augmentation? In this blogpost, we delve into the ideas presented in the paper “Equivariant Adaptation of Large Pretrained Models”.
Equivariant networks
Equivariant networks are designed so that a known transformation of the input produces a correspondingly transformed output: formally, $f(g . x) = g . f(x)$ for every transformation $g$ in a group. The beauty of this is that such networks lead to more accurate, robust predictions and need fewer samples to train. This is great in theory but hard to implement in practice, especially for large pretrained models whose equivariant counterparts are not trivial to design or are very expensive to re-train from scratch. These massive models, pretrained on internet-scale data, are extremely good at solving and reasoning about a wide range of tasks and are called foundation models.
A recent alternative to designing equivariant networks was proposed by Kaba et al.: instead of constraining the architecture, learn to map every input to a standardized canonical form. This way, our task prediction network can work on this standardized format, ensuring consistency. This process involves adding an additional inexpensive network, called the canonicalization network
or $c$, which learns to standardize the input. In our formulation, for an input $x$, the output of the canonicalization network is $c(x) = g$, where $g$ denotes the group element corresponding to the orientation of $x$. The primary network that learns to solve the task based on the standardized input is called the prediction network
or $\phi$. In this particular formulation, achieving equivariance only requires that the canonicalization process be invariant to transformations of the input. This means that no matter in which orientation the input appears, the canonicalization process should always bring it back to the same canonical orientation. This is achieved by using a shallow and cheap equivariant architecture for the canonicalization network.
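To make this concrete, below is a minimal PyTorch sketch of a canonicalization network for the group $C_8$ (rotations by multiples of 45°). It is not the authors' implementation (the paper uses a shallow rotation-equivariant architecture); instead, equivariance here comes from scoring all eight rotated copies of the input with an ordinary small CNN, since rotating the input only cyclically shifts that set of copies. The class name `C8Canonicalizer` and the architecture are illustrative choices.

```python
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

class C8Canonicalizer(nn.Module):
    """Shallow network that scores each of the 8 rotated copies of an image.

    Rotating the input only cyclically shifts the set of copies, so the argmax
    rotation index shifts along with it: the canonical orientation is invariant
    (up to ties, and up to interpolation artifacts at non-90-degree angles),
    which makes c equivariant without any special layers.
    """
    def __init__(self, in_channels=3):
        super().__init__()
        self.score_net = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(32, 1),
        )

    def rotation_logits(self, x):
        # x: (B, C, H, W); one scalar score per candidate rotation k * 45 degrees.
        return torch.cat(
            [self.score_net(TF.rotate(x, -45.0 * k)) for k in range(8)], dim=-1
        )  # (B, 8)

    def forward(self, x):
        # c(x) = g: the index of the rotation describing the orientation of x.
        return self.rotation_logits(x).argmax(dim=-1)  # (B,)
```

Note that the `argmax` is not differentiable; in practice it has to be relaxed (e.g., with a straight-through estimator) so the canonicalization network can be trained end-to-end with the task loss.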
Finally, the combination of the canonicalization network and the prediction network can be represented as $\Phi$:
\[\Phi(x) = c(x) \cdot \phi\big(c(x)^{-1} \cdot x\big)\] \[\Rightarrow \Phi(g \cdot x) = c(g \cdot x) \cdot \phi\big(c(g \cdot x)^{-1} \cdot g \cdot x\big) = \big(g \cdot c(x)\big) \cdot \phi\big(c(x)^{-1} \cdot g^{-1} \cdot g \cdot x\big)\] \[\Rightarrow \Phi(g \cdot x) = g \cdot \Big(c(x) \cdot \phi\big(c(x)^{-1} \cdot x\big)\Big) = g \cdot \Phi(x)\]Here the second step uses the equivariance of the canonicalization network, $c(g \cdot x) = g \cdot c(x)$, which is exactly the condition that the canonical orientation is invariant. The beauty of this approach lies in how the canonicalization network separates the equivariance requirement from the core prediction network architecture. This means that you have the flexibility to employ any powerful pretrained large neural network for the main prediction task.
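Putting the two networks together is then only a few lines. The sketch below (a hedged illustration continuing the one above, not the authors' code) wraps any pretrained prediction network: it canonicalizes the input with $c(x)^{-1}$, runs $\phi$, and, for tasks with spatial outputs, maps the result back with $c(x)$.

```python
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

class CanonicalizedModel(nn.Module):
    """Phi(x) = c(x) . phi(c(x)^{-1} . x), specialized to C8 image rotations."""
    def __init__(self, canonicalizer, prediction_net, equivariant_output=False):
        super().__init__()
        self.c = canonicalizer           # e.g. the C8Canonicalizer sketched above
        self.phi = prediction_net        # any pretrained network, left unchanged
        self.equivariant_output = equivariant_output  # True for dense outputs

    @staticmethod
    def _rotate_each(t, ks, sign):
        # Apply a (possibly different) rotation of sign * k * 45 degrees per sample.
        return torch.stack(
            [TF.rotate(ti, sign * 45.0 * int(ki)) for ti, ki in zip(t, ks)]
        )

    def forward(self, x):
        k = self.c(x)                               # c(x): rotation index per sample
        x_canon = self._rotate_each(x, k, sign=-1)  # c(x)^{-1} . x
        out = self.phi(x_canon)                     # phi sees the canonical orientation
        if self.equivariant_output:                 # map spatial outputs back: c(x) . out
            out = self._rotate_each(out, k, sign=+1)
        return out
```

For an invariant task such as classification, the outer action is trivial, so `CanonicalizedModel(C8Canonicalizer(), pretrained_vit)` yields $C_8$-invariant logits from an unmodified pretrained ViT (here `pretrained_vit` is any classifier you supply).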
Sounds straightforward? Well, it has a hitch.
The main challenge is ensuring the canonicalization network ‘plays nice’ with the prediction network. For example, the canonicalization network can output orientations that hurt the training of the prediction network, leading to poor task performance. This becomes even more important when the prediction network is pretrained on a certain dataset. For instance, if the canonicalization network transforms all images to be upside-down, but our pretrained prediction network wasn’t trained on upside-down images, the whole system falls apart. So, it’s vital that the canonicalization network outputs orientations of the data that are in-distribution for the pretrained prediction network.
The magic lies in designing our canonicalization function not just to transform data, but to do so while being aware of how our prediction model was initially trained. The key is ensuring that the transformed (or standardized) data aligns with what the pretrained prediction model expects. Mathematically, the goal is to bring the predicted out-of-distribution orientations back to the distribution of orientations the pretrained prediction network has seen.
In simple terms, the paper introduces a canonicalization prior: a guiding force ensuring that our canonicalization function behaves and produces output that the pretrained prediction network would expect. We leverage the idea that our training data provides hints about the ‘typical’ orientations it comes in. By encoding this into a prior, one can guide our canonicalization function to produce transformed data that is not just standardized but also aligned with what the prediction network was trained on.
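Concretely, one way to implement such a prior is sketched below, under the assumption that the canonicalizer exposes its per-rotation scores (as `rotation_logits` does in the earlier sketch); `beta` is a hypothetical weighting, not a value from the paper. The idea: treat training images as already being in their typical orientation and penalize the canonicalizer whenever it predicts anything other than the identity rotation. With a point-mass prior at the identity, the KL regularizer reduces to a cross-entropy term.

```python
import torch
import torch.nn.functional as F

def prior_regularized_loss(task_loss, rotation_logits, beta=1.0):
    """Canonicalization prior: push predicted orientations toward the identity.

    rotation_logits: (B, 8) per-rotation scores from the canonicalizer on
    *training* images, which are assumed to already be in the orientations the
    pretrained prediction network has seen. The prior therefore pulls the
    predicted orientation toward the identity element (index 0).
    """
    identity = rotation_logits.new_zeros(rotation_logits.size(0), dtype=torch.long)
    prior_loss = F.cross_entropy(rotation_logits, identity)
    return task_loss + beta * prior_loss
```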
While mathematical and intricate, this entire process boils down to ensuring that the large pretrained prediction network always sees in-distribution samples. The result is a highly robust model that can confidently handle known transformations of the input data and produce consistent predictions.
This section highlights the effectiveness of the approach for image classification and instance segmentation tasks. Additional results and experiments, including point cloud classification and part segmentation, are detailed in the paper.
The authors select the Vision Transformer (ViT) as the pretrained prediction network for the image classification experiments.
The authors compare different fine-tuning setups. First, Vanilla indicates standard fine-tuning on the downstream dataset. C8-Aug. indicates fine-tuning on the downstream dataset with \(C_8\) group data augmentations. LC is the learned canonicalization approach proposed by Kaba et al., and Prior-Regularized LC adds the canonicalization prior described above.
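The quantity behind such comparisons is accuracy on a rotation-augmented test set. Here is a hedged sketch of that evaluation loop (illustrative, not the authors' evaluation code): every test image is presented in all eight $C_8$ orientations, and accuracy is averaged across them.

```python
import torch
import torchvision.transforms.functional as TF

def c8_test_accuracy(model, loader, device="cuda"):
    """Classification accuracy averaged over all 8 rotations of the test set."""
    model.eval().to(device)
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            for k in range(8):  # present each image in every C8 orientation
                preds = model(TF.rotate(images, 45.0 * k)).argmax(dim=-1)
                correct += (preds == labels).sum().item()
                total += labels.numel()
    return correct / total
```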
Furthermore, the authors scale this idea to large foundation models like the Segment Anything Model (SAM), making its instance segmentation masks robust to rotations of the input images.
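The same wrapper idea extends to dense prediction: canonicalize the image, run the frozen segmentation model, and rotate the predicted masks back with $c(x)$. A hypothetical usage of the earlier sketch follows; `segmentation_net` stands in for any network that maps images to per-pixel outputs, since SAM's real, prompt-based interface differs.

```python
# Hypothetical usage; `segmentation_net` is a stand-in, not SAM's actual API.
robust_seg = CanonicalizedModel(C8Canonicalizer(), segmentation_net,
                                equivariant_output=True)
masks = robust_seg(images)  # masks come back aligned with the input orientation
```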
Finally, to facilitate experimentation with the ideas discussed here on equivariant adaptation of large-scale models, the authors provide an open-source package, Equiadapt.
In the ever-evolving world of AI and deep learning, it is critical to ensure models are robust and aware of symmetries. By learning to transform input data into the orientations that pretrained models expect, we can create large-scale models that are both powerful and aware of data transformations, bringing us a step closer to AI systems that understand the world as we do. As research into scaling continues, the fusion of large foundation models with equivariant adaptation techniques such as the one presented in this blogpost has the potential to become a fundamental approach for enhancing the consistency and reliability of AI systems.