from fasterai.core.criteria import *
from fasterai.regularize.all import *
from fastai.vision.all import *
Regularize Callback
Perform group regularization within fastai's callback system.
Get your data
# Fetch and extract the PETS dataset with fastai's `untar_data`, which returns
# the local dataset directory, then collect the image files under "images".
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")
def label_func(f):
    """Label an image by the case of the first character of its name.

    Returns True when ``f`` starts with an uppercase character, else False.
    ``f`` is presumably the file-name string that
    ``ImageDataLoaders.from_name_func`` passes in (confirm). In the PETS
    dataset this presumably separates capitalized (cat) breed names from
    lowercase (dog) ones -- verify against the dataset.
    """
    initial = f[0]
    return initial.isupper()
# Build the DataLoaders: labels come from `label_func` applied to each file
# name, and every item is resized to 64 px to keep training fast.
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))
Train a model without regularization as a baseline.
# Baseline: a ResNet-18 learner trained with no regularization callback,
# tracking accuracy on the validation set.
learn = vision_learner(dls, resnet18, metrics=accuracy)
# Make every layer group trainable, not just the newly added head.
learn.unfreeze()
learn.fit_one_cycle(3)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.675158 | 0.390553 | 0.845061 | 00:06 |
1 | 0.334588 | 0.219738 | 0.901894 | 00:03 |
2 | 0.178833 | 0.194565 | 0.919486 | 00:04 |
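Create the `RegularizeCallback` and train a fresh model with it. The losses reported below are much larger than the baseline's even though accuracy is comparable, presumably because they include the regularization penalty the callback adds to the loss.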
# Regularization callback. 'filter' is presumably the group granularity
# (one group per convolution filter) and `wd` the weight of the penalty --
# confirm against the fasterai RegularizeCallback documentation.
reg_cb = RegularizeCallback('filter', wd=0.0001)

# Rebuild a fresh learner so the comparison with the baseline starts from the
# same pretrained weights, then train with the callback attached.
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit_one_cycle(3, cbs=reg_cb)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 16.783606 | 15.932514 | 0.855210 | 00:04 |
1 | 14.619355 | 13.426625 | 0.922192 | 00:03 |
2 | 13.077019 | 12.727697 | 0.921516 | 00:04 |