 Note:
     See: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
 """
-import argparse
+
 import logging
-import os
 from pathlib import Path
 from typing import Union

 from pytorch_lightning import LightningDataModule
 from pytorch_lightning.callbacks.finetuning import BaseFinetuning
 from pytorch_lightning.utilities import rank_zero_info
+from pytorch_lightning.utilities.cli import LightningCLI

 log = logging.getLogger(__name__)
 DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"
@@ -93,10 +93,17 @@ class CatDogImageDataModule(LightningDataModule):
 
     def __init__(
         self,
-        dl_path: Union[str, Path],
+        dl_path: Union[str, Path] = "data",
         num_workers: int = 0,
         batch_size: int = 8,
     ):
+        """CatDogImageDataModule
+
+        Args:
+            dl_path: root directory where the data will be downloaded
+            num_workers: number of CPU workers
+            batch_size: number of samples in a batch
106+ """
100107 super ().__init__ ()
101108
102109 self ._dl_path = dl_path
@@ -146,17 +153,6 @@ def val_dataloader(self):
         log.info("Validation data loaded.")
         return self.__dataloader(train=False)

-    @staticmethod
-    def add_model_specific_args(parent_parser):
-        parser = parent_parser.add_argument_group("CatDogImageDataModule")
-        parser.add_argument(
-            "--num-workers", default=0, type=int, metavar="W", help="number of CPU workers", dest="num_workers"
-        )
-        parser.add_argument(
-            "--batch-size", default=8, type=int, metavar="W", help="number of sample in a batch", dest="batch_size"
-        )
-        return parent_parser
-

 # --- Pytorch-lightning module ---

@@ -166,17 +162,22 @@ class TransferLearningModel(pl.LightningModule):
     def __init__(
         self,
         backbone: str = "resnet50",
-        train_bn: bool = True,
-        milestones: tuple = (5, 10),
+        train_bn: bool = False,
+        milestones: tuple = (2, 4),
         batch_size: int = 32,
-        lr: float = 1e-2,
+        lr: float = 1e-3,
         lr_scheduler_gamma: float = 1e-1,
         num_workers: int = 6,
         **kwargs,
     ) -> None:
-        """
+        """TransferLearningModel
+
         Args:
-            dl_path: Path where the data will be downloaded
+            backbone: Name (as in ``torchvision.models``) of the feature extractor
+            train_bn: Whether the BatchNorm layers should be trainable
+            milestones: List of two epoch milestones
+            lr: Initial learning rate
+            lr_scheduler_gamma: Factor by which the learning rate is reduced at each milestone
         """
         super().__init__()
         self.backbone = backbone
@@ -269,90 +270,31 @@ def configure_optimizers(self):
         scheduler = MultiStepLR(optimizer, milestones=self.milestones, gamma=self.lr_scheduler_gamma)
         return [optimizer], [scheduler]

-    @staticmethod
-    def add_model_specific_args(parent_parser):
-        parser = parent_parser.add_argument_group("TransferLearningModel")
-        parser.add_argument(
-            "--backbone",
-            default="resnet50",
-            type=str,
-            metavar="BK",
-            help="Name (as in ``torchvision.models``) of the feature extractor",
-        )
-        parser.add_argument(
-            "--epochs", default=15, type=int, metavar="N", help="total number of epochs", dest="nb_epochs"
-        )
-        parser.add_argument("--batch-size", default=8, type=int, metavar="B", help="batch size", dest="batch_size")
-        parser.add_argument("--gpus", type=int, default=0, help="number of gpus to use")
-        parser.add_argument(
-            "--lr", "--learning-rate", default=1e-3, type=float, metavar="LR", help="initial learning rate", dest="lr"
-        )
-        parser.add_argument(
-            "--lr-scheduler-gamma",
-            default=1e-1,
-            type=float,
-            metavar="LRG",
-            help="Factor by which the learning rate is reduced at each milestone",
-            dest="lr_scheduler_gamma",
-        )
-        parser.add_argument(
-            "--train-bn",
-            default=False,
-            type=bool,
-            metavar="TB",
-            help="Whether the BatchNorm layers should be trainable",
-            dest="train_bn",
-        )
-        parser.add_argument(
-            "--milestones", default=[2, 4], type=list, metavar="M", help="List of two epochs milestones"
-        )
-        return parent_parser
-
-
-def main(args: argparse.Namespace) -> None:
-    """Train the model.
-
-    Args:
-        args: Model hyper-parameters
-
-    Note:
-        For the sake of the example, the images dataset will be downloaded
-        to a temporary directory.
-    """

-    datamodule = CatDogImageDataModule(
-        dl_path=os.path.join(args.root_data_path, 'data'), batch_size=args.batch_size, num_workers=args.num_workers
-    )
-    model = TransferLearningModel(**vars(args))
-    finetuning_callback = MilestonesFinetuning(milestones=args.milestones)
+class MyLightningCLI(LightningCLI):

-    trainer = pl.Trainer(
-        weights_summary=None,
-        progress_bar_refresh_rate=1,
-        num_sanity_val_steps=0,
-        gpus=args.gpus,
-        max_epochs=args.nb_epochs,
-        callbacks=[finetuning_callback]
-    )
+    def add_arguments_to_parser(self, parser):
+        parser.add_class_arguments(MilestonesFinetuning, 'finetuning')
+        parser.link_arguments('data.batch_size', 'model.batch_size')
+        parser.link_arguments('finetuning.milestones', 'model.milestones')
+        parser.link_arguments('finetuning.train_bn', 'model.train_bn')
+        parser.set_defaults({
+            'trainer.max_epochs': 15,
+            'trainer.weights_summary': None,
+            'trainer.progress_bar_refresh_rate': 1,
+            'trainer.num_sanity_val_steps': 0,
+        })
 
-    trainer.fit(model, datamodule=datamodule)
+    def instantiate_trainer(self):
+        finetuning_callback = MilestonesFinetuning(**self.config_init['finetuning'])
+        self.trainer_defaults['callbacks'] = [finetuning_callback]
+        super().instantiate_trainer()


-def get_args() -> argparse.Namespace:
-    parent_parser = argparse.ArgumentParser(add_help=False)
-    parent_parser.add_argument(
-        "--root-data-path",
-        metavar="DIR",
-        type=str,
-        default=Path.cwd().as_posix(),
-        help="Root directory where to download the data",
-        dest="root_data_path",
-    )
-    parser = TransferLearningModel.add_model_specific_args(parent_parser)
-    parser = CatDogImageDataModule.add_argparse_args(parser)
-    return parser.parse_args()
+def cli_main():
+    MyLightningCLI(TransferLearningModel, CatDogImageDataModule, seed_everything_default=1234)


 if __name__ == "__main__":
     cli_lightning_logo()
-    main(get_args())
+    cli_main()
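
After this change the example is configured entirely through LightningCLI: the __init__ arguments of TransferLearningModel, CatDogImageDataModule, and the MilestonesFinetuning callback (registered under the 'finetuning' group via add_class_arguments) become command-line options automatically, and the link_arguments calls keep batch_size, milestones, and train_bn consistent across the data, model, and finetuning groups. A minimal usage sketch, assuming the script is saved as computer_vision_fine_tuning.py (the flag values below are illustrative, not part of the diff):

    python computer_vision_fine_tuning.py --trainer.gpus 1 --data.dl_path ./data --model.backbone resnet18 --finetuning.milestones '[5, 10]'

Since LightningCLI is built on jsonargparse, the same options can also be supplied from a YAML file via --config, and --print_config should dump the fully resolved configuration, including the trainer defaults set in add_arguments_to_parser.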