Prune Transformers

Prune a transformer architecture with fasterai
Note

This example code is taken from the fastai docs

from fastai.text.all import *
from fasterai.sparse.all import *   # fasterai's sparsifying API (module path may vary across fasterai versions)
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

pretrained_weights = 'gpt2'
tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_weights)
model = GPT2LMHeadModel.from_pretrained(pretrained_weights)
path = untar_data(URLs.WIKITEXT_TINY)
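
The dls DataLoaders and the DropOutput callback used below are not defined in this excerpt; they come from the fastai transformers tutorial. The sketch below is adapted from that tutorial (the batch size and sequence length are arbitrary choices, not necessarily those used for the results shown here).

# Sketch adapted from the fastai transformers tutorial (assumed setup, not part of fasterai).
df_train = pd.read_csv(path/'train.csv', header=None)
df_valid = pd.read_csv(path/'test.csv', header=None)
all_texts = np.concatenate([df_train[0].values, df_valid[0].values])

class TransformersTokenizer(Transform):
    "Wrap the HuggingFace tokenizer as a fastai Transform"
    def __init__(self, tokenizer): self.tokenizer = tokenizer
    def encodes(self, x): return tensor(self.tokenizer.encode(x))
    def decodes(self, x): return TitledStr(self.tokenizer.decode(x.cpu().numpy()))

class DropOutput(Callback):
    "Keep only the logits (first element) of the model output"
    def after_pred(self): self.learn.pred = self.pred[0]

splits = [list(range_of(df_train)), list(range(len(df_train), len(all_texts)))]
tls = TfmdLists(all_texts, TransformersTokenizer(tokenizer), splits=splits, dl_type=LMDataLoader)
dls = tls.dataloaders(bs=8, seq_len=256)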

Let’s create our fastai Learner.

learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), cbs=[DropOutput], metrics=Perplexity())

And let’s try to extend a given prompt with the pretrained model.

prompt = "\n = Unicorn = \n \n A unicorn is a magical creature with a rainbow tail and a horn"
prompt_ids = tokenizer.encode(prompt)
inp = tensor(prompt_ids)[None]

preds = learn.model.generate(inp, max_length=40, num_beams=5, temperature=1.5)
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
tokenizer.decode(preds[0].cpu().numpy())
'\n = Unicorn = \n \n A unicorn is a magical creature with a rainbow tail and a horn on its head.\n\nA unicorn is a magical creature with a rainbow tail and a horn'
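
As an aside (not part of the original example), the two warnings above can be silenced by passing the attention mask and a pad token id explicitly; GPT-2 has no pad token, so the EOS token is commonly reused for this purpose.

# Optional: pass the attention mask and a pad token id to silence the warnings.
# Both the model and the inputs are still on the CPU here, as in the cell above.
inputs = tokenizer(prompt, return_tensors='pt')
preds = learn.model.generate(inputs['input_ids'],
                             attention_mask=inputs['attention_mask'],
                             pad_token_id=tokenizer.eos_token_id,
                             max_length=40, num_beams=5, temperature=1.5)
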
learn.validate()
(#2) [3.695716619491577,40.2744255065918]
learn.fit_one_cycle(1, 1e-4)
epoch train_loss valid_loss perplexity time
0 3.124115 2.844266 17.188944 07:50
prompt_ids = tokenizer.encode(prompt)
inp = tensor(prompt_ids)[None]

preds = learn.model.generate(inp.cuda(), max_length=40, num_beams=5, temperature=1.5)

tokenizer.decode(preds[0].cpu().numpy())
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper__index_select)

The full traceback (trimmed here) ends in the embedding lookup inside the GPT-2 transformer: the model weights are still on the CPU while the prompt was moved to the GPU with .cuda(), so generate fails with a device mismatch.
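
One way to avoid the mismatch (a sketch, not part of the original run) is to query the device the weights live on and move the prompt there instead of hard-coding .cuda():

# Sketch: move the prompt to whatever device the model weights are on.
device = next(learn.model.parameters()).device
preds = learn.model.generate(inp.to(device), max_length=40, num_beams=5, temperature=1.5)
tokenizer.decode(preds[0].cpu().numpy())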

Make it sparse!

Let’s now retrain our model, this time introducing sparsity.

learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), cbs=[DropOutput], metrics=Perplexity())

Unfortunately, the transformer model uses a custom layer, Conv1D, which is not part of PyTorch. To overcome this, we have to add this layer to our Granularities class so that fasterai knows what to sparsify.

Here, Conv1D behaves like a Linear layer, except that its weight matrix is stored transposed, with shape (nx, nf) instead of Linear's (nf, nx).

doc(Conv1D)

Conv1D

Conv1D(nf, nx)

1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). Basically works like a linear layer but the weights are transposed.

Args:
    nf (`int`): The number of output features.
    nx (`int`): The number of input features.
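
A quick check makes the transposed storage concrete (a sketch; the import path depends on your transformers version):

# Conv1D stores its weight as (nx, nf), i.e. transposed relative to nn.Linear's (nf, nx).
from transformers.pytorch_utils import Conv1D  # transformers.modeling_utils in older releases

layer = Conv1D(nf=3 * 768, nx=768)   # e.g. the fused QKV projection in a GPT-2 small block
print(layer.weight.shape)            # torch.Size([768, 2304])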

We can thus add the Conv1D granularity by using the add_granularity method, indicating the target module and the corresponding granularities it can handle (the same as Linear, so we can reuse them).

Granularities.add_granularity(Conv1D, Granularities._granularities_Linear)

Let’s now define our SparsifyCallback. Let’s say we want to make the model 30% sparse, removing the weights with the lowest norm in each Conv1D layer (local context, large_final criteria), with the sparsity level following a one-cycle schedule.

sp_cb = SparsifyCallback(sparsity=30, granularity='weight', context='local', criteria=large_final, schedule=one_cycle, layer_type=Conv1D)

We now only have to pass our callback to fastai.

learn.fit_one_cycle(1, 1e-4, cbs=sp_cb)
Pruning of weight until a sparsity of [30]%
Saving Weights at epoch 0
epoch train_loss valid_loss perplexity time
0 3.151266 2.882525 17.859306 09:44
Sparsity at the end of epoch 0: [30.0]%
Final Sparsity: [30.0]%
Sparsity in Conv1D 9: 30.00%
Sparsity in Conv1D 10: 30.00%
Sparsity in Conv1D 15: 30.00%
Sparsity in Conv1D 16: 30.00%
Sparsity in Conv1D 22: 30.00%
Sparsity in Conv1D 23: 30.00%
Sparsity in Conv1D 28: 30.00%
Sparsity in Conv1D 29: 30.00%
Sparsity in Conv1D 34: 30.00%
Sparsity in Conv1D 35: 30.00%
Sparsity in Conv1D 40: 30.00%
Sparsity in Conv1D 41: 30.00%
Sparsity in Conv1D 46: 30.00%
Sparsity in Conv1D 47: 30.00%
Sparsity in Conv1D 52: 30.00%
Sparsity in Conv1D 53: 30.00%
Sparsity in Conv1D 58: 30.00%
Sparsity in Conv1D 59: 30.00%
Sparsity in Conv1D 64: 30.00%
Sparsity in Conv1D 65: 30.00%
Sparsity in Conv1D 70: 30.00%
Sparsity in Conv1D 71: 30.00%
Sparsity in Conv1D 76: 30.00%
Sparsity in Conv1D 77: 30.00%
Sparsity in Conv1D 82: 30.00%
Sparsity in Conv1D 83: 30.00%
Sparsity in Conv1D 88: 30.00%
Sparsity in Conv1D 89: 30.00%
Sparsity in Conv1D 94: 30.00%
Sparsity in Conv1D 95: 30.00%
Sparsity in Conv1D 100: 30.00%
Sparsity in Conv1D 101: 30.00%
Sparsity in Conv1D 106: 30.00%
Sparsity in Conv1D 107: 30.00%
Sparsity in Conv1D 112: 30.00%
Sparsity in Conv1D 113: 30.00%
Sparsity in Conv1D 118: 30.00%
Sparsity in Conv1D 119: 30.00%
Sparsity in Conv1D 124: 30.00%
Sparsity in Conv1D 125: 30.00%
Sparsity in Conv1D 130: 30.00%
Sparsity in Conv1D 131: 30.00%
Sparsity in Conv1D 136: 30.00%
Sparsity in Conv1D 137: 30.00%
Sparsity in Conv1D 142: 30.00%
Sparsity in Conv1D 143: 30.00%
Sparsity in Conv1D 148: 30.00%
Sparsity in Conv1D 149: 30.00%
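
As a sanity check independent of fasterai’s own logging (a sketch, assuming the pruned weights are left at exactly zero), the zeros can be counted directly with PyTorch:

# Count the exactly-zero entries left in the pruned Conv1D layers.
total, zeros = 0, 0
for name, module in learn.model.named_modules():
    if isinstance(module, Conv1D):   # Conv1D is already in scope (used for add_granularity above)
        w = module.weight.data
        total += w.numel()
        zeros += (w == 0).sum().item()
print(f'Overall Conv1D sparsity: {100.0 * zeros / total:.2f}%')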

And we can check the prediction for the same prompt as before.

prompt_ids = tokenizer.encode(prompt)
inp = tensor(prompt_ids)[None]

preds = learn.model.generate(inp.cuda(), max_length=40, num_beams=5, temperature=1.5)

tokenizer.decode(preds[0].cpu().numpy())
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
'\n = Unicorn = \n \n A unicorn is a magical creature with a rainbow tail and a horn @-@ shaped head. The unicorn is a member of the <unk> <unk>'

That’s it! You now have a sparse Transformer that performs almost as well as the dense model. However, this model is not yet any more efficient in terms of speed or storage. To get such a speed-up, I suggest you take a look at the granularity section.