diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst
index 9d44d65db7e70..b11d505c502ad 100644
--- a/docs/source/common/lightning_cli.rst
+++ b/docs/source/common/lightning_cli.rst
@@ -14,18 +14,14 @@
         def __init__(
             self,
             encoder_layers: int = 12,
-            decoder_layers: List[int] = [2, 4]
+            decoder_layers: List[int] = [2, 4],
+            batch_size: int = 8,
         ):
-            """Example encoder-decoder model
-
-            Args:
-                encoder_layers: Number of layers for the encoder
-                decoder_layers: Number of layers for each decoder block
-            """
             pass
 
     class MyDataModule(LightningDataModule):
-        pass
+        def __init__(self, batch_size: int = 8):
+            pass
 
     def send_email(address, message):
         pass
@@ -119,7 +115,7 @@ The start of a possible implementation of :class:`MyModel` including the recomme
 docstring could be the one below. Note that by using type hints and docstrings there is no need to duplicate this
 information to define its configurable arguments.
 
-.. testcode::
+.. testcode:: mymodel
 
     class MyModel(LightningModule):
 
@@ -373,8 +369,46 @@ before and after the execution of fit. The code would be something like:
     cli = MyLightningCLI(MyModel)
 
 Note that the config object :code:`self.config` is a dictionary whose keys are global options or groups of options. It
-has the same structure as the yaml format as described previously. This means for instance that the parameters used for
+has the same structure as the yaml format described previously. This means for instance that the parameters used for
 instantiating the trainer class can be found in :code:`self.config['trainer']`.
 
-For more advanced use cases, other methods of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class could be
-extended. For further information have a look at the corresponding API reference.
+Another case in which it might be desirable to extend :class:`~pytorch_lightning.utilities.cli.LightningCLI` is when
+the model and data module depend on a common parameter. For example, in some cases both classes need to know the
+:code:`batch_size`. Giving the same value twice in a config file is a burden and error prone. To avoid this, the
+parser can be configured so that a value is only given once and then propagated accordingly. With a tool implemented
+as shown below, the :code:`batch_size` only has to be provided in the :code:`data` section of the config.
+
+.. testcode::
+
+    from pytorch_lightning.utilities.cli import LightningCLI
+
+    class MyLightningCLI(LightningCLI):
+
+        def add_arguments_to_parser(self, parser):
+            parser.link_arguments('data.batch_size', 'model.batch_size')
+
+    cli = MyLightningCLI(MyModel, MyDataModule)
+
+The linking of arguments is reflected in the help of the tool, which for this example would look like:
+
+.. code-block:: bash
+
+    $ python trainer.py --help
+    ...
+      --data.batch_size BATCH_SIZE
+                            Number of samples in a batch (type: int, default: 8)
+
+    Linked arguments:
+      model.batch_size <-- data.batch_size
+                            Number of samples in a batch (type: int)
+
+.. tip::
+
+    The linking of arguments can be used for more complex cases, for example to derive a value via a function that
+    takes multiple settings as input. For more details have a look at the API of `link_arguments
+    `_.
+
+.. tip::
+
+    Have a look at the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class API reference to learn about other
+    methods that can be extended to customize a CLI.
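The first tip above can be made concrete with a short sketch. It assumes jsonargparse's ``compute_fn`` parameter of
``link_arguments`` and hypothetical ``image_height``, ``image_width`` and ``input_size`` arguments; it is an
illustration, not part of the documentation changed in this diff:

.. code-block:: python

    from pytorch_lightning.utilities.cli import LightningCLI

    class MyLightningCLI(LightningCLI):

        def add_arguments_to_parser(self, parser):
            # Derive a single model argument from two data arguments (argument names are hypothetical)
            parser.link_arguments(
                ('data.image_height', 'data.image_width'),
                'model.input_size',
                compute_fn=lambda height, width: height * width,
            )

    cli = MyLightningCLI(MyModel, MyDataModule)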
diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py
index a358d300ca055..ec191defc5e74 100644
--- a/pl_examples/domain_templates/computer_vision_fine_tuning.py
+++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py
@@ -37,9 +37,8 @@
 Note:
     See: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
 """
-import argparse
+
 import logging
-import os
 from pathlib import Path
 from typing import Union
 
@@ -59,6 +58,7 @@
 from pytorch_lightning import LightningDataModule
 from pytorch_lightning.callbacks.finetuning import BaseFinetuning
 from pytorch_lightning.utilities import rank_zero_info
+from pytorch_lightning.utilities.cli import LightningCLI
 
 log = logging.getLogger(__name__)
 DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"
@@ -93,10 +93,17 @@ class CatDogImageDataModule(LightningDataModule):
 
     def __init__(
        self,
-        dl_path: Union[str, Path],
+        dl_path: Union[str, Path] = "data",
         num_workers: int = 0,
         batch_size: int = 8,
     ):
+        """CatDogImageDataModule
+
+        Args:
+            dl_path: root directory where to download the data
+            num_workers: number of CPU workers
+            batch_size: number of samples in a batch
+        """
         super().__init__()
 
         self._dl_path = dl_path
@@ -146,17 +153,6 @@ def val_dataloader(self):
         log.info("Validation data loaded.")
         return self.__dataloader(train=False)
 
-    @staticmethod
-    def add_model_specific_args(parent_parser):
-        parser = parent_parser.add_argument_group("CatDogImageDataModule")
-        parser.add_argument(
-            "--num-workers", default=0, type=int, metavar="W", help="number of CPU workers", dest="num_workers"
-        )
-        parser.add_argument(
-            "--batch-size", default=8, type=int, metavar="W", help="number of sample in a batch", dest="batch_size"
-        )
-        return parent_parser
-
 
 # --- Pytorch-lightning module ---
 
@@ -166,17 +162,22 @@ class TransferLearningModel(pl.LightningModule):
     def __init__(
         self,
         backbone: str = "resnet50",
-        train_bn: bool = True,
-        milestones: tuple = (5, 10),
+        train_bn: bool = False,
+        milestones: tuple = (2, 4),
         batch_size: int = 32,
-        lr: float = 1e-2,
+        lr: float = 1e-3,
         lr_scheduler_gamma: float = 1e-1,
         num_workers: int = 6,
         **kwargs,
     ) -> None:
-        """
+        """TransferLearningModel
+
         Args:
-            dl_path: Path where the data will be downloaded
+            backbone: Name (as in ``torchvision.models``) of the feature extractor
+            train_bn: Whether the BatchNorm layers should be trainable
+            milestones: List of two epoch milestones
+            lr: Initial learning rate
+            lr_scheduler_gamma: Factor by which the learning rate is reduced at each milestone
         """
         super().__init__()
         self.backbone = backbone
@@ -269,90 +270,31 @@ def configure_optimizers(self):
         scheduler = MultiStepLR(optimizer, milestones=self.milestones, gamma=self.lr_scheduler_gamma)
         return [optimizer], [scheduler]
 
-    @staticmethod
-    def add_model_specific_args(parent_parser):
-        parser = parent_parser.add_argument_group("TransferLearningModel")
-        parser.add_argument(
-            "--backbone",
-            default="resnet50",
-            type=str,
-            metavar="BK",
-            help="Name (as in ``torchvision.models``) of the feature extractor",
-        )
-        parser.add_argument(
-            "--epochs", default=15, type=int, metavar="N", help="total number of epochs", dest="nb_epochs"
-        )
-        parser.add_argument("--batch-size", default=8, type=int, metavar="B", help="batch size", dest="batch_size")
-        parser.add_argument("--gpus", type=int, default=0, help="number of gpus to use")
-        parser.add_argument(
-            "--lr", "--learning-rate", default=1e-3, type=float, metavar="LR", help="initial learning rate", dest="lr"
-        )
-        parser.add_argument(
-            "--lr-scheduler-gamma",
-            default=1e-1,
-            type=float,
-            metavar="LRG",
-            help="Factor by which the learning rate is reduced at each milestone",
-            dest="lr_scheduler_gamma",
-        )
-        parser.add_argument(
-            "--train-bn",
-            default=False,
-            type=bool,
-            metavar="TB",
-            help="Whether the BatchNorm layers should be trainable",
-            dest="train_bn",
-        )
-        parser.add_argument(
-            "--milestones", default=[2, 4], type=list, metavar="M", help="List of two epochs milestones"
-        )
-        return parent_parser
-
-
-def main(args: argparse.Namespace) -> None:
-    """Train the model.
-
-    Args:
-        args: Model hyper-parameters
-
-    Note:
-        For the sake of the example, the images dataset will be downloaded
-        to a temporary directory.
-    """
-    datamodule = CatDogImageDataModule(
-        dl_path=os.path.join(args.root_data_path, 'data'), batch_size=args.batch_size, num_workers=args.num_workers
-    )
-    model = TransferLearningModel(**vars(args))
-    finetuning_callback = MilestonesFinetuning(milestones=args.milestones)
+class MyLightningCLI(LightningCLI):
 
-    trainer = pl.Trainer(
-        weights_summary=None,
-        progress_bar_refresh_rate=1,
-        num_sanity_val_steps=0,
-        gpus=args.gpus,
-        max_epochs=args.nb_epochs,
-        callbacks=[finetuning_callback]
-    )
+    def add_arguments_to_parser(self, parser):
+        parser.add_class_arguments(MilestonesFinetuning, 'finetuning')
+        parser.link_arguments('data.batch_size', 'model.batch_size')
+        parser.link_arguments('finetuning.milestones', 'model.milestones')
+        parser.link_arguments('finetuning.train_bn', 'model.train_bn')
+        parser.set_defaults({
+            'trainer.max_epochs': 15,
+            'trainer.weights_summary': None,
+            'trainer.progress_bar_refresh_rate': 1,
+            'trainer.num_sanity_val_steps': 0,
+        })
 
-    trainer.fit(model, datamodule=datamodule)
+    def instantiate_trainer(self):
+        finetuning_callback = MilestonesFinetuning(**self.config_init['finetuning'])
+        self.trainer_defaults['callbacks'] = [finetuning_callback]
+        super().instantiate_trainer()
 
 
-def get_args() -> argparse.Namespace:
-    parent_parser = argparse.ArgumentParser(add_help=False)
-    parent_parser.add_argument(
-        "--root-data-path",
-        metavar="DIR",
-        type=str,
-        default=Path.cwd().as_posix(),
-        help="Root directory where to download the data",
-        dest="root_data_path",
-    )
-    parser = TransferLearningModel.add_model_specific_args(parent_parser)
-    parser = CatDogImageDataModule.add_argparse_args(parser)
-    return parser.parse_args()
+def cli_main():
+    MyLightningCLI(TransferLearningModel, CatDogImageDataModule, seed_everything_default=1234)
 
 
 if __name__ == "__main__":
     cli_lightning_logo()
-    main(get_args())
+    cli_main()
diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py
index 61348567e5577..413b06f39f7a6 100644
--- a/pytorch_lightning/utilities/cli.py
+++ b/pytorch_lightning/utilities/cli.py
@@ -161,9 +161,8 @@ def __init__(
         self.parser_kwargs.update({'description': description, 'env_prefix': env_prefix, 'default_env': env_parse})
 
         self.init_parser()
-        self.add_arguments_to_parser(self.parser)
         self.add_core_arguments_to_parser()
-        self.before_parse_arguments(self.parser)
+        self.add_arguments_to_parser(self.parser)
         self.parse_arguments()
         if self.config['seed_everything'] is not None:
             seed_everything(self.config['seed_everything'])
@@ -178,13 +177,6 @@ def init_parser(self) -> None:
         """Method that instantiates the argument parser"""
         self.parser = LightningArgumentParser(**self.parser_kwargs)
 
-    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
"""Implement to add extra arguments to parser - - Args: - parser: The argument parser object to which arguments should be added - """ - def add_core_arguments_to_parser(self) -> None: """Adds arguments from the core classes to the parser""" self.parser.add_argument( @@ -200,11 +192,11 @@ def add_core_arguments_to_parser(self) -> None: if self.datamodule_class is not None: self.parser.add_lightning_class_args(self.datamodule_class, 'data', subclass_mode=self.subclass_mode_data) - def before_parse_arguments(self, parser: LightningArgumentParser) -> None: - """Implement to run some code before parsing arguments + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + """Implement to add extra arguments to parser or link arguments Args: - parser: The argument parser object that will be used to parse + parser: The argument parser object to which arguments can be added """ def parse_arguments(self) -> None: diff --git a/requirements/extra.txt b/requirements/extra.txt index 89b16b1095891..db2e66540eef1 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -7,4 +7,4 @@ torchtext>=0.5 # onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 -jsonargparse[signatures]>=3.11.0 +jsonargparse[signatures]>=3.11.1