58 changes: 46 additions & 12 deletions docs/source/common/lightning_cli.rst
@@ -14,18 +14,14 @@
def __init__(
self,
encoder_layers: int = 12,
decoder_layers: List[int] = [2, 4]
decoder_layers: List[int] = [2, 4],
batch_size: int = 8,
):
"""Example encoder-decoder model

Args:
encoder_layers: Number of layers for the encoder
decoder_layers: Number of layers for each decoder block
batch_size: Number of samples in a batch
"""
pass

class MyDataModule(LightningDataModule):
pass
def __init__(self, batch_size: int = 8):
pass

def send_email(address, message):
pass
@@ -119,7 +115,7 @@ The start of a possible implementation of :class:`MyModel` including the recommended
docstring could be the one below. Note that by using type hints and docstrings there is no need to duplicate this
information to define its configurable arguments.

.. testcode::
.. testcode:: mymodel

class MyModel(LightningModule):

@@ -373,8 +369,46 @@ before and after the execution of fit. The code would be something like:
cli = MyLightningCLI(MyModel)

Note that the config object :code:`self.config` is a dictionary whose keys are global options or groups of options. It
has the same structure as the yaml format as described previously. This means for instance that the parameters used for
has the same structure as the yaml format described previously. This means for instance that the parameters used for
instantiating the trainer class can be found in :code:`self.config['trainer']`.
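
For example, continuing the snippet above, a value from this structure could be read in one of the extended methods.
This is only a minimal sketch; the printed message is illustrative:

.. code-block:: python

    class MyLightningCLI(LightningCLI):

        def before_fit(self):
            # The 'trainer' group holds the options used to instantiate the Trainer
            trainer_config = self.config['trainer']
            print(f"Trainer will be created with: {trainer_config}")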

For more advanced use cases, other methods of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class could be
extended. For further information have a look at the corresponding API reference.
Another case in which it might be desirable to extend :class:`~pytorch_lightning.utilities.cli.LightningCLI` is when the
model and data module depend on a common parameter. For example, in some cases both classes need to know the
:code:`batch_size`. Giving the same value twice in a config file is a burden and error prone. To avoid this, the
parser can be configured so that a value is given only once and then propagated accordingly. With a tool implemented
as shown below, the :code:`batch_size` only has to be provided in the :code:`data` section of the config.

.. testcode::

from pytorch_lightning.utilities.cli import LightningCLI

class MyLightningCLI(LightningCLI):

def add_arguments_to_parser(self, parser):
parser.link_arguments('data.batch_size', 'model.batch_size')

cli = MyLightningCLI(MyModel, MyDataModule)

The linking of arguments is reflected in the help of the tool, which for this example would look like:

.. code-block:: bash

$ python trainer.py --help
...
--data.batch_size BATCH_SIZE
Number of samples in a batch (type: int, default: 8)

Linked arguments:
model.batch_size <-- data.batch_size
Number of samples in a batch (type: int)

.. tip::

The linking of arguments can be used for more complex cases, for example to derive a value via a function that takes
multiple settings as input, as in the sketch below. For more details have a look at the API of `link_arguments
<https://jsonargparse.readthedocs.io/en/stable/#jsonargparse.core.ArgumentParser.link_arguments>`_.
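
A minimal sketch of such a derived link could look like the following. The ``data.num_samples`` and
``model.steps_per_epoch`` options and the lambda are purely illustrative assumptions, not part of the classes defined
above:

.. code-block:: python

    from pytorch_lightning.utilities.cli import LightningCLI

    class MyLightningCLI(LightningCLI):

        def add_arguments_to_parser(self, parser):
            # Compute the target value from several source settings via compute_fn
            parser.link_arguments(
                ('data.num_samples', 'data.batch_size'),
                'model.steps_per_epoch',
                compute_fn=lambda num_samples, batch_size: num_samples // batch_size,
            )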

.. tip::

Have a look at the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class API reference to learn about other
methods that can be extended to customize a CLI.
136 changes: 39 additions & 97 deletions pl_examples/domain_templates/computer_vision_fine_tuning.py
@@ -37,9 +37,8 @@
Note:
See: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
"""
import argparse

import logging
import os
from pathlib import Path
from typing import Union

@@ -59,6 +58,7 @@
from pytorch_lightning import LightningDataModule
from pytorch_lightning.callbacks.finetuning import BaseFinetuning
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.cli import LightningCLI

log = logging.getLogger(__name__)
DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"
@@ -93,10 +93,17 @@ class CatDogImageDataModule(LightningDataModule):

def __init__(
self,
dl_path: Union[str, Path],
dl_path: Union[str, Path] = "data",
num_workers: int = 0,
batch_size: int = 8,
):
"""CatDogImageDataModule

Args:
dl_path: root directory where to download the data
num_workers: number of CPU workers
batch_size: number of samples in a batch
"""
super().__init__()

self._dl_path = dl_path
@@ -146,17 +153,6 @@ def val_dataloader(self):
log.info("Validation data loaded.")
return self.__dataloader(train=False)

@staticmethod
def add_model_specific_args(parent_parser):
parser = parent_parser.add_argument_group("CatDogImageDataModule")
parser.add_argument(
"--num-workers", default=0, type=int, metavar="W", help="number of CPU workers", dest="num_workers"
)
parser.add_argument(
"--batch-size", default=8, type=int, metavar="W", help="number of sample in a batch", dest="batch_size"
)
return parent_parser


# --- Pytorch-lightning module ---

@@ -166,17 +162,22 @@ class TransferLearningModel(pl.LightningModule):
def __init__(
self,
backbone: str = "resnet50",
train_bn: bool = True,
milestones: tuple = (5, 10),
train_bn: bool = False,
milestones: tuple = (2, 4),
batch_size: int = 32,
lr: float = 1e-2,
lr: float = 1e-3,
lr_scheduler_gamma: float = 1e-1,
num_workers: int = 6,
**kwargs,
) -> None:
"""
"""TransferLearningModel

Args:
dl_path: Path where the data will be downloaded
backbone: Name (as in ``torchvision.models``) of the feature extractor
train_bn: Whether the BatchNorm layers should be trainable
milestones: List of two epochs milestones
lr: Initial learning rate
lr_scheduler_gamma: Factor by which the learning rate is reduced at each milestone
"""
super().__init__()
self.backbone = backbone
@@ -269,90 +270,31 @@ def configure_optimizers(self):
scheduler = MultiStepLR(optimizer, milestones=self.milestones, gamma=self.lr_scheduler_gamma)
return [optimizer], [scheduler]

@staticmethod
def add_model_specific_args(parent_parser):
parser = parent_parser.add_argument_group("TransferLearningModel")
parser.add_argument(
"--backbone",
default="resnet50",
type=str,
metavar="BK",
help="Name (as in ``torchvision.models``) of the feature extractor",
)
parser.add_argument(
"--epochs", default=15, type=int, metavar="N", help="total number of epochs", dest="nb_epochs"
)
parser.add_argument("--batch-size", default=8, type=int, metavar="B", help="batch size", dest="batch_size")
parser.add_argument("--gpus", type=int, default=0, help="number of gpus to use")
parser.add_argument(
"--lr", "--learning-rate", default=1e-3, type=float, metavar="LR", help="initial learning rate", dest="lr"
)
parser.add_argument(
"--lr-scheduler-gamma",
default=1e-1,
type=float,
metavar="LRG",
help="Factor by which the learning rate is reduced at each milestone",
dest="lr_scheduler_gamma",
)
parser.add_argument(
"--train-bn",
default=False,
type=bool,
metavar="TB",
help="Whether the BatchNorm layers should be trainable",
dest="train_bn",
)
parser.add_argument(
"--milestones", default=[2, 4], type=list, metavar="M", help="List of two epochs milestones"
)
return parent_parser


def main(args: argparse.Namespace) -> None:
"""Train the model.

Args:
args: Model hyper-parameters

Note:
For the sake of the example, the images dataset will be downloaded
to a temporary directory.
"""

datamodule = CatDogImageDataModule(
dl_path=os.path.join(args.root_data_path, 'data'), batch_size=args.batch_size, num_workers=args.num_workers
)
model = TransferLearningModel(**vars(args))
finetuning_callback = MilestonesFinetuning(milestones=args.milestones)
class MyLightningCLI(LightningCLI):

trainer = pl.Trainer(
weights_summary=None,
progress_bar_refresh_rate=1,
num_sanity_val_steps=0,
gpus=args.gpus,
max_epochs=args.nb_epochs,
callbacks=[finetuning_callback]
)
def add_arguments_to_parser(self, parser):
parser.add_class_arguments(MilestonesFinetuning, 'finetuning')
parser.link_arguments('data.batch_size', 'model.batch_size')
parser.link_arguments('finetuning.milestones', 'model.milestones')
parser.link_arguments('finetuning.train_bn', 'model.train_bn')
parser.set_defaults({
'trainer.max_epochs': 15,
'trainer.weights_summary': None,
'trainer.progress_bar_refresh_rate': 1,
'trainer.num_sanity_val_steps': 0,
})

trainer.fit(model, datamodule=datamodule)
def instantiate_trainer(self):
finetuning_callback = MilestonesFinetuning(**self.config_init['finetuning'])
self.trainer_defaults['callbacks'] = [finetuning_callback]
super().instantiate_trainer()


def get_args() -> argparse.Namespace:
parent_parser = argparse.ArgumentParser(add_help=False)
parent_parser.add_argument(
"--root-data-path",
metavar="DIR",
type=str,
default=Path.cwd().as_posix(),
help="Root directory where to download the data",
dest="root_data_path",
)
parser = TransferLearningModel.add_model_specific_args(parent_parser)
parser = CatDogImageDataModule.add_argparse_args(parser)
return parser.parse_args()
def cli_main():
MyLightningCLI(TransferLearningModel, CatDogImageDataModule, seed_everything_default=1234)
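
# Hypothetical invocations of this CLI (argument values are illustrative):
#
#   python computer_vision_fine_tuning.py --help
#   python computer_vision_fine_tuning.py --trainer.max_epochs 20 --data.batch_size 16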


if __name__ == "__main__":
cli_lightning_logo()
main(get_args())
cli_main()
16 changes: 4 additions & 12 deletions pytorch_lightning/utilities/cli.py
@@ -161,9 +161,8 @@ def __init__(
self.parser_kwargs.update({'description': description, 'env_prefix': env_prefix, 'default_env': env_parse})

self.init_parser()
self.add_arguments_to_parser(self.parser)
self.add_core_arguments_to_parser()
self.before_parse_arguments(self.parser)
self.add_arguments_to_parser(self.parser)
self.parse_arguments()
if self.config['seed_everything'] is not None:
seed_everything(self.config['seed_everything'])
@@ -178,13 +177,6 @@ def init_parser(self) -> None:
"""Method that instantiates the argument parser"""
self.parser = LightningArgumentParser(**self.parser_kwargs)

def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
"""Implement to add extra arguments to parser

Args:
parser: The argument parser object to which arguments should be added
"""

def add_core_arguments_to_parser(self) -> None:
"""Adds arguments from the core classes to the parser"""
self.parser.add_argument(
@@ -200,11 +192,11 @@ def add_core_arguments_to_parser(self) -> None:
if self.datamodule_class is not None:
self.parser.add_lightning_class_args(self.datamodule_class, 'data', subclass_mode=self.subclass_mode_data)

def before_parse_arguments(self, parser: LightningArgumentParser) -> None:
"""Implement to run some code before parsing arguments
def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
"""Implement to add extra arguments to parser or link arguments

Args:
parser: The argument parser object that will be used to parse
parser: The argument parser object to which arguments can be added
"""

def parse_arguments(self) -> None:
2 changes: 1 addition & 1 deletion requirements/extra.txt
@@ -7,4 +7,4 @@ torchtext>=0.5
# onnx>=1.7.0
onnxruntime>=1.3.0
hydra-core>=1.0
jsonargparse[signatures]>=3.11.0
jsonargparse[signatures]>=3.11.1