# KnowledgeDistillation Callback

We’ll illustrate how to use Knowledge Distillation to distill the knowledge of a Resnet34 (the teacher) into a Resnet18 (the student).

Let’s grab some data:

```python
from fastai.vision.all import *
# Assuming the distillation tools are exposed under this module, as in fasterai's docs
from fasterai.distill.all import *

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f): return f[0].isupper()

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))
```
The first step is to train the teacher model. We’ll start from a pretrained model to ensure good results on our dataset.
```python
teacher = vision_learner(dls, resnet34, metrics=accuracy)
teacher.unfreeze()
teacher.fit_one_cycle(10, 1e-3)
```
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.721918 | 0.643276 | 0.841678 | 00:09 |
1 | 0.484658 | 0.604135 | 0.828146 | 00:08 |
2 | 0.401239 | 1.103915 | 0.815291 | 00:08 |
3 | 0.394400 | 0.318618 | 0.860622 | 00:08 |
4 | 0.276733 | 0.276223 | 0.878890 | 00:08 |
5 | 0.187687 | 0.515996 | 0.851150 | 00:08 |
6 | 0.127520 | 0.230542 | 0.911367 | 00:08 |
7 | 0.071110 | 0.233229 | 0.924222 | 00:08 |
8 | 0.044975 | 0.199706 | 0.931664 | 00:08 |
9 | 0.031355 | 0.177644 | 0.939784 | 00:08 |
## Without KD
We’ll now train a Resnet18 from scratch, without any help from the teacher model, to serve as a baseline:
```python
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
student.fit_one_cycle(10, 1e-3)
```
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.608119 | 0.594279 | 0.679296 | 00:07 |
1 | 0.577984 | 0.637746 | 0.690798 | 00:07 |
2 | 0.543163 | 0.532345 | 0.732070 | 00:07 |
3 | 0.508363 | 0.468151 | 0.772666 | 00:07 |
4 | 0.464459 | 0.442890 | 0.780108 | 00:07 |
5 | 0.405926 | 0.410481 | 0.816644 | 00:07 |
6 | 0.355392 | 0.429471 | 0.821380 | 00:07 |
7 | 0.278941 | 0.365873 | 0.838972 | 00:07 |
8 | 0.218126 | 0.366222 | 0.855886 | 00:07 |
9 | 0.165694 | 0.367872 | 0.857239 | 00:07 |
## With KD
And now we train the same model, but with the help of the teacher. The chosen loss is a combination of the regular classification loss (cross-entropy) and a distillation loss pushing the student to match the teacher’s predictions.
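For intuition, the soft-target loss introduced by Hinton et al. (2015) can be sketched as follows; this is a minimal illustration, and fasterai’s actual `SoftTarget` loss may differ in details such as the default temperature:

```python
import torch.nn.functional as F

def soft_target_loss(student_logits, teacher_logits, T=4.0):
    # Soften both distributions with temperature T and match them with KL divergence.
    # The T**2 factor keeps gradient magnitudes comparable across temperatures.
    return F.kl_div(F.log_softmax(student_logits / T, dim=-1),
                    F.softmax(teacher_logits / T, dim=-1),
                    reduction='batchmean') * T**2
```

In fasterai, this idea is packaged as the `SoftTarget` loss passed to the callback below: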
```python
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(teacher.model, SoftTarget)
student.fit_one_cycle(10, 1e-3, cbs=kd)
```
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 2.335700 | 1.860445 | 0.692828 | 00:09 |
1 | 2.241398 | 1.773348 | 0.727334 | 00:09 |
2 | 2.055018 | 1.710084 | 0.723951 | 00:09 |
3 | 1.851421 | 1.632465 | 0.761840 | 00:09 |
4 | 1.620585 | 1.675239 | 0.755751 | 00:09 |
5 | 1.393245 | 1.410955 | 0.774019 | 00:09 |
6 | 1.155736 | 1.087842 | 0.826116 | 00:09 |
7 | 0.908853 | 0.983743 | 0.838972 | 00:09 |
8 | 0.696537 | 0.852848 | 0.857916 | 00:09 |
9 | 0.564625 | 0.854901 | 0.857239 | 00:09 |
When helped by the teacher, the student model performs better!

There exist more complex KD losses, such as the one from *Paying More Attention to Attention*, where the student tries to replicate the teacher’s attention maps at intermediate layers.
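As a rough sketch of that loss (assuming squared activations summed over channels and L2-normalised per sample, as in the paper; fasterai’s `Attention` loss may differ in its exact formulation):

```python
import torch.nn.functional as F

def attention_map(fmap):
    # (N, C, H, W) -> (N, H*W): square activations, sum over channels,
    # flatten spatially, then L2-normalise each sample's map.
    return F.normalize(fmap.pow(2).sum(dim=1).flatten(1), dim=1)

def attention_loss(student_fmap, teacher_fmap):
    # Push the student to reproduce the teacher's normalised attention map.
    return (attention_map(student_fmap) - attention_map(teacher_fmap)).pow(2).mean()
```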
Using such a loss requires specifying the layers whose attention maps the student should replicate. To do so, we refer to them by their string names, which can be obtained with the `get_model_layers` function.
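For instance, something along these lines should list the candidate layer names (a sketch; assuming `get_model_layers` simply takes a model and returns the names of its layers):

```python
# Inspect layer names on both models to pick matching blocks:
# torchvision's resnet18 exposes 'layer1'...'layer4', while the fastai
# teacher is a Sequential whose blocks appear as '0.4'...'0.7'.
print(get_model_layers(student.model))
print(get_model_layers(teacher.model))
```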
For example, we set the loss to be applied after each residual block of our models, passing the layer names of the student and of the teacher, along with a `weight` balancing the distillation loss against the classification loss:
```python
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(teacher.model, Attention, ['layer1', 'layer2', 'layer3', 'layer4'], ['0.4', '0.5', '0.6', '0.7'], weight=0.9)
student.fit_one_cycle(10, 1e-3, cbs=kd)
```
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.088313 | 0.088667 | 0.679973 | 00:09 |
1 | 0.079737 | 0.077369 | 0.719892 | 00:09 |
2 | 0.070380 | 0.065641 | 0.765223 | 00:09 |
3 | 0.061056 | 0.061554 | 0.792963 | 00:09 |
4 | 0.055300 | 0.058515 | 0.790934 | 00:09 |
5 | 0.048522 | 0.052656 | 0.830853 | 00:09 |
6 | 0.040360 | 0.047567 | 0.847767 | 00:09 |
7 | 0.032288 | 0.046334 | 0.855210 | 00:09 |
8 | 0.023988 | 0.045383 | 0.868065 | 00:09 |
9 | 0.020456 | 0.044370 | 0.866712 | 00:09 |