The code is available here


Introducing FasterAI

FasterAI is a project that I started to make my neural networks smaller and faster with the use of the fastai library. The techniques implemented here can easily be used with plain PyTorch, but the idea was to express them in an abstract and easy-to-use manner (à la fastai).

In this article, we'll explain how to use FasterAI by going through an example use-case.


Ready? Let's dive in then!


Let's start with a bit of context for the purpose of the demonstration. Imagine that we want to deploy a VGG16 model on a mobile device that has limited storage capacity, and that our task requires the model to run sufficiently fast. Parameter efficiency and speed are known not to be the strong points of VGG16, but let's see what we can do with it.

Let's first check the number of parameters and the inference time of VGG16.

learn = Learner(data, models.vgg16_bn(num_classes=10), metrics=[accuracy])

So VGG16 has 134 million parameters:

Total parameters : 134,309,962

And takes 4.03ms to perform inference on a single image.

4.03 ms ± 18.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
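The code behind these two measurements is not shown here. A minimal sketch of how such numbers could be obtained in plain PyTorch follows; the helper names `count_parameters` and `time_inference` are hypothetical (they are not part of FasterAI), and a tiny stand-in model is used so that the snippet runs in a fraction of a second.

```python
import time

import torch
import torch.nn as nn


def count_parameters(model: nn.Module) -> int:
    """Return the total number of parameters held by the model."""
    return sum(p.numel() for p in model.parameters())


def time_inference(model: nn.Module, x: torch.Tensor, n_runs: int = 20) -> float:
    """Return the mean forward-pass time in milliseconds."""
    model.eval()
    with torch.no_grad():
        model(x)  # warm-up run, excluded from the measurement
        start = time.perf_counter()
        for _ in range(n_runs):
            model(x)
        elapsed = time.perf_counter() - start
    return 1000 * elapsed / n_runs


# Tiny stand-in model (an assumption); use learn.model for the real numbers.
tiny_model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 8 * 8, 10),
)
n_params = count_parameters(tiny_model)
ms = time_inference(tiny_model, torch.randn(1, 3, 8, 8))
print(f"Total parameters : {n_params:,}")
print(f"{ms:.3f} ms per forward pass")
```

When timing on a GPU, a call to `torch.cuda.synchronize()` before reading the clock is needed, because CUDA kernels run asynchronously.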

Snap! This is more than we can afford for deployment; ideally we would like our model to take only half of that... but should we give up? Nope, there are actually a lot of techniques that we can use to help reduce the size and improve the speed of our models! Let's see how to apply them with FasterAI.


We will first train our VGG16 model to have a baseline of what performance we should expect from it.

learn.fit_one_cycle(10, 1e-4)
epoch train_loss valid_loss accuracy time
0 2.016354 1.778865 0.368917 01:31
1 1.777570 1.508860 0.523567 01:31
2 1.436139 1.421571 0.569172 01:32
3 1.275864 1.118840 0.630064 01:31
4 1.136620 0.994999 0.687898 01:31
5 0.970474 0.824344 0.739618 01:31
6 0.878756 0.764273 0.765605 01:32
7 0.817084 0.710727 0.781911 01:31
8 0.716041 0.625853 0.804841 01:31
9 0.668815 0.605727 0.810955 01:31

So we would like our network to have comparable accuracy but fewer parameters, and to run faster... The first technique that we will show how to use is called Knowledge Distillation.




Knowledge Distillation

Knowledge distillation is a simple yet very efficient way to train a model. It was introduced in 2006 by Caruana et al. [1]. The main idea is to use a small model (called the student) to approximate the function learned by a larger, high-performing model (called the teacher). This can be done by using the large model to pseudo-label the data. This idea has been used very recently to break the state-of-the-art accuracy on ImageNet [2].

When we train a model for classification, we usually use a softmax as the last layer. This softmax tends to squash low-valued logits towards 0 and the highest logit towards 1. As a result, almost all of the inter-class information, sometimes called the dark knowledge, is lost. This is the valuable information that we want to transfer from the teacher to the student.
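The effect is easy to see on a toy example. The sketch below compares a regular softmax with a softened one obtained by dividing the logits by a temperature T > 1, which is the usual way of softening a softmax in the distillation literature. The logit values and T = 4 are arbitrary choices made for illustration.

```python
import math


def softmax(logits, temperature=1.0):
    """Softmax with a temperature; higher temperatures give softer outputs."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [8.0, 3.0, 1.0]       # illustrative values only
hard = softmax(logits)         # regular softmax, T = 1
soft = softmax(logits, 4.0)    # softened softmax, T = 4

print([round(p, 3) for p in hard])
print([round(p, 3) for p in soft])
```

With the regular softmax nearly all of the probability mass goes to the first class, while the softened version keeps a visible ranking between the two other classes.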

To do so, we still use a regular classification loss but, at the same time, we use another loss computed between the softened logits of the teacher (our soft labels) and the softened logits of the student (our soft predictions). Those soft values are obtained with a softened softmax, which avoids squashing the values at its output. Our implementation follows this paper [3] and the basic principle of training is represented in the figure below:


To use Knowledge Distillation with FasterAI, you only need to use this callback when training your student model:


 KnowledgeDistillation(student:Learner, teacher:Learner) 

You only need to give the callback your student learner and your teacher learner. Behind the scenes, FasterAI will take care of training your model with knowledge distillation.
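For readers who want to see what such a loss can look like in plain PyTorch, here is a minimal sketch of a common distillation formulation. It is not necessarily what the callback does internally: the function name `distillation_loss`, the temperature `T` and the weighting `alpha` are assumptions made for this example.

```python
import torch
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, targets, T=4.0, alpha=0.5):
    """Weighted sum of the hard-label loss and the soft-label loss."""
    # Regular classification loss on the ground-truth labels.
    hard_loss = F.cross_entropy(student_logits, targets)
    # KL divergence between softened student and teacher distributions.
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    )
    # The T**2 factor keeps the soft-loss gradients on a comparable scale.
    return alpha * hard_loss + (1 - alpha) * (T ** 2) * soft_loss


torch.manual_seed(0)
student_logits = torch.randn(8, 10)   # batch of 8 images, 10 classes
teacher_logits = torch.randn(8, 10)
targets = torch.randint(0, 10, (8,))

loss = distillation_loss(student_logits, teacher_logits, targets)
print(loss.item())
```

In a real training loop, the teacher logits would be computed under `torch.no_grad()` with the teacher in evaluation mode, so that only the student is updated.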


The first thing to do is to find a teacher, which can be any model, preferably one that performs well. We will choose VGG19 for our demonstration. To make sure it performs better than our VGG16 model, let's start from a pretrained version.

teacher = cnn_learner(data, models.vgg19_bn, metrics=[accuracy])
teacher.fit_one_cycle(3, 1e-4)
epoch train_loss valid_loss accuracy time
0 0.249884 0.088749 0.972739 01:02
1 0.201829 0.087495 0.974268 01:02
2 0.261882 0.082631 0.974013 01:01

Our teacher reaches 97.4% accuracy, which is pretty good; it is ready to take a student under its wing. So let's create our student model and train it with the Knowledge Distillation callback:

student = Learner(data, models.vgg16_bn(num_classes=10), metrics=[accuracy])
student.fit_one_cycle(10, 1e-4, callbacks=[KnowledgeDistillation(student, teacher)])
epoch train_loss valid_loss accuracy time
0 2.323744 2.102873 0.410955 02:16
1 2.099557 2.441147 0.571465 02:16
2 1.829197 2.215419 0.607643 02:16
3 1.617705 1.683477 0.667006 02:16
4 1.364808 1.366435 0.713376 02:16
5 1.257906 0.985063 0.788025 02:16
6 1.087404 0.877424 0.801019 02:17
7 0.949960 0.777630 0.822166 02:16
8 0.868683 0.733206 0.837707 02:17
9 0.756630 0.707806 0.843057 02:16

And we can see that the knowledge of the teacher was indeed useful for the student, as it clearly outperforms the vanilla VGG16 (84.3% against 81.1% accuracy).

OK, so now we are able to get more from a given model, which is kind of cool! With some experimentation we could come up with a model smaller than VGG16 but able to reach the same performance as our baseline! You can try to find it by yourself later, but for now let's continue with the next technique!




Sparsifying

Now that we have a student model that performs better than our baseline, we have some room to compress it. We'll start by making the network sparse. As explained in a previous article, there are many ways to obtain a sparse network.


Note: Usually, the process of making a network sparse is called Pruning. I prefer using the term Pruning when parameters are actually removed from the network, which we will do in the next section.


[Figure: Pruning]


By default, FasterAI uses the Automated Gradual Pruning paradigm: it removes parameters as the model trains and doesn't require a pretrained model, so it is usually much faster. In FasterAI, this is also managed with a callback, which replaces the least important parameters of your model with zeroes during training. The callback has a wide variety of parameters to tune your sparsifying operation; let's take a look at them:


SparsifyCallback(learn, sparsity, granularity, method, criteria, sched_func)
  • sparsity: the percentage of sparsity that you want in your network
  • granularity: on what granularity you want the sparsification to be operated (currently supported: weight, filter)
  • method: either local or global; determines whether the parameters to remove are chosen in each layer independently (local) or across the whole network (global).
  • criteria: the criteria used to select which parameters to remove (currently supported: l1, taylor)
  • sched_func: which schedule you want to follow for the sparsification (currently supported: any scheduling function of fastai, e.g. annealing_linear, annealing_cos, ... and annealing_gradual, the schedule proposed by Zhu & Gupta [4], represented in the figure below)

[Figure: AGP, the Automated Gradual Pruning sparsity schedule]
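To give an idea of the shape of these schedules, here is a small sketch of the cubic schedule described by Zhu & Gupta next to a cosine schedule. The function names `agp_sparsity` and `cos_sparsity` are hypothetical and are not the fastai or FasterAI functions; the exact values produced by the library during training may differ slightly.

```python
import math


def agp_sparsity(step, total_steps, final_sparsity, initial_sparsity=0.0):
    """Cubic schedule of Zhu & Gupta: fast pruning early, slow at the end."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    return final_sparsity + (initial_sparsity - final_sparsity) * (1 - progress) ** 3


def cos_sparsity(step, total_steps, final_sparsity, initial_sparsity=0.0):
    """Cosine schedule: slow start, fast middle, slow end."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    ramp = (1 - math.cos(math.pi * progress)) / 2
    return initial_sparsity + (final_sparsity - initial_sparsity) * ramp


for epoch in range(11):
    agp = agp_sparsity(epoch, 10, final_sparsity=40)
    cos = cos_sparsity(epoch, 10, final_sparsity=40)
    print(f"after {epoch:2d} epochs  gradual: {agp:5.2f}%  cosine: {cos:5.2f}%")
```

The gradual schedule removes most of the parameters early, while the network still has many redundant weights, and slows down as the target sparsity gets closer.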

Although I found that Automated Gradual Pruning usually works best, you may want to use the other paradigms. They can easily be achieved by doing:

One-Shot Pruning

sparsifier = Sparsifier(granularity, method, criteria)
new_model = sparsifier.prune(learn.model, sparsity)

To perform One-Shot Pruning, you can simply prune your model to the desired sparsity. This is probably highly suboptimal as removing parameters will shake up the model and hurt it quite a bit.


Iterative Pruning

new_model = sparsifier.prune(learn.model, sparsity)
learn = Learner(data, new_model)
learn.fit(num_epochs, lr, callbacks=[SparsifyCallback(learn, sparsity, granularity, method, criteria, sched_func=annealing_no)])
sparsity += increase_value
# REPEAT

To perform Iterative Pruning, we first need to train our model, then perform several iterations of pruning and fine-tuning until desired sparsity. Fine-tuning has to be done with SparsifyCallback and the annealing_no schedule to ensure our zero-weights don't get updated.


But let's come back to our example!

Here, we will make our network 40% sparse by zeroing entire filters, selected locally and based on their L1 norm. We will train with a slightly smaller learning rate to be gentle with our network, because it has already been trained. The selected schedule is a cosine, so the pruning starts and ends quite slowly.

student.fit(10, 1e-5, callbacks=[SparsifyCallback(student, sparsity=40, granularity='filter', method='local', criteria='l1', sched_func=annealing_cos)])
Pruning of filter until a sparsity of 40%
epoch train_loss valid_loss accuracy time
0 0.584072 0.532074 0.838471 01:34
1 0.583805 0.499353 0.844586 01:34
2 0.599410 0.527805 0.836433 01:34
3 0.610081 0.544566 0.828025 01:35
4 0.625637 0.543279 0.829809 01:34
5 0.628777 0.563051 0.819618 01:34
6 0.688617 0.617627 0.800000 01:34
7 0.691044 0.629927 0.801019 01:34
8 0.669935 0.576220 0.814013 01:33
9 0.682428 0.562718 0.823949 01:34
Sparsity at epoch 0: 0.98%
Sparsity at epoch 1: 3.83%
Sparsity at epoch 2: 8.25%
Sparsity at epoch 3: 13.83%
Sparsity at epoch 4: 20.01%
Sparsity at epoch 5: 26.19%
Sparsity at epoch 6: 31.76%
Sparsity at epoch 7: 36.19%
Sparsity at epoch 8: 39.02%
Sparsity at epoch 9: 40.00%
Final Sparsity: 40.00

Our network now has 40% of its filters composed entirely of zeroes, at the cost of about 2 points of accuracy (from 84.3% to 82.4%). Obviously, choosing a higher sparsity makes it more difficult for the network to keep a similar accuracy. Other parameters can also widely change the behaviour of the sparsification process. For example, choosing a more fine-grained sparsity usually leads to better results, but it is then more difficult to take advantage of in terms of speed.

We can double-check that about 40% of the filters have indeed been zeroed in each convolutional layer of our model.

Sparsity in Conv2d 2: 39.06%
Sparsity in Conv2d 5: 39.06%
Sparsity in Conv2d 9: 39.84%
Sparsity in Conv2d 12: 39.84%
Sparsity in Conv2d 16: 39.84%
Sparsity in Conv2d 19: 39.84%
Sparsity in Conv2d 22: 39.84%
Sparsity in Conv2d 26: 39.84%
Sparsity in Conv2d 29: 39.84%
Sparsity in Conv2d 32: 39.84%
Sparsity in Conv2d 36: 39.84%
Sparsity in Conv2d 39: 39.84%
Sparsity in Conv2d 42: 39.84%

We don't have exactly 40% because, as we zero out complete filters, the number of filters to remove is not necessarily a whole number.
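The arithmetic can be checked with a small NumPy sketch of local, filter-level, L1-based sparsification. It assumes that the number of filters to zero is rounded down, and the helper names are hypothetical (this is not FasterAI's code): 40% of 64 filters is 25.6, so 25 filters are zeroed (39.06%), and 40% of 128 filters is 51.2, so 51 are zeroed (39.84%).

```python
import numpy as np


def sparsify_filters_l1(weight, sparsity):
    """Zero the filters with the smallest L1 norm; weight is (out, in, kh, kw)."""
    n_filters = weight.shape[0]
    n_zero = int(n_filters * sparsity / 100)         # whole filters only
    l1_norms = np.abs(weight).sum(axis=(1, 2, 3))    # one norm per filter
    to_zero = np.argsort(l1_norms)[:n_zero]          # least important first
    pruned = weight.copy()
    pruned[to_zero] = 0.0
    return pruned


def filter_sparsity(weight):
    """Percentage of filters that are entirely zero."""
    zero_filters = (np.abs(weight).sum(axis=(1, 2, 3)) == 0).sum()
    return 100 * zero_filters / weight.shape[0]


rng = np.random.default_rng(0)
for n_filters in (64, 128, 256, 512):                # VGG16 layer widths
    w = rng.normal(size=(n_filters, 3, 3, 3))        # random stand-in weights
    s = filter_sparsity(sparsify_filters_l1(w, sparsity=40))
    print(f"{n_filters} filters: {s:.2f}% sparse")
```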


Let's now see how much we gained in terms of speed. Because we removed 40% of the convolution filters, we should expect a crazy speed-up, right?

4.02 ms ± 5.77 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Well actually, no. We didn't remove any parameters, we just replaced some by zeroes, remember? The amount of parameters is still the same:

Total parameters : 134,309,962

Which leads us to the next section.




Pruning

Important: This is currently only supported for fully-feedforward models such as VGG-like models, as more complex architectures require increasingly difficult and usually model-dependent implementations.


Why don't we see any acceleration even though we zeroed out 40% of the filters? That's because, natively, our GPU does not know that our matrices are sparse and thus isn't able to accelerate the computation. The easiest workaround is to physically remove the parameters we zeroed out. But this operation requires changing the architecture of the network.

This pruning only works if we have zeroed out entire filters beforehand, as it is the only case where the architecture can be changed accordingly. Hopefully, sparse computations will soon be available in common deep learning libraries, making this section unnecessary, but for the moment it is the best solution I could come up with 🤷


Here is what it looks like with fasterai:


pruner = Pruner()
pruned_model = pruner.prune_model(learn.model)

You just need to pass the model whose filters have previously been sparsified and FasterAI will take care of removing them.

Note: This operation should be lossless as it only removes filters that already do not participate in the network anymore.
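To see why removing zeroed filters does not change the output, here is a minimal sketch on two stacked convolutions in plain PyTorch. It is an illustration under simplifying assumptions, not the `Pruner` implementation: the biases of the zeroed filters are set to zero as well, and there is no batch norm between the layers (with batch norm, its shift would also have to be taken into account).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv1 = nn.Conv2d(3, 8, 3, padding=1)
conv2 = nn.Conv2d(8, 4, 3, padding=1)

# Pretend filters 1, 4 and 6 of conv1 were zeroed by the sparsification step.
with torch.no_grad():
    zeroed = [1, 4, 6]
    conv1.weight[zeroed] = 0.0
    conv1.bias[zeroed] = 0.0

# Keep only the filters that still have non-zero weights.
keep = conv1.weight.detach().abs().sum(dim=(1, 2, 3)) > 0
n_keep = int(keep.sum())

small1 = nn.Conv2d(3, n_keep, 3, padding=1)
small2 = nn.Conv2d(n_keep, 4, 3, padding=1)
with torch.no_grad():
    small1.weight.copy_(conv1.weight[keep])       # drop the zeroed output filters
    small1.bias.copy_(conv1.bias[keep])
    small2.weight.copy_(conv2.weight[:, keep])    # drop the matching input channels
    small2.bias.copy_(conv2.bias)

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    big_out = conv2(torch.relu(conv1(x)))
    small_out = small2(torch.relu(small1(x)))

print(n_keep, torch.allclose(big_out, small_out, atol=1e-5))
```

Removing a filter in one layer also removes the corresponding input channel of the next layer, which is why the architecture has to change and why this is much harder for networks with skip connections.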


So in the case of our example, it gives:

pruner = Pruner()
pruned_model = pruner.prune_model(student.model)

Let's see what our model is capable of now:

Total parameters : 83,975,344

And in terms of speed:

2.44 ms ± 3.51 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


Yay! Now we can talk! Let's just double-check that our accuracy is unchanged and that we didn't mess up somewhere:

[0.5641388, tensor(0.8229)]


And there is actually more that we can do ! Let's keep going !




Batch Normalization Folding

Batch Normalization Folding is a straightforward idea that is really easy to implement. The gist is that batch normalization is nothing more than a normalization of the input data at each layer. Moreover, at inference time, the batch statistics used for this normalization are fixed. We can thus incorporate the normalization directly into the convolution by changing its weights, and completely remove the batch normalization layers, which is a gain both in terms of parameters and in terms of computations. For a more in-depth explanation, see my previous post.

This is how to use it with FasterAI:


bn_folder = BN_Folder()
bn_folder.fold(learn.model)

Again, you only need to pass your model and FasterAI takes care of the rest. For models built with nn.Sequential, you don't need to change anything. For others, if you want to see speedup and compression, you actually need to subclass your model to remove the batch norm from the parameters and from the forward method of your network.

Note: This operation should also be lossless as it redefines the convolution to take batch norm into account and is thus equivalent.
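The equivalence can be verified on a single convolution followed by a batch norm. The sketch below is a minimal illustration in plain PyTorch, not the `BN_Folder` code: the helper name `fold_conv_bn` is hypothetical, and it ignores dilation and groups. With $\gamma$, $\beta$ the batch norm weight and bias and $\mu$, $\sigma^2$ its running statistics, the folded weights are $W' = W \cdot \gamma / \sqrt{\sigma^2 + \epsilon}$ and the folded bias is $b' = (b - \mu) \cdot \gamma / \sqrt{\sigma^2 + \epsilon} + \beta$.

```python
import torch
import torch.nn as nn


def fold_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Return one convolution equivalent to conv followed by bn in eval mode."""
    folded = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                       stride=conv.stride, padding=conv.padding, bias=True)
    # Per-channel scale applied by batch norm at inference time.
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    with torch.no_grad():
        folded.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
        folded.bias.copy_((conv_bias - bn.running_mean) * scale + bn.bias)
    return folded


torch.manual_seed(0)
conv, bn = nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8)
# Give the batch norm non-trivial statistics, as if it had been trained.
with torch.no_grad():
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-1, 1)
conv.eval()
bn.eval()

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    same = torch.allclose(bn(conv(x)), fold_conv_bn(conv, bn)(x), atol=1e-5)
print(same)
```

The folding is only valid in evaluation mode, where the running statistics are used; during training the batch statistics change at every step.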


Let's do this with our model !

folded_model = bn_folding_model(pruned_learner.model)

The drop in parameters is generally not that significant, especially in a network such as VGG where almost all parameters are contained in the FC layers but, hey, any gain is good to take.

Total parameters : 83,970,260


Now that we removed the batch normalization layers, we should again see a speedup.

2.27 ms ± 1.22 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Again, let's double check that we didn't mess up somewhere:

[0.5641388, tensor(0.8229)]


And we're still not done! As we know, most of the parameters of VGG16 are contained in the fully-connected layers, so there should be something that we can do about it, right?




FC Layers Factorization

We can indeed factorize our big fully-connected layers and replace each of them by an approximation made of two smaller layers. The idea is to compute the singular value decomposition (SVD) of the weight matrix, which expresses the original matrix as a product of 3 matrices: $U \Sigma V^T$, with $\Sigma$ being a diagonal matrix with non-negative values along its diagonal (the singular values). We then choose a number $k$ of singular values to keep and truncate the matrices $U$ and $V^T$ accordingly. The result is an approximation of the initial matrix.

[Figure: SVD]

In FasterAI, to decompose the fully-connected layers of your model, here is what you need to do:


FCD = FCDecomposer()
decomposed_model = FCD.decompose(model, percent_removed)

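The percent_removed argument corresponds to the proportion of singular values removed (the remaining ones are the $k$ values kept above). In the example that follows it is given as a fraction, so 0.5 removes half of them.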

Note: This time, the decomposition is not exact, so we expect a drop in performance afterwards and further retraining will be needed.
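The sketch below illustrates the idea on a small random layer with NumPy. It is not the `FCDecomposer` implementation: the function name `factorize_fc` and the layer sizes are assumptions chosen for the example. The first new layer applies the truncated $V^T$ and the second one applies $U \Sigma$ followed by the original bias.

```python
import numpy as np


def factorize_fc(weight, bias, percent_removed):
    """Split y = W x + b into two smaller layers using a truncated SVD."""
    u, s, vt = np.linalg.svd(weight, full_matrices=False)
    k = int(len(s) * (1 - percent_removed))    # number of singular values kept
    first_w = vt[:k]                           # (k, in): first layer, no bias
    second_w = u[:, :k] * s[:k]                # (out, k): second layer
    return first_w, second_w, bias


rng = np.random.default_rng(0)
out_features, in_features = 64, 256
weight = rng.normal(size=(out_features, in_features))
bias = rng.normal(size=out_features)

first_w, second_w, b = factorize_fc(weight, bias, percent_removed=0.5)
x = rng.normal(size=in_features)
exact = weight @ x + bias
approx = second_w @ (first_w @ x) + b

before = weight.size
after = first_w.size + second_w.size
rel_error = np.linalg.norm(exact - approx) / np.linalg.norm(exact)
print(before, after, rel_error)
```

The factorization only saves parameters when $k \cdot (\text{in} + \text{out})$ is smaller than $\text{in} \cdot \text{out}$. A random matrix such as this one has no low-rank structure, so the approximation error is large here; trained layers are usually far more compressible.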


With our example, if we want to keep only half of the singular values, this gives:

fc_decomposer = FCDecomposer()
decomposed_model = fc_decomposer.decompose(folded_model, percent_removed=0.5)

How many parameters do we have now ?

Total parameters : 61,430,022

And how much time did we gain ?

2.11 ms ± 462 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)


However, this technique is an approximation and thus not lossless, so we should retrain our network a bit to recover its performance.

final_learner = Learner(data, decomposed_model, metrics=[accuracy])
final_learner.fit_one_cycle(5, 1e-5)
epoch train_loss valid_loss accuracy time
0 0.795416 0.759886 0.772994 00:51
1 0.752566 0.701141 0.794395 00:52
2 0.700373 0.650178 0.804841 00:51
3 0.604264 0.606801 0.821656 00:51
4 0.545705 0.592318 0.823185 00:52

This operation is usually less useful for more recent architectures as they usually do not have that many parameters in their fully-connected layers.




So to recap, we saw in this article how to use fasterai to:

  1. Make a student model learn from a teacher model (Knowledge Distillation)
  2. Make our network sparse (Sparsifying)
  3. Optionally, physically remove the zero-filters (Pruning)
  4. Remove the batch norm layers (Batch Normalization Folding)
  5. Approximate our big fully-connected layers by smaller ones (Fully-Connected Layers Factorization)


And we saw that by applying those, we could reduce our VGG16 model from 134 million parameters down to 61 million, and also speed up inference from 4.03 ms to 2.11 ms without any drop in accuracy (even a slight increase, actually) compared to the baseline.

Of course, those techniques can be used in conjunction with quantization or mixed-precision training, which are already available in PyTorch, for even more compression and speedup.


Note: Please keep in mind that the techniques presented above are not magic 🧙‍♂️, so do not expect to see a 200% speedup and compression every time. What you can achieve highly depends on the architecture that you are using (some are already speed/parameter efficient by design) and on the task it is doing (some datasets are so easy that you can remove almost all of your network without seeing a drop in performance).


That's all! Thank you for reading, I hope that you'll like FasterAI. I do not claim that it is perfect, you'll probably find a lot of bugs. If you do, just please tell me, so I can try to solve them 😌




If you notice any mistake or possible improvement, please contact me! If you found this post useful, please consider citing it as:

@article{hubens2020fasterai,
  title   = "FasterAI",
  author  = "Hubens, Nathan",
  journal = "nathanhubens.github.io",
  year    = "2020",
  url     = "https://nathanhubens.github.io/posts/deep%20learning/2020/08/17/FasterAI.html"
}

References