Knowledge Distillation
Knowledge Distillation, sometimes called teacher-student training, is a compression method in which a small model (the student) is trained to mimic the behaviour of a larger model (the teacher).
The main goal is to reveal what is called the Dark Knowledge hidden in the teacher model.
Let's take the same example provided by Geoffrey Hinton et al.
The main problem with classification is that the output activation function (softmax) will, by design, make a single value very high and squash all the others.
\[ p_{i}=\frac{\exp \left(z_{i}\right)}{\sum_{j} \exp \left(z_{j}\right)} \]
where \(p_i\) is the probability of class \(i\), computed from the logits \(z\).
Here is an example to illustrate this phenomenon:
Let’s say that we have trained a model to discriminate between the following 5 classes: [cow, dog, plane, cat, car]
And here is the output of the final layer (the logits) when the model is fed a new input image:

import torch
import torch.nn.functional as F

logits = torch.tensor([1.3, 3.1, 0.2, 1.9, -0.3])
Judging by these logits, the model seems confident that the input image is a dog, quite confident that it is neither a plane nor a car, and gives moderately high scores to cow and cat.
So the model has not only learned to recognize a dog in the image, but also that a dog is very different from a car or a plane and shares similarities with cats and cows. This information is what is called dark knowledge!
When passing those logits through a softmax, we have:
predictions = F.softmax(logits, dim=-1); predictions
tensor([0.1063, 0.6431, 0.0354, 0.1937, 0.0215])
This accentuates the differences between the logits, discarding some of the dark knowledge acquired earlier. The way to keep this knowledge is to “soften” the softmax outputs by adding a temperature parameter: the higher the temperature, the softer the predictions.
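Concretely, the temperature \(\tau\) simply divides the logits before applying the softmax (the standard formulation from Hinton et al., consistent with the code cell below):

\[ p_{i}^{\tau}=\frac{\exp \left(z_{i} / \tau\right)}{\sum_{j} \exp \left(z_{j} / \tau\right)} \]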
soft_predictions = F.softmax(logits/3, dim=-1); soft_predictions
tensor([0.1879, 0.3423, 0.1302, 0.2294, 0.1102])
If the temperature is equal to 1, we recover the regular softmax.
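We can check this directly with the logits defined above; the temperature of 100 is just an arbitrarily large value chosen for illustration:

```python
# T=1 gives back the regular softmax; a very large T flattens the distribution
F.softmax(logits / 1, dim=-1)    # identical to `predictions` above
F.softmax(logits / 100, dim=-1)  # close to uniform (~0.2 for each of the 5 classes)
```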
When applying Knowledge Distillation, we want to keep the Dark Knowledge that the teacher model has acquired during its training but not rely entirely on it. So we combine two losses:
- The Teacher loss, between the softened predictions of the teacher and the softened predictions of the student
- The Classification loss, the regular loss between the student's (non-softened) predictions and the hard labels
The combination of those losses is weighted by an additional parameter \(\alpha\):
\[ L_{KD}=\alpha \cdot \text{CrossEntropy}\left(p_{S}^{\tau}, p_{T}^{\tau}\right)+(1-\alpha) \cdot \text{CrossEntropy}\left(p_{S}, y_{\text{true}}\right) \]
where \(p_{S}^{\tau}\) and \(p_{T}^{\tau}\) are the softened predictions of the student and the teacher, and \(y_{\text{true}}\) are the hard labels.
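To make the formula concrete, here is a minimal sketch of such a combined loss in plain PyTorch. This is not fasterai's actual implementation; the function name, the temperature T=3 and the weight alpha=0.5 are illustrative choices, and the cross-entropy between the softened distributions is written as a KL divergence (equivalent up to the teacher's entropy, which is constant with respect to the student):

```python
import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels, T=3.0, alpha=0.5):
    # Softened student log-probabilities and teacher probabilities
    log_p_student = F.log_softmax(student_logits / T, dim=-1)
    p_teacher = F.softmax(teacher_logits / T, dim=-1)
    # Cross-entropy between the softened distributions, written as a KL divergence.
    # The T**2 factor keeps gradient magnitudes comparable across temperatures.
    distill_loss = F.kl_div(log_p_student, p_teacher, reduction='batchmean') * T**2
    # Regular classification loss against the hard labels
    class_loss = F.cross_entropy(student_logits, labels)
    return alpha * distill_loss + (1 - alpha) * class_loss

# Example with random logits for a batch of 8 samples and 5 classes
student_logits = torch.randn(8, 5)
teacher_logits = torch.randn(8, 5)
labels = torch.randint(0, 5, (8,))
loss = kd_loss(student_logits, teacher_logits, labels)
```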
In practice, the implementation of the distillation loss differs a bit from the formula above.
This can be done with fastai, using the Callback system!
KnowledgeDistillationCallback
KnowledgeDistillationCallback (teacher, loss, activations_student=None, activations_teacher=None, weight=0.5)
Basic class handling tweaks of the training loop by changing a Learner in various events
The loss function that is used may depend on the use case. For classification, we usually use the one presented above, named SoftTarget in fasterai. But for regression cases, we may want to perform regression on the logits directly.
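As a rough end-to-end sketch, distilling a resnet34 teacher into a resnet18 student on the Pets dataset might look like the following. The fasterai import path, the SoftTarget arguments and passing teacher.model rather than the Learner are assumptions to verify against the fasterai documentation:

```python
from fastai.vision.all import *
from fasterai.distill.all import *   # assumed import path for the callback and SoftTarget

path = untar_data(URLs.PETS)
files = get_image_files(path/'images')
dls = ImageDataLoaders.from_name_re(path/'images', files, pat=r'(.+)_\d+.jpg',
                                    item_tfms=Resize(224))

# Teacher: a larger model, trained on the task beforehand
teacher = vision_learner(dls, resnet34, metrics=accuracy)
teacher.fine_tune(3)

# Student: a smaller model trained with the distillation callback
# (SoftTarget's temperature argument is an assumption)
student = vision_learner(dls, resnet18, metrics=accuracy)
kd = KnowledgeDistillationCallback(teacher.model, SoftTarget(T=3), weight=0.5)
student.fit_one_cycle(5, cbs=kd)
```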