Brief Review — Improved Knowledge Distillation via Teacher Assistant
Teacher Assistant Knowledge Distillation (TAKD)
Improved Knowledge Distillation via Teacher Assistant
Teacher Assistant (TA), by Washington State University, DeepMind, and D.E. Shaw
2020 AAAI, Over 1000 Citations (Sik-Ho Tsang @ Medium)
- Knowledge distillation (KD) is a popular approach in which a large pre-trained (teacher) network is used to train a smaller (student) network. However, the student's performance degrades when the gap between the student and the teacher is large.
- In this paper, multi-step knowledge distillation is proposed, namely Teacher Assistant Knowledge Distillation (TAKD), which employs an intermediate-sized network (teacher assistant) to bridge the gap between the student and the teacher.
Outline
- Baseline Knowledge Distillation
- Teacher Assistant Knowledge Distillation (TAKD)
- Results
1. Baseline Knowledge Distillation
1.1. Baseline Knowledge Distillation
- (If you are already familiar with basic knowledge distillation, skip directly to sub-section 1.2.)
- The idea behind knowledge distillation is to have the student network (S) be trained not only via the information provided by true labels but also by observing how the teacher network (T) represents and works with the data.
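- Let $a_t$ and $a_s$ be the logits (the inputs to the final softmax) of the teacher and student networks, respectively.
- In classic supervised learning, the mismatch between the student's softmax output, $\mathrm{softmax}(a_s)$, and the ground-truth label $y_r$ is usually penalized using the cross-entropy loss $\mathcal{H}$:
$$\mathcal{L}_{SL} = \mathcal{H}(\mathrm{softmax}(a_s), y_r)$$
- In knowledge distillation, one also tries to match the softened outputs of the student, $y_s = \mathrm{softmax}(a_s/\tau)$, and of the teacher, $y_t = \mathrm{softmax}(a_t/\tau)$, via a KL divergence loss:
$$\mathcal{L}_{KD} = \tau^2 \, \mathrm{KL}(y_t \,\|\, y_s)$$
- where the hyperparameter $\tau$, referred to as the temperature, gives additional control over how much the signal is softened. The $\tau^2$ factor keeps the gradient scale of this term roughly independent of $\tau$.
- The student network is then trained under the following loss function:
$$\mathcal{L}_{student} = (1 - \lambda)\,\mathcal{L}_{SL} + \lambda\,\mathcal{L}_{KD}$$
- where $\lambda$ is a second hyperparameter controlling the trade-off between the two losses.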
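The three losses of sub-section 1.1 can be sketched in plain Python for a single example. This is a minimal, framework-free sketch, not the authors' code: `softmax`, `cross_entropy`, `kl_divergence` and `student_loss` are hypothetical helper names, the KD term is written as τ²·KL(teacher ∥ student) following the common convention, and the defaults `tau=4.0` and `lam=0.5` are illustrative assumptions rather than values from the paper.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], tau: float = 1.0) -> List[float]:
    """Temperature-scaled softmax, softmax(a / tau)."""
    scaled = [a / tau for a in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs: Sequence[float], label: int) -> float:
    """Cross-entropy between a predicted distribution and a one-hot label."""
    return -math.log(probs[label])


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for two discrete distributions over the same classes."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def student_loss(a_s: Sequence[float], a_t: Sequence[float], y_r: int,
                 tau: float = 4.0, lam: float = 0.5) -> float:
    """(1 - lam) * L_SL + lam * L_KD for one example (defaults are assumptions)."""
    l_sl = cross_entropy(softmax(a_s), y_r)    # hard-label term at tau = 1
    y_s = softmax(a_s, tau)                    # softened student output
    y_t = softmax(a_t, tau)                    # softened teacher output
    l_kd = tau ** 2 * kl_divergence(y_t, y_s)  # assumed tau^2 scaling
    return (1.0 - lam) * l_sl + lam * l_kd


# Hypothetical student/teacher logits for a 3-class example with true class 1.
print(student_loss([1.0, 2.0, 0.5], [0.5, 3.0, -1.0], y_r=1))
```

In practice these would be batched tensor operations inside a deep learning framework; the pure-Python version only makes the arithmetic explicit. When the student's logits equal the teacher's, the KD term vanishes and only the weighted cross-entropy remains.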
1.2. The Gap Between Student and Teacher
- A plain CNN student with 2 convolutional layers is trained via distillation from similar but larger teachers with 4, 6, 8, and 10 convolutional layers, on both the CIFAR-10 and CIFAR-100 datasets.
As the teacher size increases, the teacher's own test accuracy increases (plotted in red on the right axis). However, the accuracy of the distilled student first increases and then decreases (depicted in blue on the left axis).
- The explanation is that several factors are at play:
- The teacher's performance increases, so it provides better supervision for the student by being a better predictor.
- The teacher becomes so complex that the student lacks sufficient capacity or mechanics to mimic its behavior, despite receiving hints.
- The teacher's certainty about the data increases, making its logits (soft targets) less soft. This weakens the knowledge transfer, which relies on matching the soft targets.
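The third factor can be illustrated numerically. The sketch below measures how soft a teacher's targets are via the entropy of softmax(logits/τ), and models a more certain teacher by simply scaling up its logits. That modeling choice, the logits, and τ=4 are illustrative assumptions, not the paper's analysis; `soft_target_entropy` is a hypothetical helper name.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], tau: float = 1.0) -> List[float]:
    """Temperature-scaled softmax, softmax(a / tau)."""
    scaled = [a / tau for a in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def soft_target_entropy(logits: Sequence[float], tau: float) -> float:
    """Entropy (in nats) of softmax(logits / tau); higher means softer targets."""
    p = softmax(logits, tau)
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


# Hypothetical teacher logits for one image. Multiplying them by a larger
# "confidence" factor mimics a teacher that is more certain about the data.
base_logits = [2.0, 1.0, 0.5]
for confidence in (1.0, 2.0, 4.0):
    logits = [confidence * a for a in base_logits]
    print(confidence, round(soft_target_entropy(logits, tau=4.0), 3))
```

At a fixed temperature the entropy drops as the logits grow, so the soft targets move closer to a one-hot label and carry less information about how the teacher ranks the non-target classes.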
2. Teacher Assistant Knowledge Distillation (TAKD)
- TAKD proposes to use intermediate-sized networks to fill the gap between the teacher and the student.
- The teacher assistant (TA) lies somewhere between the teacher and the student in terms of size or capacity.
First, the TA network is distilled from the teacher.
Then, the TA plays the role of a teacher and trains the student via distillation.
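The two-stage procedure extends naturally to a chain of assistants, where each freshly trained network becomes the teacher of the next one. The sketch below is a hypothetical outline, not the authors' code: `takd` and the `distill(teacher, student)` callback are illustrative names, and `distill` is assumed to train the student with the KD loss of sub-section 1.1 and return the trained network. The toy callback in the usage only records which network teaches which.

```python
from typing import Callable, Sequence, TypeVar

Model = TypeVar("Model")


def takd(teacher: Model, assistants: Sequence[Model], student: Model,
         distill: Callable[[Model, Model], Model]) -> Model:
    """Distill teacher -> assistants[0] -> ... -> assistants[-1] -> student."""
    current = teacher
    for ta in assistants:
        # The freshly distilled TA becomes the teacher for the next stage.
        current = distill(current, ta)
    return distill(current, student)


# Usage with a toy callback that only logs (teacher, student) pairs.
calls = []


def record_distill(teacher: str, student: str) -> str:
    calls.append((teacher, student))
    return student  # pretend the student has been trained


takd("T10", ["TA4"], "S2", record_distill)
print(calls)  # [('T10', 'TA4'), ('TA4', 'S2')]
```

With a single assistant this is exactly the two steps above; passing an empty `assistants` list reduces it to baseline KD.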
3. Results
3.1. KD Comparisons
- The Teacher Assistant based method (TAKD) is compared with baseline knowledge distillation (BLKD) and with normal training without any distillation (NOKD), as shown above.
The proposed method outperforms both baseline knowledge distillation and normal training by a reasonable margin.
3.2. Ablation Studies
For the plain CNN, TA=4 performs better than TA=6 or TA=8. Interestingly, the optimal TA size (4) lies close to the middle between teacher and student in terms of accuracy, not in terms of size.
For ResNet (Table 3), TA=14 is optimal on CIFAR-10, while TA=20 is best on CIFAR-100.
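The observation that the best TA sits near the accuracy midpoint, rather than the size midpoint, can be turned into a rough rule of thumb. The sketch below is a hypothetical heuristic, not a selection procedure from the paper: `pick_ta_by_accuracy_midpoint` is an illustrative name, and it assumes the accuracies of each candidate TA (for example, when trained without distillation) and of the teacher and student are already known.

```python
from typing import Mapping


def pick_ta_by_accuracy_midpoint(candidate_acc: Mapping[str, float],
                                 teacher_acc: float,
                                 student_acc: float) -> str:
    """Hypothetical heuristic: return the candidate TA whose accuracy is
    closest to the midpoint of the teacher and student accuracies."""
    target = (teacher_acc + student_acc) / 2.0
    return min(candidate_acc, key=lambda name: abs(candidate_acc[name] - target))
```

It would be called with measured accuracies, e.g. `pick_ta_by_accuracy_midpoint({"TA4": acc4, "TA6": acc6, "TA8": acc8}, teacher_acc, student_acc)`, where the `acc*` names are placeholders. Since the ResNet results show the optimum shifting with the dataset, such a rule can at best shortlist candidates to validate.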
- Multi-step TA paths are also tried instead of a single TA.
- First, for all student sizes (S=2, 4, 6), TAKD works better than BLKD or NOKD, no matter how many TAs are included in the distillation path.
All multi-step TAKD variants perform comparably well and considerably better than BLKD and NOKD. The full path, going through all possible intermediate TA networks, performs best.
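To make the space of multi-step variants concrete, the sketch below enumerates every distillation path from a teacher to a student through any subset of intermediate networks, always moving from larger to smaller ones. `distillation_paths` is an illustrative helper, not from the paper, and the sizes 10, 8, 6, 4 and 2 are example values in the style of the plain CNN experiments.

```python
from itertools import combinations
from typing import List, Sequence, Tuple


def distillation_paths(teacher: int, student: int,
                       intermediates: Sequence[int]) -> List[Tuple[int, ...]]:
    """Every path teacher -> (any subset of TAs, largest first) -> student."""
    # Only networks strictly between the student and the teacher can act as TAs.
    tas = sorted((m for m in intermediates if student < m < teacher), reverse=True)
    paths: List[Tuple[int, ...]] = []
    for k in range(len(tas) + 1):            # k = number of TAs on the path
        for subset in combinations(tas, k):  # preserves the descending order
            paths.append((teacher, *subset, student))
    return paths


for path in distillation_paths(10, 2, [4, 6, 8]):
    print(" -> ".join(str(size) for size in path))
```

The first path, with no TA, corresponds to BLKD, and the last one is the full path through every intermediate network; with n candidate TAs there are 2^n paths in total.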
3.3. Further Analysis
According to Loss Landscape visualizations, the network trained with TAKD has a flatter surface around the local minimum. Flatness is associated with robustness against noisy inputs, which in turn leads to better generalization.