Introduction

In a previous post, we spoke about pruning schedules. More specifically, we introduced the two most common ones, One-Shot Pruning and Iterative Pruning, as well as another interesting schedule: Automated Gradual Pruning [1].


Of all of those, the most commonly used pruning schedule is Iterative Pruning, because of its simplicity. The process can be summarized as:

  1. Train the network until convergence
  2. Prune the network
  3. Fine-tune it to recover lost performance
  4. Repeat from step 2 until desired sparsity.

It thus consists of several cycles of pruning/fine-tuning, usually using a smaller learning rate than for the initial training. But is it really better than the other schedules? And could we come up with an even better one?

Moreover, when using such a schedule, several questions remain open, for example:

  • Should we really wait until the network has converged before pruning?
  • During the fine-tuning step, should we wait until the network has recovered the lost performance before performing a new pruning step?
  • How many cycles should we perform?
  • How should we choose the learning rate for fine-tuning?


In this post, we'll try to provide some answers to those questions. In particular, we will compare all of the pruning schedules above and introduce a new one.

We will use the Imagenette dataset to perform our experiments. All we need is to import fastai [2] and the sparse module from fasterai [3]:

from fastai.vision.all import *
from fasterai.sparse.all import *
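
The experiments below assume that the Imagenette DataLoaders object dls already exists. As a minimal sketch, and assuming the 160 px version of the dataset with an arbitrary image size and batch size (these values are not from the original experiments), it could be built as follows:

# Minimal sketch: build the Imagenette DataLoaders used as dls below
path = untar_data(URLs.IMAGENETTE_160)
dls = ImageDataLoaders.from_folder(path, train='train', valid='val',
                                   item_tfms=Resize(128), bs=64)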


Common Schedules

We will first compare the results of the common schedules introduced earlier. To get a fair comparison, all the schedules will be compared on a fixed training budget, using the same hyperparameters.


One-Shot Pruning

The One-Shot Pruning schedule is pretty simple. It consists of 3 phases:

  1. Train the network
  2. Prune a portion of the weights
  3. Fine-Tune the remaining weights

Given that our training budget is fixed, we have to decide whether to put more budget into the initial training phase, i.e. step 1, or into the fine-tuning phase, i.e. step 3. We empirically find that training for $40\%$ of the budget and fine-tuning for the remaining time gives the best results, which suggests that the fine-tuning step is slightly more important than the training one. This can be explained by the fact that performing the pruning too late in training makes it difficult for the network to recover the lost performance, as we use a small learning rate towards the end of training.
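
Concretely, with the 10-epoch budget used below, this 40/60 split translates into the start_epoch argument of the callback. As a purely illustrative computation (not a fasterai helper):

# Hypothetical computation of start_epoch from the budget split
total_epochs = 10
start_epoch = int(0.4 * total_epochs)  # = 4: prune once 40% of the budget is spent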

If we plot the sparsity of our model over the course of training, One-Shot Pruning thus looks like this:

We train our model with $0\%$ sparsity for $40\%$ of our training budget, prune it, then fine-tune it for the remaining time. To do this, all that is needed is the fasterai SparsifyCallback.

learn = Learner(dls, resnet18(num_classes=dls.c), metrics=accuracy)
sp_cb = SparsifyCallback(95, 'weight', 'local', large_final, one_shot, start_epoch=4)
learn.fit(10, 1e-3, cbs=sp_cb)
Pruning of weight until a sparsity of 95%
epoch train_loss valid_loss accuracy time
0 1.561453 1.426115 0.527643 00:11
1 1.270037 1.601807 0.473121 00:11
2 1.124796 1.406487 0.545987 00:11
3 1.021256 1.357861 0.571465 00:11
4 0.977683 0.965558 0.687643 00:17
5 0.879473 0.840178 0.725860 00:17
6 0.796095 0.793685 0.739873 00:17
7 0.768487 0.825092 0.737325 00:17
8 0.732183 0.865477 0.725605 00:17
9 0.703469 0.756053 0.755159 00:17
Saving Weights at epoch 0
Sparsity at the end of epoch 0: 0.00%
Sparsity at the end of epoch 1: 0.00%
Sparsity at the end of epoch 2: 0.00%
Sparsity at the end of epoch 3: 0.00%
Sparsity at the end of epoch 4: 95.00%
Sparsity at the end of epoch 5: 95.00%
Sparsity at the end of epoch 6: 95.00%
Sparsity at the end of epoch 7: 95.00%
Sparsity at the end of epoch 8: 95.00%
Sparsity at the end of epoch 9: 95.00%
Final Sparsity: 95.00

One-Shot Pruning is often seen as the simplest pruning schedule, so we will use the results above as our baseline.


Iterative Pruning

As described above, the Iterative Pruning schedule can be broken down as:

  1. Train the network until convergence
  2. Prune the network
  3. Fine-tune it to recover lost performance
  4. Repeat from step 2 until desired sparsity.

Iterative Pruning is slightly different from One-Shot Pruning. Here, the pruning doesn't happen in one step but over several cycles, alternating phases of pruning and fine-tuning. We found that, given a fixed budget, allotting $20\%$ of the training budget to the initial training provides the best results. For simplicity, we use the same budget for each fine-tuning phase.

learn = Learner(dls, resnet18(num_classes=dls.c), metrics=accuracy)
sp_cb = SparsifyCallback(95, 'weight', 'local', large_final, iterative, start_epoch=2)
learn.fit_one_cycle(10, 1e-3, cbs=sp_cb)
Pruning of weight until a sparsity of 95%
epoch train_loss valid_loss accuracy time
0 1.652174 1.787058 0.414268 00:15
1 1.388728 1.397470 0.530955 00:13
2 1.156839 1.455710 0.572739 00:18
3 1.033233 1.359484 0.570446 00:17
4 0.918170 0.894611 0.712866 00:17
5 0.750040 0.786598 0.753885 00:17
6 0.684740 0.735729 0.764841 00:17
7 1.275929 1.178440 0.606624 00:18
8 1.106493 1.048619 0.661911 00:17
9 1.055405 1.035571 0.668790 00:18
Saving Weights at epoch 0
Sparsity at the end of epoch 0: 0.00%
Sparsity at the end of epoch 1: 0.00%
Sparsity at the end of epoch 2: 31.67%
Sparsity at the end of epoch 3: 31.67%
Sparsity at the end of epoch 4: 63.33%
Sparsity at the end of epoch 5: 63.33%
Sparsity at the end of epoch 6: 63.33%
Sparsity at the end of epoch 7: 95.00%
Sparsity at the end of epoch 8: 95.00%
Sparsity at the end of epoch 9: 95.00%
Final Sparsity: 95.00

As we can see, Iterative Pruning leads to worse results than plain One-Shot Pruning. How come? This is because we imposed a fixed training budget and, as several works have reported, Iterative Pruning requires a significantly longer fine-tuning process to produce a better-performing pruned network [4].


Automated Gradual Pruning (AGP)

The main problem with the previous schedules is the discontinuity that occurs at each pruning step. Indeed, when the pruning is performed, the network's sparsity suddenly increases by a large amount, making it very difficult for the network to recover its previous performance. More recently, Automated Gradual Pruning was introduced, which varies the pruning rate over time, making the pruning process "smoother". However, it still requires setting a starting point, which we found works best at around $20\%$ of the training.
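
For intuition, here is a sketch of the cubic ramp AGP uses, written with the same (start, end, pos) signature as the other schedules. It is an approximation for illustration, not necessarily fasterai's exact sched_agp implementation:

# Sketch of the AGP cubic ramp (Zhu & Gupta, 2017): large pruning steps early on,
# progressively smaller ones as pos (the fraction of the pruning phase) approaches 1.
def sched_agp_sketch(start, end, pos):
    return end + (start - end) * (1 - pos)**3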

learn = Learner(dls, resnet18(num_classes=dls.c), metrics=accuracy)
sp_cb = SparsifyCallback(95, 'weight', 'local', large_final, sched_agp, start_epoch=2)
learn.fit_one_cycle(10, 1e-3, cbs=sp_cb)
Pruning of weight until a sparsity of 95%
epoch train_loss valid_loss accuracy time
0 1.641825 1.530526 0.501911 00:12
1 1.386651 1.331521 0.565350 00:12
2 1.215227 1.174754 0.607898 00:17
3 1.069576 1.253404 0.596433 00:17
4 0.964384 1.204816 0.629045 00:17
5 0.903077 0.989587 0.684331 00:17
6 0.846415 0.961558 0.685605 00:17
7 0.777686 0.757304 0.753631 00:17
8 0.742017 0.719318 0.766624 00:17
9 0.685183 0.702805 0.771465 00:17
Saving Weights at epoch 0
Sparsity at the end of epoch 0: 0.00%
Sparsity at the end of epoch 1: 0.00%
Sparsity at the end of epoch 2: 31.36%
Sparsity at the end of epoch 3: 54.92%
Sparsity at the end of epoch 4: 71.81%
Sparsity at the end of epoch 5: 83.12%
Sparsity at the end of epoch 6: 89.99%
Sparsity at the end of epoch 7: 93.52%
Sparsity at the end of epoch 8: 94.81%
Sparsity at the end of epoch 9: 95.00%
Final Sparsity: 95.00

With AGP, we can see that we are able to outperform One-Shot Pruning. Indeed, the smoother pruning probably makes it easier for the network to accommodate the increase in sparsity.



Other schedules

What other "smooth" schedules can we think of? Fasterai lets you try the schedules available by default in fastai, so let's give them a shot!

Those default schedules are (a short sketch of each follows the list):

  • Annealing Linear
  • Annealing Exponential
  • Annealing Cosine
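
As a rough sketch, these annealing functions all share the (start, end, pos) signature, with pos going from 0 to 1 over the pruning phase. The versions below are approximations for illustration; the exact fastai implementations may differ slightly:

# Approximate forms of fastai's default annealing schedules
import math

def lin_sketch(start, end, pos):   # straight line from start to end
    return start + pos * (end - start)

def exp_sketch(start, end, pos):   # geometric interpolation; requires start > 0
    return start * (end / start)**pos

def cos_sketch(start, end, pos):   # half-cosine ramp from start to end
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2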


Linear Schedule

learn = Learner(dls, resnet18(num_classes=dls.c), metrics=accuracy)
sp_cb = SparsifyCallback(95, 'weight', 'local', large_final, sched_lin)
learn.fit_one_cycle(10, 1e-3, cbs=sp_cb)
Pruning of weight until a sparsity of 95%
epoch train_loss valid_loss accuracy time
0 1.604190 1.709037 0.472102 00:17
1 1.375103 1.457334 0.534522 00:17
2 1.178370 1.318951 0.561019 00:17
3 1.039873 1.149634 0.624713 00:17
4 0.926598 1.011326 0.674650 00:17
5 0.820038 0.889869 0.720764 00:17
6 0.741654 0.782931 0.738089 00:17
7 0.653089 0.791211 0.744204 00:17
8 0.640457 0.852838 0.729427 00:18
9 1.305882 2.434855 0.220892 00:17
Saving Weights at epoch 0
Sparsity at the end of epoch 0: 9.50%
Sparsity at the end of epoch 1: 19.00%
Sparsity at the end of epoch 2: 28.50%
Sparsity at the end of epoch 3: 38.00%
Sparsity at the end of epoch 4: 47.50%
Sparsity at the end of epoch 5: 57.00%
Sparsity at the end of epoch 6: 66.50%
Sparsity at the end of epoch 7: 76.00%
Sparsity at the end of epoch 8: 85.50%
Sparsity at the end of epoch 9: 95.00%
Final Sparsity: 95.00

The linear schedule looks fine until the very last iterations. As we saw with Iterative Pruning, the network needs a bit of fine-tuning after some weights have been pruned, which doesn't happen here: we keep pruning until the very end, so the sparsity of the network never settles.


Exponential Schedule

learn = Learner(dls, resnet18(num_classes=dls.c), metrics=accuracy)
sp_cb = SparsifyCallback(95, 'weight', 'local', large_final, sched_exp, start_sparsity=0.0001)
learn.fit_one_cycle(10, 1e-3, cbs=sp_cb)
Pruning of weight until a sparsity of 95%
epoch train_loss valid_loss accuracy time
0 1.637005 1.904222 0.437962 00:17
1 1.368260 2.153302 0.376561 00:17
2 1.193901 1.408279 0.567898 00:17
3 1.083588 1.378163 0.550318 00:17
4 0.921528 1.076043 0.654013 00:17
5 0.828063 1.027332 0.669554 00:17
6 0.751379 0.730978 0.760255 00:17
7 0.625477 0.683287 0.782166 00:18
8 0.546839 0.601401 0.808408 00:17
9 0.850577 2.636575 0.098344 00:17
Saving Weights at epoch 0
Sparsity at the end of epoch 0: 0.00%
Sparsity at the end of epoch 1: 0.00%
Sparsity at the end of epoch 2: 0.01%
Sparsity at the end of epoch 3: 0.02%
Sparsity at the end of epoch 4: 0.10%
Sparsity at the end of epoch 5: 0.39%
Sparsity at the end of epoch 6: 1.53%
Sparsity at the end of epoch 7: 6.06%
Sparsity at the end of epoch 8: 23.99%
Sparsity at the end of epoch 9: 95.00%
Final Sparsity: 95.00

The exponential schedule gives even worse results. This was to be expected, as the increase in sparsity mostly happens at the end of training (from $24\%$ to $95\%$ in the last epoch), giving the network even less time to recover.


Cosine Schedule

learn = Learner(dls, resnet18(num_classes=dls.c), metrics=accuracy)
sp_cb = SparsifyCallback(95, 'weight', 'local', large_final, sched_cos)
learn.fit_one_cycle(10, 1e-3, cbs=sp_cb)
Pruning of weight until a sparsity of 95%
epoch train_loss valid_loss accuracy time
0 1.669174 2.382740 0.333503 00:18
1 1.357350 1.414894 0.556433 00:18
2 1.162969 1.857623 0.497580 00:18
3 1.023889 1.120237 0.635159 00:18
4 0.913671 1.300077 0.585478 00:18
5 0.828866 1.084937 0.663185 00:18
6 0.754994 0.804495 0.742420 00:17
7 0.708546 0.784795 0.745223 00:17
8 0.762766 0.890425 0.712357 00:17
9 1.133683 1.242696 0.591847 00:17
Saving Weights at epoch 0
Sparsity at the end of epoch 0: 2.32%
Sparsity at the end of epoch 1: 9.07%
Sparsity at the end of epoch 2: 19.58%
Sparsity at the end of epoch 3: 32.82%
Sparsity at the end of epoch 4: 47.50%
Sparsity at the end of epoch 5: 62.18%
Sparsity at the end of epoch 6: 75.42%
Sparsity at the end of epoch 7: 85.93%
Sparsity at the end of epoch 8: 92.68%
Sparsity at the end of epoch 9: 95.00%
Final Sparsity: 95.00

The cosine schedule is a bit better, but we can still see the drop in performance at the end, because the sparsity of the network never settles.


So what can we do from here?


From what we have seen, Automated Gradual Pruning is the technique that works best so far. AGP possesses a long "tail", allowing the network to be fine-tuned with almost no increase in sparsity towards the end of pruning, something the default fastai schedules definitely lack.

Can we modify the previous schedules to obtain a similar behaviour? What if we artificially add a tail to our cosine schedule?

In fasterai, this can be done by passing the argument end_epoch, corresponding to the epoch at which we stop pruning. In this case, it means that we have 3 entire epochs during which the sparsity doesn't change, so the fine-tuning may be more effective.
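
The code for this run is not shown here; a sketch of what it could look like is given below. The value end_epoch=7 is an assumption, chosen so that pruning stops after epoch 6 and the last 3 epochs are pure fine-tuning, which matches the sparsity log that follows:

# Sketch of the cosine schedule with a "tail" (end_epoch=7 is an assumed value)
learn = Learner(dls, resnet18(num_classes=dls.c), metrics=accuracy)
sp_cb = SparsifyCallback(95, 'weight', 'local', large_final, sched_cos, end_epoch=7)
learn.fit_one_cycle(10, 1e-3, cbs=sp_cb)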

Pruning of weight until a sparsity of 95%
epoch train_loss valid_loss accuracy time
0 1.668230 2.046625 0.390828 00:18
1 1.352506 1.737922 0.458854 00:18
2 1.175191 2.125888 0.474395 00:18
3 1.032842 1.133045 0.644076 00:18
4 0.984874 1.461341 0.530191 00:18
5 0.926132 1.015559 0.656560 00:18
6 0.902270 0.815804 0.730701 00:18
7 0.727744 0.706084 0.771210 00:18
8 0.682093 0.692142 0.776306 00:18
9 0.660345 0.688686 0.777070 00:18
Saving Weights at epoch 0
Sparsity at the end of epoch 0: 4.70%
Sparsity at the end of epoch 1: 17.88%
Sparsity at the end of epoch 2: 36.93%
Sparsity at the end of epoch 3: 58.07%
Sparsity at the end of epoch 4: 77.12%
Sparsity at the end of epoch 5: 90.30%
Sparsity at the end of epoch 6: 95.00%
Sparsity at the end of epoch 7: 95.00%
Sparsity at the end of epoch 8: 95.00%
Sparsity at the end of epoch 9: 95.00%
Final Sparsity: 95.00

As we can see, this kind of schedule allows our network to reach a performance similar to AGP.



Can we do better?

It has been shown recently that the most critical phase in the training of a neural network happens during the very first iterations [5], and that applying regularization after that initial transient phase has little effect on the final performance of the network [6].

As network pruning removes weights, reducing the capacity of the network, it can be seen as a kind of regularization. One should thus apply pruning early in training to take advantage of its regularization effects, but must do so very carefully so as not to irremediably damage the network during this brittle period.

Can we create a schedule that gets the best of both worlds, i.e. one that starts pruning slowly right from the start and keeps a long fine-tuning phase at the end? The cosine schedule with a tail seemed like a good start, but it is a bit lacking in terms of customization.

Introducing One-Cycle Pruning

We thus introduce the One-Cycle Pruning schedule which, as the name suggests, consists of a single cycle of pruning spread over the whole training. The sparsity along the training is given by:

$$ s_t = s_i + (s_f - s_i) \cdot \frac{1+e^{-\alpha+\beta}}{1+e^{-\alpha t + \beta}} $$

with $s_t$ the level of sparsity at training step $t$ (with $t$ going from $0$ to $1$ over the course of training), and $s_i$ and $s_f$ respectively the initial and final levels of sparsity.

This schedule can be customized by varying the slope of the pruning (the $\alpha$ parameter) or its offset (the $\beta$ parameter), but we have found that good default values are respectively $14$ and $5$.

To use it with fasterai, we only need to create the corresponding function:

def sched_onecycle(start, end, pos, α=14, β=5):
    # Sigmoid-shaped interpolation between the start and end sparsity:
    # α controls the slope of the pruning phase, β its offset.
    out = (1 + np.exp(-α + β)) / (1 + np.exp(-α * pos + β))
    return start + (end - start) * out
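
As a quick sanity check of the endpoints with the default α and β (the printed values are approximate):

print(sched_onecycle(0, 95, 0.0))  # ~0.64: almost no pruning at the very start
print(sched_onecycle(0, 95, 1.0))  # 95.0: full target sparsity at the end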

Then use it in the Callback:

learn = Learner(dls, resnet18(num_classes=dls.c), metrics=accuracy)
sp_cb = SparsifyCallback(95, 'weight', 'local', large_final, sched_onecycle)
learn.fit_one_cycle(10, 1e-3, cbs=sp_cb)
Pruning of weight until a sparsity of 95%
epoch train_loss valid_loss accuracy time
0 1.629503 1.880378 0.441783 00:17
1 1.369810 1.583246 0.484841 00:17
2 1.192665 1.541420 0.520510 00:18
3 1.069878 1.174934 0.625478 00:17
4 1.043352 1.543538 0.500892 00:17
5 0.947589 1.077303 0.638981 00:17
6 0.842358 0.792739 0.739873 00:17
7 0.729982 0.711569 0.774013 00:17
8 0.656394 0.652494 0.788280 00:17
9 0.624965 0.645904 0.793376 00:18
Saving Weights at epoch 0
Sparsity at the end of epoch 0: 2.53%
Sparsity at the end of epoch 1: 9.48%
Sparsity at the end of epoch 2: 29.46%
Sparsity at the end of epoch 3: 61.34%
Sparsity at the end of epoch 4: 83.69%
Sparsity at the end of epoch 5: 91.94%
Sparsity at the end of epoch 6: 94.24%
Sparsity at the end of epoch 7: 94.82%
Sparsity at the end of epoch 8: 94.96%
Sparsity at the end of epoch 9: 95.00%
Final Sparsity: 95.00

As we can see, such a schedule allows our network to reach a higher performance given our training budget.



In this blog post, we experimented with a few pruning schedules and showed that, under a strict and fixed training budget, One-Cycle Pruning performs best. If the training budget doesn't matter, then Iterative Pruning might be a good default option.

Feel free to experiment as well, and maybe come up with your own pruning schedule that perfectly fits your task!




If you notice any mistake or possible improvement, please contact me! If you found this post useful, please consider citing it as:

@article{hubens2021schedule,
  title   = "Which Pruning Schedule Should I Use ?",
  author  = "Hubens, Nathan",
  journal = "nathanhubens.github.io",
  year    = "2021",
  url     = "https://nathanhubens.github.io/posts/deep%20learning/2021/06/15/OneCycle.html"
}

References