1 Introduction

Self-supervised learning (SSL) has gained popularity as a learning technique for acquiring meaningful representations in an unsupervised manner. By training on large unlabeled datasets using self-supervised pretext and auxiliary tasks, SSL produces features that can efficiently be applied to downstream tasks with fewer labels, as demonstrated in Ahmed et al. (2021), Caron et al. (2021), Assran et al. (2022), He et al. (2021), Xie et al. (2021), Atito et al. (2021). Furthermore, SSL has enabled Vision Transformers (ViTs) (Dosovitskiy et al., 2020) to outperform Convolutional Neural Networks (CNNs) in various image-related tasks, including classification, detection, and segmentation (Ahmed et al., 2021; Xie et al., 2021; He et al., 2021; Zhou et al., 2021).

Fig. 1 Accuracy of different methods on the low-shot classification task for 1, 2, and 5 images per label, and for 1% of the images, of ImageNet-1K

Self-supervised methods for modelling global discriminative features often employ either contrastive pretext tasks (Chen et al., 2020a, b, c, 2021; He et al., 2019) or clustering pretext tasks (Caron et al., 2021, 2020; Assran et al., 2022). On the other hand, an alternative avenue of SSL, i.e. Masked Image Modeling (MIM) methods, has emerged with a distinct focus on capturing contextual information by reconstruction either at the pixel level (Atito et al., 2022; Xie et al., 2021; He et al., 2021) or the token level (Wei et al., 2021; Bao et al., 2021), thereby lacking the discriminative details that are crucial for generating globally informative features. Jiang et al. (2023) explored a combination of contrastive pretext tasks and MIM in the pixel space, which has been shown to yield improved representations.

iBoT (Zhou et al., 2021) follows a similar approach, but replaces the contrastive task with a clustering task at both the global and patch levels and utilises MIM for token-level masked region prediction. Further, MSN (Assran et al., 2022) proposes the ME-MAX loss instead of the centring trick proposed in DINO (Caron et al., 2021) for better low-shot linear evaluation.

All previous methods have demonstrated strong performance in the large dataset size regime. However, their performance in low-shot scenarios has been largely overlooked, with the notable exception of the MSN approach. Even so, MSN does not analyse the effect of different SSL components, such as the choice of the pretext task and the collapse avoidance mechanism, on low-shot learning. Moreover, MSN confines its evaluation to low-shot linear classification on ImageNet-1K, which presents several limitations. For one, pre-training with a contrastive or clustering loss function typically results in high linear evaluation performance, especially when the model is assessed on the same dataset on which it was pre-trained. This outcome can falsely imply effectiveness in low-shot scenarios. Additionally, this narrow method of assessment does not accurately predict the model's transferability to different tasks and datasets. Hence, there is a notable gap in the literature in comprehensive, system-level analyses of the impact of SSL and its components on low-shot applications.

In this paper, we perform a detailed study of the impact of pretext tasks, i.e. clustering, contrastive learning, and MIM, and of the choice of a collapse avoidance method, i.e. ME-MAX, Sinkhorn, and centring, on the performance of low-shot downstream tasks. In addition, we also study the effect of extending the instance discrimination pretext task to the patch level. We provide an overview of the different pretext tasks and collapse avoidance mechanisms used by previous frameworks in Table 1.

Based on the above analysis, we investigate a simple model with a combination of two different pretext tasks, namely clustering and MIM, for low-shot learning. Clustering is done at both the class token level, to capture global semantics, and the patch level, to capture local semantics. We perform MIM at the pixel level, in addition to clustering, to capture fine-grained details. When evaluated on several low-shot downstream tasks, namely multi-label classification, multi-class classification and semantic segmentation, the proposed simple model works better due to its ability to capture details at various levels. We also present the performance of state-of-the-art self-supervised models on these downstream tasks. Figure 1 shows the performance comparison of various SSL methods in low-shot classification on ImageNet-1K. To analyse the scaling behaviour on full datasets, we finetune the model under standard finetuning evaluation settings following previous SSL approaches (Caron et al., 2021; Atito et al., 2021; Zhou et al., 2021). We find that our model performs favourably in these settings as well.

Table 1 A review of different self supervised methods and their pretext tasks and collapse avoidance mechanisms

2 Related Works

The early SSL methods in computer vision relied on simple pretext tasks, such as solving jigsaw puzzles (Noroozi & Favaro, 2016), predicting colour from grayscale images (Xie et al., 2016), or classifying relative positions (Doersch et al., 2015). However, recent advances introduced more sophisticated pretext tasks with complex training objectives. Generative methods, which randomly mask parts of the input and predict those regions at the pixel or token level, gained popularity in SSL (Ahmed et al., 2021; Vincent et al., 2010; Chen et al., 2020d; Atito et al., 2022; Xie et al., 2021; He et al., 2021). GMML in SiT (Ahmed et al., 2021) was the first ViT method to demonstrate that a masked autoencoder, i.e., randomly masking large proportions of image patches and reconstructing them, is a strong self-supervised pretext task capable of outperforming supervised pretraining. BeiT (Bao et al., 2021) extended the masked autoencoder idea using a discrete variational autoencoder (dVAE) for token generation and prediction. SimMIM (Xie et al., 2021) and MAE (He et al., 2021) scaled up the idea of heavy masking and recovery of information using an autoencoder-style approach for pixel-wise reconstruction. GMML, MAE, and SimMIM reconstruct at the pixel level, whereas BeiT reconstructs at the token level. These methods do not enforce global-level representation consistency across different views of the same image.

Contrastive methods, on the other hand, learn invariance by emphasising similarity between positive views and reducing similarity between negative views using the InfoNCE loss (Oord & Vinyals, 2018). SimCLR (Chen et al., 2020a) highlighted the importance of data augmentation, while MoCo (He et al., 2019) introduced a memory bank to address the issue of large batch size. SiT (Ahmed et al., 2021) combined MIM with contrastive learning, leading to performance improvements. Clustering-based methods achieve invariance by learning similar cluster assignments for different augmented views. SwAV (Caron et al., 2020) used cluster assignments as a supervisory signal, and DINO (Caron et al., 2021) emphasised the role of the momentum encoder and multiple crops for SSL. Collapse, where embeddings do not have enough variance in the representation space, is a major issue for self-supervised methods, as it produces trivial solutions. Recent methods have used asymmetry in design (Chen & He, 2020), Sinkhorn normalisation of the teacher cluster assignments (Caron et al., 2020), momentum encoders to generate target embeddings (He et al., 2019), and centring to make the teacher distribution more uniform, together with sharpening (Caron et al., 2021), to avoid collapse. MSN (Assran et al., 2022) is a variant of DINO, where collapse is avoided with the ME-MAX loss instead of centring or Sinkhorn, showing superior performance in low-shot linear evaluation. By default, MSN applies ME-MAX together with Sinkhorn to avoid having to set the scaling factor for the ME-MAX loss. iBoT (Zhou et al., 2021) extends DINO with masking and clustering applied to both patch and class tokens, yet iBoT lacks fine-grained context due to the absence of pixel reconstruction.

Recent works have established benchmarks for evaluating SSL methods in the natural image, speech, and medical domains. For instance, a systematic analysis of SSL pretraining methods was conducted, evaluating their scalability across multiple datasets and tasks, including object detection and segmentation (Goyal et al., 2019). However, this study primarily focuses on general downstream performance without addressing low-shot learning scenarios. Similarly, DABS, a domain-agnostic benchmark for SSL, was proposed, with evaluations spanning diverse domains such as natural images and speech (Tamkin et al., 2021). This benchmark, however, does not explore performance in extreme low-data regimes. Another investigation into SSL methods for pathology images highlighted their efficacy in medical imaging tasks (Kang et al., 2022), yet the focus remains on standard downstream evaluations rather than low-shot settings. These benchmarks illustrate the progress made in SSL evaluation, but none explicitly address the critical challenge of low-shot learning, underscoring the motivation for focusing on this underexplored area.

Few-shot learning has been extensively studied over the years, particularly in classification and segmentation tasks. Prototypical Networks learn a metric space where classification is achieved by computing distances to prototype representations for each class (Snell et al., 2017). Relation Networks leverage neural networks to learn a similarity metric specifically designed for few-shot classification (Sung et al., 2018). Model-Agnostic Meta-Learning (MAML) enables rapid adaptation of models to new tasks with limited data (Finn et al., 2017). An optimization-based meta-learning framework was proposed, laying the foundation for several subsequent few-shot methods (Ravi & Larochelle, 2017). In segmentation, OSLSM enables adaptation to novel classes with a minimal number of examples (Shaban et al., 2017). CANet improves segmentation performance through iterative refinement of predictions (Zhang et al., 2019). The FSS-1000 dataset provides pixel-wise annotations and serves as a rigorous benchmark for evaluating few-shot segmentation models (Li et al., 2021). PANet incorporates prototype alignment to enhance segmentation performance in few-shot settings (Wang et al., 2019). Feature-Proxy Transformer (FPTrans) introduces a novel architecture for few-shot segmentation by leveraging proxy representations to better capture class-specific features with minimal examples, further advancing the field (Zhang et al., 2022). Despite these advancements, none of these works explicitly study the effects of self-supervised learning (SSL) on few-shot classification and segmentation tasks. This gap motivates our work, which investigates the impact of SSL pretraining on low-shot finetuning performance across diverse datasets and tasks.

Semi-supervised methods (Lucas et al., 2022; Sohn et al., 2020) have considered extreme low-data scenarios, which have been mostly overlooked by the SSL community with the exception of MSN (Assran et al., 2022). Moreover, the existing literature lacks a comprehensive examination of the impact of various SSL components on low-shot learning performance. While MSN demonstrates superior low-shot linear evaluation through the use of the ME-MAX loss, it falls short of exploring the influence of diverse pretext tasks and collapse avoidance mechanisms on various low-shot downstream tasks. Additionally, its exclusive focus on low-shot linear evaluation using the pretraining dataset (ImageNet-1K) may not generalize effectively to different datasets and downstream tasks. Hence, our emphasis is on low-shot finetuning across diverse tasks and datasets. We perform a thorough analysis of different components, including the choice of pretext task and the choice of collapse avoidance mechanism. Motivated by our findings, we propose a method with multiple pretext tasks: clustering and masked image modelling. The introduced model applies clustering to both class and patch tokens and performs reconstruction with a pixel-level loss. When compared to other SSL methods, we find that our model performs the best across several low-shot downstream tasks. The performance also scales well when finetuning on full datasets.

Fig. 2 An overview of the MaskCluster architecture for low-shot. Multiple global views are generated, which are then masked to generate masked global crops. Unmasked global views are processed by a teacher encoder, which consists of multiple transformer layers, to generate output embeddings \({\textbf{Z}}^{e, G}_{c}\), \({\textbf{Z}}^{e, G}_{p}\) corresponding to the class token and the patch tokens respectively. The teacher cluster layers \(CL_t\) and \(PL_t\) generate class and patch cluster assignments \({\textbf{P}}^{G}_{c}, {\textbf{P}}^{G}_{p}\) from the embeddings \({\textbf{Z}}^{e, G}_{c}\), \({\textbf{Z}}^{e, G}_{p}\) respectively. A similar procedure is followed by the student to generate output embeddings \(\bar{\textbf{Z}}^{e, G}_{c}, \bar{\textbf{Z}}^{e, G}_{p}\) for the masked global crops, which are used to generate cluster assignments \(\bar{\textbf{P}}^{G}_{c}, \bar{\textbf{P}}^{G}_{p}\) by the student clustering layers \(CL_s\) and \(PL_s\) for the class and patch tokens respectively. Further, the masked global output patch embeddings \(\bar{\textbf{Z}}^{e, G}_{p}\) are reconstructed back to the pixel space with the help of the reconstruction head attached to the student. Finally, the cross entropy loss is applied between the teacher and student cluster assignments, and the \(\ell \)1-loss is used for reconstruction between the original view and the reconstructed image

3 Details of SSL Components and Analysis

We aim to provide a detailed analysis of the effect of the choice of pretext tasks and the choice of a collapse avoidance mechanism on low-shot downstream tasks. In addition, we propose an architecture which is based on the findings of the study. We believe that the choice of pretext task has a huge impact when finetuning on low-shot tasks. Generally, clustering/contrastive learning focuses only on instance-level discrimination and can be classified as an instance discrimination task. Focusing on a single global instance might not be beneficial in low data regimes. We believe that fine-grained contextual information is necessary for the model to perform better in low data regimes. Collapse avoidance also plays a crucial role when evaluating the self-supervised model on low-shot data (Assran et al., 2022). We therefore present a study on the effects of collapse avoidance mechanisms such as ME-MAX (Assran et al., 2022), Sinkhorn (Caron et al., 2020) and centring (Caron et al., 2021).

In addition, we answer the question of whether instance discrimination should be applied only at the class token level or at both the class and patch levels. iBoT (Zhou et al., 2021) shows that applying instance discrimination at both the patch and class levels helps when finetuning on full-scale data. We avoid studying different architectures and stick to the vision transformer, particularly ViT-S (Dosovitskiy et al., 2020), for fair comparison. Table 1 presents an overview of the different components used by previous SSL frameworks. Based on this detailed study, we introduce a low-shot capable self-supervised model which also scales to large-scale datasets.

3.1 Introduction of Different SSL Pretext Tasks

We provide a brief introduction to different pretext tasks and their formulation. In this study we focus on contrastive learning, clustering, and masked image modelling with pixel reconstruction as the main pretext tasks. Instance discrimination pretext tasks like clustering and contrastive learning are better at learning semantics. MIM-based pretext tasks, which reconstruct the masked image at the pixel level, learn local context. We define these tasks in the context of the vision transformer (Dosovitskiy et al., 2020).

Contrastive Learning: Contrastive learning, introduced for self-supervision in SimCLR (Chen et al., 2020a), generally has 2N augmented data points generated from N original images, where each image generates two random augmented views. Let \({\textbf{z}}_i\) be the output embedding generated from the \(i_{th}\) data point after passing through a vision transformer with a projection head attached (Chen et al., 2020a). If \({\textbf{z}}_i\), \({\textbf{z}}_j\) are the embeddings of a positive pair, then the contrastive loss is given by Eq. 1.

$$\begin{aligned} {\mathcal {L}}_{con}(i, j) = -\log \frac{\exp (\text {sim}({\textbf{z}}_{i}, {\textbf{z}}_{j})/\tau )}{\sum ^{2N}_{k=1}\mathbbm {1}_{[k\ne i]}\exp (\text {sim}({\textbf{z}}_{i}, {\textbf{z}}_{k})/\tau )} \end{aligned}$$
(1)

Here the loss \({\mathcal {L}}_{con}(i, j)\) is for a single positive pair (i, j), and it is applied to all positive pairs, including (j, i). \(\mathbbm {1}_{[k\ne i]} \in \{0, 1\}\) is an indicator function that equals 1 if \(k\ne i\).
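To make Eq. 1 concrete, the following is a minimal PyTorch sketch of the contrastive loss; it is illustrative only, and the batch layout (views \(i\) and \(i+N\) form a positive pair) and the temperature value are our own assumptions rather than the exact SimCLR implementation.

```python
import torch
import torch.nn.functional as F


def nt_xent_loss(z: torch.Tensor, tau: float = 0.1) -> torch.Tensor:
    """Contrastive loss of Eq. 1 for a batch of 2N embeddings.

    z: (2N, d) embeddings, where rows i and i + N are the two augmented
    views of the same image (an assumed layout for this sketch).
    """
    z = F.normalize(z, dim=1)                       # cosine similarity via dot product
    sim = z @ z.t() / tau                           # (2N, 2N) scaled similarity matrix
    n = z.shape[0] // 2

    # remove self-similarity so it never enters the denominator (the indicator in Eq. 1)
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))

    # the positive of sample i is its other view: i <-> i + N
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)

    # cross-entropy over the remaining 2N - 1 candidates realises Eq. 1 for all pairs
    return F.cross_entropy(sim, targets)
```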

Clustering: Clustering is a negative-sample-free pretext task, which makes it less sensitive to the choice and number of negative samples within the batch. It also does not require the large batch sizes or memory bank used in contrastive learning (Chen et al., 2020a, 2021). We generate two random global augmented views \({\textbf{x}}_{g1}\), \({\textbf{x}}_{g2}\) from an input image \({\textbf{x}}\). Caron et al. (2021) and Zhou et al. (2021) also use several local crops in addition to the global crops, but we skip them in the loss formulation for simplicity. Generally, clustering requires a teacher and a student, where the teacher is an exponential moving average (EMA) of the student. Let \({\textbf{c}}_{g1}\), \({\textbf{c}}_{g2}\) be the predicted cluster assignments corresponding to the class token of the student after passing through the projection head, similar to DINO (Caron et al., 2021). Let \(\bar{\textbf{c}}_{g1}\), \(\bar{\textbf{c}}_{g2}\) be the target cluster assignments corresponding to the class token of the teacher after passing through its projection head. Then the class-level clustering loss is defined as the cross entropy between \({\textbf{c}}_{g1}\), \(\bar{\textbf{c}}_{g2}\) and between \({\textbf{c}}_{g2}\), \(\bar{\textbf{c}}_{g1}\), given by Eq. 2.

$$\begin{aligned} {\mathcal {L}}_{cc} = \frac{1}{2}(\text {H}({\textbf{c}}_{g1}, \bar{\textbf{c}}_{g2}) + \text {H}({\textbf{c}}_{g2}, \bar{\textbf{c}}_{g1})) \end{aligned}$$
(2)

iBoT (Zhou et al., 2021) considers extending the clustering loss to the patches as well, for which we provide the formulation in Eq. 3. Let \({\textbf{p}}_{g1}^{i}\), \({\textbf{p}}_{g2}^{i}\) be the predicted cluster assignments corresponding to patch token \(i\in \{1 \dots N\}\) of the student encoder after the projection head, where N is the number of tokens. Similarly, the teacher encoder generates the target patch cluster assignments \(\bar{\textbf{p}}_{g1}^{i}\), \(\bar{\textbf{p}}_{g2}^{i}\).

$$\begin{aligned} {\mathcal {L}}_{pt} = \frac{1}{2N}\sum _{i=1}^{N}(\text {H}({\textbf{p}}_{g1}^i, \bar{\textbf{p}}_{g1}^i) + \text {H}({\textbf{p}}_{g2}^i, \bar{\textbf{p}}_{g2}^i)) \end{aligned}$$
(3)
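As an illustration of Eqs. 2 and 3, the sketch below computes the class- and patch-level clustering losses in PyTorch, assuming the student produces logits and the teacher assignments have already been sharpened into probability distributions; the variable names and the student temperature are our own assumptions.

```python
import torch.nn.functional as F


def cluster_ce(student_logits, teacher_probs, tau_s: float = 0.1):
    """Cross entropy H(., .) between teacher targets and student predictions."""
    log_p_s = F.log_softmax(student_logits / tau_s, dim=-1)
    return -(teacher_probs * log_p_s).sum(dim=-1).mean()


def class_clustering_loss(c_g1, c_g2, c_bar_g1, c_bar_g2):
    """Eq. 2: swapped prediction between the two global views (class token); (B, K) tensors."""
    return 0.5 * (cluster_ce(c_g1, c_bar_g2) + cluster_ce(c_g2, c_bar_g1))


def patch_clustering_loss(p_g1, p_g2, p_bar_g1, p_bar_g2):
    """Eq. 3: prediction of the teacher patch assignments of the same view,
    averaged over the N patch tokens; (B, N, K) tensors."""
    return 0.5 * (cluster_ce(p_g1, p_bar_g1) + cluster_ce(p_g2, p_bar_g2))
```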

Masked Image Modelling: Masked image modelling was explored in SiT (Ahmed et al., 2021) for pixel-level reconstruction and has been utilised by several following works (He et al., 2021; Xie et al., 2021). Pixel-level reconstruction captures the fine-grained information required in the low data regime (Atito et al., 2022). If \({\textbf{x}}\) is an input image, we generate a masked image \({\textbf{x}}_m\) with mask \({\textbf{M}}\), where \({\textbf{M}}(i, j)=1\) if the pixel \({\textbf{x}}(i, j)\) is masked. The masked image \({\textbf{x}}_m\) is passed through a vision transformer encoder and a lightweight reconstruction head (Xie et al., 2021) that produces a reconstructed image \(\hat{\textbf{x}}\). The reconstruction loss is the \(\ell \)1-loss given in Eq. (4).

$$\begin{aligned} & {\mathcal {L}}_{mim} = \sum _{i=1}^{H} \sum _{j=1}^{W} {\textbf{M}}(i, j) \times |{\textbf{x}}(i, j) - \hat{\textbf{x}}(i, j)| \end{aligned}$$
(4)
$$\begin{aligned} & {\textbf{M}}(i, j) = {\left\{ \begin{array}{ll} 1 & \text {if }{\textbf{x}}(i, j) \, \text {is masked}\\ 0 & \text {otherwise} \end{array}\right. } \end{aligned}$$
(5)
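A minimal sketch of the masked \(\ell \)1 reconstruction loss of Eqs. 4 and 5 follows; the tensor shapes and the normalisation by the number of masked pixels are our own assumptions (Eq. 4 is written as a plain sum over masked locations).

```python
import torch


def masked_l1_loss(x: torch.Tensor, x_hat: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Eq. 4: l1 reconstruction error accumulated only over masked pixels.

    x, x_hat: (B, C, H, W) original and reconstructed images.
    mask:     (B, 1, H, W) binary mask of Eq. 5 (1 where the pixel was masked).
    """
    per_pixel = (x - x_hat).abs()                 # |x(i, j) - x_hat(i, j)|
    loss = (mask * per_pixel).sum()
    # normalisation by the number of masked pixels is an assumption of this sketch
    return loss / mask.sum().clamp(min=1.0)
```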

3.2 Collapse Avoidance

Collapse avoidance has mainly been applied to clustering methods (Caron et al., 2021; Assran et al., 2022; Zhou et al., 2021). We study three different methods, namely centring, explored in DINO (Caron et al., 2021) and iBoT (Zhou et al., 2021), Sinkhorn, explored in SwAV (Caron et al., 2020), and the ME-MAX loss, introduced in MSN (Assran et al., 2022). We refer the reader to the above-mentioned literature for the details of these collapse avoidance mechanisms, but provide a formulation of ME-MAX since we also extend it to patches. MSN studies the effect of ME-MAX only on the class token and by default combines it with Sinkhorn. ME-MAX also differs from the other methods in that it is applied as a loss on the student cluster assignments, whereas both Sinkhorn and centring can be considered a normalisation applied to the teacher target cluster assignments. Let \(\bar{\textbf{c}}\) be the average cluster assignment corresponding to the class token and \(\bar{\textbf{p}}\) the average patch-level cluster assignment, both generated from the student. We define ME-MAX at the class token level in Eq. 6 and at the patch level in Eq. 7, where \({\textbf{H}}\) is the entropy function.

$$\begin{aligned} & {\mathcal {L}}_{mc} = - {\textbf{H}}(\bar{\textbf{c}}) \end{aligned}$$
(6)
$$\begin{aligned} & {\mathcal {L}}_{mp} = - {\textbf{H}}(\bar{\textbf{p}}) \end{aligned}$$
(7)
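For reference, a minimal sketch of the ME-MAX regulariser of Eqs. 6 and 7 is given below: it maximises the entropy of the batch-averaged student assignment, pushing the network to use all clusters. The temperature and the flattening of patch tokens into the batch dimension are our own assumptions.

```python
import torch
import torch.nn.functional as F


def me_max_loss(student_logits: torch.Tensor, tau: float = 0.1) -> torch.Tensor:
    """Eqs. 6/7: negative entropy of the mean cluster assignment.

    student_logits: (B, K) for the class token, or (B * N, K) for patch tokens.
    """
    probs = F.softmax(student_logits / tau, dim=-1)
    mean_probs = probs.mean(dim=0)                               # average assignment over the batch
    entropy = -(mean_probs * torch.log(mean_probs + 1e-8)).sum()
    return -entropy                                              # minimising this maximises entropy
```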

3.3 Analysis

We use ViT-S as our base architecture to study the choice of pretext task and collapse avoidance mechanism on low-shot downstream tasks. We use low-shot classification on ImageNet-1K (Deng et al., 2009) to study the effect these choices have on performance. For 1-shot, 2-shot and 5-shot ImageNet-1K classification, we utilise the standard dataset made available by MSN (Assran et al., 2022), where each setting has three different splits. All the experiments are pretrained for a similar number of epochs and finetuned on the target dataset, with the accuracy reported as the mean over the three splits.

Table 2 Evaluation of pretext task on ImageNet-1K low-shot multi-class classification performance

Which pretext task to choose? We explore clustering and contrastive learning as instance discrimination pretext tasks, along with MIM for pixel reconstruction, emphasizing context. Comparative experiments, detailed in Table 2, reveal that clustering with ME-MAX loss outperforms contrastive learning. Focusing on a single semantic object is suboptimal for low-shot performance, as evidenced by inferior results in instance discrimination methods compared to masked image modeling. Clustering, unaffected by the choice of negative samples, surpasses contrastive learning. Combining clustering with MIM yields the best performance, underscoring the importance of fine-grained context and discriminative information for low-shot scenarios.

Table 3 Evaluation of collapse avoidance mechanisms on ImageNet-1K low-shot multi-class classification performance

Collapse avoidance makes a difference Collapse avoidance is performed either through centring (Caron et al., 2021), ME-MAX (Assran et al., 2022) or Sinkhorn (Caron et al., 2020). We evaluate the effect of all the above methods for collapse avoidance on low-shot evaluation performance in Table 3. We find that applying ME-MAX is better than Sinkhorn and centring. Forcing the network, through the loss, to use all the available clusters at the output helps the network in the low-shot regime.
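To contrast the teacher-side alternatives with the student-side ME-MAX loss above, the following is a loose sketch of centring and Sinkhorn normalisation as applied to teacher outputs; it is simplified for illustration (the actual DINO and SwAV implementations include temperatures and further normalisation that we omit), and the momentum and iteration counts are assumptions.

```python
import torch


def center_teacher(teacher_logits: torch.Tensor, center: torch.Tensor, momentum: float = 0.9):
    """Centring (DINO-style): subtract a running mean of the teacher logits.

    Returns the centred logits and the updated centre (an EMA over batches).
    """
    centred = teacher_logits - center
    center = momentum * center + (1 - momentum) * teacher_logits.mean(dim=0)
    return centred, center


def sinkhorn(teacher_probs: torch.Tensor, n_iters: int = 3) -> torch.Tensor:
    """Sinkhorn-Knopp normalisation (SwAV-style) of a (B, K) assignment matrix:
    alternately normalise clusters and samples so assignments spread over all clusters."""
    q = teacher_probs.clone()
    for _ in range(n_iters):
        q = q / q.sum(dim=0, keepdim=True).clamp(min=1e-8)   # equal total mass per cluster
        q = q / q.sum(dim=1, keepdim=True).clamp(min=1e-8)   # each sample sums to one
    return q
```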

Table 4 Evaluation of class level or both class and patch level clustering on ImageNet-1K low-shot multi-class classification performance

Is instance discrimination at the patch level needed? We assess the impact of patch-level instance discrimination on low-shot performance. Inspired by iBoT (Zhou et al., 2021), which demonstrates enhanced ImageNet-1K finetuning with patch clustering, we explore the potential benefits for low-shot scenarios. In Table 4, we analyze clustering applied solely at the class level, solely at the patch level, and at both the class and patch levels. Given clustering's superiority over contrastive learning (Table 2) and the effectiveness of clustering with ME-MAX (Table 3), applying clustering at both levels proves more effective for low-shot evaluation than at the class or patch level alone. Notably, patch clustering alone performs poorly due to the absence of global-level context, as it lacks the class-level information provided by class token clustering. This emphasizes the significance of local discriminative information, which enhances network performance in downstream tasks.

3.4 Simple Pretext Combination for Low-Shot

Based on the above analysis, we find that capturing information at various levels is required for low-shot learning. Thus, we introduce the MaskCluster model, which captures global and local semantics with the clustering pretext task while also obtaining local contextual information from the MIM pretext task with pixel-level reconstruction. The model has student and teacher networks, both based on the vision transformer (Dosovitskiy et al., 2020). We attach a projection head, similar to previous approaches (Ahmed et al., 2021; Zhou et al., 2021; Caron et al., 2021; Assran et al., 2022), to generate cluster assignments. The architecture overview is presented in Fig. 2. The design of the pretext tasks is discussed in the previous subsections, and our MIM pretext task closely follows GMML (Atito et al., 2022). In addition, our clustering uses the ME-MAX loss, which shows a slight improvement in the low-shot setting. Our total loss is provided in Eq. 8, which is the summation of the clustering loss at the patch and class levels, the ME-MAX loss on the patch and class cluster assignments, and finally the MIM reconstruction loss. Further implementation details are provided in the supplementary material.

$$\begin{aligned} {\mathcal {L}}_{total} = {\mathcal {L}}_{cc} + {\mathcal {L}}_{pt} + {\mathcal {L}}_{mim} + {\mathcal {L}}_{mc} + {\mathcal {L}}_{mp} \end{aligned}$$
(8)
Fig. 3 An efficient implementation of attention computation with cross and self attention components to slightly improve pretraining time. Self attention is computed between queries \(Q_{u}\), keys \(K_{u}\) and values \(V_{u}\). Cross attention is calculated between queries \(Q_{m}\), keys \(K_{u}\), and values \(V_{u}\). Cross attention and self attention values are concatenated to generate the final output \(X_{attn}\)

3.5 A Simple Trick to Improve Pretraining Time

Most of the time complexity of the transformer is due to the quadratic complexity of attention. When performing attention on masked images, unmasked tokens need not attend to masked tokens, as masked tokens lack any salient visual information, and masked tokens only need information from the unmasked tokens for reconstruction in the MIM task. Motivated by this, we introduce an efficient attention, illustrated in Fig. 3, for masked images that splits attention into two components: self-attention and cross-attention. Self attention, given by Eq. 9, is performed between unmasked tokens only; cross attention, given by Eq. 10, is performed between masked and unmasked tokens. Finally, the total attention in Eq. 11 is the concatenation of these attentions. In the above, \({\textbf{Q}}_{u}\), \({\textbf{K}}_{u}\), \({\textbf{V}}_{u}\) represent the query, key and value from the unmasked tokens. Similarly, \({\textbf{Q}}_{m}\) represents the query generated from the masked tokens. The generation of all queries, keys and values follows separate linear projection layers, similar to ViT (Dosovitskiy et al., 2020). Table 14 highlights increased throughput at various masking ratios, demonstrating enhancements even on a low-end GPU (RTX 3060). This attention is applied when processing the student masked global crops.

$$\begin{aligned} & \text {SA} = (\text {Softmax}({\textbf{Q}}_{u}{\textbf{K}}_{u}^T/\sqrt{d})){\textbf{V}}_{u} \end{aligned}$$
(9)
$$\begin{aligned} & \text {CA} = (\text {Softmax}({\textbf{Q}}_{m}{\textbf{K}}_{u}^T/\sqrt{d})){\textbf{V}}_{u} \end{aligned}$$
(10)
$$\begin{aligned} & \text {EA} = \text {Concat}(\text {SA}, \text {CA}) \end{aligned}$$
(11)
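A single-head PyTorch sketch of the split attention of Eqs. 9-11 is shown below; the separation of tokens into unmasked and masked sets, the absence of multi-head and output projections, and the variable names are our own simplifications for illustration.

```python
import math
import torch


def efficient_masked_attention(q_u, k_u, v_u, q_m):
    """Eqs. 9-11: self-attention among unmasked tokens plus cross-attention
    from masked queries to unmasked keys/values, concatenated.

    q_u, k_u, v_u: (B, N_u, d) queries, keys and values of unmasked tokens.
    q_m:           (B, N_m, d) queries of masked tokens.
    """
    scale = 1.0 / math.sqrt(q_u.shape[-1])

    # Eq. 9: unmasked tokens attend only to unmasked tokens
    sa = torch.softmax(q_u @ k_u.transpose(-2, -1) * scale, dim=-1) @ v_u

    # Eq. 10: masked tokens gather information from unmasked tokens only
    ca = torch.softmax(q_m @ k_u.transpose(-2, -1) * scale, dim=-1) @ v_u

    # Eq. 11: concatenate along the token dimension
    return torch.cat([sa, ca], dim=1)
```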

4 Experimental Results

We conduct a comprehensive evaluation of the pretrained model's performance in low-shot scenarios across three distinct downstream tasks: multi-class classification, multi-label classification, and semantic segmentation. Our assessment also employs the standard evaluation protocol using the complete datasets for the same tasks, as detailed in Sect. 4.2.

Performing both pretraining and linear evaluation on ImageNet-1K confers an unfair advantage to methods that effectively capture dominant semantic concepts, such as DINO (Caron et al., 2021), MSN (Assran et al., 2022), MoCo (He et al., 2019), and iBoT (Zhou et al., 2021). Additionally, the incorporation of pixel-level reconstruction (MIM) detrimentally affects linear evaluation performance on ImageNet-1K, as illustrated by MAE (He et al., 2021). To mitigate these biases, we opt to finetune on ImageNet-1K and perform linear evaluation on datasets encompassing multiple semantic concepts, namely PASCAL VOC, COCO, and VisualGenome.

We conduct ablation studies to dissect the individual components of our approach and analyze their respective contributions, as elaborated in Sect. 4.3. Furthermore, supplementary materials provide additional insights into training, finetuning details, and visualizations. For pretraining, we adhere to the methodology outlined in DINO (Caron et al., 2021), while for finetuning evaluation, we adopt DeiT (Touvron et al., 2020), making only necessary adjustments to learning rates and epochs. Our linear evaluation process follows the guidelines set forth by DINO (Caron et al., 2021), with tailored adjustments to learning rates for different datasets.

4.1 Pre-training Setup

We use ViT-S/16 as the base backbone architecture, pretrained on ImageNet-1K (Deng et al., 2009). Our approach follows the conventional multi-crop method, with two global crops of resolution \(224 \times 224\) and ten local crops of resolution \(96 \times 96\). The class and patch clustering layers have an output dimension of 8192. We pretrain the model for 800 epochs with a \(50\%\) masking ratio, following the GMML masking strategy (Atito et al., 2022). We use the AdamW optimiser (Loshchilov & Hutter, 2017) with a weight decay of 0.05, a learning rate of \(5 \times 10^{-4}\), a gradient clipping threshold of 3.0, and 15 warmup epochs.
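The optimiser setup described above can be sketched as follows; the stand-in model is hypothetical, and the learning-rate schedule, parameter grouping and multi-crop data pipeline (which follow DINO and GMML) are not reproduced here.

```python
import torch

# hypothetical stand-in for the ViT-S/16 student with projection and reconstruction heads
model = torch.nn.Linear(384, 8192)

# AdamW with the stated weight decay and base learning rate
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.05)


def training_step(loss: torch.Tensor) -> None:
    """One optimisation step with the stated gradient clipping threshold of 3.0."""
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0)
    optimizer.step()
```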

Table 5 The low-shot performance evaluation on subsets of ImageNet-1K. All the models are pretrained on the ImageNet-1K dataset. We report the mean top-1 accuracy over three different splits

4.2 Main Results

Low-shot Multi-class Classification. To assess label efficiency, we fine-tune our model on smaller subsets of ImageNet-1K (Deng et al., 2009), utilizing MSN's (Assran et al., 2022) data subsets for 1, 2, or 5 images per label. Each of these datasets has three different splits, and we evaluate on all the splits while reporting the mean accuracy. Furthermore, we finetune on 1% of ImageNet-1K (Deng et al., 2009) based on the SimCLR (Chen et al., 2020a) split. Results (Fig. 1) highlight our approach's substantial performance lead over current state-of-the-art techniques, particularly under extremely limited data during fine-tuning. For instance, with only 1 image per label, our method attains 22.4% accuracy, surpassing the state-of-the-art by 14.4%. Surprisingly, we also perform better than MAE (He et al., 2021) with the much larger ViT-B as its encoder. The results are provided in Table 5.

Table 6 Pascal 5i low-shot segmentation results. Comparison of few-shot and self-supervised methods used to initialize ViT-S/16 in FPTrans. FPTrans (ViT init.) uses ImageNet-1K supervised weights. s0, s1, s2, s3 represent different splits

Low-shot Semantic Segmentation. To assess low-shot segmentation performance, we employ FPTrans (Zhang et al., 2022), a few-shot segmentation framework tailored for vision transformers. Table 6 presents traditional few-shot frameworks for semantic segmentation under low-shot settings, all of which initialize their backbones with ImageNet-1K supervised weights. We extend this evaluation to include self-supervised models by using their weights for initialization, including those of our proposed method. The evaluation is conducted on the Pascal 5i (Shaban et al., 2017) dataset under both one-shot and five-shot scenarios. Specifically, we initialize the ViT-S/16 backbone within FPTrans using pre-trained weights from self-supervised learning (SSL) methods and adopt the training settings specified in FPTrans. For comparative analysis, FPTrans is also trained using ImageNet-1K supervised weights for ViT-S/16. Our model surpasses other self-supervised approaches and delivers competitive performance compared to few-shot segmentation methods in both one-shot and five-shot settings (Table 6). These results highlight the remarkable effectiveness of our approach for low-shot semantic segmentation.

Table 7 Low-shot multi-label classification on Pascal VOC and mini COCO with mAP as a metric. All SSL models are pretrained on ImageNet-1K, employing ViT-S/16 backbone
Table 8 Linear evaluation of low-shot multi-label classification on Pascal VOC and mini COCO with mAP as a metric

Low-shot multi-label Classification. To assess the performance of our model in low-shot multi-label classification, we created datasets with 1, 2, and 5 images per label through random sampling with a fixed seed. Additionally, we included the Mini.COCO (Samet et al., 2020) dataset for evaluation, which contains only 20% of the MS-COCO (Lin et al., 2014) training data. All models were trained under similar conditions using a resolution of \(224 \times 224\). Results are summarized in Tables 7 and 8, where models were assessed on the full validation sets. Table 7 presents performance after fine-tuning on the low-shot target dataset, showing our method outperforming previous approaches, akin to our multi-class results. Table 8, employing the linear evaluation protocol, demonstrates similar trends to fine-tuning. Notably, linear evaluation yields superior results for 1 image per label from Pascal VOC (Everingham et al., 2010) compared to fine-tuning. As training data increases, fine-tuning performs competitively with linear evaluation. Interestingly, MSN's performance matches ours in linear evaluation with 5 images per label from Pascal VOC. Our model surpasses all SSL methods, showcasing exceptional labelling efficiency in multi-label classification tasks.

Multi-class Classification. We evaluate the performance of our method by finetuning the pretrained ImageNet-1K weights on various downstream datasets (Table 9). Our model achieves \(82.1\%\) top-1 accuracy on ImageNet-1K (Deng et al., 2009), which is comparable to state-of-the-art methods. Additionally, we showcase the effectiveness of our approach using the results obtained by finetuning on smaller datasets like Flowers (Nilsback & Zisserman, 2008), Cars (Krause et al., 2013), Pets (Parkhi et al., 2012), CIFAR-10 (Krizhevsky, 2009), and CIFAR-100 (Krizhevsky, 2009). Across all these datasets, our method consistently outperforms state-of-the-art approaches, demonstrating the strong transferability of our proposed framework.

Table 9 Transfer learning by finetuning pretrained models with the ViT-S/16 backbone on diverse datasets. We report top-1 accuracy

Multi-label Classification. Our model is fine-tuned on the Pascal VOC (Everingham et al., 2010), MS-COCO (Lin et al., 2014), and VisualGenome (Krishna et al., 2016) datasets (\(224 \times 224\) input) for 100 epochs. In Table 10, our method excels on all datasets, showcasing its ability to capture multiple semantics. Moreover, in the linear setting (training only the classification layer), our model surpasses others on the Pascal VOC and Visual Genome datasets while maintaining strong performance on COCO (Table 11, which outlines the results of the multi-label linear evaluation), further emphasizing our method's ability to capture multiple semantics even in the linear scenario.

Table 10 Transfer learning of ViT-S on multi-label datasets with mAP as the evaluation metric

Semantic Segmentation. Our model is evaluated on the Pascal VOC (Everingham et al., 2010) dataset using two distinct settings: linear mode and finetuning mode. In the linear mode, only the linear layer is trained, while keeping the backbone frozen. Conversely, in the finetuning mode, the entire network, including the UPerNet (Xiao et al., 2018) task layer, is trained. UPerNet (Xiao et al., 2018) is a decoder introduced for semantic segmentation, which has been utilized for evaluation on semantic segmentation by other self-supervised methods (Caron et al., 2021; Zhou et al., 2021; He et al., 2021).

The results, outlined in Table 12, unequivocally illustrate the exceptional performance of our approach in the task of image segmentation in both linear and finetuning settings. These findings underscore the effectiveness of our model in capturing intricate details and accurately segmenting objects within images.

Table 11 Linear evaluation on multi-label datasets with mAP as the evaluation metric. All the models are pretrained on ImageNet-1K employing the ViT-S/16 backbone
Table 12 Performance of semantic segmentation on Pascal VOC
Table 13 Ablation study on losses with ViT-S for multi-class (1Img per label ImageNet-1K), low-shot segmentation (Pascal 5i), and multi-label classification (Mini COCO, Pascal VOC 1-Shot)

4.3 Ablation Study

Table 13 summarizes our findings on the reconstruction loss and on clustering at the class level or at both the class and patch levels, assessing low-shot classification and segmentation tasks. Removing the reconstruction loss notably harms performance, especially in 1-shot ImageNet-1K, underscoring its importance for capturing fine-grained information in the low-shot regime. Combining patch-level clustering with class-level clustering further enhances performance. Additionally, we evaluate the impact of using the efficient attention, observing increased throughput with higher mask ratios in Table 14, enabling reduced pretraining time.

4.4 Visualisation

In this section, we provide visualisations to showcase the learning of the patch and class clustering layers.

Class Clustering Visualisations: Figure 4 displays the top 9 images from each class cluster, showcasing unsupervised grouping of visually similar objects. Notably, the class clusters focus on butterflies, cars, cellphones, and fruits.

Patch Clustering Visualisation: We visualize the top 9 patches with the highest confidence for the selected clusters, along with their \(5\times 5\) neighborhood, in Fig. 5. These patterns demonstrate that our network captures both semantic information, such as heads and poles, as well as texture details at patch level.

Table 14 Throughput for different percentage of masking using the proposed efficient attention on RTX 3060 GPU
Fig. 4 Visualisation of the top 9 images from different class clusters. The left-most cluster focuses on butterflies, the next one on cars, the third column shows cellphones and the last one fruits

Fig. 5 The top 9 patches from different patch clusters with their \(5\times 5\) surrounding patches. The left-most patch cluster focuses on insect legs, the next one on poles and the remaining on textures

5 Conclusion

In this study, we examine how different pretext tasks and collapse avoidance strategies impact low-shot performance. Clustering surpasses contrastive learning as an instance discrimination task. Combining reconstruction with instance discrimination, particularly clustering, improves low-shot performance. Further enhancements occur when applying clustering at both the class and patch levels. Based on these insights, we propose a multi-level architecture that uses clustering for global and local semantics, along with reconstruction of masked images. This architecture excels in low-shot downstream tasks and scales to full-dataset fine-tuning across multiple tasks. Future work aims to extend this model's low-shot performance to multi-modal settings.