Review — CompRess: Self-Supervised Learning by Compressing Representations

CompRess, Distilling a Large Model into a Small Model in SSL

Sik-Ho Tsang
May 17, 2023
Proposed CompRess Models (Red) Outperform SOTA SSL Models (Blue)

CompRess: Self-Supervised Learning by Compressing Representations,
CompRess, by University of Maryland,
2020 NeurIPS, Over 50 Citations (Sik-Ho Tsang @ Medium)


  • A model compression method, CompRess, is proposed to compress an already trained, deep self-supervised model (teacher) into a smaller one (student).
  • The student model is trained to mimic the relative similarities between data points in the teacher’s embedding space.
  • For the first time, a self-supervised AlexNet outperforms the supervised one on ImageNet classification.

Outline

  1. CompRess
  2. Results
  3. Impacts

1. CompRess

CompRess Model Framework

1.1. Goal

  • The goal is to train a deep model (e.g., ResNet-50) using an off-the-shelf self-supervised learning algorithm and then compress it into a shallower model (e.g., AlexNet).
  • Given a frozen teacher embedding t(x) with parameters θt that maps an image x into an N-D feature space, the aim is to learn a student embedding s(x) with parameters θs that behaves like t(x) when used for a downstream supervised task, e.g., image classification.

1.2. Framework

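  • For simplicity, let t_i = t(x_i) denote the teacher embedding of the input image x_i, normalized to unit l2 norm (student embeddings s_i = s(x_i) are defined the same way).
  • A random set of data points serves as anchor points a_j for both the teacher and the student; their corresponding embeddings are:
    $$\{t^{a_j}\}_{j=1}^{m} \quad \text{and} \quad \{s^{a_j}\}_{j=1}^{m}$$
  • Given a query image q_i with teacher embedding t^{q_i} and student embedding s^{q_i}, the pairwise similarity between the query and each anchor point is computed.
  • For the teacher, the probability of the i-th query for the j-th anchor point is:
    $$p^i_j(t) = \frac{\exp\left(t^{q_i\top} t^{a_j} / \tau\right)}{\sum_{k=1}^{m} \exp\left(t^{q_i\top} t^{a_k} / \tau\right)}$$
  • where τ is the temperature.
  • The loss for a particular query point is the KL divergence between the probabilities over all anchor points under the teacher and student models, and this loss is summed over all query points:
    $$\mathcal{L} = \sum_i \mathrm{KL}\left(p^i(t) \,\|\, p^i(s)\right)$$
  • where p^i(s) is the probability distribution of query i over all anchor points under the student network, computed in the same way from s^{q_i} and s^{a_j}.
  • Finally, since the teacher is frozen, the entropy of p^i(t) does not depend on θs, so minimizing the KL divergence is equivalent to minimizing the cross-entropy, and the student is optimized by solving:
    $$\theta_s^{*} = \arg\min_{\theta_s} \sum_i \mathrm{KL}\left(p^i(t) \,\|\, p^i(s)\right) = \arg\min_{\theta_s} -\sum_i \sum_j p^i_j(t) \log p^i_j(s)$$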
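
The loss above can be sketched in plain Python. This is a minimal, framework-free illustration, not the authors' code: it assumes embeddings are already l2-normalized lists of floats and uses dot products as the similarity. The helper names (`anchor_probs`, `kl_divergence`, `compress_loss`) are hypothetical, and the gradient step on θs, which an autodiff framework would normally handle, is omitted.

```python
import math


def l2_normalize(v):
    """Scale a vector to unit l2 norm, as assumed for all embeddings."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def anchor_probs(query, anchors, tau):
    """p^i_j: softmax over anchors j of (query . anchor_j) / tau."""
    logits = [dot(query, a) / tau for a in anchors]
    top = max(logits)  # subtract the max logit for numerical stability
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for discrete distributions; terms with p_j = 0 contribute 0."""
    return sum(pj * math.log(pj / qj) for pj, qj in zip(p, q) if pj > 0)


def compress_loss(t_queries, t_anchors, s_queries, s_anchors, tau=0.04):
    """Sum over queries i of KL(p^i(t) || p^i(s)).

    Teacher and student may have different embedding sizes here, because each
    side is only compared against its own anchors (the "Ours-2q" setting).
    """
    loss = 0.0
    for t_q, s_q in zip(t_queries, s_queries):
        p_teacher = anchor_probs(t_q, t_anchors, tau)
        p_student = anchor_probs(s_q, s_anchors, tau)
        loss += kl_divergence(p_teacher, p_student)
    return loss


# Toy example: a 3-D teacher and a 2-D student embedding the same three anchors
# and one query (all values are made up for illustration).
t_anchors = [l2_normalize(v) for v in ([1, 0, 0], [0, 1, 0], [0, 0, 1])]
s_anchors = [l2_normalize(v) for v in ([1, 0], [0, 1], [1, 1])]
t_queries = [l2_normalize([0.9, 0.1, 0.0])]
s_queries = [l2_normalize([0.2, 0.8])]

print(compress_loss(t_queries, t_anchors, s_queries, s_anchors))
```

Since the teacher is frozen, only the student's embeddings would receive gradients in a real training loop. If the student reproduced the teacher's similarity distribution exactly, the loss would be zero.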

1.3. Implementations

  • A memory bank: a memory bank of anchor points from the several most recent iterations is maintained, similar to MoCo v2. The memory bank size is 128,000, and the moving-average (momentum) weight for the key encoder is set to 0.999.
  • A small temperature τ: a temperature less than one is used so that the model focuses mainly on transferring the relationships within the close neighborhood of the query rather than those with faraway points. τ = 0.04 is used.
  • Ours-2q: the teacher and student embeddings are decoupled, so a separate memory bank (queue) is used for each.
  • Ours-1q: the teacher’s anchor points are also used when calculating the similarities for the student model, so only one memory bank is needed.
  • Caching teacher: calculating the teacher embeddings is expensive in terms of both computation and memory, so they are cached, which makes training faster. The drawback is that the images cannot be augmented on the teacher side.
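
The memory bank and the momentum-updated key encoder can be sketched as below. This is a hypothetical, framework-free sketch modeled on MoCo, not the authors' implementation: the queue is a `collections.deque` with a fixed maximum length, encoder parameters are plain lists of floats, and the helper `student_anchors` (a name introduced here for illustration) shows how Ours-2q and Ours-1q differ in which queue the student's query is compared against. We assume, as in MoCo, that the key encoder is an exponential moving average of the query encoder.

```python
from collections import deque


class MemoryBank:
    """FIFO queue of anchor embeddings from the most recent iterations."""

    def __init__(self, size):
        # Once full, appending new embeddings silently drops the oldest ones.
        self.queue = deque(maxlen=size)

    def enqueue(self, embeddings):
        self.queue.extend(embeddings)

    def anchors(self):
        return list(self.queue)


def momentum_update(key_params, query_params, m=0.999):
    """Exponential moving average: theta_k <- m * theta_k + (1 - m) * theta_q."""
    return [m * k + (1.0 - m) * q for k, q in zip(key_params, query_params)]


def student_anchors(teacher_bank, student_bank, variant):
    """Anchors that the student's query is compared against (hypothetical helper).

    Ours-2q: the student's own queue (teacher and student spaces are decoupled).
    Ours-1q: the teacher's queue (student and teacher dimensions must then match).
    """
    if variant == "2q":
        return student_bank.anchors()
    if variant == "1q":
        return teacher_bank.anchors()
    raise ValueError(f"unknown variant: {variant!r}")


# Toy usage with a bank of size 3 (the text uses 128,000).
teacher_bank, student_bank = MemoryBank(3), MemoryBank(3)
teacher_bank.enqueue([[1.0, 0.0], [0.0, 1.0]])
teacher_bank.enqueue([[0.6, 0.8], [0.8, 0.6]])  # evicts the oldest entry
student_bank.enqueue([[0.0, 1.0]])
print(teacher_bank.anchors())
print(student_anchors(teacher_bank, student_bank, "2q"))
print(momentum_update([1.0, 2.0], [0.0, 0.0]))
```

A `deque` with `maxlen` gives the first-in, first-out eviction cheaply; a real implementation would more likely keep a fixed-size tensor with a moving write pointer.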

1.4. Models

  • Three teachers are used: (a) a ResNet-50 trained with MoCo v2 for 800 epochs [10], (b) a ResNet-50 trained with SwAV [8] for 800 epochs, and (c) a ResNet-50×4 trained with SimCLR for 1000 epochs.
  • Compressing from ResNet-50×4 to ResNet-50 takes ~100 hours on four Titan-RTX GPUs while compressing from ResNet-50 to ResNet-18 takes ~90 hours on two 2080-TI GPUs.

2. Results

2.1. Distillation Comparisons on ImageNet

Full ImageNet.
  • Linear, Nearest Neighbor (NN), and Cluster Assignment (CA) evaluations are used.
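
The NN evaluation can be illustrated with a minimal sketch. It assumes a 1-NN rule under cosine similarity on frozen features; the paper's exact protocol (for example the number of neighbors or any feature preprocessing) may differ, and the function names are hypothetical.

```python
import math


def normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cosine(u, v):
    return sum(a * b for a, b in zip(normalize(u), normalize(v)))


def nn_predict(train_feats, train_labels, query):
    """Return the label of the training feature most similar to the query."""
    best = max(range(len(train_feats)), key=lambda i: cosine(train_feats[i], query))
    return train_labels[best]


def nn_accuracy(train_feats, train_labels, test_feats, test_labels):
    """Fraction of test features whose nearest training neighbor has the same label."""
    hits = sum(
        nn_predict(train_feats, train_labels, f) == y
        for f, y in zip(test_feats, test_labels)
    )
    return hits / len(test_labels)


# Made-up 2-D features from a frozen (student) encoder.
train_feats = [[1.0, 0.1], [0.9, 0.2], [0.1, 1.0], [0.2, 0.8]]
train_labels = ["cat", "cat", "dog", "dog"]
test_feats = [[0.95, 0.0], [0.0, 0.9]]
test_labels = ["cat", "dog"]
print(nn_accuracy(train_feats, train_labels, test_feats, test_labels))
```

Like linear evaluation, which trains a linear classifier on the same frozen features, this measures feature quality without fine-tuning the encoder; NN additionally needs no training at all.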

The proposed method outperforms other distillation methods on all evaluation benchmarks.

Table 2: Full ImageNet, Table 3: Fewer-Label ImageNet.

Table 2: A SwAV ResNet-50 is used as the teacher and compressed into a ResNet-18. The distilled student obtains better accuracy than those produced by the other distillation methods.

Table 3: In the 1-shot setting, the “Ours-2q” model achieves accuracy close to that of the supervised model, which has seen all ImageNet labels while learning its features.

2.2. Downstream Tasks

CUB200 and Cars196.

Surprisingly, for the combination of the Cars196 dataset and the ResNet-50×4 teacher, the proposed model even outperforms the ImageNet supervised model.

“Ours-2q” gives better results in almost all transfer experiments, likely because it is less restricted: its student is compared against its own anchor points rather than the teacher’s.

2.3. SSL SOTA Comparisons

Linear Evaluation on ImageNet and Places.

The proposed method outperforms all baselines on all small capacity architectures (AlexNet, MobileNetV2, and ResNet-18).

  • On AlexNet, it even outperforms the supervised model. Table 6 shows the results of a linear classifier trained on only 1% and 10% of ImageNet for ResNet-50.
Table 6: Smaller ImageNet, Table 7: PASCAL-VOC.

Table 6: The ResNet-50 distilled from ResNet-50×4 performs competitively with, or even better than, the compared methods.

Table 7: The distilled AlexNet performs competitively with, or even better than, the compared methods.

2.4. Ablation Studies

Ablation and Qualitative Results.
  • 25% of ImageNet is used for the ablation studies.

(a): The optimal temperature is 0.04.

(b): A larger memory bank results in a more accurate student because when coupled with a small temperature, the large memory bank can help find anchor points that are closer to a query point, thus accurately depicting its close neighborhood.
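
Both ablation trends can be illustrated numerically. The sketch below uses made-up cosine similarities to show that a small τ concentrates the softmax on the closest anchors, and random unit vectors (an assumption standing in for real embeddings) to show that a larger bank, built here as a superset of the smaller one, can only bring the nearest anchor closer to the query.

```python
import math
import random


def softmax(sims, tau):
    logits = [s / tau for s in sims]
    top = max(logits)  # subtract the max logit for numerical stability
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up cosine similarities between one query and five anchors.
sims = [0.9, 0.8, 0.3, 0.1, -0.2]
for tau in (0.04, 0.5, 1.0):
    print(tau, [round(p, 3) for p in softmax(sims, tau)])

rng = random.Random(0)


def random_unit(dim):
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


query = random_unit(8)
pool = [random_unit(8) for _ in range(4096)]


def nearest_sim(size):
    """Similarity to the closest anchor among the first `size` entries of the pool."""
    return max(sum(a * b for a, b in zip(query, anchor)) for anchor in pool[:size])


for size in (16, 256, 4096):
    print(size, round(nearest_sim(size), 3))
```

Because the nested banks share their first entries, the nearest similarity is non-decreasing in bank size by construction; with real embeddings and a FIFO queue this is a tendency rather than a guarantee.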

(c): Randomly selected images from randomly selected clusters for the best AlexNet model. Each row is a cluster.

Momentum: The authors observe no reduction in accuracy when the momentum is removed.

Caching Teacher: Caching reduces accuracy by only a small margin, from 53.4% to 53.0% on NN and from 61.7% to 61.2% on linear evaluation, while reducing the running time by a factor of almost 3. For the ResNet-50×4 experiments, training without caching is not affordable.
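
Caching can be sketched as a thin wrapper around the frozen teacher. This is a hypothetical illustration (the `CachedTeacher` class and the toy `toy_teacher` function are made up here): each image's teacher embedding is computed once, keyed by the image's dataset index, and reused in later epochs. Because only one embedding per image is stored, the teacher effectively sees a single fixed view, which is why augmentation on the teacher side is lost.

```python
class CachedTeacher:
    """Wrap a frozen teacher so each image's embedding is computed at most once."""

    def __init__(self, teacher_fn):
        self.teacher_fn = teacher_fn  # frozen, so its outputs never change
        self.cache = {}
        self.teacher_calls = 0

    def embed(self, index, image):
        if index not in self.cache:
            self.teacher_calls += 1
            self.cache[index] = self.teacher_fn(image)
        return self.cache[index]


def toy_teacher(image):
    """Stand-in for an expensive forward pass (e.g. a ResNet-50x4 teacher)."""
    return [sum(image), max(image)]


dataset = [[1, 2], [3, 4], [5, 6], [7, 8]]
teacher = CachedTeacher(toy_teacher)
for epoch in range(3):
    for idx, img in enumerate(dataset):
        t_emb = teacher.embed(idx, img)  # the student would be trained against this
print(teacher.teacher_calls)
```

Over three epochs the toy teacher runs only once per image; in practice the embeddings could equally be precomputed in a single pass over the dataset before training the student.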

3. Impacts

  • Rich self-supervised features may enable harmful surveillance applications.
  • CompRess may make rich deep models accessible to a larger community that lacks access to expensive computation and labeling resources.
  • Model compression enables running deep models on devices with limited computational and power resources, e.g., IoT devices. Processing data on the device can also reduce privacy issues.
