path = untar_data(URLs.PETS)
files = get_image_files(path/"images")
def label_func(f): return f[0].isupper()
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))
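The labeling rule relies on a convention of the Oxford-IIIT Pets dataset: cat images have capitalized filenames, dog images lowercase ones. The function can be checked on its own, without fastai (the filenames below are illustrative):

```python
# Oxford-IIIT Pets convention: cat breeds are capitalized
# ("Abyssinian_1.jpg"), dog breeds are lowercase ("beagle_1.jpg").
def label_func(f): return f[0].isupper()

print(label_func("Abyssinian_1.jpg"))  # True  -> cat
print(label_func("beagle_1.jpg"))      # False -> dog
```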
Prune Callback
Use the pruner in the fastai Callback system.
Let’s try our PruneCallback on the Pets dataset.

We’ll train a vanilla ResNet18 for 5 epochs to get an idea of the expected performance.
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit_one_cycle(5)
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.653510 | 0.805452 | 0.831529 | 00:08 |
| 1 | 0.373264 | 0.246071 | 0.901894 | 00:03 |
| 2 | 0.226383 | 0.212931 | 0.912043 | 00:03 |
| 3 | 0.118254 | 0.186566 | 0.920162 | 00:03 |
| 4 | 0.067994 | 0.185255 | 0.924899 | 00:03 |
base_macs, base_params = tp.utils.count_ops_and_params(learn.model, torch.randn(1,3,224,224).to(default_device()))
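`count_ops_and_params` from torch-pruning traces the model with the given dummy input and sums per-layer costs. As a rough sanity check on what it counts, a convolution's MAC count is C_out × H_out × W_out × C_in × k_h × k_w (`conv2d_macs` is a hypothetical helper, not part of any library):

```python
# Back-of-the-envelope MACs for one Conv2d layer (hypothetical helper):
# every output element costs C_in * kh * kw multiply-accumulates.
def conv2d_macs(c_in, c_out, h_out, w_out, kh, kw):
    return c_out * h_out * w_out * c_in * kh * kw

# ResNet18's stem: 64 filters of 7x7 over 3 channels, stride 2,
# so a 224x224 input gives a 112x112 output.
print(f"{conv2d_macs(3, 64, 112, 112, 7, 7) / 1e6:.1f} MMACs")  # 118.0 MMACs
```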
Let’s now try adding our callback to remove some filters from the model.
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
We’ll set the sparsity to 45 (i.e. remove 45% of the filters), the context to global (i.e. we remove filters from anywhere in the network), the criteria to large_final (i.e. keep the highest-value filters) and the schedule to one_cycle (i.e. follow the One-Cycle schedule to remove filters along training).
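The schedule decides how the sparsity target ramps up over training. Purely as an illustration (this sketch is not fasterai's actual one_cycle code), an S-shaped ramp that starts slowly, moves fastest mid-training and saturates at the target behaves much like the per-epoch sparsity log shown further down:

```python
import math

# Illustrative S-shaped pruning schedule (NOT fasterai's exact
# one_cycle implementation): sparsity grows slowly at first,
# fastest mid-training, and saturates at the target.
def sched_one_cycle(progress, target=45.0, steepness=9.0):
    """progress in [0, 1] -> scheduled sparsity in [0, target]."""
    raw = 1 / (1 + math.exp(-steepness * (progress - 0.5)))
    lo = 1 / (1 + math.exp(steepness * 0.5))   # sigmoid value at progress = 0
    hi = 1 / (1 + math.exp(-steepness * 0.5))  # sigmoid value at progress = 1
    return target * (raw - lo) / (hi - lo)     # rescale to hit 0 and target exactly

for epoch in range(10):
    print(f"epoch {epoch}: {sched_one_cycle((epoch + 1) / 10):.2f}%")
```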
pr_cb = PruneCallback(sparsity=45, context='global', criteria=large_final, schedule=one_cycle, layer_type=[nn.Conv2d])
learn.fit_one_cycle(10, cbs=pr_cb)
Pruning until a sparsity of [45]%
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.923606 | 0.602696 | 0.727334 | 00:04 |
| 1 | 0.524770 | 0.355539 | 0.857239 | 00:05 |
| 2 | 0.352097 | 0.269183 | 0.890392 | 00:05 |
| 3 | 0.270549 | 0.325706 | 0.890392 | 00:05 |
| 4 | 0.205549 | 0.192651 | 0.920162 | 00:06 |
| 5 | 0.155798 | 0.221400 | 0.908660 | 00:05 |
| 6 | 0.137832 | 0.197844 | 0.907984 | 00:05 |
| 7 | 0.109144 | 0.196927 | 0.924222 | 00:05 |
| 8 | 0.085867 | 0.183181 | 0.933694 | 00:05 |
| 9 | 0.084885 | 0.186639 | 0.927605 | 00:05 |
Sparsity at the end of epoch 0: [0.45]%
Sparsity at the end of epoch 1: [1.76]%
Sparsity at the end of epoch 2: [6.39]%
Sparsity at the end of epoch 3: [18.07]%
Sparsity at the end of epoch 4: [32.91]%
Sparsity at the end of epoch 5: [41.27]%
Sparsity at the end of epoch 6: [44.03]%
Sparsity at the end of epoch 7: [44.77]%
Sparsity at the end of epoch 8: [44.95]%
Sparsity at the end of epoch 9: [45.0]%
Final Sparsity: [45.0]%
pruned_macs, pruned_params = tp.utils.count_ops_and_params(learn.model, torch.randn(1,3,224,224).to(default_device()))
We observe that our pruned network reaches virtually the same accuracy as the baseline. But how many parameters have we removed, and how much compute does that save?
print(f'The pruned model has {pruned_macs/base_macs:.2f}x the compute of the original model')
The pruned model has 0.75x the compute of the original model
print(f'The pruned model has {pruned_params/base_params:.2f}x the parameters of the original model')
The pruned model has 0.26x the parameters of the original model
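Inverting those two ratios gives the headline reduction factors:

```python
macs_ratio = 0.75    # pruned MACs / base MACs, printed above
params_ratio = 0.26  # pruned params / base params, printed above

print(f"~{1 / params_ratio:.1f}x fewer parameters")  # ~3.8x fewer parameters
print(f"~{1 / macs_ratio:.2f}x less compute")        # ~1.33x less compute
```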
So, at virtually no cost in accuracy, we now have a model that is almost 4x smaller and requires about 1.3x less compute.