1-
21# Copyright The PyTorch Lightning team.
32#
43# Licensed under the Apache License, Version 2.0 (the "License");
6059
6160DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"
6261
63-
6462# --- Finetunning Callback ---
6563
64+
6665class MilestonesFinetuningCallback (BaseFinetuningCallback ):
6766
68- def __init__ (self ,
69- milestones : tuple = (5 , 10 ),
70- train_bn : bool = True ):
67+ def __init__ (self , milestones : tuple = (5 , 10 ), train_bn : bool = True ):
7168 self .milestones = milestones
7269 self .train_bn = train_bn
7370
@@ -78,17 +75,13 @@ def finetunning_function(self, pl_module: pl.LightningModule, epoch: int, optimi
7875 if epoch == self .milestones [0 ]:
7976 # unfreeze 5 last layers
8077 self .unfreeze_and_add_param_group (
81- module = pl_module .feature_extractor [- 5 :],
82- optimizer = optimizer ,
83- train_bn = self .train_bn
78+ module = pl_module .feature_extractor [- 5 :], optimizer = optimizer , train_bn = self .train_bn
8479 )
8580
8681 elif epoch == self .milestones [1 ]:
8782 # unfreeze remaing layers
8883 self .unfreeze_and_add_param_group (
89- module = pl_module .feature_extractor [:- 5 ],
90- optimizer = optimizer ,
91- train_bn = self .train_bn
84+ module = pl_module .feature_extractor [:- 5 ], optimizer = optimizer , train_bn = self .train_bn
9285 )
9386
9487
@@ -149,10 +142,12 @@ def __build_model(self):
149142 self .feature_extractor = nn .Sequential (* _layers )
150143
151144 # 2. Classifier:
152- _fc_layers = [nn .Linear (2048 , 256 ),
153- nn .ReLU (),
154- nn .Linear (256 , 32 ),
155- nn .Linear (32 , 1 )]
145+ _fc_layers = [
146+ nn .Linear (2048 , 256 ),
147+ nn .ReLU (),
148+ nn .Linear (256 , 32 ),
149+ nn .Linear (32 , 1 ),
150+ ]
156151 self .fc = nn .Sequential (* _fc_layers )
157152
158153 # 3. Loss:
@@ -218,25 +213,21 @@ def setup(self, stage: str):
218213
219214 train_dataset = ImageFolder (
220215 root = data_path .joinpath ("train" ),
221- transform = transforms .Compose (
222- [
223- transforms .Resize ((224 , 224 )),
224- transforms .RandomHorizontalFlip (),
225- transforms .ToTensor (),
226- normalize ,
227- ]
228- ),
216+ transform = transforms .Compose ([
217+ transforms .Resize ((224 , 224 )),
218+ transforms .RandomHorizontalFlip (),
219+ transforms .ToTensor (),
220+ normalize ,
221+ ]),
229222 )
230223
231224 valid_dataset = ImageFolder (
232225 root = data_path .joinpath ("validation" ),
233- transform = transforms .Compose (
234- [
235- transforms .Resize ((224 , 224 )),
236- transforms .ToTensor (),
237- normalize ,
238- ]
239- ),
226+ transform = transforms .Compose ([
227+ transforms .Resize ((224 , 224 )),
228+ transforms .ToTensor (),
229+ normalize ,
230+ ]),
240231 )
241232
242233 self .train_dataset = train_dataset
0 commit comments