Sparsify Callback

Use the Sparsifier in the fastai Callback system.
from fastai.vision.all import *

# Grab the Oxford-IIIT Pets dataset
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

# Cat breeds are capitalized in the file names, so this labels cats vs. dogs
def label_func(f): return f[0].isupper()
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

The most important part of our Callback happens in before_batch: there, we first compute the current sparsity of the network according to our schedule, then remove parameters accordingly. A simplified sketch of this logic is shown below.
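
To make this concrete, here is a minimal sketch of such a callback. This is not fasterai's actual implementation: ToySparsifyCallback, its sparsifier argument, and the prune_model call are all hypothetical, and we assume a fastai-style annealing schedule with signature schedule(start, end, pos).

class ToySparsifyCallback(Callback):
    "Toy sketch: anneal the target sparsity every batch, then prune."
    def __init__(self, end_sparsity, schedule, sparsifier):
        store_attr()
    def before_batch(self):
        # self.pct_train goes from 0 to 1 over the whole training run
        current_sparsity = self.schedule(0., self.end_sparsity, self.pct_train)
        self.sparsifier.prune_model(current_sparsity)  # hypothetical pruning call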

Let’s try it on a pretrained ResNet-34 from timm, pruning entire filters:

import timm
# Load a pretrained ResNet-34 and swap its 1000-class ImageNet head
# for a 2-class head matching our cats vs. dogs task
pretrained_resnet_34 = timm.create_model('resnet34', pretrained=True)
pretrained_resnet_34.fc = nn.Linear(512, 2)
learn = Learner(dls, pretrained_resnet_34, metrics=accuracy)
sp_cb = SparsifyCallback(sparsity=50, granularity='filter', context='global', criteria=large_final, schedule=one_cycle)
learn.fit_one_cycle(5, cbs=sp_cb)
Pruning of filter until a sparsity of [50]%
Saving Weights at epoch 0
Sparsity at the end of epoch 0: [1.96]%
Sparsity at the end of epoch 1: [20.07]%
Sparsity at the end of epoch 2: [45.86]%
Sparsity at the end of epoch 3: [49.74]%
Sparsity at the end of epoch 4: [50.0]%
Final Sparsity: [50.0]%
Sparsity in Conv2d 1: 50.00%
Sparsity in Conv2d 7: 53.12%
Sparsity in Conv2d 12: 50.47%
Sparsity in Conv2d 16: 50.00%
Sparsity in Conv2d 21: 50.53%
Sparsity in Conv2d 25: 50.00%
Sparsity in Conv2d 30: 50.41%
Sparsity in Conv2d 35: 50.00%
Sparsity in Conv2d 40: 50.24%
Sparsity in Conv2d 44: 50.00%
Sparsity in Conv2d 47: 50.00%
Sparsity in Conv2d 52: 50.24%
Sparsity in Conv2d 56: 50.00%
Sparsity in Conv2d 61: 50.38%
Sparsity in Conv2d 65: 50.00%
Sparsity in Conv2d 70: 50.11%
Sparsity in Conv2d 75: 50.00%
Sparsity in Conv2d 80: 50.12%
Sparsity in Conv2d 84: 50.00%
Sparsity in Conv2d 87: 50.00%
Sparsity in Conv2d 92: 50.09%
Sparsity in Conv2d 96: 50.00%
Sparsity in Conv2d 101: 50.10%
Sparsity in Conv2d 105: 50.00%
Sparsity in Conv2d 110: 50.03%
Sparsity in Conv2d 114: 50.00%
Sparsity in Conv2d 119: 50.11%
Sparsity in Conv2d 123: 50.00%
Sparsity in Conv2d 128: 50.12%
Sparsity in Conv2d 133: 50.00%
Sparsity in Conv2d 138: 50.08%
Sparsity in Conv2d 142: 50.00%
Sparsity in Conv2d 145: 50.01%
Sparsity in Conv2d 150: 50.14%
Sparsity in Conv2d 154: 50.00%
Sparsity in Conv2d 159: 50.28%
epoch train_loss valid_loss accuracy time
0 0.434834 0.498019 0.758457 00:07
1 0.409158 0.432628 0.809202 00:06
2 0.338707 0.371978 0.832206 00:07
3 0.274749 0.368416 0.854533 00:06
4 0.237638 0.373818 0.849797 00:06
Because we pruned with granularity='filter', entire filters have been zeroed out, which we can verify by inspecting the first convolution layer:

learn.model.conv1.weight
Parameter containing:
tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          ...,
          [ 0.0000e+00, -0.0000e+00, -0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00]],

         ...,

         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          ...,
          [ 0.0000e+00,  0.0000e+00, -0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00]]],


        ...,


        [[[ 1.4643e-01,  5.8291e-03, -9.3539e-02,  ..., -1.0934e-01,
           -6.7021e-02,  1.4150e-02],
          ...,
          [ 9.8079e-02,  1.6960e-01,  2.1411e-01,  ...,  1.8952e-01,
            6.1990e-03,  1.1999e-02]],

         ...,

         [[ 1.4723e-01, -1.3924e-02, -1.1114e-02,  ..., -3.3590e-03,
           -3.8765e-03, -9.1093e-04],
          ...,
          [ 1.0739e-01,  1.3187e-01,  1.1567e-01,  ...,  1.6808e-01,
            1.1086e-01,  8.2069e-02]]]], device='cuda:0', requires_grad=True)
As a point of comparison, let’s first train a regular, dense ResNet-18:

learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit_one_cycle(5)
epoch train_loss valid_loss accuracy time
0 0.711885 1.064277 0.843708 00:45
1 0.409735 0.217008 0.913396 00:03
2 0.265280 0.284833 0.898512 00:03
3 0.144334 0.158726 0.936401 00:03
4 0.082726 0.153889 0.939784 00:03

Let’s now try adding some sparsity to our model.

learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

Compared to the Sparsifier, the SparsifyCallback requires one additional argument: the pruning schedule to follow during training, which tells the callback how much to prune at each step.

You can use any scheduling function already available in fastai or come up with your own, as sketched below! For more information about the pruning schedules, take a look at the Schedules section.
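
For instance, here is a minimal sketch of a custom schedule, assuming the fastai-style annealing signature schedule(start, end, pos) where pos goes from 0 to 1 over training (sched_quadratic is our own hypothetical name):

def sched_quadratic(start, end, pos):
    "Anneal quadratically from `start` to `end` as `pos` goes from 0 to 1"
    return start + (end - start) * pos ** 2

Such a function could then be passed directly as schedule=sched_quadratic. Below, we stick with the built-in one_cycle schedule: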

sp_cb = SparsifyCallback(sparsity=50, granularity='weight', context='local', criteria=large_final, schedule=one_cycle)
learn.fit_one_cycle(5, cbs=sp_cb)
Pruning of weight until a sparsity of [50]%
Saving Weights at epoch 0
Sparsity at the end of epoch 0: [1.96]%
Sparsity at the end of epoch 1: [20.07]%
Sparsity at the end of epoch 2: [45.86]%
Sparsity at the end of epoch 3: [49.74]%
Sparsity at the end of epoch 4: [50.0]%
Final Sparsity: [50.0]%
Sparsity in Conv2d 2: 50.00%
Sparsity in Conv2d 8: 50.00%
Sparsity in Conv2d 11: 50.00%
Sparsity in Conv2d 14: 50.00%
Sparsity in Conv2d 17: 50.00%
Sparsity in Conv2d 21: 50.00%
Sparsity in Conv2d 24: 50.00%
Sparsity in Conv2d 27: 50.00%
Sparsity in Conv2d 30: 50.00%
Sparsity in Conv2d 33: 50.00%
Sparsity in Conv2d 37: 50.00%
Sparsity in Conv2d 40: 50.00%
Sparsity in Conv2d 43: 50.00%
Sparsity in Conv2d 46: 50.00%
Sparsity in Conv2d 49: 50.00%
Sparsity in Conv2d 53: 50.00%
Sparsity in Conv2d 56: 50.00%
Sparsity in Conv2d 59: 50.00%
Sparsity in Conv2d 62: 50.00%
Sparsity in Conv2d 65: 50.00%
epoch train_loss valid_loss accuracy time
0 0.711270 0.742762 0.775372 00:07
1 0.383374 0.307700 0.864005 00:07
2 0.219235 0.217708 0.905954 00:07
3 0.121921 0.213659 0.933018 00:07
4 0.067208 0.200506 0.930988 00:07

Surprisingly, our network, now composed of 50% zeroes, performs reasonably well compared to our plain, dense network.
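
To double-check that claim, here is a quick, unofficial way to measure the global sparsity of the convolutions by counting exact zeros:

# Count exact zeros across all Conv2d weights to verify the reached sparsity
total, zeros = 0, 0
for m in learn.model.modules():
    if isinstance(m, nn.Conv2d):
        total += m.weight.numel()
        zeros += (m.weight == 0).sum().item()
print(f'Global Conv2d sparsity: {100 * zeros / total:.2f}%')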

The SparsifyCallback also accepts a list of sparsities, corresponding to each layer of layer_type to be pruned. Below, we show how to prune only the intermediate layers of ResNet-18.

learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
sparsities = [0, 0, 0, 0, 0, 0, 50, 50, 50, 50, 50, 50, 50, 50, 0, 0, 0, 0, 0, 0]
sp_cb = SparsifyCallback(sparsity=sparsities, granularity='weight', context='local', criteria=large_final, schedule=cos)
learn.fit_one_cycle(5, cbs=sp_cb)
Pruning of weight until a sparsity of [0, 0, 0, 0, 0, 0, 50, 50, 50, 50, 50, 50, 50, 50, 0, 0, 0, 0, 0, 0]%
Saving Weights at epoch 0
Sparsity at the end of epoch 0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 1: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 2: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 3: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 4: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Final Sparsity: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity in Conv2d 2: 0.00%
Sparsity in Conv2d 8: 0.00%
Sparsity in Conv2d 11: 0.00%
Sparsity in Conv2d 14: 0.00%
Sparsity in Conv2d 17: 0.00%
Sparsity in Conv2d 21: 0.00%
Sparsity in Conv2d 24: 50.00%
Sparsity in Conv2d 27: 50.00%
Sparsity in Conv2d 30: 50.00%
Sparsity in Conv2d 33: 50.00%
Sparsity in Conv2d 37: 50.00%
Sparsity in Conv2d 40: 50.00%
Sparsity in Conv2d 43: 50.00%
Sparsity in Conv2d 46: 50.00%
Sparsity in Conv2d 49: 0.00%
Sparsity in Conv2d 53: 0.00%
Sparsity in Conv2d 56: 0.00%
Sparsity in Conv2d 59: 0.00%
Sparsity in Conv2d 62: 0.00%
Sparsity in Conv2d 65: 0.00%
epoch train_loss valid_loss accuracy time
0 0.731650 0.570400 0.811908 00:07
1 0.396108 0.262083 0.895805 00:07
2 0.250992 0.210679 0.909337 00:07
3 0.132799 0.192091 0.925575 00:07
4 0.079732 0.159255 0.938430 00:07

On top of that, the SparsifyCallback can also take optional arguments, such as layer_type, which selects which type of layers to prune.

For example, so far we have pruned the convolution layers of our model, but we could imagine pruning the Linear layers, or even only the BatchNorm ones!
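
As a sketch, and assuming layer_type accepts any nn.Module subclass (the examples above all pruned the default convolution layers), targeting the Linear layers instead could look like this:

# Hypothetical: target the Linear layers instead of the convolutions
sp_cb = SparsifyCallback(sparsity=50, granularity='weight', context='local',
                         criteria=large_final, schedule=one_cycle,
                         layer_type=nn.Linear)
learn.fit_one_cycle(5, cbs=sp_cb)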