Prune Callback

Use the pruner within the fastai Callback system.

Let's try our PruneCallback on the Pets dataset.

from fastai.vision.all import *   # data, learner and training APIs
from fasterai.prune.all import *  # PruneCallback, criteria and schedules (import path assumed from the fasterai docs)
import torch_pruning as tp        # used below to count MACs and parameters
import torch

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f): return f[0].isupper()  # in the Pets dataset, cat images have capitalized filenames

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))
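
Before training, we can take a quick look at a batch to sanity-check the labels (show_batch is the standard fastai helper):

dls.show_batch(max_n=9)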

We'll first train a vanilla ResNet-18 for 5 epochs to get an idea of the baseline performance.

learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit_one_cycle(5)
epoch train_loss valid_loss accuracy time
0 0.653510 0.805452 0.831529 00:08
1 0.373264 0.246071 0.901894 00:03
2 0.226383 0.212931 0.912043 00:03
3 0.118254 0.186566 0.920162 00:03
4 0.067994 0.185255 0.924899 00:03
base_macs, base_params = tp.utils.count_ops_and_params(learn.model, torch.randn(1,3,224,224).to(default_device()))
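
The returned values are raw counts, which we can format for readability:

print(f'Baseline: {base_macs/1e9:.2f} GMACs, {base_params/1e6:.2f}M parameters')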

Let's now try adding our PruneCallback to remove some filters from the model.

learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

We'll set the sparsity to 45 (i.e. remove 45% of filters), the context to global (i.e. filters can be removed from anywhere in the network), the criteria to large_final (i.e. keep the highest-valued filters), and the schedule to one_cycle (i.e. follow the One-Cycle schedule to remove filters progressively over training).
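
For intuition, here is a rough illustration of what a one-cycle-style sparsity ramp looks like. This is a hypothetical helper for illustration only; fasterai's actual one_cycle schedule may use a different formula:

import math

def one_cycle_like(pct, target=45., beta=5.):
    # S-shaped ramp from ~0% to the target sparsity: slow start,
    # steep middle, plateau at the end (qualitatively similar to
    # the per-epoch sparsity log printed during training below)
    return target / (1 + math.exp(beta * (1 - 2 * pct)))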

pr_cb = PruneCallback(sparsity=45, context='global', criteria=large_final, schedule=one_cycle, layer_type=[nn.Conv2d])
learn.fit_one_cycle(10, cbs=pr_cb)
Pruning until a sparsity of [45]%
Sparsity at the end of epoch 0: [0.45]%
Sparsity at the end of epoch 1: [1.76]%
Sparsity at the end of epoch 2: [6.39]%
Sparsity at the end of epoch 3: [18.07]%
Sparsity at the end of epoch 4: [32.91]%
Sparsity at the end of epoch 5: [41.27]%
Sparsity at the end of epoch 6: [44.03]%
Sparsity at the end of epoch 7: [44.77]%
Sparsity at the end of epoch 8: [44.95]%
Sparsity at the end of epoch 9: [45.0]%
Final Sparsity: [45.0]%
epoch train_loss valid_loss accuracy time
0 0.923606 0.602696 0.727334 00:04
1 0.524770 0.355539 0.857239 00:05
2 0.352097 0.269183 0.890392 00:05
3 0.270549 0.325706 0.890392 00:05
4 0.205549 0.192651 0.920162 00:06
5 0.155798 0.221400 0.908660 00:05
6 0.137832 0.197844 0.907984 00:05
7 0.109144 0.196927 0.924222 00:05
8 0.085867 0.183181 0.933694 00:05
9 0.084885 0.186639 0.927605 00:05
pruned_macs, pruned_params = tp.utils.count_ops_and_params(learn.model, torch.randn(1,3,224,224).to(default_device()))

We observe that the pruned network ends up matching the baseline accuracy (92.8% vs. 92.5%). But how many parameters have we removed, and how much compute does that save?

print(f'The pruned model has {pruned_macs/base_macs:.2f}x the compute of the original model')
The pruned model has 0.75x the compute of the original model
print(f'The pruned model has {pruned_params/base_params:.2f}x the parameters of the original model')
The pruned model has 0.26x the parameters of the original model
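
MAC counts don't map one-to-one to wall-clock time, but we can time a forward pass to see the effect in practice. This is an illustrative benchmark; absolute numbers depend on hardware, and the same snippet can be run on the baseline learner for comparison:

import time

x = torch.randn(16, 3, 224, 224).to(default_device())
learn.model.eval()
with torch.no_grad():
    for _ in range(5): learn.model(x)  # warm-up
    if torch.cuda.is_available(): torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(50): learn.model(x)
    if torch.cuda.is_available(): torch.cuda.synchronize()
print(f'{(time.perf_counter() - start) / 50 * 1000:.2f} ms per batch of 16')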

So, with no real loss of accuracy, we now have a model that is almost 4x smaller and requires about 1.3x less compute.
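
Because the filters are physically removed from the network (structured pruning), the pruned model needs no special handling at inference time; it can be exported like any fastai model (the filename below is arbitrary):

learn.export('pruned_resnet18.pkl')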