diff --git a/CHANGELOG.md b/CHANGELOG.md index b24229b8eef71..8393c6a493e12 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added more explicit exception message when trying to execute `trainer.test()` or `trainer.validate()` with `fast_dev_run=True` ([#6667](https://github.com/PyTorchLightning/pytorch-lightning/pull/6667)) +- Added `LightningCLI` class to provide simple reproducibility with minimum boilerplate training cli. ([#4492](https://github.com/PyTorchLightning/pytorch-lightning/pull/4492)) + + - Trigger warning when non-metric logged value with multi processes hasn't been reduced ([#6417](https://github.com/PyTorchLightning/pytorch-lightning/pull/6417)) diff --git a/docs/source/api_references.rst b/docs/source/api_references.rst index e9520dea8045f..cbe11defc3a06 100644 --- a/docs/source/api_references.rst +++ b/docs/source/api_references.rst @@ -93,5 +93,6 @@ Utilities API :toctree: api :nosignatures: + cli argparse_utils seed diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst new file mode 100644 index 0000000000000..6d8ae6701fd40 --- /dev/null +++ b/docs/source/common/lightning_cli.rst @@ -0,0 +1,316 @@ +.. testsetup:: * + :skipif: not _JSONARGPARSE_AVAILABLE + + from unittest import mock + from typing import List + from pytorch_lightning.core.lightning import LightningModule + from pytorch_lightning.core.datamodule import LightningDataModule + from pytorch_lightning.utilities.cli import LightningCLI + + original_fit = LightningCLI.fit + LightningCLI.fit = lambda self: None + + class MyModel(LightningModule): + def __init__( + self, + encoder_layers: int = 12, + decoder_layers: List[int] = [2, 4] + ): + """Example encoder-decoder model + + Args: + encoder_layers: Number of layers for the encoder + decoder_layers: Number of layers for each decoder block + """ + pass + + class MyDataModule(LightningDataModule): + pass + + def send_email(address, message): + pass + + MyModelBaseClass = MyModel + MyDataModuleBaseClass = MyDataModule + + mock_argv = mock.patch("sys.argv", ["any.py"]) + mock_argv.start() + +.. testcleanup:: * + + LightningCLI.fit = original_fit + mock_argv.stop() + + +Lightning CLI and config files +------------------------------ + +Another source of boilerplate code that Lightning can help to reduce is in the implementation of training command line +tools. Furthermore, it provides a standardized way to configure trainings using a single file that includes settings for +:class:`~pytorch_lightning.trainer.trainer.Trainer` and user extended +:class:`~pytorch_lightning.core.lightning.LightningModule` and +:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes. The full configuration is automatically saved +in the log directory. This has the benefit of greatly simplifying the reproducibility of experiments. + +The main requirement for user extended classes to be made configurable is that all relevant init arguments must have +type hints. This is not a very demanding requirement since it is good practice to do anyway. As a bonus if the arguments +are described in the docstrings, then the help of the training tool will display them. + +.. warning:: ``LightningCLI`` is in beta and subject to change. + +---------- + + +LightningCLI +^^^^^^^^^^^^ + +The implementation of training command line tools is done via the :class:`~pytorch_lightning.utilities.cli.LightningCLI` +class. 
The minimal installation of pytorch-lightning does not include this support. To enable it either install +lightning with the :code:`all` extras require or install the package :code:`jsonargparse[signatures]`. + +The case in which the user's :class:`~pytorch_lightning.core.lightning.LightningModule` class implements all required +:code:`*_dataloader` methods, a :code:`trainer.py` tool can be as simple as: + +.. testcode:: + + from pytorch_lightning.utilities.cli import LightningCLI + + cli = LightningCLI(MyModel) + +The help of the tool describing all configurable options and default values can be shown by running :code:`python +trainer.py --help`. Default options can be changed by providing individual command line arguments. However, it is better +practice to create a configuration file and provide this to the tool. A way to do this would be: + +.. code-block:: bash + + # Dump default configuration to have as reference + python trainer.py --print_config > default_config.yaml + # Create config including only options to modify + nano config.yaml + # Run training using created configuration + python trainer.py --config config.yaml + +The instantiation of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class takes care of parsing command line +and config file options, instantiating the classes, setting up a callback to save the config in the log directory and +finally running :func:`trainer.fit`. The resulting object :code:`cli` can be used for instance to get the result of fit, +i.e., :code:`cli.fit_result`. + +After multiple trainings with different configurations, each run will have in its respective log directory a +:code:`config.yaml` file. This file can be used for reference to know in detail all the settings that were used for each +particular run, and also could be used to trivially reproduce a training, e.g.: + +.. code-block:: bash + + python trainer.py --config lightning_logs/version_7/config.yaml + +If a separate :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class is required, the trainer tool just +needs a small modification as follows: + +.. testcode:: + + from pytorch_lightning.utilities.cli import LightningCLI + + cli = LightningCLI(MyModel, MyDataModule) + +The start of a possible implementation of :class:`MyModel` including the recommended argument descriptions in the +docstring could be the one below. Note that by using type hints and docstrings there is no need to duplicate this +information to define its configurable arguments. + +.. code-block:: python + + class MyModel(LightningModule): + + def __init__( + self, + encoder_layers: int = 12, + decoder_layers: List[int] = [2, 4] + ): + """Example encoder-decoder model + + Args: + encoder_layers: Number of layers for the encoder + decoder_layers: Number of layers for each decoder block + """ + ... + +With this model class, the help of the trainer tool would look as follows: + +.. code-block:: bash + + $ python trainer.py --help + usage: trainer.py [-h] [--print_config] [--config CONFIG] + [--trainer.logger LOGGER] + ... + + pytorch-lightning trainer command line tool + + optional arguments: + -h, --help show this help message and exit + --print_config print configuration and exit + --config CONFIG Path to a configuration file in json or yaml format. + (default: null) + + Customize every aspect of training via flags: + ... + --trainer.max_epochs MAX_EPOCHS + Stop training once this number of epochs is reached. 
+ (type: int, default: 1000) + --trainer.min_epochs MIN_EPOCHS + Force training for at least these many epochs (type: int, + default: 1) + ... + + Example encoder-decoder model: + --model.encoder_layers ENCODER_LAYERS + Number of layers for the encoder (type: int, default: 12) + --model.decoder_layers DECODER_LAYERS + Number of layers for each decoder block (type: List[int], + default: [2, 4]) + +The default configuration that option :code:`--print_config` gives is in yaml format and for the example above would +look as follows: + +.. code-block:: bash + + $ python trainer.py --print_config + model: + decoder_layers: + - 2 + - 4 + encoder_layers: 12 + trainer: + accelerator: null + accumulate_grad_batches: 1 + amp_backend: native + amp_level: O2 + ... + +Note that there is a section for each class (model and trainer) including all the init parameters of the class. This +grouping is also used in the formatting of the help shown previously. + + +Trainer Callbacks and arguments with class type +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +A very important argument of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class is the :code:`callbacks`. In +contrast to other more simple arguments which just require numbers or strings, :code:`callbacks` expects a list of +instances of subclasses of :class:`~pytorch_lightning.callbacks.Callback`. To specify this kind of argument in a config +file, each callback must be given as a dictionary including a :code:`class_path` entry with an import path of the class, +and optionally an :code:`init_args` entry with arguments required to instantiate it. Therefore, a simple configuration +file example that defines a couple of callbacks is the following: + +.. code-block:: yaml + + trainer: + callbacks: + - class_path: pytorch_lightning.callbacks.EarlyStopping + init_args: + patience: 5 + - class_path: pytorch_lightning.callbacks.LearningRateMonitor + init_args: + ... + +Similar to the callbacks, any arguments in :class:`~pytorch_lightning.trainer.trainer.Trainer` and user extended +:class:`~pytorch_lightning.core.lightning.LightningModule` and +:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that have as type hint a class can be configured +the same way using :code:`class_path` and :code:`init_args`. + + +Multiple models and/or datasets +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In the previous examples :class:`~pytorch_lightning.utilities.cli.LightningCLI` works only for a single model and +datamodule class. However, there are many cases in which the objective is to easily be able to run many experiments for +multiple models and datasets. For these cases the tool can be configured such that a model and/or a datamodule is +specified by an import path and init arguments. For example, with a tool implemented as: + +.. testcode:: + + from pytorch_lightning.utilities.cli import LightningCLI + + cli = LightningCLI( + MyModelBaseClass, + MyDataModuleBaseClass, + subclass_mode_model=True, + subclass_mode_data=True + ) + +A possible config file could be as follows: + +.. code-block:: yaml + + model: + class_path: mycode.mymodels.MyModel + init_args: + decoder_layers: + - 2 + - 4 + encoder_layers: 12 + data: + class_path: mycode.mydatamodules.MyDataModule + init_args: + ... + trainer: + callbacks: + - class_path: pytorch_lightning.callbacks.EarlyStopping + init_args: + patience: 5 + ... + +Only model classes that are a subclass of :code:`MyModelBaseClass` would be allowed, and similarly only subclasses of +:code:`MyDataModuleBaseClass`. 
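+
+A minimal sketch of what these base classes could look like is the following (the :code:`learning_rate` argument is
+only illustrative; any init parameters with type hints would work the same way):
+
+.. code-block:: python
+
+    class MyModelBaseClass(LightningModule):
+        """Base class from which all selectable models inherit"""
+
+        def __init__(self, learning_rate: float = 1e-3):
+            super().__init__()
+            self.learning_rate = learning_rate
+
+
+    class MyModel(MyModelBaseClass):
+
+        def __init__(
+            self,
+            encoder_layers: int = 12,
+            decoder_layers: List[int] = [2, 4],
+            learning_rate: float = 1e-3
+        ):
+            super().__init__(learning_rate=learning_rate)
+            ...
+
+As in the single-model case, only init parameters that have type hints become configurable, and their docstring
+descriptions are shown in the help of the tool.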
+
+
+Customizing LightningCLI
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+The init parameters of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class can be used to customize some
+things, namely: the description of the tool, enabling parsing of environment variables, and additional arguments to
+instantiate the trainer and configuration parser.
+
+Nevertheless, the init arguments are not enough for many use cases. For this reason the class is designed so that it
+can be extended to customize different parts of the command line tool. The argument parser class used by
+:class:`~pytorch_lightning.utilities.cli.LightningCLI` is
+:class:`~pytorch_lightning.utilities.cli.LightningArgumentParser`, which is an extension of Python's argparse, so
+adding arguments can be done using the :func:`add_argument` method. In contrast to argparse it has additional methods
+to add arguments, for example :func:`add_class_arguments`, which adds all arguments from the init of a class, though it
+requires parameters to have type hints. For more details about this, please refer to the `respective documentation
+`_.
+
+The :class:`~pytorch_lightning.utilities.cli.LightningCLI` class has the
+:meth:`~pytorch_lightning.utilities.cli.LightningCLI.add_arguments_to_parser` method which can be implemented to include
+more arguments. After parsing, the configuration is stored in the :code:`config` attribute of the class instance. The
+:class:`~pytorch_lightning.utilities.cli.LightningCLI` class also has two methods that can be used to run code before
+and after :code:`trainer.fit` is executed: :meth:`~pytorch_lightning.utilities.cli.LightningCLI.before_fit` and
+:meth:`~pytorch_lightning.utilities.cli.LightningCLI.after_fit`. A realistic example for these would be to send an email
+before and after the execution of fit. The code would be something like:
+
+.. testcode::
+
+    from pytorch_lightning.utilities.cli import LightningCLI
+
+    class MyLightningCLI(LightningCLI):
+
+        def add_arguments_to_parser(self, parser):
+            parser.add_argument('--notification_email', default='will@email.com')
+
+        def before_fit(self):
+            send_email(
+                address=self.config['notification_email'],
+                message='trainer.fit starting'
+            )
+
+        def after_fit(self):
+            send_email(
+                address=self.config['notification_email'],
+                message='trainer.fit finished'
+            )
+
+    cli = MyLightningCLI(MyModel)
+
+Note that the config object :code:`self.config` is a dictionary whose keys are global options or groups of options. It
+has the same structure as the yaml format described previously. This means, for instance, that the parameters used for
+instantiating the trainer class can be found in :code:`self.config['trainer']`.
+
+For more advanced use cases, other methods of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class could be
+extended. For further information, have a look at the corresponding API reference.
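+
+Besides extending the class, some behavior can also be adjusted directly through the init parameters of
+:class:`~pytorch_lightning.utilities.cli.LightningCLI`. For instance, :code:`trainer_defaults` can be used both to
+override the default of a trainer option and to add callbacks that are included in every run regardless of the given
+configuration, e.g.:
+
+.. code-block:: python
+
+    from pytorch_lightning.callbacks import LearningRateMonitor
+    from pytorch_lightning.utilities.cli import LightningCLI
+
+    cli = LightningCLI(
+        MyModel,
+        trainer_defaults={
+            'max_epochs': 10,
+            'callbacks': [LearningRateMonitor()]
+        }
+    )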
diff --git a/docs/source/conf.py b/docs/source/conf.py index 1c1f3be8a636a..68a33c709ef91 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -389,6 +389,6 @@ def package_list_from_file(file): _TORCHVISION_AVAILABLE, _module_available, ) -TORCHVISION_AVAILABLE = _module_available("torchvision") +_JSONARGPARSE_AVAILABLE = _module_available("jsonargparse") """ coverage_skip_undoc_in_source = True diff --git a/docs/source/index.rst b/docs/source/index.rst index 1432badf2038f..030ab6d70aa3e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -102,6 +102,7 @@ PyTorch Lightning Documentation common/early_stopping common/fast_training common/hyperparameters + common/lightning_cli advanced/lr_finder advanced/multi_gpu advanced/multiple_loaders diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 315e3c60c0557..2f01363086478 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -199,13 +199,7 @@ def slurm_job_id(self) -> Optional[int]: @classmethod def default_attributes(cls) -> dict: init_signature = inspect.signature(cls) - - args = {} - for param_name in init_signature.parameters: - value = init_signature.parameters[param_name].default - args[param_name] = value - - return args + return {k: v.default for k, v in init_signature.parameters.items()} @classmethod def get_deprecated_arg_names(cls) -> List: diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py new file mode 100644 index 0000000000000..33c424829606a --- /dev/null +++ b/pytorch_lightning/utilities/cli.py @@ -0,0 +1,260 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from argparse import Namespace +from typing import Any, Dict, Type, Union + +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.core.datamodule import LightningDataModule +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.trainer.trainer import Trainer +from pytorch_lightning.utilities import _module_available + +_JSONARGPARSE_AVAILABLE = _module_available("jsonargparse") +if _JSONARGPARSE_AVAILABLE: + from jsonargparse import ActionConfigFile, ArgumentParser +else: + ArgumentParser = object + + +class LightningArgumentParser(ArgumentParser): + """Extension of jsonargparse's ArgumentParser for pytorch-lightning""" + + def __init__(self, *args, parse_as_dict: bool = True, **kwargs) -> None: + """Initialize argument parser that supports configuration file input + + For full details of accepted arguments see `ArgumentParser.__init__ + `_. + """ + if not _JSONARGPARSE_AVAILABLE: + raise ModuleNotFoundError( + '`jsonargparse` is not installed but it is required for the CLI.' + ' Install it with `pip install jsonargparse[signatures]`.' + ) + super().__init__(*args, parse_as_dict=parse_as_dict, **kwargs) + self.add_argument( + '--config', action=ActionConfigFile, help='Path to a configuration file in json or yaml format.' 
+ ) + + def add_lightning_class_args( + self, + lightning_class: Union[Type[Trainer], Type[LightningModule], Type[LightningDataModule]], + nested_key: str, + subclass_mode: bool = False + ) -> None: + """ + Adds arguments from a lightning class to a nested key of the parser + + Args: + lightning_class: Any subclass of {Trainer,LightningModule,LightningDataModule}. + nested_key: Name of the nested namespace to store arguments. + subclass_mode: Whether allow any subclass of the given class. + """ + assert issubclass(lightning_class, (Trainer, LightningModule, LightningDataModule)) + if subclass_mode: + return self.add_subclass_arguments(lightning_class, nested_key) + return self.add_class_arguments(lightning_class, nested_key) + + +class SaveConfigCallback(Callback): + """Saves a LightningCLI config to the log_dir when training starts""" + + def __init__( + self, + parser: LightningArgumentParser, + config: Union[Namespace, Dict[str, Any]], + config_filename: str = 'config.yaml' + ) -> None: + self.parser = parser + self.config = config + self.config_filename = config_filename + + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + log_dir = trainer.log_dir or trainer.default_root_dir + config_path = os.path.join(log_dir, self.config_filename) + self.parser.save(self.config, config_path, skip_none=False) + + +class LightningCLI: + """Implementation of a configurable command line tool for pytorch-lightning""" + + def __init__( + self, + model_class: Type[LightningModule], + datamodule_class: Type[LightningDataModule] = None, + save_config_callback: Type[SaveConfigCallback] = SaveConfigCallback, + trainer_class: Type[Trainer] = Trainer, + trainer_defaults: Dict[str, Any] = None, + description: str = 'pytorch-lightning trainer command line tool', + env_prefix: str = 'PL', + env_parse: bool = False, + parser_kwargs: Dict[str, Any] = None, + subclass_mode_model: bool = False, + subclass_mode_data: bool = False + ) -> None: + """ + Receives as input pytorch-lightning classes, which are instantiated + using a parsed configuration file and/or command line args and then runs + trainer.fit. Parsing of configuration from environment variables can + be enabled by setting ``env_parse=True``. A full configuration yaml would + be parsed from ``PL_CONFIG`` if set. Individual settings are so parsed from + variables named for example ``PL_TRAINER__MAX_EPOCHS``. + + Example, first implement the ``trainer.py`` tool as:: + + from mymodels import MyModel + from pytorch_lightning.utilities.cli import LightningCLI + LightningCLI(MyModel) + + Then in a shell, run the tool with the desired configuration:: + + $ python trainer.py --print_config > config.yaml + $ nano config.yaml # modify the config as desired + $ python trainer.py --cfg config.yaml + + .. warning:: ``LightningCLI`` is in beta and subject to change. + + Args: + model_class: The LightningModule class to train on. + datamodule_class: An optional LightningDataModule class. + save_config_callback: A callback class to save the training config. + trainer_class: An optional extension of the Trainer class. + trainer_defaults: Set to override Trainer defaults or add persistent callbacks. + description: Description of the tool shown when running --help. + env_prefix: Prefix for environment variables. + env_parse: Whether environment variable parsing is enabled. + parser_kwargs: Additional arguments to instantiate LightningArgumentParser. + subclass_mode_model: Whether model can be any `subclass + `_ + of the given class. 
+ subclass_mode_data: Whether datamodule can be any `subclass + `_ + of the given class. + """ + assert issubclass(trainer_class, Trainer) + assert issubclass(model_class, LightningModule) + if datamodule_class is not None: + assert issubclass(datamodule_class, LightningDataModule) + self.model_class = model_class + self.datamodule_class = datamodule_class + self.save_config_callback = save_config_callback + self.trainer_class = trainer_class + self.trainer_defaults = {} if trainer_defaults is None else trainer_defaults + self.subclass_mode_model = subclass_mode_model + self.subclass_mode_data = subclass_mode_data + self.parser_kwargs = {} if parser_kwargs is None else parser_kwargs + self.parser_kwargs.update({'description': description, 'env_prefix': env_prefix, 'default_env': env_parse}) + + self.init_parser() + self.add_arguments_to_parser(self.parser) + self.add_core_arguments_to_parser() + self.before_parse_arguments(self.parser) + self.parse_arguments() + self.before_instantiate_classes() + self.instantiate_classes() + self.prepare_fit_kwargs() + self.before_fit() + self.fit() + self.after_fit() + + def init_parser(self) -> None: + """Method that instantiates the argument parser""" + self.parser = LightningArgumentParser(**self.parser_kwargs) + + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + """Implement to add extra arguments to parser + + Args: + parser: The argument parser object to which arguments should be added + """ + pass + + def add_core_arguments_to_parser(self) -> None: + """Adds arguments from the core classes to the parser""" + self.parser.add_lightning_class_args(self.trainer_class, 'trainer') + trainer_defaults = {'trainer.' + k: v for k, v in self.trainer_defaults.items() if k != 'callbacks'} + self.parser.set_defaults(trainer_defaults) + self.parser.add_lightning_class_args(self.model_class, 'model', subclass_mode=self.subclass_mode_model) + if self.datamodule_class is not None: + self.parser.add_lightning_class_args(self.datamodule_class, 'data', subclass_mode=self.subclass_mode_data) + + def before_parse_arguments(self, parser: LightningArgumentParser) -> None: + """Implement to run some code before parsing arguments + + Args: + parser: The argument parser object that will be used to parse + """ + pass + + def parse_arguments(self) -> None: + """Parses command line arguments and stores it in self.config""" + self.config = self.parser.parse_args() + + def before_instantiate_classes(self) -> None: + """Implement to run some code before instantiating the classes""" + pass + + def instantiate_classes(self) -> None: + """Instantiates the classes using settings from self.config""" + self.config_init = self.parser.instantiate_subclasses(self.config) + self.instantiate_datamodule() + self.instantiate_model() + self.instantiate_trainer() + + def instantiate_datamodule(self) -> None: + """Instantiates the datamodule using self.config_init['data'] if given""" + if self.datamodule_class is None: + self.datamodule = None + elif self.subclass_mode_data: + self.datamodule = self.config_init['data'] + else: + self.datamodule = self.datamodule_class(**self.config_init.get('data', {})) + + def instantiate_model(self) -> None: + """Instantiates the model using self.config_init['model']""" + if self.subclass_mode_model: + self.model = self.config_init['model'] + else: + self.model = self.model_class(**self.config_init.get('model', {})) + + def instantiate_trainer(self) -> None: + """Instantiates the trainer using self.config_init['trainer']""" + if 
self.config_init['trainer'].get('callbacks') is None: + self.config_init['trainer']['callbacks'] = [] + if 'callbacks' in self.trainer_defaults: + if isinstance(self.trainer_defaults['callbacks'], list): + self.config_init['trainer']['callbacks'].extend(self.trainer_defaults['callbacks']) + else: + self.config_init['trainer']['callbacks'].append(self.trainer_defaults['callbacks']) + if self.save_config_callback is not None: + self.config_init['trainer']['callbacks'].append(self.save_config_callback(self.parser, self.config)) + self.trainer = self.trainer_class(**self.config_init['trainer']) + + def prepare_fit_kwargs(self) -> None: + """Prepares fit_kwargs including datamodule using self.config_init['data'] if given""" + self.fit_kwargs = {'model': self.model} + if self.datamodule is not None: + self.fit_kwargs['datamodule'] = self.datamodule + + def before_fit(self) -> None: + """Implement to run some code before fit is started""" + pass + + def fit(self) -> None: + """Runs fit of the instantiated trainer class and prepared fit keyword arguments""" + self.fit_result = self.trainer.fit(**self.fit_kwargs) + + def after_fit(self) -> None: + """Implement to run some code after fit has finished""" + pass diff --git a/requirements/extra.txt b/requirements/extra.txt index cee1fd0eb07e1..46a726fe05c43 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -9,4 +9,5 @@ onnxruntime>=1.3.0 hydra-core>=1.0 # todo: when switch to standard package stream, drop `fairscale` from hard mocked docs libs https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip +jsonargparse[signatures]>=3.3.1 deepspeed>=0.3.13 diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py new file mode 100644 index 0000000000000..9b2c825c7aeb8 --- /dev/null +++ b/tests/utilities/test_cli.py @@ -0,0 +1,329 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import os +import pickle +import sys +from argparse import Namespace +from unittest import mock + +import pytest +import yaml + +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint +from pytorch_lightning.utilities import _TPU_AVAILABLE +from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI, SaveConfigCallback +from tests.helpers import BoringDataModule, BoringModel + + +@mock.patch('argparse.ArgumentParser.parse_args') +def test_default_args(mock_argparse, tmpdir): + """Tests default argument parser for Trainer""" + mock_argparse.return_value = Namespace(**Trainer.default_attributes()) + + parser = LightningArgumentParser(add_help=False, parse_as_dict=False) + args = parser.parse_args([]) + + args.max_epochs = 5 + trainer = Trainer.from_argparse_args(args) + + assert isinstance(trainer, Trainer) + assert trainer.max_epochs == 5 + + +@pytest.mark.parametrize('cli_args', [['--accumulate_grad_batches=22'], ['--weights_save_path=./'], []]) +def test_add_argparse_args_redefined(cli_args): + """Redefines some default Trainer arguments via the cli and + tests the Trainer initialization correctness. + """ + parser = LightningArgumentParser(add_help=False, parse_as_dict=False) + parser.add_lightning_class_args(Trainer, None) + + args = parser.parse_args(cli_args) + + # make sure we can pickle args + pickle.dumps(args) + + # Check few deprecated args are not in namespace: + for depr_name in ('gradient_clip', 'nb_gpu_nodes', 'max_nb_epochs'): + assert depr_name not in args + + trainer = Trainer.from_argparse_args(args=args) + pickle.dumps(trainer) + + assert isinstance(trainer, Trainer) + + +@pytest.mark.parametrize('cli_args', [['--callbacks=1', '--logger'], ['--foo', '--bar=1']]) +def test_add_argparse_args_redefined_error(cli_args, monkeypatch): + """Asserts error raised in case of passing not default cli arguments.""" + + class _UnkArgError(Exception): + pass + + def _raise(): + raise _UnkArgError + + parser = LightningArgumentParser(add_help=False, parse_as_dict=False) + parser.add_lightning_class_args(Trainer, None) + + monkeypatch.setattr(parser, 'exit', lambda *args: _raise(), raising=True) + + with pytest.raises(_UnkArgError): + parser.parse_args(cli_args) + + +@pytest.mark.parametrize( + ['cli_args', 'expected'], + [ + ('--auto_lr_find=True --auto_scale_batch_size=power', dict(auto_lr_find=True, auto_scale_batch_size='power')), + ( + '--auto_lr_find any_string --auto_scale_batch_size ON', + dict(auto_lr_find='any_string', auto_scale_batch_size=True), + ), + ('--auto_lr_find=Yes --auto_scale_batch_size=On', dict(auto_lr_find=True, auto_scale_batch_size=True)), + ('--auto_lr_find Off --auto_scale_batch_size No', dict(auto_lr_find=False, auto_scale_batch_size=False)), + ('--auto_lr_find TRUE --auto_scale_batch_size FALSE', dict(auto_lr_find=True, auto_scale_batch_size=False)), + ('--tpu_cores=8', dict(tpu_cores=8)), + ('--tpu_cores=1,', dict(tpu_cores='1,')), + ('--limit_train_batches=100', dict(limit_train_batches=100)), + ('--limit_train_batches 0.8', dict(limit_train_batches=0.8)), + ('--weights_summary=null', dict(weights_summary=None)), + ( + "", + dict( + # These parameters are marked as Optional[...] in Trainer.__init__, + # with None as default. They should not be changed by the argparse + # interface. 
+ min_steps=None, + max_steps=None, + log_gpu_memory=None, + distributed_backend=None, + weights_save_path=None, + truncated_bptt_steps=None, + resume_from_checkpoint=None, + profiler=None + ), + ), + ], +) +def test_parse_args_parsing(cli_args, expected): + """Test parsing simple types and None optionals not modified.""" + cli_args = cli_args.split(' ') if cli_args else [] + parser = LightningArgumentParser(add_help=False, parse_as_dict=False) + parser.add_lightning_class_args(Trainer, None) + with mock.patch("sys.argv", ["any.py"] + cli_args): + args = parser.parse_args() + + for k, v in expected.items(): + assert getattr(args, k) == v + if 'tpu_cores' not in expected or _TPU_AVAILABLE: + assert Trainer.from_argparse_args(args) + + +@pytest.mark.parametrize( + ['cli_args', 'expected', 'instantiate'], + [ + (['--gpus', '[0, 2]'], dict(gpus=[0, 2]), False), + (['--tpu_cores=[1,3]'], dict(tpu_cores=[1, 3]), False), + (['--accumulate_grad_batches={"5":3,"10":20}'], dict(accumulate_grad_batches={ + 5: 3, + 10: 20 + }), True), + ], +) +def test_parse_args_parsing_complex_types(cli_args, expected, instantiate): + """Test parsing complex types.""" + parser = LightningArgumentParser(add_help=False, parse_as_dict=False) + parser.add_lightning_class_args(Trainer, None) + with mock.patch("sys.argv", ["any.py"] + cli_args): + args = parser.parse_args() + + for k, v in expected.items(): + assert getattr(args, k) == v + if instantiate: + assert Trainer.from_argparse_args(args) + + +@pytest.mark.parametrize( + ['cli_args', 'expected_gpu'], + [ + ('--gpus 1', [0]), + ('--gpus 0,', [0]), + ('--gpus 0,1', [0, 1]), + ], +) +def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu): + """Test parsing of gpus and instantiation of Trainer.""" + monkeypatch.setattr("torch.cuda.device_count", lambda: 2) + cli_args = cli_args.split(' ') if cli_args else [] + parser = LightningArgumentParser(add_help=False, parse_as_dict=False) + parser.add_lightning_class_args(Trainer, None) + with mock.patch("sys.argv", ["any.py"] + cli_args): + args = parser.parse_args() + + trainer = Trainer.from_argparse_args(args) + assert trainer.data_parallel_device_ids == expected_gpu + + +@pytest.mark.skipif( + sys.version_info < (3, 7), + reason="signature inspection while mocking is not working in Python < 3.7 despite autospec", +) +@pytest.mark.parametrize( + ['cli_args', 'extra_args'], + [ + ({}, {}), + (dict(logger=False), {}), + (dict(logger=False), dict(logger=True)), + (dict(logger=False), dict(checkpoint_callback=True)), + ], +) +def test_init_from_argparse_args(cli_args, extra_args): + unknown_args = dict(unknown_arg=0) + + # unkown args in the argparser/namespace should be ignored + with mock.patch('pytorch_lightning.Trainer.__init__', autospec=True, return_value=None) as init: + trainer = Trainer.from_argparse_args(Namespace(**cli_args, **unknown_args), **extra_args) + expected = dict(cli_args) + expected.update(extra_args) # extra args should override any cli arg + init.assert_called_with(trainer, **expected) + + # passing in unknown manual args should throw an error + with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'unknown_arg'"): + Trainer.from_argparse_args(Namespace(**cli_args), **extra_args, **unknown_args) + + +@pytest.mark.parametrize(['cli_args', 'expected_model', 'expected_trainer'], [( + ['--model.model_param=7', '--trainer.limit_train_batches=100'], + dict(model_param=7), + dict(limit_train_batches=100), +)]) +def test_lightning_cli(cli_args, expected_model, 
expected_trainer, monkeypatch): + """Test that LightningCLI correctly instantiates model, trainer and calls fit.""" + + def fit(trainer, model): + for k, v in model.expected_model.items(): + assert getattr(model, k) == v + for k, v in model.expected_trainer.items(): + assert getattr(trainer, k) == v + save_callback = [x for x in trainer.callbacks if isinstance(x, SaveConfigCallback)] + assert len(save_callback) == 1 + save_callback[0].on_train_start(trainer, model) + + def on_train_start(callback, trainer, model): + config_dump = callback.parser.dump(callback.config, skip_none=False) + for k, v in model.expected_model.items(): + assert f' {k}: {v}' in config_dump + for k, v in model.expected_trainer.items(): + assert f' {k}: {v}' in config_dump + trainer.ran_asserts = True + + monkeypatch.setattr(Trainer, 'fit', fit) + monkeypatch.setattr(SaveConfigCallback, 'on_train_start', on_train_start) + + class TestModel(LightningModule): + + def __init__(self, model_param: int): + super().__init__() + self.model_param = model_param + + TestModel.expected_model = expected_model + TestModel.expected_trainer = expected_trainer + + with mock.patch('sys.argv', ['any.py'] + cli_args): + cli = LightningCLI(TestModel, trainer_class=Trainer, save_config_callback=SaveConfigCallback) + assert hasattr(cli.trainer, 'ran_asserts') and cli.trainer.ran_asserts + + +def test_lightning_cli_args_callbacks(tmpdir): + + callbacks = [ + dict( + class_path='pytorch_lightning.callbacks.LearningRateMonitor', + init_args=dict(logging_interval='epoch', log_momentum=True) + ), + dict(class_path='pytorch_lightning.callbacks.ModelCheckpoint', init_args=dict(monitor='NAME')), + ] + + class TestModel(BoringModel): + + def on_fit_start(self): + callback = [c for c in self.trainer.callbacks if isinstance(c, LearningRateMonitor)] + assert len(callback) == 1 + assert callback[0].logging_interval == 'epoch' + assert callback[0].log_momentum is True + callback = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)] + assert len(callback) == 1 + assert callback[0].monitor == 'NAME' + self.trainer.ran_asserts = True + + with mock.patch('sys.argv', ['any.py', f'--trainer.callbacks={json.dumps(callbacks)}']): + cli = LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True)) + + assert cli.trainer.ran_asserts + + +def test_lightning_cli_args(tmpdir): + + cli_args = [ + f'--data.data_dir={tmpdir}', + f'--trainer.default_root_dir={tmpdir}', + '--trainer.max_epochs=1', + '--trainer.weights_summary=null', + ] + + with mock.patch('sys.argv', ['any.py'] + cli_args): + cli = LightningCLI(BoringModel, BoringDataModule, trainer_defaults={'callbacks': [LearningRateMonitor()]}) + + assert cli.fit_result == 1 + config_path = tmpdir / 'lightning_logs' / 'version_0' / 'config.yaml' + assert os.path.isfile(config_path) + with open(config_path) as f: + config = yaml.safe_load(f.read()) + assert 'model' not in config and 'model' not in cli.config # no arguments to include + assert config['data'] == cli.config['data'] + assert config['trainer'] == cli.config['trainer'] + + +def test_lightning_cli_config_and_subclass_mode(tmpdir): + + config = dict( + model=dict(class_path='tests.helpers.BoringModel'), + data=dict(class_path='tests.helpers.BoringDataModule', init_args=dict(data_dir=str(tmpdir))), + trainer=dict(default_root_dir=str(tmpdir), max_epochs=1, weights_summary=None) + ) + config_path = tmpdir / 'config.yaml' + with open(config_path, 'w') as f: + f.write(yaml.dump(config)) + + with 
mock.patch('sys.argv', ['any.py', '--config', str(config_path)]): + cli = LightningCLI( + BoringModel, + BoringDataModule, + subclass_mode_model=True, + subclass_mode_data=True, + trainer_defaults={'callbacks': LearningRateMonitor()} + ) + + assert cli.fit_result == 1 + config_path = tmpdir / 'lightning_logs' / 'version_0' / 'config.yaml' + assert os.path.isfile(config_path) + with open(config_path) as f: + config = yaml.safe_load(f.read()) + assert config['model'] == cli.config['model'] + assert config['data'] == cli.config['data'] + assert config['trainer'] == cli.config['trainer']
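

# A possible additional test (sketch, not part of this diff): check that the before_fit and
# after_fit hooks of LightningCLI are executed around trainer.fit.
def test_lightning_cli_before_and_after_fit_hooks(tmpdir):

    calls = []

    class HookedCLI(LightningCLI):

        def before_fit(self):
            calls.append('before_fit')

        def after_fit(self):
            calls.append('after_fit')

    with mock.patch('sys.argv', ['any.py']):
        HookedCLI(BoringModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True))

    # fit runs in between, so both hooks must have been called in order
    assert calls == ['before_fit', 'after_fit']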