Quantize Callback

Quantize your network during training

QuantizeCallback

 QuantizeCallback (qconfig_mapping=None, backend='x86')

Callback that applies quantization-aware training to the model during fit, converting it to a quantized model when training finishes
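
The two arguments mirror PyTorch's FX graph mode quantization API: qconfig_mapping chooses which observers and fake-quantize modules are inserted into the graph, and backend selects the quantized kernel set ('x86' for server CPUs, 'qnnpack' for ARM). A minimal sketch of a custom configuration, assuming the callback forwards the mapping to torch.ao.quantization unchanged; get_default_qat_qconfig_mapping is a stock PyTorch helper, not part of this library:

from torch.ao.quantization import get_default_qat_qconfig_mapping

# Start from PyTorch's default QAT mapping for the x86 backend,
# then pass it explicitly instead of relying on qconfig_mapping=None
qconfig_mapping = get_default_qat_qconfig_mapping('x86')
cb = QuantizeCallback(qconfig_mapping=qconfig_mapping, backend='x86')
# later: learn.fit_one_cycle(epochs, cbs=cb)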

import timm

# Load a pretrained ResNet-34 from timm and replace its 1000-class head
# with a 2-class linear layer before wrapping it in a Learner
pretrained_resnet_34 = timm.create_model('resnet34', pretrained=True)
pretrained_resnet_34.fc = nn.Linear(512, 2)
learn = Learner(dls, pretrained_resnet_34, metrics=accuracy)
learn.fit_one_cycle(10, cbs=QuantizeCallback())
epoch train_loss valid_loss accuracy time
0 0.593340 0.501450 0.718539 00:12
1 0.395149 0.378658 0.830853 00:11
2 0.247482 0.231792 0.901218 00:11
3 0.164855 0.265116 0.892422 00:11
4 0.112191 0.228632 0.914073 00:11
5 0.070206 0.214058 0.929635 00:11
6 0.050446 0.202638 0.937754 00:11
7 0.033936 0.203362 0.941137 00:11
8 0.024784 0.201417 0.938430 00:11
9 0.022369 0.193307 0.941813 00:12
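
The callback composes with fastai's own learners as well; below, a vision_learner ResNet-18 is unfrozen and fine-tuned with the callback attached.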
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit(5, cbs=QuantizeCallback())
epoch train_loss valid_loss accuracy time
0 0.571546 0.424050 0.789581 00:06
1 0.469406 0.363851 0.843708 00:06
2 0.407703 0.399239 0.817997 00:06
3 0.375065 0.309377 0.865359 00:06
4 0.323774 0.331475 0.873478 00:06
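
After training finishes, the callback has converted the network into a quantized torch.fx GraphModule with int8 weights, which we can inspect: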
learn.model
GraphModule(
  (0): Module(
    (0): QuantizedConvReLU2d(3, 64, kernel_size=(7, 7), stride=(2, 2), scale=0.029317768290638924, zero_point=0, padding=(3, 3))
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Module(
      (0): Module(
        (conv1): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.017887497320771217, zero_point=0, padding=(1, 1))
        (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.0466480627655983, zero_point=66, padding=(1, 1))
      )
      (1): Module(
        (conv1): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.017889995127916336, zero_point=0, padding=(1, 1))
        (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.07470479607582092, zero_point=66, padding=(1, 1))
      )
    )
    (5): Module(
      (0): Module(
        (conv1): QuantizedConvReLU2d(64, 128, kernel_size=(3, 3), stride=(2, 2), scale=0.0174386166036129, zero_point=0, padding=(1, 1))
        (conv2): QuantizedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.047718875110149384, zero_point=60, padding=(1, 1))
        (downsample): Module(
          (0): QuantizedConv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), scale=0.04965509846806526, zero_point=68)
        )
      )
      (1): Module(
        (conv1): QuantizedConvReLU2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.019585009664297104, zero_point=0, padding=(1, 1))
        (conv2): QuantizedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.05827442184090614, zero_point=70, padding=(1, 1))
      )
    )
    (6): Module(
      (0): Module(
        (conv1): QuantizedConvReLU2d(128, 256, kernel_size=(3, 3), stride=(2, 2), scale=0.02278205193579197, zero_point=0, padding=(1, 1))
        (conv2): QuantizedConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.05654977634549141, zero_point=57, padding=(1, 1))
        (downsample): Module(
          (0): QuantizedConv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), scale=0.019852932542562485, zero_point=75)
        )
      )
      (1): Module(
        (conv1): QuantizedConvReLU2d(256, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.021630365401506424, zero_point=0, padding=(1, 1))
        (conv2): QuantizedConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.06945421546697617, zero_point=73, padding=(1, 1))
      )
    )
    (7): Module(
      (0): Module(
        (conv1): QuantizedConvReLU2d(256, 512, kernel_size=(3, 3), stride=(2, 2), scale=0.019869942218065262, zero_point=0, padding=(1, 1))
        (conv2): QuantizedConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=0.07700460404157639, zero_point=63, padding=(1, 1))
        (downsample): Module(
          (0): QuantizedConv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), scale=0.045847173780202866, zero_point=68)
        )
      )
      (1): Module(
        (conv1): QuantizedConvReLU2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=0.02446889691054821, zero_point=0, padding=(1, 1))
        (conv2): QuantizedConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=0.3400780260562897, zero_point=54, padding=(1, 1))
      )
    )
  )
  (1): Module(
    (0): Module(
      (mp): AdaptiveMaxPool2d(output_size=1)
      (ap): AdaptiveAvgPool2d(output_size=1)
    )
    (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): QuantizedDropout(p=0.25, inplace=False)
    (4): QuantizedLinearReLU(in_features=1024, out_features=512, scale=0.5987890958786011, zero_point=0, qscheme=torch.per_channel_affine)
    (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): QuantizedDropout(p=0.5, inplace=False)
    (8): QuantizedLinear(in_features=512, out_features=2, scale=0.6221694350242615, zero_point=113, qscheme=torch.per_channel_affine)
  )
)
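
The converted GraphModule embeds its own quantize/dequantize steps, so it accepts ordinary float tensors for CPU inference, and it can be serialized with TorchScript. A hedged sketch, assuming dls is the DataLoaders used above and that the quantized kernels run on CPU:

import torch

learn.model.cpu().eval()
xb, _ = dls.one_batch()            # grab a batch from the DataLoaders
with torch.no_grad():
    preds = learn.model(xb.cpu())  # int8 inference on a float input batch

# TorchScript works on FX GraphModules, giving a standalone int8 artifact
scripted = torch.jit.script(learn.model)
torch.jit.save(scripted, 'resnet18_int8.pt')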