KnowledgeDistillation Callback

How to apply knowledge distillation with fasterai

We’ll illustrate how to use Knowledge Distillation to distill the knowledge of a ResNet34 (the teacher) into a ResNet18 (the student).

Let’s grab some data.

from fastai.vision.all import *
from fasterai.distill.all import *

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

# In the Pets dataset, cat images have filenames starting with an uppercase letter
def label_func(f): return f[0].isupper()

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

The first step is to train the teacher model. We’ll start from a pretrained model, so we quickly get good results on our dataset.

teacher = vision_learner(dls, resnet34, metrics=accuracy)
teacher.unfreeze()
teacher.fit_one_cycle(10, 1e-3)
epoch train_loss valid_loss accuracy time
0 0.721918 0.643276 0.841678 00:09
1 0.484658 0.604135 0.828146 00:08
2 0.401239 1.103915 0.815291 00:08
3 0.394400 0.318618 0.860622 00:08
4 0.276733 0.276223 0.878890 00:08
5 0.187687 0.515996 0.851150 00:08
6 0.127520 0.230542 0.911367 00:08
7 0.071110 0.233229 0.924222 00:08
8 0.044975 0.199706 0.931664 00:08
9 0.031355 0.177644 0.939784 00:08

Without KD

We’ll now train a ResNet18 from scratch, without any help from the teacher model, to serve as a baseline.

student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
student.fit_one_cycle(10, 1e-3)
epoch train_loss valid_loss accuracy time
0 0.608119 0.594279 0.679296 00:07
1 0.577984 0.637746 0.690798 00:07
2 0.543163 0.532345 0.732070 00:07
3 0.508363 0.468151 0.772666 00:07
4 0.464459 0.442890 0.780108 00:07
5 0.405926 0.410481 0.816644 00:07
6 0.355392 0.429471 0.821380 00:07
7 0.278941 0.365873 0.838972 00:07
8 0.218126 0.366222 0.855886 00:07
9 0.165694 0.367872 0.857239 00:07

With KD

And now we train the same model, but with the help of the teacher. The chosen loss is a combination of the regular classification loss (cross-entropy) and a distillation loss (here SoftTarget) that pushes the student to mimic the teacher’s predictions.
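
The exact implementation lives in fasterai’s SoftTarget loss, but the idea can be sketched as follows: compute the usual cross-entropy against the labels, plus a KL divergence between the student’s and teacher’s temperature-softened logits. The function below is only an illustrative sketch; the temperature T, the weighting alpha and the function name are assumptions, not fasterai’s API.

import torch.nn.functional as F

def kd_loss_sketch(student_logits, teacher_logits, targets, T=3.0, alpha=0.7):
    # Regular classification loss against the ground-truth labels
    ce = F.cross_entropy(student_logits, targets)
    # KL divergence between temperature-softened student and teacher distributions
    kl = F.kl_div(F.log_softmax(student_logits/T, dim=-1),
                  F.softmax(teacher_logits/T, dim=-1),
                  reduction='batchmean') * T * T  # T^2 keeps gradient scales comparable
    return (1 - alpha) * ce + alpha * kl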

student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
# Distill from the trained teacher using the SoftTarget loss
kd = KnowledgeDistillationCallback(teacher.model, SoftTarget)
student.fit_one_cycle(10, 1e-3, cbs=kd)
epoch train_loss valid_loss accuracy time
0 2.335700 1.860445 0.692828 00:09
1 2.241398 1.773348 0.727334 00:09
2 2.055018 1.710084 0.723951 00:09
3 1.851421 1.632465 0.761840 00:09
4 1.620585 1.675239 0.755751 00:09
5 1.393245 1.410955 0.774019 00:09
6 1.155736 1.087842 0.826116 00:09
7 0.908853 0.983743 0.838972 00:09
8 0.696537 0.852848 0.857916 00:09
9 0.564625 0.854901 0.857239 00:09

When helped by the teacher, the student model performs better!

There exist more elaborate KD losses, such as the one from Paying More Attention to Attention, where the student tries to replicate the teacher’s attention maps at intermediate layers.
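
A common way to build such a loss (following the paper, not necessarily fasterai’s exact code) is to collapse each feature map into a spatial attention map by averaging squared activations over channels, normalise it, and penalise the distance between the student’s and teacher’s maps. A minimal sketch:

import torch.nn.functional as F

def attention_map(fm, p=2):
    # Channel-wise mean of activations^p, flattened to (batch, H*W) and L2-normalised
    return F.normalize(fm.pow(p).mean(dim=1).flatten(1), dim=1)

def attention_loss_sketch(student_fm, teacher_fm):
    # Mean squared distance between the two normalised attention maps
    return (attention_map(student_fm) - attention_map(teacher_fm)).pow(2).mean()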

Using such a loss requires specifying the layers from which the attention maps are taken. Those layers are referred to by their string names, which can be obtained with the get_model_layers function.
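
For instance, one might list the available names like this (assuming get_model_layers simply takes a model; check the fasterai docs for the exact signature):

get_model_layers(student.model)  # plain torchvision ResNet18: names like 'layer1', 'layer2', ...
get_model_layers(teacher.model)  # fastai wraps the pretrained body, so its stages show up as '0.4', '0.5', ...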

For example, here we apply the loss after each residual stage (layer1 to layer4) of our models:

student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
# 'layer1'...'layer4' are the student's residual stages; '0.4'...'0.7' are the corresponding stages inside the teacher's wrapped body
kd = KnowledgeDistillationCallback(teacher.model, Attention, ['layer1', 'layer2', 'layer3', 'layer4'], ['0.4', '0.5', '0.6', '0.7'], weight=0.9)
student.fit_one_cycle(10, 1e-3, cbs=kd)
epoch train_loss valid_loss accuracy time
0 0.088313 0.088667 0.679973 00:09
1 0.079737 0.077369 0.719892 00:09
2 0.070380 0.065641 0.765223 00:09
3 0.061056 0.061554 0.792963 00:09
4 0.055300 0.058515 0.790934 00:09
5 0.048522 0.052656 0.830853 00:09
6 0.040360 0.047567 0.847767 00:09
7 0.032288 0.046334 0.855210 00:09
8 0.023988 0.045383 0.868065 00:09
9 0.020456 0.044370 0.866712 00:09