From b04b51593158c1db6658d0a7c00f0b4dde2c93f9 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Tue, 3 Nov 2020 08:18:01 +0100 Subject: [PATCH 01/35] Added new trainer_cli that reduces boilerplate --- docs/source/trainer_cli.rst | 156 +++++++++++ .../utilities/jsonargparse_utils.py | 150 +++++++++++ requirements.txt | 1 + tests/trainer/test_trainer_cli_new.py | 245 ++++++++++++++++++ 4 files changed, 552 insertions(+) create mode 100644 docs/source/trainer_cli.rst create mode 100644 pytorch_lightning/utilities/jsonargparse_utils.py create mode 100644 tests/trainer/test_trainer_cli_new.py diff --git a/docs/source/trainer_cli.rst b/docs/source/trainer_cli.rst new file mode 100644 index 0000000000000..da6163d12fb76 --- /dev/null +++ b/docs/source/trainer_cli.rst @@ -0,0 +1,156 @@ +Trainer CLI and config files +---------------------------- + +Another source of boilerplate code that Lightning can help to reduce is in the +implementation of training command line tools. Furthermore, it provides a +standardized way to configure trainings using a single file that includes +settings for :class:`~pytorch_lightning.trainer.trainer.Trainer` and user +extended :class:`~pytorch_lightning.core.lightning.LightningModule` and +:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes. The +full configuration is automatically saved in the log directory. This has the +benefit of greatly simplifying the reproducibility of experiments. + +The main requirement for user extended classes to be made configurable is that +all relevant init arguments must have type hints. This is not a very demanding +requirement since anyway it is good practice to do anyway. As a bonus if the +arguments are also described in the docstrings, then the help of the training +tool will display them. + +---------- + + +training_cli +^^^^^^^^^^^^ + +The case in which the user's :class:`LightningModule` class implements all +required :code:`*_dataloader` methods, a :code:`trainer.py` tool can be as +simple as: + +.. code-block:: python + + from pytorch_lightning.utilities.jsonargparse_utils import trainer_cli + from mycode import LitModel + + trainer_cli(LitModel) + +The help of the tool describing all configurable options and default values can +be shown by running :code:`python trainer.py --help`. Default options can be +changed by providing individual command line arguments. However, it is better +practice to create a configuration file and provide this to the trainer. A way +to do this would be: + +.. code-block:: bash + + # Dump default configuration to have as reference + python trainer.py --print-config > default_config.yaml + # Create config including only options to modify + nano config.yaml + # Run training using created configuration + python trainer.py --cfg config.yaml + +The call to the :func:`trainer_cli` function takes care of parsing command line +and config file options, instantiating the classes, setting up a callback to +save the config in the log directory and finally running :func:`trainer.fit`. + +After multiple trainings with different configurations, a previous run can be +trivially reproduced by using the config in the respective log directory, e.g.: + +.. code-block:: bash + + python trainer.py --cfg lightning_logs/version_7/config.yaml + +The start of a possible implementation of :class:`LitModel` including the +recommended argument descriptions in the docstring could be the one below. 
Note +that by using type hints and docstrings there is no need to duplicate this +information to define its configurable arguments. + +.. code-block:: python + + class LitModel(LightningModule): + + def __init__(self, + encoder_layers: int = 12, + decoder_layers: List[int] = [2, 4]): + """Example encoder-decoder model + + Args: + encoder_layers: Number of layers for the encoder + decoder_layers: Number of layers for each decoder block + """ + ... + +If a separate :class:`LightningDataModule` class is required, the trainer tool +just needs a small modification as follows: + +.. code-block:: python + + from pytorch_lightning.utilities.jsonargparse_utils import trainer_cli + from mycode import LitModel, LitDataModule + + trainer_cli(LitModel, LitDataModule) + + +LightningArgumentParser +^^^^^^^^^^^^^^^^^^^^^^^ + +Even though :func:`trainer_cli` can reduce boilerplate code to a minimum, +clearly there are cases in which it is not enough. For this Lightning provides +the :class:`LightningArgumentParser` class which is an extension of the built-in +Python ArgumentParser that makes it very simple to implement configurable +training tools with the same features as :func:`trainer_cli`. + +An example of a more complex training tool could be one in which there are +several independent modules that require configuration. The code for such a case +could look something like: + +.. code-block:: python + + from pytorch_lightning.utilities.jsonargparse_utils import LightningArgumentParser, SaveConfigCallback + from mycode import LitModule1, LitModule2, LitModel, LitDataModule + + # Define parser + parser = LightningArgumentParser(description='pytorch-lightning trainer', + parse_as_dict=True) + parser.add_trainer_args() + parser.add_module_args(LitModule1, 'module1') + parser.add_module_args(LitModule2, 'module2') + parser.add_datamodule_args(LitDataModule) + + # Parse configuration + config = parser.parse_args() + + # Instantiate classes + module1 = LitModule1(**config['module1']) + module2 = LitModule2(**config['module2']) + model = LitModel(module1, module2) + datamodule = LitDataModule(**config['data']) + config['trainer']['callbacks'] = [SaveConfigCallback(parser, config)] + trainer = Trainer(**config['trainer']) + + # Start training + trainer.fit(model, datamodule) + +Note that the configuration object has all options for each module, data and +trainer in different dict keys. The structure of the yaml configuration file is +analogous. Reproducing the training can also be done with the config saved in +the log directory. + +The parser is like any other from argparse, thus it can be used to include +global options, for example: + +.. code-block:: python + + parser.add_argument('--notification_email', default='will@email.com') + +The argument parser is also able to parse environment variables. To enable this +feature, initialize :class:`LightningArgumentParser` including +:code:`default_env=True, env_prefix='PL'`. With this for instance the +:code:`PL_TRAINER__MAX_EPOCHS` environment variable if set would be used to +override the default :code:`max_epochs` of the trainer. Similarly options for +the data module could be set using variables that start with :code:`PL_DATA_` +and likewise for the modules. + +Arguments from any other class that have appropriate type hints can also be +added. An example which would store the options for a class :class:`MyClass` in +the :code:`myclass` key of the configuration object would be +:code:`parser.add_class_arguments(MyClass, 'myclass')`. 
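As a minimal sketch of that last pattern (illustrative only; :class:`MyClass` is
hypothetical and not part of the patch), continuing the parser example above,
the options of an arbitrary class could be made configurable and used as
follows:

.. code-block:: python

    class MyClass:

        def __init__(self, batch_size: int = 32, shuffle: bool = True):
            """Hypothetical class, only for illustration

            Args:
                batch_size: Number of samples per batch
                shuffle: Whether the data order is randomized
            """
            self.batch_size = batch_size
            self.shuffle = shuffle

    parser.add_class_arguments(MyClass, 'myclass')

    config = parser.parse_args()
    myclass = MyClass(**config['myclass'])

Here the instantiation mirrors the :code:`LitModule1(**config['module1'])`
pattern shown earlier, since the parser was created with
:code:`parse_as_dict=True`.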
diff --git a/pytorch_lightning/utilities/jsonargparse_utils.py b/pytorch_lightning/utilities/jsonargparse_utils.py new file mode 100644 index 0000000000000..6de09ab5ba8ee --- /dev/null +++ b/pytorch_lightning/utilities/jsonargparse_utils.py @@ -0,0 +1,150 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Type +from jsonargparse import ArgumentParser, ActionConfigFile +from pytorch_lightning import Trainer, LightningModule, LightningDataModule +from pytorch_lightning.callbacks import Callback + + +class LightningArgumentParser(ArgumentParser): + def __init__(self, *args, **kwargs): + """Initialize argument parser that supports configuration file input""" + super().__init__(*args, **kwargs) + self.add_argument('--cfg', + action=ActionConfigFile, + help='Path to a configuration file in json or yaml format.') + + + def add_trainer_args( + self, + trainer_class: Type[Trainer] = Trainer, + nested_key: str = 'trainer' + ): + """ + Adds arguments from a trainer class to a nested key of the parser + + Args: + trainer_class: Optional extension of the Trainer class. + nested_key: Name of the nested namespace where parsed arguments are stored. + """ + assert issubclass(trainer_class, Trainer) + self.add_class_arguments(trainer_class, nested_key) + + + def add_module_args( + self, + module_class: Type[LightningModule], + nested_key: str = 'module' + ): + """ + Adds arguments from a module class to a nested key of the parser + + Args: + module_class: A LightningModule class. + nested_key: Name of the nested namespace where parsed arguments are stored. + """ + assert issubclass(module_class, LightningModule) + self.add_class_arguments(module_class, nested_key) + + + def add_datamodule_args( + self, + datamodule_class: Type[LightningDataModule], + nested_key: str = 'data' + ): + """ + Adds arguments from a datamodule class to a nested key of the parser + + Args: + datamodule_class: A LightningDataModule class. + nested_key: Name of the nested namespace where parsed arguments are stored. 
+        """
+        assert issubclass(datamodule_class, LightningDataModule)
+        self.add_class_arguments(datamodule_class, nested_key)
+
+
+class SaveConfigCallback(Callback):
+    """Callback that saves a trainer_cli config to the log_dir when training starts"""
+
+    def __init__(self, parser, config):
+        self.config_dump = parser.dump(config, skip_none=False)
+
+
+    def on_train_start(self, trainer, pl_module):
+        config_path = os.path.join(trainer.logger.log_dir, 'config.yaml')
+        with open(config_path, 'w') as outstream:
+            outstream.write(self.config_dump)
+
+
+def trainer_cli(
+    model_class: Type[LightningModule],
+    datamodule_class: Type[LightningDataModule] = None,
+    save_config_callback: Type[Callback] = SaveConfigCallback,
+    trainer_class: Type[Trainer] = Trainer,
+    description: str = 'pytorch-lightning trainer command line tool',
+    parse_env: bool = False,
+):
+    """
+    Implementation of a simple configurable Trainer command line tool
+
+    Receives as input pytorch-lightning classes, which are instantiated using a
+    parsed configuration file or command line options and then runs trainer.fit.
+
+    Example, first implement the trainer.py tool as::
+
+        from mymodels import MyModel
+        from pytorch_lightning.utilities.jsonargparse_utils import trainer_cli
+        trainer_cli(MyModel)
+
+    Then in a shell, run the tool with the desired configuration::
+
+        $ python trainer.py --print-config > config.yaml
+        $ nano config.yaml # modify the config as desired
+        $ python trainer.py --cfg config.yaml
+
+    Args:
+        model_class: The LightningModule class to train on.
+        datamodule_class: An optional LightningDataModule class.
+        save_config_callback: A callback class to save the training config.
+        trainer_class: An optional extension of the Trainer class.
+        description: Description of the tool shown when running --help.
+        parse_env: Whether environment variables are also parsed.
+    """
+    # Define parser
+    parser = LightningArgumentParser(description=description,
+                                     parse_as_dict=True,
+                                     default_env=parse_env,
+                                     env_prefix='PL')
+    parser.add_trainer_args(trainer_class, 'trainer')
+    parser.add_module_args(model_class, 'model')
+    if datamodule_class is not None:
+        parser.add_datamodule_args(datamodule_class, 'data')
+
+    # Parse configuration
+    config = parser.parse_args()
+
+    # Instantiate classes
+    model = model_class(**config.get('model', {}))
+    kwargs = {'model': model}
+    if datamodule_class is not None:
+        kwargs['datamodule'] = datamodule_class(**config.get('data', {}))
+
+    if save_config_callback is not None:
+        config['trainer']['callbacks'] = [save_config_callback(parser, config)]
+    trainer = trainer_class(**config['trainer'])
+
+    # Start training
+    trainer.fit(**kwargs)
diff --git a/requirements.txt b/requirements.txt
index d270e2bc5d854..85507150e0ca5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,6 +5,7 @@ torch>=1.3,<1.8
 future>=0.17.1 # required for builtins in setup.py
 # pyyaml>=3.13
 PyYAML>=5.1 # OmegaConf requirement >=5.1
+jsonargparse[signatures]>=3.0.0.dev4 # New trainer_cli requirement
 tqdm>=4.41.0
 fsspec>=0.8.0
 tensorboard>=2.2.0
diff --git a/tests/trainer/test_trainer_cli_new.py b/tests/trainer/test_trainer_cli_new.py
new file mode 100644
index 0000000000000..063c60f970679
--- /dev/null
+++ b/tests/trainer/test_trainer_cli_new.py
@@ -0,0 +1,245 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import pickle +import sys +from argparse import Namespace +from unittest import mock + +import pytest +import torch + +import tests.base.develop_utils as tutils +from pytorch_lightning import Trainer, LightningModule +from pytorch_lightning.utilities.jsonargparse_utils import LightningArgumentParser, SaveConfigCallback, trainer_cli + + +@mock.patch('argparse.ArgumentParser.parse_args') +def test_default_args(mock_argparse, tmpdir): + """Tests default argument parser for Trainer""" + mock_argparse.return_value = Namespace(**Trainer.default_attributes()) + + # logger file to get meta + logger = tutils.get_default_logger(tmpdir) + + parser = LightningArgumentParser(add_help=False) + args = parser.parse_args([]) + args.logger = logger + + args.max_epochs = 5 + trainer = Trainer.from_argparse_args(args) + + assert isinstance(trainer, Trainer) + assert trainer.max_epochs == 5 + + +@pytest.mark.parametrize('cli_args', [ + ['--accumulate_grad_batches=22'], + ['--weights_save_path=./'], + [] +]) +def test_add_argparse_args_redefined(cli_args): + """Redefines some default Trainer arguments via the cli and + tests the Trainer initialization correctness. + """ + parser = LightningArgumentParser(add_help=False) + parser.add_trainer_args(Trainer, None) + + args = parser.parse_args(cli_args) + + # make sure we can pickle args + pickle.dumps(args) + + # Check few deprecated args are not in namespace: + for depr_name in ('gradient_clip', 'nb_gpu_nodes', 'max_nb_epochs'): + assert depr_name not in args + + trainer = Trainer.from_argparse_args(args=args) + pickle.dumps(trainer) + + assert isinstance(trainer, Trainer) + + +@pytest.mark.parametrize('cli_args', [ + ['--callbacks=1', '--logger'], + ['--foo', '--bar=1'] +]) +def test_add_argparse_args_redefined_error(cli_args, monkeypatch): + """Asserts that an error raised in case of passing not default cli arguments.""" + + class _UnkArgError(Exception): + pass + + def _raise(): + raise _UnkArgError + + parser = LightningArgumentParser(add_help=False) + parser.add_trainer_args(Trainer, None) + + monkeypatch.setattr(parser, 'exit', lambda *args: _raise(), raising=True) + + with pytest.raises(_UnkArgError): + parser.parse_args(cli_args) + + +@pytest.mark.parametrize(['cli_args', 'expected'], [ + #pytest.param('--auto_lr_find --auto_scale_batch_size power', + # {'auto_lr_find': True, 'auto_scale_batch_size': 'power'}), + #pytest.param('--auto_lr_find any_string --auto_scale_batch_size', + # {'auto_lr_find': 'any_string', 'auto_scale_batch_size': True}), + #pytest.param('--auto_lr_find t --auto_scale_batch_size ON', + # {'auto_lr_find': True, 'auto_scale_batch_size': True}), + #pytest.param('--auto_lr_find 0 --auto_scale_batch_size n', + # {'auto_lr_find': False, 'auto_scale_batch_size': False}), + pytest.param('--auto_lr_find TRUE --auto_scale_batch_size FALSE', + {'auto_lr_find': True, 'auto_scale_batch_size': False}), + pytest.param('--tpu_cores=8', + {'tpu_cores': 8}), + pytest.param('--tpu_cores=1,', + {'tpu_cores': '1,'}), + pytest.param('--limit_train_batches=100', + {'limit_train_batches': 100}), + 
pytest.param('--limit_train_batches 0.8',
+                 {'limit_train_batches': 0.8}),
+    pytest.param('--weights_summary=null',
+                 {'weights_summary': None}),
+    pytest.param(
+        "",
+        {
+            # These parameters are marked as Optional[...] in Trainer.__init__, with None as default.
+            # They should not be changed by the argparse interface.
+            "min_steps": None,
+            "max_steps": None,
+            "log_gpu_memory": None,
+            "distributed_backend": None,
+            "weights_save_path": None,
+            "truncated_bptt_steps": None,
+            "resume_from_checkpoint": None,
+            "profiler": None,
+        }),
+])
+def test_parse_args_parsing(cli_args, expected):
+    """Test parsing simple types and None optionals not modified."""
+    cli_args = cli_args.split(' ') if cli_args else []
+    parser = LightningArgumentParser(add_help=False)
+    parser.add_trainer_args(Trainer, None)
+    with mock.patch("sys.argv", ["any.py"] + cli_args):
+        args = parser.parse_args()
+
+    for k, v in expected.items():
+        assert getattr(args, k) == v
+    assert Trainer.from_argparse_args(args)
+
+
+@pytest.mark.parametrize(['cli_args', 'expected', 'instantiate'], [
+    pytest.param(['--gpus', '[0, 2]'],
+                 {'gpus': [0, 2]},
+                 False),
+    pytest.param(['--tpu_cores=[1,3]'],
+                 {'tpu_cores': [1, 3]},
+                 False),
+    pytest.param(['--accumulate_grad_batches={"5":3,"10":20}'],
+                 {'accumulate_grad_batches': {5: 3, 10: 20}},
+                 True),
+])
+def test_parse_args_parsing_complex_types(cli_args, expected, instantiate):
+    """Test parsing complex types."""
+    parser = LightningArgumentParser(add_help=False)
+    parser.add_trainer_args(Trainer, None)
+    with mock.patch("sys.argv", ["any.py"] + cli_args):
+        args = parser.parse_args()
+
+    for k, v in expected.items():
+        assert getattr(args, k) == v
+    if instantiate:
+        assert Trainer.from_argparse_args(args)
+
+
+@pytest.mark.parametrize(['cli_args', 'expected_gpu'], [
+    pytest.param('--gpus 1', [0]),
+    pytest.param('--gpus 0,', [0]),
+])
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+def test_parse_args_parsing_gpus(cli_args, expected_gpu):
+    """Test parsing of gpus and instantiation of Trainer."""
+    cli_args = cli_args.split(' ') if cli_args else []
+    parser = LightningArgumentParser(add_help=False)
+    parser.add_trainer_args(Trainer, None)
+    with mock.patch("sys.argv", ["any.py"] + cli_args):
+        args = parser.parse_args()
+
+    trainer = Trainer.from_argparse_args(args)
+    assert trainer.data_parallel_device_ids == expected_gpu
+
+
+@pytest.mark.skipif(sys.version_info < (3, 7),
+                    reason="signature inspection while mocking is not working in Python < 3.7 despite autospec")
+@pytest.mark.parametrize(['cli_args', 'extra_args'], [
+    pytest.param({}, {}),
+    pytest.param({'logger': False}, {}),
+    pytest.param({'logger': False}, {'logger': True}),
+    pytest.param({'logger': False}, {'checkpoint_callback': True}),
+])
+def test_init_from_argparse_args(cli_args, extra_args):
+    unknown_args = dict(unknown_arg=0)
+
+    # unknown args in the argparser/namespace should be ignored
+    with mock.patch('pytorch_lightning.Trainer.__init__', autospec=True, return_value=None) as init:
+        trainer = Trainer.from_argparse_args(Namespace(**cli_args, **unknown_args), **extra_args)
+        expected = dict(cli_args)
+        expected.update(extra_args)  # extra args should override any cli arg
+        init.assert_called_with(trainer, **expected)
+
+    # passing in unknown manual args should throw an error
+    with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'unknown_arg'"):
+        Trainer.from_argparse_args(Namespace(**cli_args), **extra_args, **unknown_args)
+
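# Illustrative sketch, not part of the original patch: a test like the one
# below could exercise the environment variable support mentioned in the docs
# (a PL_TRAINER__MAX_EPOCHS variable overriding trainer.max_epochs). The test
# name is hypothetical, and it assumes jsonargparse's default_env behaviour
# plus an extra `import os` in addition to the imports above.
import os


@mock.patch.dict(os.environ, {'PL_TRAINER__MAX_EPOCHS': '7'})
def test_parse_args_from_environment_sketch():
    parser = LightningArgumentParser(add_help=False, default_env=True, env_prefix='PL')
    parser.add_trainer_args(Trainer, 'trainer')
    with mock.patch('sys.argv', ['any.py']):
        args = parser.parse_args()

    # the value comes from the environment variable, converted to int by the
    # type hint of Trainer's max_epochs parameter
    assert args.trainer.max_epochs == 7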
+@pytest.mark.parametrize(['cli_args', 'expected_model', 'expected_trainer'], [ + pytest.param(['--model.model_param=7', '--trainer.limit_train_batches=100'], + {'model_param': 7}, + {'limit_train_batches': 100}), +]) +def test_trainer_cli(cli_args, expected_model, expected_trainer, monkeypatch): + """Test that trainer_cli correctly instantiates model, trainer and calls fit.""" + + def fit(trainer, model): + for k, v in model.expected_model.items(): + assert getattr(model, k) == v + for k, v in model.expected_trainer.items(): + assert getattr(trainer, k) == v + save_callback = [x for x in trainer.callbacks if isinstance(x, SaveConfigCallback)] + assert len(save_callback) == 1 + save_callback[0].on_train_start(trainer, model) + + def on_train_start(callback, trainer, model): + for k, v in model.expected_model.items(): + assert f' {k}: {v}' in callback.config_dump + for k, v in model.expected_trainer.items(): + assert f' {k}: {v}' in callback.config_dump + + monkeypatch.setattr(Trainer, 'fit', fit) + monkeypatch.setattr(SaveConfigCallback, 'on_train_start', on_train_start) + + class TestModel(LightningModule): + def __init__(self, model_param: int): + super().__init__() + self.model_param = model_param + + TestModel.expected_model = expected_model + TestModel.expected_trainer = expected_trainer + + with mock.patch('sys.argv', ['any.py'] + cli_args): + trainer_cli(TestModel, trainer_class=Trainer, save_config_callback=SaveConfigCallback) From 186b6d812796cef19c0878fed22f457cbecc92a0 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Wed, 18 Nov 2020 09:32:37 +0100 Subject: [PATCH 02/35] - Converted trainer_cli function into a class that can be more easily customized. - Changes based on pull request reviews. --- docs/source/trainer_cli.rst | 300 ++++++++++++------ .../utilities/jsonargparse_utils.py | 150 --------- pytorch_lightning/utilities/trainer_cli.py | 230 ++++++++++++++ requirements.txt | 2 +- tests/trainer/test_trainer_cli_new.py | 45 +-- 5 files changed, 465 insertions(+), 262 deletions(-) delete mode 100644 pytorch_lightning/utilities/jsonargparse_utils.py create mode 100644 pytorch_lightning/utilities/trainer_cli.py diff --git a/docs/source/trainer_cli.rst b/docs/source/trainer_cli.rst index da6163d12fb76..4bb6f853eb92a 100644 --- a/docs/source/trainer_cli.rst +++ b/docs/source/trainer_cli.rst @@ -1,3 +1,37 @@ +.. testsetup:: * + + from typing import List + from pytorch_lightning import LightningModule, LightningDataModule + from pytorch_lightning.utilities.trainer_cli import TrainerCli + + original_fit = TrainerCli.fit + TrainerCli.fit = lambda self: None + + class MyModel(LightningModule): + def __init__( + self, + encoder_layers: int = 12, + decoder_layers: List[int] = [2, 4] + ): + """Example encoder-decoder model + + Args: + encoder_layers: Number of layers for the encoder + decoder_layers: Number of layers for each decoder block + """ + pass + + class MyDataModule(LightningDataModule): + pass + + def send_email(address, message): + pass + +.. testcleanup:: * + + TrainerCli.fit = original_fit + + Trainer CLI and config files ---------------------------- @@ -12,26 +46,26 @@ benefit of greatly simplifying the reproducibility of experiments. The main requirement for user extended classes to be made configurable is that all relevant init arguments must have type hints. This is not a very demanding -requirement since anyway it is good practice to do anyway. As a bonus if the -arguments are also described in the docstrings, then the help of the training -tool will display them. 
+requirement since it is good practice to do anyway. As a bonus if the arguments +are described in the docstrings, then the help of the training tool will display +them. ---------- -training_cli -^^^^^^^^^^^^ +TrainerCli +^^^^^^^^^^ -The case in which the user's :class:`LightningModule` class implements all +The case in which the user's +:class:`~pytorch_lightning.core.lightning.LightningModule` class implements all required :code:`*_dataloader` methods, a :code:`trainer.py` tool can be as simple as: -.. code-block:: python +.. testcode:: - from pytorch_lightning.utilities.jsonargparse_utils import trainer_cli - from mycode import LitModel + from pytorch_lightning.utilities.trainer_cli import TrainerCli - trainer_cli(LitModel) + TrainerCli(MyModel) The help of the tool describing all configurable options and default values can be shown by running :code:`python trainer.py --help`. Default options can be @@ -42,15 +76,16 @@ to do this would be: .. code-block:: bash # Dump default configuration to have as reference - python trainer.py --print-config > default_config.yaml + python trainer.py --print_config > default_config.yaml # Create config including only options to modify nano config.yaml # Run training using created configuration - python trainer.py --cfg config.yaml + python trainer.py --config config.yaml -The call to the :func:`trainer_cli` function takes care of parsing command line -and config file options, instantiating the classes, setting up a callback to -save the config in the log directory and finally running :func:`trainer.fit`. +The call to the :class:`~pytorch_lightning.utilities.trainer_cli.TrainerCli` +class takes care of parsing command line and config file options, instantiating +the classes, setting up a callback to save the config in the log directory and +finally running :func:`trainer.fit`. After multiple trainings with different configurations, a previous run can be trivially reproduced by using the config in the respective log directory, e.g.: @@ -59,18 +94,29 @@ trivially reproduced by using the config in the respective log directory, e.g.: python trainer.py --cfg lightning_logs/version_7/config.yaml -The start of a possible implementation of :class:`LitModel` including the +If a separate :class:`~pytorch_lightning.core.datamodule.LightningDataModule` +class is required, the trainer tool just needs a small modification as follows: + +.. testcode:: + + from pytorch_lightning.utilities.trainer_cli import TrainerCli + + TrainerCli(MyModel, MyDataModule) + +The start of a possible implementation of :class:`MyModel` including the recommended argument descriptions in the docstring could be the one below. Note that by using type hints and docstrings there is no need to duplicate this information to define its configurable arguments. .. code-block:: python - class LitModel(LightningModule): + class MyModel(LightningModule): - def __init__(self, - encoder_layers: int = 12, - decoder_layers: List[int] = [2, 4]): + def __init__( + self, + encoder_layers: int = 12, + decoder_layers: List[int] = [2, 4] + ): """Example encoder-decoder model Args: @@ -79,78 +125,150 @@ information to define its configurable arguments. """ ... -If a separate :class:`LightningDataModule` class is required, the trainer tool -just needs a small modification as follows: +With this model class, the help of the trainer tool would look as follows: -.. 
code-block:: python - - from pytorch_lightning.utilities.jsonargparse_utils import trainer_cli - from mycode import LitModel, LitDataModule - - trainer_cli(LitModel, LitDataModule) - - -LightningArgumentParser -^^^^^^^^^^^^^^^^^^^^^^^ - -Even though :func:`trainer_cli` can reduce boilerplate code to a minimum, -clearly there are cases in which it is not enough. For this Lightning provides -the :class:`LightningArgumentParser` class which is an extension of the built-in -Python ArgumentParser that makes it very simple to implement configurable -training tools with the same features as :func:`trainer_cli`. - -An example of a more complex training tool could be one in which there are -several independent modules that require configuration. The code for such a case -could look something like: - -.. code-block:: python - - from pytorch_lightning.utilities.jsonargparse_utils import LightningArgumentParser, SaveConfigCallback - from mycode import LitModule1, LitModule2, LitModel, LitDataModule - - # Define parser - parser = LightningArgumentParser(description='pytorch-lightning trainer', - parse_as_dict=True) - parser.add_trainer_args() - parser.add_module_args(LitModule1, 'module1') - parser.add_module_args(LitModule2, 'module2') - parser.add_datamodule_args(LitDataModule) - - # Parse configuration - config = parser.parse_args() - - # Instantiate classes - module1 = LitModule1(**config['module1']) - module2 = LitModule2(**config['module2']) - model = LitModel(module1, module2) - datamodule = LitDataModule(**config['data']) - config['trainer']['callbacks'] = [SaveConfigCallback(parser, config)] - trainer = Trainer(**config['trainer']) - - # Start training - trainer.fit(model, datamodule) - -Note that the configuration object has all options for each module, data and -trainer in different dict keys. The structure of the yaml configuration file is -analogous. Reproducing the training can also be done with the config saved in -the log directory. - -The parser is like any other from argparse, thus it can be used to include -global options, for example: - -.. code-block:: python +.. code-block:: bash - parser.add_argument('--notification_email', default='will@email.com') + $ python trainer.py --help + usage: trainer.py [-h] [--print_config] [--config CONFIG] + [--trainer.logger LOGGER] + ... + + pytorch-lightning trainer command line tool + + optional arguments: + -h, --help show this help message and exit + --print_config print configuration and exit + --config CONFIG Path to a configuration file in json or yaml format. + (default: null) + + Customize every aspect of training via flags: + ... + --trainer.max_epochs MAX_EPOCHS + Stop training once this number of epochs is reached. + (type: int, default: 1000) + --trainer.min_epochs MIN_EPOCHS + Force training for at least these many epochs (type: int, + default: 1) + ... + + Example encoder-decoder model: + --model.encoder_layers ENCODER_LAYERS + Number of layers for the encoder (type: int, default: 12) + --model.decoder_layers DECODER_LAYERS + Number of layers for each decoder block (type: List[int], + default: [2, 4]) + +The default configuration that option :code:`--print_config` gives is in yaml +format and for the example above would look as follows: -The argument parser is also able to parse environment variables. To enable this -feature, initialize :class:`LightningArgumentParser` including -:code:`default_env=True, env_prefix='PL'`. 
With this for instance the -:code:`PL_TRAINER__MAX_EPOCHS` environment variable if set would be used to -override the default :code:`max_epochs` of the trainer. Similarly options for -the data module could be set using variables that start with :code:`PL_DATA_` -and likewise for the modules. +.. code-block:: bash -Arguments from any other class that have appropriate type hints can also be -added. An example which would store the options for a class :class:`MyClass` in -the :code:`myclass` key of the configuration object would be -:code:`parser.add_class_arguments(MyClass, 'myclass')`. + $ python trainer.py --print_config + model: + decoder_layers: + - 2 + - 4 + encoder_layers: 12 + trainer: + accelerator: null + accumulate_grad_batches: 1 + amp_backend: native + amp_level: O2 + ... + +Note that for each class, model and trainer, there is a section each with the +init parameters of the class. This grouping is also used in the formatting of +the help shown previously. + + +Customizing TrainerCli +^^^^^^^^^^^^^^^^^^^^^^ + +The init parameters of the +:class:`~pytorch_lightning.utilities.trainer_cli.TrainerCli` class can be used +to customize some things. + +- :code:`save_config_callback`: By default is + :class:`~pytorch_lightning.utilities.trainer_cli.SaveConfigCallback` which is + the callback that saves the config to the log directory. It could be extended + for example to log the config as an artifact. + +- :code:`description`: The command line tool description shown in the help. + +- :code:`parse_env`: A boolean that can be used to enable parsing of environment + variables. With this for instance the :code:`PL_TRAINER__MAX_EPOCHS` + environment variable if set would be used to override the default + :code:`max_epochs` of the trainer. Similarly options for the data module could + be set using variables that start with :code:`PL_DATA_` and likewise for the + modules. + +- :code:`**kwargs`: All other keyword arguments are used to initialize the + trainer class. Thus, this can be used for instance to set callbacks. + +Even though :class:`~pytorch_lightning.utilities.trainer_cli.TrainerCli` and its +init parameters can reduce boilerplate code to a minimum, clearly there are +cases in which it is not enough. The class is designed so that can be extended +to customize different parts of the command line tool. The argument parser class +used by :class:`~pytorch_lightning.utilities.trainer_cli.TrainerCli` is +:class:`~pytorch_lightning.utilities.trainer_cli.LightningArgumentParser` which +is an extension of python's argparse, thus adding arguments can be done using +the :func:`add_argument` method. In contrast to argparse it has additional +methods to add arguments, for example :func:`add_class_arguments` adds all +arguments from the init of a class, though requiring parameters to have type +hints. For more details about this please refer to the `respective documentation +`_. + +The :class:`~pytorch_lightning.utilities.trainer_cli.TrainerCli` class has the +:meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.add_arguments_to_parser` +method which can be implemented to include more arguments. After parsing, the +configuration is stored in the :code:`config` attribute of the class instance. 
+The :class:`~pytorch_lightning.utilities.trainer_cli.TrainerCli` class also has
+two methods that can be used to run code before and after :code:`trainer.fit` is
+executed: :meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.before_fit`
+and :meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.after_fit`. A
+simple example for these would be to send an email before and after fit. The
+code would be something like:
+
+.. testcode::
+
+    from pytorch_lightning.utilities.trainer_cli import TrainerCli
+
+    class MyTrainerCli(TrainerCli):
+
+        def add_arguments_to_parser(self, parser):
+            parser.add_argument('--notification_email', default='will@email.com')
+
+        def before_fit(self):
+            send_email(
+                address=self.config['notification_email'],
+                message='trainer.fit starting'
+            )
+
+        def after_fit(self):
+            send_email(
+                address=self.config['notification_email'],
+                message='trainer.fit finished'
+            )
+
+    MyTrainerCli(MyModel)
+
+Note that the config object :code:`self.config` is a dictionary whose keys are
+global options or groups of options. It has the same structure as the yaml
+format described previously. This means for instance that the parameters used
+for instantiating the trainer class can be found in
+:code:`self.config['trainer']`.
+
+For more advanced use cases, other methods of the
+:class:`~pytorch_lightning.utilities.trainer_cli.TrainerCli` class could be
+extended. The complete list of methods is:
+
+- :meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.init_parser`
+- :meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.add_arguments_to_parser`
+- :meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.add_core_arguments_to_parser`
+- :meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.parse_arguments`
+- :meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.instantiate_classes`
+- :meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.before_fit`
+- :meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.after_fit`
+- :meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.fit`
+- :meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.run`
diff --git a/pytorch_lightning/utilities/jsonargparse_utils.py b/pytorch_lightning/utilities/jsonargparse_utils.py
deleted file mode 100644
index 6de09ab5ba8ee..0000000000000
--- a/pytorch_lightning/utilities/jsonargparse_utils.py
+++ /dev/null
@@ -1,150 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
- -import os -from typing import Type -from jsonargparse import ArgumentParser, ActionConfigFile -from pytorch_lightning import Trainer, LightningModule, LightningDataModule -from pytorch_lightning.callbacks import Callback - - -class LightningArgumentParser(ArgumentParser): - def __init__(self, *args, **kwargs): - """Initialize argument parser that supports configuration file input""" - super().__init__(*args, **kwargs) - self.add_argument('--cfg', - action=ActionConfigFile, - help='Path to a configuration file in json or yaml format.') - - - def add_trainer_args( - self, - trainer_class: Type[Trainer] = Trainer, - nested_key: str = 'trainer' - ): - """ - Adds arguments from a trainer class to a nested key of the parser - - Args: - trainer_class: Optional extension of the Trainer class. - nested_key: Name of the nested namespace where parsed arguments are stored. - """ - assert issubclass(trainer_class, Trainer) - self.add_class_arguments(trainer_class, nested_key) - - - def add_module_args( - self, - module_class: Type[LightningModule], - nested_key: str = 'module' - ): - """ - Adds arguments from a module class to a nested key of the parser - - Args: - module_class: A LightningModule class. - nested_key: Name of the nested namespace where parsed arguments are stored. - """ - assert issubclass(module_class, LightningModule) - self.add_class_arguments(module_class, nested_key) - - - def add_datamodule_args( - self, - datamodule_class: Type[LightningDataModule], - nested_key: str = 'data' - ): - """ - Adds arguments from a datamodule class to a nested key of the parser - - Args: - datamodule_class: A LightningDataModule class. - nested_key: Name of the nested namespace where parsed arguments are stored. - """ - assert issubclass(datamodule_class, LightningDataModule) - self.add_class_arguments(datamodule_class, nested_key) - - -class SaveConfigCallback(Callback): - """Callback that saves a trainer_cli config to the log_dir when training starts""" - - def __init__(self, parser, config): - self.config_dump = parser.dump(config, skip_none=False) - - - def on_train_start(self, trainer, pl_module): - config_path = os.path.join(trainer.logger.log_dir, 'config.yaml') - with open(config_path, 'w') as outstream: - outstream.write(self.config_dump) - - -def trainer_cli( - model_class: Type[LightningModule], - datamodule_class: Type[LightningDataModule] = None, - save_config_callback: Type[Callback] = SaveConfigCallback, - trainer_class: Type[Trainer] = Trainer, - description: str = 'pytorch-lightning trainer command line tool', - parse_env: bool = False, -): - """ - Implementation of a simple configurable Trainer command line tool - - Receives as input pytorch-lightning classes, which are instantiated using a - parsed configuration file or command line options and then runs trainer.fit. - - Example, first implement the trainer.py tool as:: - - from mymodels import MyModel - from pytorch_lightning.utilities.jsonargparse_utils import trainer_cli - trainer_cli(MyModel) - - Then in a shell, run the tool with the desired configuration:: - - $ python trainer.py --print-config > config.yaml - $ nano config.yaml # modify the config as desired - $ python trainer.py --cfg config.yaml - - Args: - model_class: The LightningModule class to train on. - datamodule_class: An optional LightningDataModule class. - save_config_callback: A callback class to save the training config. - trainer_class: An optional extension of the Trainer class. - description: Description of the tool shown when running --help. 
-        parse_env: Whether environment variables are also parsed.
-    """
-    # Define parser
-    parser = LightningArgumentParser(description=description,
-                                     parse_as_dict=True,
-                                     default_env=parse_env,
-                                     env_prefix='PL')
-    parser.add_trainer_args(trainer_class, 'trainer')
-    parser.add_module_args(model_class, 'model')
-    if datamodule_class is not None:
-        parser.add_datamodule_args(datamodule_class, 'data')
-
-    # Parse configuration
-    config = parser.parse_args()
-
-    # Instantiate classes
-    model = model_class(**config.get('model', {}))
-    kwargs = {'model': model}
-    if datamodule_class is not None:
-        kwargs['datamodule'] = datamodule_class(**config.get('data', {}))
-
-    if save_config_callback is not None:
-        config['trainer']['callbacks'] = [save_config_callback(parser, config)]
-    trainer = trainer_class(**config['trainer'])
-
-    # Start training
-    trainer.fit(**kwargs)
diff --git a/pytorch_lightning/utilities/trainer_cli.py b/pytorch_lightning/utilities/trainer_cli.py
new file mode 100644
index 0000000000000..655605de832f5
--- /dev/null
+++ b/pytorch_lightning/utilities/trainer_cli.py
@@ -0,0 +1,230 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Type
+from jsonargparse import ArgumentParser, ActionConfigFile
+from pytorch_lightning import Trainer, LightningModule, LightningDataModule
+from pytorch_lightning.callbacks import Callback
+
+
+class LightningArgumentParser(ArgumentParser):
+    """Extension of jsonargparse's ArgumentParser for pytorch-lightning"""
+
+    def __init__(
+        self,
+        *args,
+        parse_as_dict: bool = True,
+        **kwargs
+    ):
+        """Initialize argument parser that supports configuration file input"""
+        super().__init__(*args, parse_as_dict=parse_as_dict, **kwargs)
+        self.add_argument(
+            '--config',
+            action=ActionConfigFile,
+            help='Path to a configuration file in json or yaml format.'
+        )
+
+    def add_trainer_args(
+        self,
+        trainer_class: Type[Trainer] = Trainer,
+        nested_key: str = 'trainer'
+    ):
+        """
+        Adds arguments from a trainer class to a nested key of the parser
+
+        Args:
+            trainer_class: Optional extension of the Trainer class.
+            nested_key: Name of the nested namespace to store arguments.
+        """
+        assert issubclass(trainer_class, Trainer)
+        self.add_class_arguments(trainer_class, nested_key)
+
+    def add_module_args(
+        self,
+        module_class: Type[LightningModule],
+        nested_key: str = 'module'
+    ):
+        """
+        Adds arguments from a module class to a nested key of the parser
+
+        Args:
+            module_class: A LightningModule class.
+            nested_key: Name of the nested namespace to store arguments.
+        """
+        assert issubclass(module_class, LightningModule)
+        self.add_class_arguments(module_class, nested_key)
+
+    def add_datamodule_args(
+        self,
+        datamodule_class: Type[LightningDataModule],
+        nested_key: str = 'data'
+    ):
+        """
+        Adds arguments from a datamodule class to a nested key of the parser
+
+        Args:
+            datamodule_class: A LightningDataModule class.
+            nested_key: Name of the nested namespace to store arguments.
+        """
+        assert issubclass(datamodule_class, LightningDataModule)
+        self.add_class_arguments(datamodule_class, nested_key)
+
+
+class SaveConfigCallback(Callback):
+    """Saves a TrainerCli config to the log_dir when training starts"""
+
+    def __init__(self, parser, config):
+        self.config_dump = parser.dump(config, skip_none=False)
+
+    def on_train_start(self, trainer, pl_module):
+        config_path = os.path.join(trainer.logger.log_dir, 'config.yaml')
+        with open(config_path, 'w') as outstream:
+            outstream.write(self.config_dump)
+
+
+class TrainerCli:
+    def __init__(
+        self,
+        model_class: Type[LightningModule],
+        datamodule_class: Type[LightningDataModule] = None,
+        save_config_callback: Type[Callback] = SaveConfigCallback,
+        trainer_class: Type[Trainer] = Trainer,
+        description: str = 'pytorch-lightning trainer command line tool',
+        parse_env: bool = False,
+        **kwargs
+    ):
+        """
+        Implementation of a simple configurable Trainer command line tool
+
+        Receives as input pytorch-lightning classes, which are instantiated using a
+        parsed configuration file or command line args and then runs trainer.fit.
+
+        Example, first implement the trainer.py tool as::
+
+            from mymodels import MyModel
+            from pytorch_lightning.utilities.trainer_cli import TrainerCli
+            TrainerCli(MyModel)
+
+        Then in a shell, run the tool with the desired configuration::
+
+            $ python trainer.py --print_config > config.yaml
+            $ nano config.yaml # modify the config as desired
+            $ python trainer.py --config config.yaml
+
+        Args:
+            model_class: The LightningModule class to train on.
+            datamodule_class: An optional LightningDataModule class.
+            save_config_callback: A callback class to save the training config.
+            trainer_class: An optional extension of the Trainer class.
+            description: Description of the tool shown when running --help.
+            parse_env: Whether environment variables are also parsed.
+            **kwargs: Additional arguments to instantiate Trainer.
+        """
+        if 'callbacks' not in kwargs:
+            kwargs['callbacks'] = []
+
+        self.model_class = model_class
+        self.datamodule_class = datamodule_class
+        self.save_config_callback = save_config_callback
+        self.trainer_class = trainer_class
+        self.trainer_kwargs = kwargs
+
+        self.init_parser(description, parse_env)
+        self.add_arguments_to_parser(self.parser)
+        self.add_core_arguments_to_parser()
+        self.parse_arguments()
+        self.instantiate_classes()
+        self.run()
+
+
+    def init_parser(
+        self,
+        description: str,
+        parse_env: bool
+    ):
+        """Method that instantiates the argument parser
+
+        Args:
+            description: Description of the tool shown when running --help.
+            parse_env: Whether environment variables are also parsed.
+ """ + self.parser = LightningArgumentParser( + description=description, + print_config='--print_config', + default_env=parse_env, + env_prefix='PL' + ) + + + def add_arguments_to_parser( + self, + parser: LightningArgumentParser + ): + """Implement to add extra arguments to parser + + Args: + parser: The argument parser object to which arguments should be added + """ + pass + + + def add_core_arguments_to_parser(self): + """Adds arguments from the core classes to the parser""" + self.parser.add_trainer_args(self.trainer_class, 'trainer') + self.parser.add_module_args(self.model_class, 'model') + if self.datamodule_class is not None: + self.parser.add_datamodule_args(self.datamodule_class, 'data') + + + def parse_arguments(self): + """Parses command line arguments and stores it in self.config""" + self.config = self.parser.parse_args() + + + def instantiate_classes(self): + """Instantiates the classes using settings from self.config and prepares fit_kwargs""" + # Instantiate model + self.model = self.model_class(**self.config.get('model', {})) + # Instantiate datamodule + self.fit_kwargs = {'model': self.model} + if self.datamodule_class is not None: + self.fit_kwargs['datamodule'] = self.datamodule_class(**self.config.get('data', {})) + # Instantiate trainer + self.trainer_kwargs.update(self.config['trainer']) + if self.save_config_callback is not None: + self.trainer_kwargs['callbacks'].append(self.save_config_callback(self.parser, self.config)) + self.trainer = self.trainer_class(**self.trainer_kwargs) + + + def before_fit(self): + """Implement to run some code before fit is started""" + pass + + + def after_fit(self): + """Implement to run some code after fit has finished""" + pass + + + def fit(self): + """Runs fit of the instantiated trainer class and prepared fit keyword arguments""" + self.trainer.fit(**self.fit_kwargs) + + + def run(self): + """Runs self.before_fit, then self.fit and finally self.after_fit""" + self.before_fit() + self.fit() + self.after_fit() diff --git a/requirements.txt b/requirements.txt index 85507150e0ca5..26cd56f1d1df2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ torch>=1.3,<1.8 future>=0.17.1 # required for builtins in setup.py # pyyaml>=3.13 PyYAML>=5.1 # OmegaConf requirement >=5.1 -jsonargparse[signatures]>=3.0.0.dev4 # New trainer_cli requirement +jsonargparse[signatures]>=3.0.0rc1 # New trainer_cli requirement tqdm>=4.41.0 fsspec>=0.8.0 tensorboard>=2.2.0 diff --git a/tests/trainer/test_trainer_cli_new.py b/tests/trainer/test_trainer_cli_new.py index 063c60f970679..d7e3c394bdbed 100644 --- a/tests/trainer/test_trainer_cli_new.py +++ b/tests/trainer/test_trainer_cli_new.py @@ -23,7 +23,11 @@ import tests.base.develop_utils as tutils from pytorch_lightning import Trainer, LightningModule -from pytorch_lightning.utilities.jsonargparse_utils import LightningArgumentParser, SaveConfigCallback, trainer_cli +from pytorch_lightning.utilities.trainer_cli import ( + LightningArgumentParser, + SaveConfigCallback, + TrainerCli +) @mock.patch('argparse.ArgumentParser.parse_args') @@ -34,7 +38,7 @@ def test_default_args(mock_argparse, tmpdir): # logger file to get meta logger = tutils.get_default_logger(tmpdir) - parser = LightningArgumentParser(add_help=False) + parser = LightningArgumentParser(add_help=False, parse_as_dict=False) args = parser.parse_args([]) args.logger = logger @@ -54,7 +58,7 @@ def test_add_argparse_args_redefined(cli_args): """Redefines some default Trainer arguments via the cli and tests the Trainer initialization 
correctness. """ - parser = LightningArgumentParser(add_help=False) + parser = LightningArgumentParser(add_help=False, parse_as_dict=False) parser.add_trainer_args(Trainer, None) args = parser.parse_args(cli_args) @@ -77,7 +81,7 @@ def test_add_argparse_args_redefined(cli_args): ['--foo', '--bar=1'] ]) def test_add_argparse_args_redefined_error(cli_args, monkeypatch): - """Asserts that an error raised in case of passing not default cli arguments.""" + """Asserts error raised in case of passing not default cli arguments.""" class _UnkArgError(Exception): pass @@ -85,7 +89,7 @@ class _UnkArgError(Exception): def _raise(): raise _UnkArgError - parser = LightningArgumentParser(add_help=False) + parser = LightningArgumentParser(add_help=False, parse_as_dict=False) parser.add_trainer_args(Trainer, None) monkeypatch.setattr(parser, 'exit', lambda *args: _raise(), raising=True) @@ -95,14 +99,14 @@ def _raise(): @pytest.mark.parametrize(['cli_args', 'expected'], [ - #pytest.param('--auto_lr_find --auto_scale_batch_size power', - # {'auto_lr_find': True, 'auto_scale_batch_size': 'power'}), - #pytest.param('--auto_lr_find any_string --auto_scale_batch_size', - # {'auto_lr_find': 'any_string', 'auto_scale_batch_size': True}), - #pytest.param('--auto_lr_find t --auto_scale_batch_size ON', - # {'auto_lr_find': True, 'auto_scale_batch_size': True}), - #pytest.param('--auto_lr_find 0 --auto_scale_batch_size n', - # {'auto_lr_find': False, 'auto_scale_batch_size': False}), + pytest.param('--auto_lr_find=True --auto_scale_batch_size=power', + {'auto_lr_find': True, 'auto_scale_batch_size': 'power'}), + pytest.param('--auto_lr_find any_string --auto_scale_batch_size ON', + {'auto_lr_find': 'any_string', 'auto_scale_batch_size': True}), + pytest.param('--auto_lr_find=Yes --auto_scale_batch_size=On', + {'auto_lr_find': True, 'auto_scale_batch_size': True}), + pytest.param('--auto_lr_find Off --auto_scale_batch_size No', + {'auto_lr_find': False, 'auto_scale_batch_size': False}), pytest.param('--auto_lr_find TRUE --auto_scale_batch_size FALSE', {'auto_lr_find': True, 'auto_scale_batch_size': False}), pytest.param('--tpu_cores=8', @@ -118,8 +122,9 @@ def _raise(): pytest.param( "", { - # These parameters are marked as Optional[...] in Trainer.__init__, with None as default. - # They should not be changed by the argparse interface. + # These parameters are marked as Optional[...] in Trainer.__init__, + # with None as default. They should not be changed by the argparse + # interface. 
"min_steps": None, "max_steps": None, "log_gpu_memory": None, @@ -133,7 +138,7 @@ def _raise(): def test_parse_args_parsing(cli_args, expected): """Test parsing simple types and None optionals not modified.""" cli_args = cli_args.split(' ') if cli_args else [] - parser = LightningArgumentParser(add_help=False) + parser = LightningArgumentParser(add_help=False, parse_as_dict=False) parser.add_trainer_args(Trainer, None) with mock.patch("sys.argv", ["any.py"] + cli_args): args = parser.parse_args() @@ -156,7 +161,7 @@ def test_parse_args_parsing(cli_args, expected): ]) def test_parse_args_parsing_complex_types(cli_args, expected, instantiate): """Test parsing complex types.""" - parser = LightningArgumentParser(add_help=False) + parser = LightningArgumentParser(add_help=False, parse_as_dict=False) parser.add_trainer_args(Trainer, None) with mock.patch("sys.argv", ["any.py"] + cli_args): args = parser.parse_args() @@ -171,11 +176,11 @@ def test_parse_args_parsing_complex_types(cli_args, expected, instantiate): pytest.param('--gpus 1', [0]), pytest.param('--gpus 0,', [0]), ]) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU") def test_parse_args_parsing_gpus(cli_args, expected_gpu): """Test parsing of gpus and instantiation of Trainer.""" cli_args = cli_args.split(' ') if cli_args else [] - parser = LightningArgumentParser(add_help=False) + parser = LightningArgumentParser(add_help=False, parse_as_dict=False) parser.add_trainer_args(Trainer, None) with mock.patch("sys.argv", ["any.py"] + cli_args): args = parser.parse_args() @@ -242,4 +247,4 @@ def __init__(self, model_param: int): TestModel.expected_trainer = expected_trainer with mock.patch('sys.argv', ['any.py'] + cli_args): - trainer_cli(TestModel, trainer_class=Trainer, save_config_callback=SaveConfigCallback) + TrainerCli(TestModel, trainer_class=Trainer, save_config_callback=SaveConfigCallback) From d286f49d8dc84344d2b36476bceeac9bd1c7524d Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Wed, 18 Nov 2020 10:11:00 +0100 Subject: [PATCH 03/35] - Fixes required by pep8speaks in trainer_cli.py. - Fixes to testsetup in trainer_cli.rst. --- docs/source/trainer_cli.rst | 3 ++- pytorch_lightning/utilities/trainer_cli.py | 9 --------- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/docs/source/trainer_cli.rst b/docs/source/trainer_cli.rst index 4bb6f853eb92a..bc66c6ac7c89b 100644 --- a/docs/source/trainer_cli.rst +++ b/docs/source/trainer_cli.rst @@ -1,7 +1,8 @@ .. 
testsetup:: * from typing import List - from pytorch_lightning import LightningModule, LightningDataModule + from pytorch_lightning.core.lightning import LightningModule + from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.utilities.trainer_cli import TrainerCli original_fit = TrainerCli.fit diff --git a/pytorch_lightning/utilities/trainer_cli.py b/pytorch_lightning/utilities/trainer_cli.py index 655605de832f5..a1e0d3bd62b89 100644 --- a/pytorch_lightning/utilities/trainer_cli.py +++ b/pytorch_lightning/utilities/trainer_cli.py @@ -148,7 +148,6 @@ def __init__( self.instantiate_classes() self.run() - def init_parser( self, description: str, @@ -167,7 +166,6 @@ def init_parser( env_prefix='PL' ) - def add_arguments_to_parser( self, parser: LightningArgumentParser @@ -179,7 +177,6 @@ def add_arguments_to_parser( """ pass - def add_core_arguments_to_parser(self): """Adds arguments from the core classes to the parser""" self.parser.add_trainer_args(self.trainer_class, 'trainer') @@ -187,12 +184,10 @@ def add_core_arguments_to_parser(self): if self.datamodule_class is not None: self.parser.add_datamodule_args(self.datamodule_class, 'data') - def parse_arguments(self): """Parses command line arguments and stores it in self.config""" self.config = self.parser.parse_args() - def instantiate_classes(self): """Instantiates the classes using settings from self.config and prepares fit_kwargs""" # Instantiate model @@ -207,22 +202,18 @@ def instantiate_classes(self): self.trainer_kwargs['callbacks'].append(self.save_config_callback(self.parser, self.config)) self.trainer = self.trainer_class(**self.trainer_kwargs) - def before_fit(self): """Implement to run some code before fit is started""" pass - def after_fit(self): """Implement to run some code after fit has finished""" pass - def fit(self): """Runs fit of the instantiated trainer class and prepared fit keyword arguments""" self.trainer.fit(**self.fit_kwargs) - def run(self): """Runs self.before_fit, then self.fit and finally self.after_fit""" self.before_fit() From ed64f31e916c66123b91331992fde79512e19b84 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Thu, 19 Nov 2020 09:07:27 +0100 Subject: [PATCH 04/35] Renamed class to LightningCLI and other minor fixes --- docs/source/trainer_cli.rst | 74 +++++++++++----------- pytorch_lightning/utilities/trainer_cli.py | 14 ++-- tests/trainer/test_trainer_cli_new.py | 4 +- 3 files changed, 47 insertions(+), 45 deletions(-) diff --git a/docs/source/trainer_cli.rst b/docs/source/trainer_cli.rst index bc66c6ac7c89b..766e1d263b7fb 100644 --- a/docs/source/trainer_cli.rst +++ b/docs/source/trainer_cli.rst @@ -3,10 +3,10 @@ from typing import List from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.datamodule import LightningDataModule - from pytorch_lightning.utilities.trainer_cli import TrainerCli + from pytorch_lightning.utilities.trainer_cli import LightningCLI - original_fit = TrainerCli.fit - TrainerCli.fit = lambda self: None + original_fit = LightningCLI.fit + LightningCLI.fit = lambda self: None class MyModel(LightningModule): def __init__( @@ -30,11 +30,11 @@ .. testcleanup:: * - TrainerCli.fit = original_fit + LightningCLI.fit = original_fit -Trainer CLI and config files ----------------------------- +Lightning CLI and config files +------------------------------ Another source of boilerplate code that Lightning can help to reduce is in the implementation of training command line tools. 
Furthermore, it provides a
@@ -54,8 +54,8 @@ them.
 
 ----------
 
-TrainerCli
-^^^^^^^^^^
+LightningCLI
+^^^^^^^^^^^^
 
 The case in which the user's
 :class:`~pytorch_lightning.core.lightning.LightningModule` class implements all
@@ -64,9 +64,9 @@ simple as:
 
 .. testcode::
 
-    from pytorch_lightning.utilities.trainer_cli import TrainerCli
+    from pytorch_lightning.utilities.trainer_cli import LightningCLI
 
-    TrainerCli(MyModel)
+    LightningCLI(MyModel)
 
 The help of the tool describing all configurable options and default values can
 be shown by running :code:`python trainer.py --help`. Default options can be
@@ -83,7 +83,7 @@ to do this would be:
     # Run training using created configuration
     python trainer.py --config config.yaml
 
-The call to the :class:`~pytorch_lightning.utilities.trainer_cli.TrainerCli`
+The call to the :class:`~pytorch_lightning.utilities.trainer_cli.LightningCLI`
 class takes care of parsing command line and config file options, instantiating
 the classes, setting up a callback to save the config in the log directory and
 finally running :func:`trainer.fit`.
@@ -93,16 +93,16 @@ trivially reproduced by using the config in the respective log directory, e.g.:
 
 .. code-block:: bash
 
-    python trainer.py --cfg lightning_logs/version_7/config.yaml
+    python trainer.py --config lightning_logs/version_7/config.yaml
 
 If a separate :class:`~pytorch_lightning.core.datamodule.LightningDataModule`
 class is required, the trainer tool just needs a small modification as follows:
 
 .. testcode::
 
-    from pytorch_lightning.utilities.trainer_cli import TrainerCli
+    from pytorch_lightning.utilities.trainer_cli import LightningCLI
 
-    TrainerCli(MyModel, MyDataModule)
+    LightningCLI(MyModel, MyDataModule)
 
 The start of a possible implementation of :class:`MyModel` including the
 recommended argument descriptions in the docstring could be the one below. Note
@@ -183,11 +183,11 @@ init parameters of the class. This grouping is also used in the formatting of
 the help shown previously.
 
 
-Customizing TrainerCli
-^^^^^^^^^^^^^^^^^^^^^^
+Customizing LightningCLI
+^^^^^^^^^^^^^^^^^^^^^^^^
 
 The init parameters of the
-:class:`~pytorch_lightning.utilities.trainer_cli.TrainerCli` class can be used
+:class:`~pytorch_lightning.utilities.trainer_cli.LightningCLI` class can be used
 to customize some things.
 
 - :code:`save_config_callback`: By default is
@@ -207,11 +207,11 @@ to customize some things.
 - :code:`**kwargs`: All other keyword arguments are used to initialize the
   trainer class. Thus, this can be used for instance to set callbacks.
 
-Even though :class:`~pytorch_lightning.utilities.trainer_cli.TrainerCli` and its
+Even though :class:`~pytorch_lightning.utilities.trainer_cli.LightningCLI` and its
 init parameters can reduce boilerplate code to a minimum, clearly there are
 cases in which it is not enough. The class is designed so that it can be extended
 to customize different parts of the command line tool. The argument parser class
-used by :class:`~pytorch_lightning.utilities.trainer_cli.TrainerCli` is
+used by :class:`~pytorch_lightning.utilities.trainer_cli.LightningCLI` is
 :class:`~pytorch_lightning.utilities.trainer_cli.LightningArgumentParser` which
 is an extension of Python's argparse, thus adding arguments can be done using
 the :func:`add_argument` method. In contrast to argparse it has additional
 methods to add arguments, for example :func:`add_class_arguments` adds all
 arguments from the init of a class, though requiring parameters to have type
 hints. For more details about this please refer to the `respective documentation
 `_.
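As an illustration of :func:`add_class_arguments` (a minimal sketch; the
:code:`OptimizerArgs` class below is hypothetical and only assumes type-hinted
init parameters):

.. code-block:: python

    from pytorch_lightning.utilities.trainer_cli import LightningArgumentParser

    class OptimizerArgs:
        def __init__(self, lr: float = 0.01, momentum: float = 0.9):
            # The type hints above are what the parser uses to build options
            self.lr = lr
            self.momentum = momentum

    parser = LightningArgumentParser()
    # Exposes --optim.lr and --optim.momentum, grouped under the 'optim' key
    parser.add_class_arguments(OptimizerArgs, 'optim')
    config = parser.parse_args(['--optim.lr=0.1'])
    assert config['optim']['lr'] == 0.1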
-The :class:`~pytorch_lightning.utilities.trainer_cli.TrainerCli` class has the -:meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.add_arguments_to_parser` +The :class:`~pytorch_lightning.utilities.trainer_cli.LightningCLI` class has the +:meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.add_arguments_to_parser` method which can be implemented to include more arguments. After parsing, the configuration is stored in the :code:`config` attribute of the class instance. -The :class:`~pytorch_lightning.utilities.trainer_cli.TrainerCli` class also has +The :class:`~pytorch_lightning.utilities.trainer_cli.LightningCLI` class also has two methods that can be used to run code before and after :code:`trainer.fit` is -executed: :meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.before_fit` -and :meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.after_fit`. A +executed: :meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.before_fit` +and :meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.after_fit`. A simple example for these would be to send an email before and after fit. The code would be something like: .. testcode:: - from pytorch_lightning.utilities.trainer_cli import TrainerCli + from pytorch_lightning.utilities.trainer_cli import LightningCLI - class MyTrainerCli(TrainerCli): + class MyLightningCLI(LightningCLI): def add_arguments_to_parser(self): self.parser.add_argument('--notification_email', default='will@email.com') @@ -252,7 +252,7 @@ code would be something like: message='trainer.fit finished' ) - MyTrainerCli(MyModel) + MyLightningCLI(MyModel) Note that the config object :code:`self.config` is a dictionary whose keys are global options or groups of options. It has the same structure as the yaml @@ -261,15 +261,15 @@ for instantiating the trainer class can be found in :code:`self.config['trainer']`. For more advanced use cases, other methods of the -:class:`~pytorch_lightning.utilities.trainer_cli.TrainerCli` class could be +:class:`~pytorch_lightning.utilities.trainer_cli.LightningCLI` class could be extended. 
The complete list of methods is: -- :meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.init_parser` -- :meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.add_arguments_to_parser` -- :meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.add_core_arguments_to_parser` -- :meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.parse_arguments` -- :meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.instantiate_classes` -- :meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.before_fit` -- :meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.after_fit` -- :meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.fit` -- :meth:`~pytorch_lightning.utilities.trainer_cli.TrainerCli.run` +- :meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.init_parser` +- :meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.add_arguments_to_parser` +- :meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.add_core_arguments_to_parser` +- :meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.parse_arguments` +- :meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.instantiate_classes` +- :meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.before_fit` +- :meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.after_fit` +- :meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.fit` +- :meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.run` diff --git a/pytorch_lightning/utilities/trainer_cli.py b/pytorch_lightning/utilities/trainer_cli.py index a1e0d3bd62b89..959b80d45890c 100644 --- a/pytorch_lightning/utilities/trainer_cli.py +++ b/pytorch_lightning/utilities/trainer_cli.py @@ -15,12 +15,14 @@ import os from typing import Type from jsonargparse import ArgumentParser, ActionConfigFile -from pytorch_lightning import Trainer, LightningModule, LightningDataModule +from pytorch_lightning.trainer.trainer import Trainer +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.callbacks import Callback class LightningArgumentParser(ArgumentParser): - """Extension of jsonargparse's ArgumentParser for pythorch-lightning""" + """Extension of jsonargparse's ArgumentParser for pytorch-lightning""" def __init__( self, @@ -83,7 +85,7 @@ def add_datamodule_args( class SaveConfigCallback(Callback): - """Saves a TrainerCli config to the log_dir when training starts""" + """Saves a LightningCLI config to the log_dir when training starts""" def __init__(self, parser, config): self.config_dump = parser.dump(config, skip_none=False) @@ -94,7 +96,7 @@ def on_train_start(self, trainer, pl_module): outstream.write(self.config_dump) -class TrainerCli: +class LightningCLI: def __init__( self, model_class: Type[LightningModule], @@ -114,8 +116,8 @@ def __init__( Example, first implement the trainer.py tool as:: from mymodels import MyModel - from pytorch_lightning.utilities.jsonargparse_utils import TrainerCli - TrainerCli(MyModel) + from pytorch_lightning.utilities.jsonargparse_utils import LightningCLI + LightningCLI(MyModel) Then in a shell, run the tool with the desired configuration:: diff --git a/tests/trainer/test_trainer_cli_new.py b/tests/trainer/test_trainer_cli_new.py index d7e3c394bdbed..c5ebcfabd2ff3 100644 --- a/tests/trainer/test_trainer_cli_new.py +++ b/tests/trainer/test_trainer_cli_new.py @@ -26,7 +26,7 @@ from pytorch_lightning.utilities.trainer_cli import ( LightningArgumentParser, SaveConfigCallback, - TrainerCli + 
LightningCLI ) @@ -247,4 +247,4 @@ def __init__(self, model_param: int): TestModel.expected_trainer = expected_trainer with mock.patch('sys.argv', ['any.py'] + cli_args): - TrainerCli(TestModel, trainer_class=Trainer, save_config_callback=SaveConfigCallback) + LightningCLI(TestModel, trainer_class=Trainer, save_config_callback=SaveConfigCallback) From 35d3aa304812e67d853304934cf44b515ad6e1b7 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Mon, 30 Nov 2020 11:03:30 +0100 Subject: [PATCH 05/35] - Fixed bug in testcode of trainer_cli.rst. - Added default_config_files to LightningCLI. - Separate methods in LightningCLI for instantiating model, data and trainer. --- docs/source/trainer_cli.rst | 4 +-- pytorch_lightning/utilities/trainer_cli.py | 37 +++++++++++++++------- requirements.txt | 2 +- 3 files changed, 29 insertions(+), 14 deletions(-) diff --git a/docs/source/trainer_cli.rst b/docs/source/trainer_cli.rst index 766e1d263b7fb..1e92d46230f1e 100644 --- a/docs/source/trainer_cli.rst +++ b/docs/source/trainer_cli.rst @@ -237,8 +237,8 @@ code would be something like: class MyLightningCLI(LightningCLI): - def add_arguments_to_parser(self): - self.parser.add_argument('--notification_email', default='will@email.com') + def add_arguments_to_parser(self, parser): + parser.add_argument('--notification_email', default='will@email.com') def before_fit(self): send_email( diff --git a/pytorch_lightning/utilities/trainer_cli.py b/pytorch_lightning/utilities/trainer_cli.py index 959b80d45890c..fa44395835b44 100644 --- a/pytorch_lightning/utilities/trainer_cli.py +++ b/pytorch_lightning/utilities/trainer_cli.py @@ -13,7 +13,7 @@ # limitations under the License. import os -from typing import Type +from typing import Type, List, Optional from jsonargparse import ArgumentParser, ActionConfigFile from pytorch_lightning.trainer.trainer import Trainer from pytorch_lightning.core.lightning import LightningModule @@ -104,6 +104,7 @@ def __init__( save_config_callback: Type[Callback] = SaveConfigCallback, trainer_class: Type[Trainer] = Trainer, description: str = 'pytorch-lightning trainer command line tool', + default_config_files: List[str] = None, parse_env: bool = False, **kwargs ): @@ -131,6 +132,7 @@ def __init__( save_config_callback: A callback class to save the training config. trainer_class: An optional extension of the Trainer class. description: Description of the tool shown when running --help. + default_config_files: Default config file locations, e.g. :code:`['~/.config/myapp/*.yaml']`. parse_env: Whether environment variables are also parsed. **kwargs: Additional arguments to instantiate Trainer. """ @@ -143,7 +145,7 @@ def __init__( self.trainer_class = trainer_class self.trainer_kwargs = kwargs - self.init_parser(description, parse_env) + self.init_parser(description, default_config_files, parse_env) self.add_arguments_to_parser(self.parser) self.add_core_arguments_to_parser() self.parse_arguments() @@ -153,17 +155,20 @@ def __init__( def init_parser( self, description: str, + default_config_files: Optional[List[str]], parse_env: bool ): """Method that instantiates the argument parser Args: description: Description of the tool shown when running --help. + default_config_files: Default config file locations, e.g. :code:`['~/.config/myapp/*.yaml']`. parse_env: Whether environment variables are also parsed. 
""" self.parser = LightningArgumentParser( description=description, print_config='--print_config', + default_config_files=default_config_files, default_env=parse_env, env_prefix='PL' ) @@ -187,21 +192,31 @@ def add_core_arguments_to_parser(self): self.parser.add_datamodule_args(self.datamodule_class, 'data') def parse_arguments(self): - """Parses command line arguments and stores it in self.config""" - self.config = self.parser.parse_args() + """Parses command line arguments and stores it in self.config_save and self.config_init""" + self.config_save = self.parser.parse_args() + self.config_init = self.parser.instantiate_subclasses(self.config_save) def instantiate_classes(self): """Instantiates the classes using settings from self.config and prepares fit_kwargs""" - # Instantiate model - self.model = self.model_class(**self.config.get('model', {})) - # Instantiate datamodule + self.instantiate_model() + self.instantiate_datamodule() + self.instantiate_trainer() + + def instantiate_model(self): + """Instantiates the model using self.config_init['model']""" + self.model = self.model_class(**self.config_init.get('model', {})) + + def instantiate_datamodule(self): + """Instantiates the datamodule using self.config_init['data']""" self.fit_kwargs = {'model': self.model} if self.datamodule_class is not None: - self.fit_kwargs['datamodule'] = self.datamodule_class(**self.config.get('data', {})) - # Instantiate trainer - self.trainer_kwargs.update(self.config['trainer']) + self.fit_kwargs['datamodule'] = self.datamodule_class(**self.config_init.get('data', {})) + + def instantiate_trainer(self): + """Instantiates the trainer using self.config_init['trainer']""" + self.trainer_kwargs.update(self.config_init['trainer']) if self.save_config_callback is not None: - self.trainer_kwargs['callbacks'].append(self.save_config_callback(self.parser, self.config)) + self.trainer_kwargs['callbacks'].append(self.save_config_callback(self.parser, self.config_save)) self.trainer = self.trainer_class(**self.trainer_kwargs) def before_fit(self): diff --git a/requirements.txt b/requirements.txt index 6ea8ea0089aaf..5330e8b723e5a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ torch>=1.3 future>=0.17.1 # required for builtins in setup.py # pyyaml>=3.13 PyYAML>=5.1 # OmegaConf requirement >=5.1 -jsonargparse[signatures]>=3.0.0rc1 # New trainer_cli requirement +jsonargparse[signatures]>=3.0.0rc4 # New trainer_cli requirement tqdm>=4.41.0 fsspec>=0.8.0 tensorboard>=2.2.0 From 6a4ae27d44f3fd9e945fc7d72be9b81ae11e130f Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Wed, 2 Dec 2020 08:12:21 +0100 Subject: [PATCH 06/35] - Renamed files to reflect new class name LightningCLI. - Added more methods to LightningCLI and made it more consistent. - Added unit test for LightningCLI using TrialMNISTDataModule and EvalModelTemplate. 
--- .../{trainer_cli.rst => lightning_cli.rst} | 96 ++++++++++--------- .../utilities/{trainer_cli.py => cli.py} | 77 ++++++++------- requirements.txt | 2 +- ...ainer_cli_new.py => test_lightning_cli.py} | 49 ++++++++-- 4 files changed, 135 insertions(+), 89 deletions(-) rename docs/source/{trainer_cli.rst => lightning_cli.rst} (68%) rename pytorch_lightning/utilities/{trainer_cli.py => cli.py} (83%) rename tests/trainer/{test_trainer_cli_new.py => test_lightning_cli.py} (83%) diff --git a/docs/source/trainer_cli.rst b/docs/source/lightning_cli.rst similarity index 68% rename from docs/source/trainer_cli.rst rename to docs/source/lightning_cli.rst index 1e92d46230f1e..f79f09f8ba390 100644 --- a/docs/source/trainer_cli.rst +++ b/docs/source/lightning_cli.rst @@ -3,7 +3,7 @@ from typing import List from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.datamodule import LightningDataModule - from pytorch_lightning.utilities.trainer_cli import LightningCLI + from pytorch_lightning.utilities.cli import LightningCLI original_fit = LightningCLI.fit LightningCLI.fit = lambda self: None @@ -64,7 +64,7 @@ simple as: .. testcode:: - from pytorch_lightning.utilities.trainer_cli import LightningCLI + from pytorch_lightning.utilities.cli import LightningCLI LightningCLI(MyModel) @@ -83,9 +83,9 @@ to do this would be: # Run training using created configuration python trainer.py --config config.yaml -The call to the :class:`~pytorch_lightning.utilities.trainer_cli.LightningCLI` -class takes care of parsing command line and config file options, instantiating -the classes, setting up a callback to save the config in the log directory and +The call to the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class +takes care of parsing command line and config file options, instantiating the +classes, setting up a callback to save the config in the log directory and finally running :func:`trainer.fit`. After multiple trainings with different configurations, a previous run can be @@ -100,7 +100,7 @@ class is required, the trainer tool just needs a small modification as follows: .. testcode:: - from pytorch_lightning.utilities.trainer_cli import LightningCLI + from pytorch_lightning.utilities.cli import LightningCLI LightningCLI(MyModel, MyDataModule) @@ -178,22 +178,22 @@ format and for the example above would look as follows: amp_level: O2 ... -Note that for each class, model and trainer, there is a section each with the -init parameters of the class. This grouping is also used in the formatting of -the help shown previously. +Note that there is a section for each class (model and trainer) including all +the init parameters of the class. This grouping is also used in the formatting +of the help shown previously. Customizing LightningCLI ^^^^^^^^^^^^^^^^^^^^^^^^ The init parameters of the -:class:`~pytorch_lightning.utilities.trainer_cli.LightningCLI` class can be used -to customize some things. +:class:`~pytorch_lightning.utilities.cli.LightningCLI` class can be used to +customize some things. - :code:`save_config_callback`: By default is - :class:`~pytorch_lightning.utilities.trainer_cli.SaveConfigCallback` which is - the callback that saves the config to the log directory. It could be extended - for example to log the config as an artifact. + :class:`~pytorch_lightning.utilities.cli.SaveConfigCallback` which is the + callback that saves the config to the log directory. It could be extended for + example to log the config as an artifact. 
- :code:`description`: The command line tool description shown in the help.
@@ -207,33 +207,33 @@ to customize some things.
 - :code:`**kwargs`: All other keyword arguments are used to initialize the
   trainer class. Thus, this can be used for instance to set callbacks.
 
-Even though :class:`~pytorch_lightning.utilities.trainer_cli.LightningCLI` and its
-init parameters can reduce boilerplate code to a minimum, clearly there are
-cases in which it is not enough. The class is designed so that can be extended
-to customize different parts of the command line tool. The argument parser class
-used by :class:`~pytorch_lightning.utilities.trainer_cli.LightningCLI` is
-:class:`~pytorch_lightning.utilities.trainer_cli.LightningArgumentParser` which
-is an extension of python's argparse, thus adding arguments can be done using
-the :func:`add_argument` method. In contrast to argparse it has additional
-methods to add arguments, for example :func:`add_class_arguments` adds all
-arguments from the init of a class, though requiring parameters to have type
-hints. For more details about this please refer to the `respective documentation
+Even though :class:`~pytorch_lightning.utilities.cli.LightningCLI` and its init
+parameters can reduce boilerplate code to a minimum, clearly there are cases in
+which it is not enough. The class is designed so that it can be extended to
+customize different parts of the command line tool. The argument parser class
+used by :class:`~pytorch_lightning.utilities.cli.LightningCLI` is
+:class:`~pytorch_lightning.utilities.cli.LightningArgumentParser` which is an
+extension of Python's argparse, thus adding arguments can be done using the
+:func:`add_argument` method. In contrast to argparse it has additional methods
+to add arguments, for example :func:`add_class_arguments` adds all arguments
+from the init of a class, though requiring parameters to have type hints. For
+more details about this please refer to the `respective documentation
 `_.
 
-The :class:`~pytorch_lightning.utilities.trainer_cli.LightningCLI` class has the
-:meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.add_arguments_to_parser`
+The :class:`~pytorch_lightning.utilities.cli.LightningCLI` class has the
+:meth:`~pytorch_lightning.utilities.cli.LightningCLI.add_arguments_to_parser`
 method which can be implemented to include more arguments. After parsing, the
 configuration is stored in the :code:`config` attribute of the class instance.
 
-The :class:`~pytorch_lightning.utilities.trainer_cli.LightningCLI` class also has
-two methods that can be used to run code before and after :code:`trainer.fit` is
-executed: :meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.before_fit`
-and :meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.after_fit`. A
-simple example for these would be to send an email before and after fit. The
-code would be something like:
+The :class:`~pytorch_lightning.utilities.cli.LightningCLI` class also has two
+methods that can be used to run code before and after :code:`trainer.fit` is
+executed: :meth:`~pytorch_lightning.utilities.cli.LightningCLI.before_fit` and
+:meth:`~pytorch_lightning.utilities.cli.LightningCLI.after_fit`. A simple
+example for these would be to send an email before and after fit. The code would
+be something like:
 
 ..
testcode:: - from pytorch_lightning.utilities.trainer_cli import LightningCLI + from pytorch_lightning.utilities.cli import LightningCLI class MyLightningCLI(LightningCLI): @@ -261,15 +261,19 @@ for instantiating the trainer class can be found in :code:`self.config['trainer']`. For more advanced use cases, other methods of the -:class:`~pytorch_lightning.utilities.trainer_cli.LightningCLI` class could be -extended. The complete list of methods is: - -- :meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.init_parser` -- :meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.add_arguments_to_parser` -- :meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.add_core_arguments_to_parser` -- :meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.parse_arguments` -- :meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.instantiate_classes` -- :meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.before_fit` -- :meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.after_fit` -- :meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.fit` -- :meth:`~pytorch_lightning.utilities.trainer_cli.LightningCLI.run` +:class:`~pytorch_lightning.utilities.cli.LightningCLI` class could be extended. +The complete list of methods is: + +- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.init_parser` +- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.add_arguments_to_parser` +- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.add_core_arguments_to_parser` +- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.before_parse_arguments` +- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.parse_arguments` +- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.before_instantiate_classes` +- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.instantiate_classes` +- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.instantiate_model` +- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.prepare_fit_kwargs` +- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.instantiate_trainer` +- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.before_fit` +- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.fit` +- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.after_fit` diff --git a/pytorch_lightning/utilities/trainer_cli.py b/pytorch_lightning/utilities/cli.py similarity index 83% rename from pytorch_lightning/utilities/trainer_cli.py rename to pytorch_lightning/utilities/cli.py index fa44395835b44..911dd882220f1 100644 --- a/pytorch_lightning/utilities/trainer_cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -51,7 +51,7 @@ def add_trainer_args( nested_key: Name of the nested namespace to store arguments. """ assert issubclass(trainer_class, Trainer) - self.add_class_arguments(trainer_class, nested_key) + return self.add_class_arguments(trainer_class, nested_key) def add_module_args( self, @@ -66,7 +66,7 @@ def add_module_args( nested_key: Name of the nested namespace to store arguments. """ assert issubclass(module_class, LightningModule) - self.add_class_arguments(module_class, nested_key) + return self.add_class_arguments(module_class, nested_key) def add_datamodule_args( self, @@ -81,19 +81,19 @@ def add_datamodule_args( nested_key: Name of the nested namespace to store arguments. 
""" assert issubclass(datamodule_class, LightningDataModule) - self.add_class_arguments(datamodule_class, nested_key) + return self.add_class_arguments(datamodule_class, nested_key) class SaveConfigCallback(Callback): """Saves a LightningCLI config to the log_dir when training starts""" def __init__(self, parser, config): - self.config_dump = parser.dump(config, skip_none=False) + self.parser = parser + self.config = config def on_train_start(self, trainer, pl_module): config_path = os.path.join(trainer.logger.log_dir, 'config.yaml') - with open(config_path, 'w') as outstream: - outstream.write(self.config_dump) + self.parser.save(self.config, config_path, skip_none=False) class LightningCLI: @@ -109,7 +109,7 @@ def __init__( **kwargs ): """ - Implementation of a simple configurable Trainer command line tool + Implementation of a configurable command line tool for pytorch-lightning Receives as input pytorch-lightning classes, which are instantiated using a parsed configuration file or command line args and then runs trainer.fit. @@ -117,12 +117,12 @@ def __init__( Example, first implement the trainer.py tool as:: from mymodels import MyModel - from pytorch_lightning.utilities.jsonargparse_utils import LightningCLI + from pytorch_lightning.utilities.cli import LightningCLI LightningCLI(MyModel) Then in a shell, run the tool with the desired configuration:: - $ python trainer.py --print-config > config.yaml + $ python trainer.py --print_config > config.yaml $ nano config.yaml # modify the config as desired $ python trainer.py --cfg config.yaml @@ -136,9 +136,6 @@ def __init__( parse_env: Whether environment variables are also parsed. **kwargs: Additional arguments to instantiate Trainer. """ - if 'callbacks' not in kwargs: - kwargs['callbacks'] = [] - self.model_class = model_class self.datamodule_class = datamodule_class self.save_config_callback = save_config_callback @@ -148,9 +145,13 @@ def __init__( self.init_parser(description, default_config_files, parse_env) self.add_arguments_to_parser(self.parser) self.add_core_arguments_to_parser() + self.before_parse_arguments(self.parser) self.parse_arguments() + self.before_instantiate_classes() self.instantiate_classes() - self.run() + self.before_fit() + self.fit() + self.after_fit() def init_parser( self, @@ -167,16 +168,12 @@ def init_parser( """ self.parser = LightningArgumentParser( description=description, - print_config='--print_config', default_config_files=default_config_files, default_env=parse_env, env_prefix='PL' ) - def add_arguments_to_parser( - self, - parser: LightningArgumentParser - ): + def add_arguments_to_parser(self, parser: LightningArgumentParser): """Implement to add extra arguments to parser Args: @@ -191,23 +188,35 @@ def add_core_arguments_to_parser(self): if self.datamodule_class is not None: self.parser.add_datamodule_args(self.datamodule_class, 'data') + def before_parse_arguments(self, parser: LightningArgumentParser): + """Implement to run some code before parsing arguments + + Args: + parser: The argument parser object that will be used to parse + """ + pass + def parse_arguments(self): - """Parses command line arguments and stores it in self.config_save and self.config_init""" - self.config_save = self.parser.parse_args() - self.config_init = self.parser.instantiate_subclasses(self.config_save) + """Parses command line arguments and stores it in self.config""" + self.config = self.parser.parse_args() + + def before_instantiate_classes(self): + """Implement to run some code before instantiating the classes""" 
+ pass def instantiate_classes(self): - """Instantiates the classes using settings from self.config and prepares fit_kwargs""" + """Instantiates the classes using settings from self.config""" + self.config_init = self.parser.instantiate_subclasses(self.config) self.instantiate_model() - self.instantiate_datamodule() + self.prepare_fit_kwargs() self.instantiate_trainer() def instantiate_model(self): """Instantiates the model using self.config_init['model']""" self.model = self.model_class(**self.config_init.get('model', {})) - def instantiate_datamodule(self): - """Instantiates the datamodule using self.config_init['data']""" + def prepare_fit_kwargs(self): + """Prepares fit_kwargs including datamodule using self.config_init['data'] if given""" self.fit_kwargs = {'model': self.model} if self.datamodule_class is not None: self.fit_kwargs['datamodule'] = self.datamodule_class(**self.config_init.get('data', {})) @@ -215,24 +224,20 @@ def instantiate_datamodule(self): def instantiate_trainer(self): """Instantiates the trainer using self.config_init['trainer']""" self.trainer_kwargs.update(self.config_init['trainer']) + if self.trainer_kwargs.get('callbacks') is None: + self.trainer_kwargs['callbacks'] = [] if self.save_config_callback is not None: - self.trainer_kwargs['callbacks'].append(self.save_config_callback(self.parser, self.config_save)) + self.trainer_kwargs['callbacks'].append(self.save_config_callback(self.parser, self.config)) self.trainer = self.trainer_class(**self.trainer_kwargs) def before_fit(self): """Implement to run some code before fit is started""" pass - def after_fit(self): - """Implement to run some code after fit has finished""" - pass - def fit(self): """Runs fit of the instantiated trainer class and prepared fit keyword arguments""" - self.trainer.fit(**self.fit_kwargs) + self.fit_result = self.trainer.fit(**self.fit_kwargs) - def run(self): - """Runs self.before_fit, then self.fit and finally self.after_fit""" - self.before_fit() - self.fit() - self.after_fit() + def after_fit(self): + """Implement to run some code after fit has finished""" + pass diff --git a/requirements.txt b/requirements.txt index 5330e8b723e5a..2d7563dacee57 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ torch>=1.3 future>=0.17.1 # required for builtins in setup.py # pyyaml>=3.13 PyYAML>=5.1 # OmegaConf requirement >=5.1 -jsonargparse[signatures]>=3.0.0rc4 # New trainer_cli requirement +jsonargparse[signatures]>=3.0.1 # LightningCLI requirement tqdm>=4.41.0 fsspec>=0.8.0 tensorboard>=2.2.0 diff --git a/tests/trainer/test_trainer_cli_new.py b/tests/trainer/test_lightning_cli.py similarity index 83% rename from tests/trainer/test_trainer_cli_new.py rename to tests/trainer/test_lightning_cli.py index c5ebcfabd2ff3..5b44fc52e7cdc 100644 --- a/tests/trainer/test_trainer_cli_new.py +++ b/tests/trainer/test_lightning_cli.py @@ -13,8 +13,10 @@ # limitations under the License. 
import inspect
+import os
 import pickle
 import sys
+import yaml
 from argparse import Namespace
 from unittest import mock
 
@@ -22,8 +24,11 @@
 import torch
 
 import tests.base.develop_utils as tutils
+from tests.base import EvalModelTemplate
+from tests.base.datamodules import TrialMNISTDataModule
+from tests.base.develop_utils import reset_seed
 from pytorch_lightning import Trainer, LightningModule
-from pytorch_lightning.utilities.trainer_cli import (
+from pytorch_lightning.utilities.cli import (
     LightningArgumentParser,
     SaveConfigCallback,
     LightningCLI
@@ -217,8 +222,8 @@ def test_init_from_argparse_args(cli_args, extra_args):
     {'model_param': 7},
     {'limit_train_batches': 100}),
 ])
-def test_trainer_cli(cli_args, expected_model, expected_trainer, monkeypatch):
-    """Test that trainer_cli correctly instantiates model, trainer and calls fit."""
+def test_lightning_cli(cli_args, expected_model, expected_trainer, monkeypatch):
+    """Test that LightningCLI correctly instantiates model, trainer and calls fit."""
 
     def fit(trainer, model):
         for k, v in model.expected_model.items():
@@ -230,10 +235,12 @@ def fit(trainer, model):
         save_callback[0].on_train_start(trainer, model)
 
     def on_train_start(callback, trainer, model):
+        config_dump = callback.parser.dump(callback.config, skip_none=False)
         for k, v in model.expected_model.items():
-            assert f'  {k}: {v}' in callback.config_dump
+            assert f'  {k}: {v}' in config_dump
         for k, v in model.expected_trainer.items():
-            assert f'  {k}: {v}' in callback.config_dump
+            assert f'  {k}: {v}' in config_dump
+        trainer.ran_asserts = True
 
     monkeypatch.setattr(Trainer, 'fit', fit)
     monkeypatch.setattr(SaveConfigCallback, 'on_train_start', on_train_start)
@@ -247,4 +254,34 @@ def __init__(self, model_param: int):
     TestModel.expected_trainer = expected_trainer
 
     with mock.patch('sys.argv', ['any.py'] + cli_args):
-        LightningCLI(TestModel, trainer_class=Trainer, save_config_callback=SaveConfigCallback)
+        cli = LightningCLI(TestModel, trainer_class=Trainer, save_config_callback=SaveConfigCallback)
+        assert hasattr(cli.trainer, 'ran_asserts') and cli.trainer.ran_asserts
+
+
+def test_lightning_cli_with_trial_mnist_datamodule(tmpdir):
+    reset_seed()
+
+    class TestModel(EvalModelTemplate):
+        pass
+
+    TestModel.validation_step = None
+    TestModel.validation_step_end = None
+    TestModel.validation_epoch_end = None
+
+    cli_args = [
+        '--data.data_dir='+str(tmpdir),
+        '--trainer.default_root_dir='+str(tmpdir),
+        '--trainer.max_epochs=1',
+        '--trainer.weights_summary=null',
+    ]
+
+    with mock.patch('sys.argv', ['trial.py'] + cli_args):
+        cli = LightningCLI(TestModel, TrialMNISTDataModule)
+        assert cli.fit_result == 1
+        config_path = os.path.join(str(tmpdir), 'lightning_logs', 'version_0', 'config.yaml')
+        assert os.path.isfile(config_path)
+        with open(config_path) as f:
+            config = yaml.safe_load(f.read())
+        assert config['model'] == cli.config['model']
+        assert config['data'] == cli.config['data']
+        assert config['trainer'] == cli.config['trainer']

From 6f0f2a094105bf1e7bad5728029d79c25f632fb6 Mon Sep 17 00:00:00 2001
From: Mauricio Villegas 
Date: Fri, 8 Jan 2021 18:16:50 +0100
Subject: [PATCH 07/35] Work on LightningCLI:

- Added cli to api_references.rst.
- Added lightning_cli.rst in toctree after hyperparameters.
- Added the LightningCLI instance object to the docs and why it could be used.
- Added to the docs an explanation of Callbacks and other class types.
- Implemented subclass mode to allow a single CLI to be used for multiple
  models/datamodules.
- Changed LightningCLI init args to have trainer_kwargs and parser_kwargs. - Added instantiate_datamodule method. --- docs/source/api_references.rst | 1 + docs/source/index.rst | 1 + docs/source/lightning_cli.rst | 160 +++++++++++++++++++--------- pytorch_lightning/utilities/cli.py | 117 ++++++++++++-------- requirements.txt | 2 +- tests/trainer/test_lightning_cli.py | 62 +++++++++-- 6 files changed, 235 insertions(+), 108 deletions(-) diff --git a/docs/source/api_references.rst b/docs/source/api_references.rst index e9520dea8045f..cbe11defc3a06 100644 --- a/docs/source/api_references.rst +++ b/docs/source/api_references.rst @@ -93,5 +93,6 @@ Utilities API :toctree: api :nosignatures: + cli argparse_utils seed diff --git a/docs/source/index.rst b/docs/source/index.rst index 650c7de6b2ab0..2678267e907d3 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -100,6 +100,7 @@ PyTorch Lightning Documentation early_stopping fast_training hyperparameters + lightning_cli lr_finder multi_gpu multiple_loaders diff --git a/docs/source/lightning_cli.rst b/docs/source/lightning_cli.rst index f79f09f8ba390..80cccac719584 100644 --- a/docs/source/lightning_cli.rst +++ b/docs/source/lightning_cli.rst @@ -28,6 +28,9 @@ def send_email(address, message): pass + MyModelBaseClass = MyModel + MyDataModuleBaseClass = MyDataModule + .. testcleanup:: * LightningCLI.fit = original_fit @@ -66,12 +69,12 @@ simple as: from pytorch_lightning.utilities.cli import LightningCLI - LightningCLI(MyModel) + cli = LightningCLI(MyModel) The help of the tool describing all configurable options and default values can be shown by running :code:`python trainer.py --help`. Default options can be changed by providing individual command line arguments. However, it is better -practice to create a configuration file and provide this to the trainer. A way +practice to create a configuration file and provide this to the tool. A way to do this would be: .. code-block:: bash @@ -83,13 +86,16 @@ to do this would be: # Run training using created configuration python trainer.py --config config.yaml -The call to the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class -takes care of parsing command line and config file options, instantiating the -classes, setting up a callback to save the config in the log directory and -finally running :func:`trainer.fit`. +The instantiation of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` +class takes care of parsing command line and config file options, instantiating +the classes, setting up a callback to save the config in the log directory and +finally running :func:`trainer.fit`. The resulting object :code:`cli` can be +used for instance to get the result of fit, i.e., :code:`cli.fit_result`. -After multiple trainings with different configurations, a previous run can be -trivially reproduced by using the config in the respective log directory, e.g.: +After multiple trainings with different configurations, each run will have in +its respective log directory a :code:`config.yaml` file. This file can be used +for reference to know in detail all the settings that were used for each +particular run, and also could be used to trivially reproduce a training, e.g.: .. 
code-block:: bash @@ -102,7 +108,7 @@ class is required, the trainer tool just needs a small modification as follows: from pytorch_lightning.utilities.cli import LightningCLI - LightningCLI(MyModel, MyDataModule) + cli = LightningCLI(MyModel, MyDataModule) The start of a possible implementation of :class:`MyModel` including the recommended argument descriptions in the docstring could be the one below. Note @@ -183,35 +189,99 @@ the init parameters of the class. This grouping is also used in the formatting of the help shown previously. +Trainer Callbacks and arguments with class type +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +A very important argument of the +:class:`~pytorch_lightning.trainer.trainer.Trainer` class is the +:code:`callbacks`. In contrast to other more simple arguments which just require +numbers or strings, :code:`callbacks` expects a list of instances of subclasses +of :class:`~pytorch_lightning.callbacks.Callback`. To specify this kind of +argument in a config file, each callback must be given as a dictionary including +a :code:`class_path` entry with an import path of the class, and optionally an +:code:`init_args` entry with arguments required to instantiate it. Therefore, a +simple configuration file example that defines a couple of callbacks is the +following: + +.. code-block:: yaml + + trainer: + callbacks: + - class_path: pytorch_lightning.callbacks.EarlyStopping + init_args: + patience: 5 + - class_path: pytorch_lightning.callbacks.LearningRateMonitor + init_args: + ... + +Similar to the callbacks, any arguments in +:class:`~pytorch_lightning.trainer.trainer.Trainer` and user extended +:class:`~pytorch_lightning.core.lightning.LightningModule` and +:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that +have as type hint a class can be configured the same way using +:code:`class_path` and :code:`init_args`. + + +Multiple models and/or datasets +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In the previous examples :class:`~pytorch_lightning.utilities.cli.LightningCLI` +works only for a single model and datamodule class. However, there are many +cases in which the objective is to easily be able to run many experiments for +multiple models and datasets. For these cases the tool can be configured such +that a model and/or a datamodule is specified by an import path and init +arguments. For example, with a tool implemented as: + +.. testcode:: + + from pytorch_lightning.utilities.cli import LightningCLI + + cli = LightningCLI( + MyModelBaseClass, + MyDataModuleBaseClass, + subclass_mode_model=True, + subclass_mode_data=True + ) + +A possible config file could be as follows: + +.. code-block:: yaml + + model: + class_path: mycode.mymodels.MyModel + init_args: + decoder_layers: + - 2 + - 4 + encoder_layers: 12 + data: + class_path: mycode.mydatamodules.MyDataModule + init_args: + ... + trainer: + callbacks: + - class_path: pytorch_lightning.callbacks.EarlyStopping + init_args: + patience: 5 + ... + +Only model classes that are a subclass of :code:`MyModelBaseClass` would be +allowed, and similarly only subclasses of :code:`MyDataModuleBaseClass`. + + Customizing LightningCLI ^^^^^^^^^^^^^^^^^^^^^^^^ The init parameters of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class can be used to -customize some things. - -- :code:`save_config_callback`: By default is - :class:`~pytorch_lightning.utilities.cli.SaveConfigCallback` which is the - callback that saves the config to the log directory. 
It could be extended for - example to log the config as an artifact. - -- :code:`description`: The command line tool description shown in the help. - -- :code:`parse_env`: A boolean that can be used to enable parsing of environment - variables. With this for instance the :code:`PL_TRAINER__MAX_EPOCHS` - environment variable if set would be used to override the default - :code:`max_epochs` of the trainer. Similarly options for the data module could - be set using variables that start with :code:`PL_DATA_` and likewise for the - modules. - -- :code:`**kwargs`: All other keyword arguments are used to initialize the - trainer class. Thus, this can be used for instance to set callbacks. - -Even though :class:`~pytorch_lightning.utilities.cli.LightningCLI` and its init -parameters can reduce boilerplate code to a minimum, clearly there are cases in -which it is not enough. The class is designed so that can be extended to -customize different parts of the command line tool. The argument parser class -used by :class:`~pytorch_lightning.utilities.cli.LightningCLI` is +customize some things, namely: the description of the tool, enabling parsing of +environment variables and additional arguments to instantiate the trainer and +configuration parser. + +Nevertheless the init arguments are not enough for many use cases. For this +reason the class is designed so that can be extended to customize different +parts of the command line tool. The argument parser class used by +:class:`~pytorch_lightning.utilities.cli.LightningCLI` is :class:`~pytorch_lightning.utilities.cli.LightningArgumentParser` which is an extension of python's argparse, thus adding arguments can be done using the :func:`add_argument` method. In contrast to argparse it has additional methods @@ -227,9 +297,9 @@ configuration is stored in the :code:`config` attribute of the class instance. The :class:`~pytorch_lightning.utilities.cli.LightningCLI` class also has two methods that can be used to run code before and after :code:`trainer.fit` is executed: :meth:`~pytorch_lightning.utilities.cli.LightningCLI.before_fit` and -:meth:`~pytorch_lightning.utilities.cli.LightningCLI.after_fit`. A simple -example for these would be to send an email before and after fit. The code would -be something like: +:meth:`~pytorch_lightning.utilities.cli.LightningCLI.after_fit`. A realistic +example for these would be to send an email before and after the execution of +fit. The code would be something like: .. testcode:: @@ -252,7 +322,7 @@ be something like: message='trainer.fit finished' ) - MyLightningCLI(MyModel) + cli = MyLightningCLI(MyModel) Note that the config object :code:`self.config` is a dictionary whose keys are global options or groups of options. It has the same structure as the yaml @@ -262,18 +332,4 @@ for instantiating the trainer class can be found in For more advanced use cases, other methods of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class could be extended. 
-The complete list of methods is:
-
-- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.init_parser`
-- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.add_arguments_to_parser`
-- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.add_core_arguments_to_parser`
-- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.before_parse_arguments`
-- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.parse_arguments`
-- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.before_instantiate_classes`
-- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.instantiate_classes`
-- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.instantiate_model`
-- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.prepare_fit_kwargs`
-- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.instantiate_trainer`
-- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.before_fit`
-- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.fit`
-- :meth:`~pytorch_lightning.utilities.cli.LightningCLI.after_fit`
+For further information have a look at the corresponding API reference.
diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py
index 911dd882220f1..ec66cdff71880 100644
--- a/pytorch_lightning/utilities/cli.py
+++ b/pytorch_lightning/utilities/cli.py
@@ -13,8 +13,8 @@
 # limitations under the License.
 
 import os
-from typing import Type, List, Optional
-from jsonargparse import ArgumentParser, ActionConfigFile
+from typing import Type, List, Optional, Dict, Any
+from jsonargparse import ArgumentParser, ActionConfigFile, SUPPRESS
 from pytorch_lightning.trainer.trainer import Trainer
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.datamodule import LightningDataModule
@@ -30,7 +30,11 @@ def __init__(
         parse_as_dict: bool = True,
         **kwargs
     ):
-        """Initialize argument parser that supports configuration file input"""
+        """Initialize argument parser that supports configuration file input
+
+        For full details of accepted arguments see `ArgumentParser.__init__
+        `_.
+        """
         super().__init__(*args, parse_as_dict=parse_as_dict, **kwargs)
         self.add_argument(
             '--config',
@@ -56,7 +60,8 @@ def add_trainer_args(
     def add_module_args(
         self,
         module_class: Type[LightningModule],
-        nested_key: str = 'module'
+        nested_key: str = 'module',
+        subclass_mode: bool = False
     ):
         """
         Adds arguments from a module class to a nested key of the parser
 
         Args:
             module_class: A LightningModule class.
             nested_key: Name of the nested namespace to store arguments.
+            subclass_mode: Whether to allow any subclass of the given class.
         """
         assert issubclass(module_class, LightningModule)
+        if subclass_mode:
+            return self.add_subclass_arguments(module_class, nested_key)
         return self.add_class_arguments(module_class, nested_key)
 
     def add_datamodule_args(
         self,
         datamodule_class: Type[LightningDataModule],
-        nested_key: str = 'data'
+        nested_key: str = 'data',
+        subclass_mode: bool = False
     ):
         """
         Adds arguments from a datamodule class to a nested key of the parser
 
         Args:
             datamodule_class: A LightningDataModule class.
             nested_key: Name of the nested namespace to store arguments.
+            subclass_mode: Whether to allow any subclass of the given class.
""" assert issubclass(datamodule_class, LightningDataModule) + if subclass_mode: + return self.add_subclass_arguments(datamodule_class, nested_key) return self.add_class_arguments(datamodule_class, nested_key) @@ -103,16 +115,23 @@ def __init__( datamodule_class: Type[LightningDataModule] = None, save_config_callback: Type[Callback] = SaveConfigCallback, trainer_class: Type[Trainer] = Trainer, + trainer_kwargs: Dict[str, Any] = None, description: str = 'pytorch-lightning trainer command line tool', - default_config_files: List[str] = None, - parse_env: bool = False, - **kwargs + env_prefix: str = 'PL', + env_parse: bool = False, + parser_kwargs: Dict[str, Any] = None, + subclass_mode_model: bool = False, + subclass_mode_data: bool = False ): """ Implementation of a configurable command line tool for pytorch-lightning - Receives as input pytorch-lightning classes, which are instantiated using a - parsed configuration file or command line args and then runs trainer.fit. + Receives as input pytorch-lightning classes, which are instantiated + using a parsed configuration file and/or command line args and then runs + trainer.fit. Parsing of configuration from environment variables can + be enabled by setting :code:`env_parse=True`. A full configuration yaml would + be parsed from :code:`PL_CONFIG` if set. Individual settings are so parsed from + variables named for example :code:`PL_TRAINER__MAX_EPOCHS`. Example, first implement the trainer.py tool as:: @@ -131,47 +150,43 @@ def __init__( datamodule_class: An optional LightningDataModule class. save_config_callback: A callback class to save the training config. trainer_class: An optional extension of the Trainer class. + trainer_kwargs: Additional arguments to instantiate Trainer. description: Description of the tool shown when running --help. - default_config_files: Default config file locations, e.g. :code:`['~/.config/myapp/*.yaml']`. - parse_env: Whether environment variables are also parsed. - **kwargs: Additional arguments to instantiate Trainer. + env_prefix: Prefix for environment variables. + env_parse: Whether environment variable parsing is enabled. + parser_kwargs: Additional arguments to instantiate LightningArgumentParser. + subclass_mode_model: Whether model can be any `subclass `_ of the given class. + subclass_mode_data: Whether datamodule can be any `subclass `_ of the given class. """ self.model_class = model_class self.datamodule_class = datamodule_class self.save_config_callback = save_config_callback self.trainer_class = trainer_class - self.trainer_kwargs = kwargs - - self.init_parser(description, default_config_files, parse_env) + self.trainer_kwargs = {} if trainer_kwargs is None else trainer_kwargs + self.subclass_mode_model = subclass_mode_model + self.subclass_mode_data = subclass_mode_data + self.parser_kwargs = {} if parser_kwargs is None else parser_kwargs + self.parser_kwargs.update({ + 'description': description, + 'env_prefix': env_prefix, + 'default_env': env_parse + }) + + self.init_parser() self.add_arguments_to_parser(self.parser) self.add_core_arguments_to_parser() self.before_parse_arguments(self.parser) self.parse_arguments() self.before_instantiate_classes() self.instantiate_classes() + self.prepare_fit_kwargs() self.before_fit() self.fit() self.after_fit() - def init_parser( - self, - description: str, - default_config_files: Optional[List[str]], - parse_env: bool - ): - """Method that instantiates the argument parser - - Args: - description: Description of the tool shown when running --help. 
- default_config_files: Default config file locations, e.g. :code:`['~/.config/myapp/*.yaml']`. - parse_env: Whether environment variables are also parsed. - """ - self.parser = LightningArgumentParser( - description=description, - default_config_files=default_config_files, - default_env=parse_env, - env_prefix='PL' - ) + def init_parser(self): + """Method that instantiates the argument parser""" + self.parser = LightningArgumentParser(**self.parser_kwargs) def add_arguments_to_parser(self, parser: LightningArgumentParser): """Implement to add extra arguments to parser @@ -184,9 +199,9 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser): def add_core_arguments_to_parser(self): """Adds arguments from the core classes to the parser""" self.parser.add_trainer_args(self.trainer_class, 'trainer') - self.parser.add_module_args(self.model_class, 'model') + self.parser.add_module_args(self.model_class, 'model', subclass_mode=self.subclass_mode_model) if self.datamodule_class is not None: - self.parser.add_datamodule_args(self.datamodule_class, 'data') + self.parser.add_datamodule_args(self.datamodule_class, 'data', subclass_mode=self.subclass_mode_data) def before_parse_arguments(self, parser: LightningArgumentParser): """Implement to run some code before parsing arguments @@ -208,18 +223,24 @@ def instantiate_classes(self): """Instantiates the classes using settings from self.config""" self.config_init = self.parser.instantiate_subclasses(self.config) self.instantiate_model() - self.prepare_fit_kwargs() + self.instantiate_datamodule() self.instantiate_trainer() def instantiate_model(self): """Instantiates the model using self.config_init['model']""" - self.model = self.model_class(**self.config_init.get('model', {})) - - def prepare_fit_kwargs(self): - """Prepares fit_kwargs including datamodule using self.config_init['data'] if given""" - self.fit_kwargs = {'model': self.model} - if self.datamodule_class is not None: - self.fit_kwargs['datamodule'] = self.datamodule_class(**self.config_init.get('data', {})) + if self.subclass_mode_model: + self.model = self.config_init['model'] + else: + self.model = self.model_class(**self.config_init.get('model', {})) + + def instantiate_datamodule(self): + """Instantiates the datamodule using self.config_init['data'] if given""" + if self.datamodule_class is None: + self.datamodule = None + elif self.subclass_mode_data: + self.datamodule = self.config_init['data'] + else: + self.datamodule = self.datamodule_class(**self.config_init.get('data', {})) def instantiate_trainer(self): """Instantiates the trainer using self.config_init['trainer']""" @@ -230,6 +251,12 @@ def instantiate_trainer(self): self.trainer_kwargs['callbacks'].append(self.save_config_callback(self.parser, self.config)) self.trainer = self.trainer_class(**self.trainer_kwargs) + def prepare_fit_kwargs(self): + """Prepares fit_kwargs including datamodule using self.config_init['data'] if given""" + self.fit_kwargs = {'model': self.model} + if self.datamodule is not None: + self.fit_kwargs['datamodule'] = self.datamodule + def before_fit(self): """Implement to run some code before fit is started""" pass diff --git a/requirements.txt b/requirements.txt index 2d7563dacee57..3d9effc77ed8b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ torch>=1.3 future>=0.17.1 # required for builtins in setup.py # pyyaml>=3.13 PyYAML>=5.1 # OmegaConf requirement >=5.1 -jsonargparse[signatures]>=3.0.1 # LightningCLI requirement +jsonargparse[signatures]>=3.3.1 # LightningCLI 
requirement tqdm>=4.41.0 fsspec>=0.8.0 tensorboard>=2.2.0 diff --git a/tests/trainer/test_lightning_cli.py b/tests/trainer/test_lightning_cli.py index 5b44fc52e7cdc..e11c8f762c871 100644 --- a/tests/trainer/test_lightning_cli.py +++ b/tests/trainer/test_lightning_cli.py @@ -258,25 +258,67 @@ def __init__(self, model_param: int): assert hasattr(cli.trainer, 'ran_asserts') and cli.trainer.ran_asserts -def test_lightning_cli_with_trial_mnist_datamodule(tmpdir): - reset_seed() +class TestLightningCLI(LightningCLI): + def before_fit(self): + for key in ['validation_step', + 'validation_step_end', + 'validation_epoch_end', + 'test_step', + 'test_step_end', + 'test_epoch_end']: + setattr(self.model, key, None) - class TestModel(EvalModelTemplate): - pass - TestModel.validation_step = None - TestModel.validation_step_end = None - TestModel.validation_epoch_end = None +def test_lightning_cli_mnist_args(tmpdir): cli_args = [ - '--data.data_dir='+str(tmpdir), - '--trainer.default_root_dir='+str(tmpdir), + '--data.data_dir=' + str(tmpdir), + '--trainer.default_root_dir=' + str(tmpdir), '--trainer.max_epochs=1', '--trainer.weights_summary=null', ] with mock.patch('sys.argv', ['trial.py'] + cli_args): - cli = LightningCLI(TestModel, TrialMNISTDataModule) + cli = TestLightningCLI(EvalModelTemplate, TrialMNISTDataModule) + assert cli.fit_result == 1 + config_path = os.path.join(str(tmpdir), 'lightning_logs', 'version_0', 'config.yaml') + assert os.path.isfile(config_path) + with open(config_path) as f: + config = yaml.safe_load(f.read()) + assert config['model'] == cli.config['model'] + assert config['data'] == cli.config['data'] + assert config['trainer'] == cli.config['trainer'] + + +def test_lightning_cli_mnist_config_and_subclass_mode(tmpdir): + + config = { + 'model': { + 'class_path': 'tests.base.EvalModelTemplate', + }, + 'data': { + 'class_path': 'tests.base.datamodules.TrialMNISTDataModule', + 'init_args': { + 'data_dir': str(tmpdir), + }, + }, + 'trainer': { + 'default_root_dir': str(tmpdir), + 'max_epochs': 1, + 'weights_summary': None, + }, + } + config_path = os.path.join(str(tmpdir), 'config.yaml') + with open(config_path, 'w') as f: + f.write(yaml.dump(config)) + + with mock.patch('sys.argv', ['trial.py', '--config', config_path]): + cli = TestLightningCLI( + EvalModelTemplate, + TrialMNISTDataModule, + subclass_mode_model=True, + subclass_mode_data=True + ) assert cli.fit_result == 1 config_path = os.path.join(str(tmpdir), 'lightning_logs', 'version_0', 'config.yaml') assert os.path.isfile(config_path) From e91a351ed0cca1da0aa845bf24baa9632b78f6cc Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Fri, 8 Jan 2021 20:51:01 +0100 Subject: [PATCH 08/35] Work on LightningCLI: - Moved jsonargparse requirement to extras. - Fix for testcode in lightning_cli.rst. - Fix for tpu_cores issue in test_lightning_cli.py. - Fix pep8speaks long lines in cli.py. - Added LightningCLI feature to CHANGELOG.md. --- CHANGELOG.md | 3 +++ docs/source/lightning_cli.rst | 5 +++++ pytorch_lightning/utilities/cli.py | 8 ++++++-- requirements.txt | 1 - requirements/extra.txt | 1 + tests/trainer/test_lightning_cli.py | 5 +++-- 6 files changed, 18 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7808e8d61c83f..feb950cc097e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added `IoU` class interface ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704)) +- Added `LightningCLI` class to provide simple reproducibility with minimum boilerplate training cli code. ([#4492](https://github.com/PyTorchLightning/pytorch-lightning/pull/4492)) + + ### Changed - Changed `automatic casting` for LoggerConnector `metrics` ([#5218](https://github.com/PyTorchLightning/pytorch-lightning/pull/5218)) diff --git a/docs/source/lightning_cli.rst b/docs/source/lightning_cli.rst index 80cccac719584..c804094f93079 100644 --- a/docs/source/lightning_cli.rst +++ b/docs/source/lightning_cli.rst @@ -1,5 +1,6 @@ .. testsetup:: * + from unittest import mock from typing import List from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.datamodule import LightningDataModule @@ -31,9 +32,13 @@ MyModelBaseClass = MyModel MyDataModuleBaseClass = MyDataModule + mock_argv = mock.patch("sys.argv", ["any.py"]) + mock_argv.start() + .. testcleanup:: * LightningCLI.fit = original_fit + mock_argv.stop() Lightning CLI and config files diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index ec66cdff71880..ca2f969ac1dc3 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -155,8 +155,12 @@ def __init__( env_prefix: Prefix for environment variables. env_parse: Whether environment variable parsing is enabled. parser_kwargs: Additional arguments to instantiate LightningArgumentParser. - subclass_mode_model: Whether model can be any `subclass `_ of the given class. - subclass_mode_data: Whether datamodule can be any `subclass `_ of the given class. + subclass_mode_model: Whether model can be any `subclass + `_ of the + given class. + subclass_mode_data: Whether datamodule can be any `subclass + `_ of the + given class. 
""" self.model_class = model_class self.datamodule_class = datamodule_class diff --git a/requirements.txt b/requirements.txt index 311fd51d6ec9a..2dd5378649851 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,6 @@ torch>=1.3 future>=0.17.1 # required for builtins in setup.py # pyyaml>=3.13 PyYAML>=5.1 # OmegaConf requirement >=5.1 -jsonargparse[signatures]>=3.3.1 # LightningCLI requirement tqdm>=4.41.0 fsspec[http]>=0.8.1 tensorboard>=2.2.0 diff --git a/requirements/extra.txt b/requirements/extra.txt index 3f14b1e5910dd..187536fdc1a36 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -8,3 +8,4 @@ onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 https://github.com/PyTorchLightning/fairscale/archive/pl_1.1.0.zip +jsonargparse[signatures]>=3.3.1 # LightningCLI requirement diff --git a/tests/trainer/test_lightning_cli.py b/tests/trainer/test_lightning_cli.py index e11c8f762c871..a855e7bcea20c 100644 --- a/tests/trainer/test_lightning_cli.py +++ b/tests/trainer/test_lightning_cli.py @@ -26,7 +26,7 @@ import tests.base.develop_utils as tutils from tests.base import EvalModelTemplate from tests.base.datamodules import TrialMNISTDataModule -from tests.base.develop_utils import reset_seed +from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning import Trainer, LightningModule from pytorch_lightning.utilities.cli import ( LightningArgumentParser, @@ -150,7 +150,8 @@ def test_parse_args_parsing(cli_args, expected): for k, v in expected.items(): assert getattr(args, k) == v - assert Trainer.from_argparse_args(args) + if 'tpu_cores' not in expected or _TPU_AVAILABLE: + assert Trainer.from_argparse_args(args) @pytest.mark.parametrize(['cli_args', 'expected', 'instantiate'], [ From 7a76ea655450f8c06b79ae10b454cb8c368279f8 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Fri, 8 Jan 2021 21:14:51 +0100 Subject: [PATCH 09/35] Work on LightningCLI: - Skip testcode in lightning_cli.rst if jsonargparse not available. - Removed unused imports. --- docs/source/conf.py | 1 + docs/source/lightning_cli.rst | 1 + pytorch_lightning/utilities/cli.py | 4 ++-- tests/trainer/test_lightning_cli.py | 1 - 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index b2e46395c7787..69fcd1d38d48f 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -369,6 +369,7 @@ def package_list_from_file(file): ) _TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None +_JSONARGPARSE_AVAILABLE = importlib.util.find_spec("jsonargparse") is not None """ coverage_skip_undoc_in_source = True diff --git a/docs/source/lightning_cli.rst b/docs/source/lightning_cli.rst index c804094f93079..51dff7746ce2c 100644 --- a/docs/source/lightning_cli.rst +++ b/docs/source/lightning_cli.rst @@ -1,4 +1,5 @@ .. testsetup:: * + :skipif: not _JSONARGPARSE_AVAILABLE from unittest import mock from typing import List diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index ca2f969ac1dc3..cc9dfb3447681 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -13,8 +13,8 @@ # limitations under the License. 
import os -from typing import Type, List, Optional, Dict, Any -from jsonargparse import ArgumentParser, ActionConfigFile, SUPPRESS +from typing import Type, Dict, Any +from jsonargparse import ArgumentParser, ActionConfigFile from pytorch_lightning.trainer.trainer import Trainer from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.datamodule import LightningDataModule diff --git a/tests/trainer/test_lightning_cli.py b/tests/trainer/test_lightning_cli.py index a855e7bcea20c..f01463bdca32b 100644 --- a/tests/trainer/test_lightning_cli.py +++ b/tests/trainer/test_lightning_cli.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect import os import pickle import sys From 03be5c418271b0aa5e3821485a37101b219a3b48 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Sun, 10 Jan 2021 23:02:27 -0500 Subject: [PATCH 10/35] Swap instantiation of datamodule and model in LightningCLI. --- pytorch_lightning/utilities/cli.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index cc9dfb3447681..1a1c8dc2f5ad7 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -226,17 +226,10 @@ def before_instantiate_classes(self): def instantiate_classes(self): """Instantiates the classes using settings from self.config""" self.config_init = self.parser.instantiate_subclasses(self.config) - self.instantiate_model() self.instantiate_datamodule() + self.instantiate_model() self.instantiate_trainer() - def instantiate_model(self): - """Instantiates the model using self.config_init['model']""" - if self.subclass_mode_model: - self.model = self.config_init['model'] - else: - self.model = self.model_class(**self.config_init.get('model', {})) - def instantiate_datamodule(self): """Instantiates the datamodule using self.config_init['data'] if given""" if self.datamodule_class is None: @@ -246,6 +239,13 @@ def instantiate_datamodule(self): else: self.datamodule = self.datamodule_class(**self.config_init.get('data', {})) + def instantiate_model(self): + """Instantiates the model using self.config_init['model']""" + if self.subclass_mode_model: + self.model = self.config_init['model'] + else: + self.model = self.model_class(**self.config_init.get('model', {})) + def instantiate_trainer(self): """Instantiates the trainer using self.config_init['trainer']""" self.trainer_kwargs.update(self.config_init['trainer']) From 69904d7212c928b014ff617d202ca94d053a08eb Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Tue, 12 Jan 2021 22:29:36 +0100 Subject: [PATCH 11/35] Changed LightningArgumentParser add args methods to a single one add_lightning_class_args --- pytorch_lightning/utilities/cli.py | 62 ++++++++--------------------- tests/trainer/test_lightning_cli.py | 10 ++--- 2 files changed, 21 insertions(+), 51 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 1a1c8dc2f5ad7..6d100bcc0a99b 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -13,7 +13,7 @@ # limitations under the License. 
import os
-from typing import Type, Dict, Any
+from typing import Type, Dict, Any, Union
 from jsonargparse import ArgumentParser, ActionConfigFile
 from pytorch_lightning.trainer.trainer import Trainer
 from pytorch_lightning.core.lightning import LightningModule
@@ -42,58 +42,24 @@ def __init__(
             help='Path to a configuration file in json or yaml format.'
         )
 
-    def add_trainer_args(
+    def add_lightning_class_args(
         self,
-        trainer_class: Type[Trainer] = Trainer,
-        nested_key: str = 'trainer'
-    ):
-        """
-        Adds arguments from a trainer class to a nested key of the parser
-
-        Args:
-            trainer_class: Optional extension of the Trainer class.
-            nested_key: Name of the nested namespace to store arguments.
-        """
-        assert issubclass(trainer_class, Trainer)
-        return self.add_class_arguments(trainer_class, nested_key)
-
-    def add_module_args(
-        self,
-        module_class: Type[LightningModule],
-        nested_key: str = 'module',
-        subclass_mode: bool = False
-    ):
-        """
-        Adds arguments from a module class to a nested key of the parser
-
-        Args:
-            module_class: A LightningModule class.
-            nested_key: Name of the nested namespace to store arguments.
-            subclass_mode: Whether allow any subclass of the given class.
-        """
-        assert issubclass(module_class, LightningModule)
-        if subclass_mode:
-            return self.add_subclass_arguments(module_class, nested_key)
-        return self.add_class_arguments(module_class, nested_key)
-
-    def add_datamodule_args(
-        self,
-        datamodule_class: Type[LightningDataModule],
-        nested_key: str = 'data',
+        lightning_class: Union[Type[Trainer], Type[LightningModule], Type[LightningDataModule]],
+        nested_key: str,
         subclass_mode: bool = False
     ):
         """
-        Adds arguments from a datamodule class to a nested key of the parser
+        Adds arguments from a lightning class to a nested key of the parser
 
         Args:
-            datamodule_class: A LightningDataModule class.
+            lightning_class: Any subclass of {Trainer,LightningModule,LightningDataModule}.
             nested_key: Name of the nested namespace to store arguments.
-            subclass_mode: Whether allow any subclass of the given class.
+            subclass_mode: Whether to allow any subclass of the given class.
         """
-        assert issubclass(datamodule_class, LightningDataModule)
+        assert issubclass(lightning_class, (Trainer, LightningModule, LightningDataModule))
         if subclass_mode:
-            return self.add_subclass_arguments(datamodule_class, nested_key)
-        return self.add_class_arguments(datamodule_class, nested_key)
+            return self.add_subclass_arguments(lightning_class, nested_key)
+        return self.add_class_arguments(lightning_class, nested_key)
 
 
 class SaveConfigCallback(Callback):
@@ -162,6 +128,10 @@ def __init__(
             `_ of the
             given class.
""" + assert issubclass(trainer_class, Trainer) + assert issubclass(model_class, LightningModule) + if datamodule_class is not None: + assert issubclass(datamodule_class, LightningDataModule) self.model_class = model_class self.datamodule_class = datamodule_class self.save_config_callback = save_config_callback @@ -202,10 +172,10 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser): def add_core_arguments_to_parser(self): """Adds arguments from the core classes to the parser""" - self.parser.add_trainer_args(self.trainer_class, 'trainer') - self.parser.add_module_args(self.model_class, 'model', subclass_mode=self.subclass_mode_model) + self.parser.add_lightning_class_args(self.trainer_class, 'trainer') + self.parser.add_lightning_class_args(self.model_class, 'model', subclass_mode=self.subclass_mode_model) if self.datamodule_class is not None: - self.parser.add_datamodule_args(self.datamodule_class, 'data', subclass_mode=self.subclass_mode_data) + self.parser.add_lightning_class_args(self.datamodule_class, 'data', subclass_mode=self.subclass_mode_data) def before_parse_arguments(self, parser: LightningArgumentParser): """Implement to run some code before parsing arguments diff --git a/tests/trainer/test_lightning_cli.py b/tests/trainer/test_lightning_cli.py index f01463bdca32b..a49629bd690df 100644 --- a/tests/trainer/test_lightning_cli.py +++ b/tests/trainer/test_lightning_cli.py @@ -63,7 +63,7 @@ def test_add_argparse_args_redefined(cli_args): tests the Trainer initialization correctness. """ parser = LightningArgumentParser(add_help=False, parse_as_dict=False) - parser.add_trainer_args(Trainer, None) + parser.add_lightning_class_args(Trainer, None) args = parser.parse_args(cli_args) @@ -94,7 +94,7 @@ def _raise(): raise _UnkArgError parser = LightningArgumentParser(add_help=False, parse_as_dict=False) - parser.add_trainer_args(Trainer, None) + parser.add_lightning_class_args(Trainer, None) monkeypatch.setattr(parser, 'exit', lambda *args: _raise(), raising=True) @@ -143,7 +143,7 @@ def test_parse_args_parsing(cli_args, expected): """Test parsing simple types and None optionals not modified.""" cli_args = cli_args.split(' ') if cli_args else [] parser = LightningArgumentParser(add_help=False, parse_as_dict=False) - parser.add_trainer_args(Trainer, None) + parser.add_lightning_class_args(Trainer, None) with mock.patch("sys.argv", ["any.py"] + cli_args): args = parser.parse_args() @@ -167,7 +167,7 @@ def test_parse_args_parsing(cli_args, expected): def test_parse_args_parsing_complex_types(cli_args, expected, instantiate): """Test parsing complex types.""" parser = LightningArgumentParser(add_help=False, parse_as_dict=False) - parser.add_trainer_args(Trainer, None) + parser.add_lightning_class_args(Trainer, None) with mock.patch("sys.argv", ["any.py"] + cli_args): args = parser.parse_args() @@ -186,7 +186,7 @@ def test_parse_args_parsing_gpus(cli_args, expected_gpu): """Test parsing of gpus and instantiation of Trainer.""" cli_args = cli_args.split(' ') if cli_args else [] parser = LightningArgumentParser(add_help=False, parse_as_dict=False) - parser.add_trainer_args(Trainer, None) + parser.add_lightning_class_args(Trainer, None) with mock.patch("sys.argv", ["any.py"] + cli_args): args = parser.parse_args() From a726c3fb49a56208412009cdb7b5115c4c8249e0 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Fri, 15 Jan 2021 17:16:38 +0100 Subject: [PATCH 12/35] Made pytorch_lightning.utilities.cli importable even when jsonargparse not available --- 
pytorch_lightning/utilities/cli.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 6d100bcc0a99b..3c132f8e80978 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -13,14 +13,20 @@ # limitations under the License. import os +import importlib from typing import Type, Dict, Any, Union -from jsonargparse import ArgumentParser, ActionConfigFile from pytorch_lightning.trainer.trainer import Trainer from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.callbacks import Callback +if importlib.util.find_spec("jsonargparse") is not None: + from jsonargparse import ArgumentParser, ActionConfigFile +else: + ArgumentParser = object + + class LightningArgumentParser(ArgumentParser): """Extension of jsonargparse's ArgumentParser for pytorch-lightning""" From 2cd09d58f095255ceb0ad595394136c073f76f91 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Fri, 15 Jan 2021 17:30:07 +0100 Subject: [PATCH 13/35] - Fix "Check valid import formatting with isort". - Made save_config_callback type more specific. --- pytorch_lightning/utilities/cli.py | 2 +- tests/trainer/test_lightning_cli.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 3c132f8e80978..3201d9168956d 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -85,7 +85,7 @@ def __init__( self, model_class: Type[LightningModule], datamodule_class: Type[LightningDataModule] = None, - save_config_callback: Type[Callback] = SaveConfigCallback, + save_config_callback: Type[SaveConfigCallback] = SaveConfigCallback, trainer_class: Type[Trainer] = Trainer, trainer_kwargs: Dict[str, Any] = None, description: str = 'pytorch-lightning trainer command line tool', diff --git a/tests/trainer/test_lightning_cli.py b/tests/trainer/test_lightning_cli.py index a49629bd690df..e38796c78d602 100644 --- a/tests/trainer/test_lightning_cli.py +++ b/tests/trainer/test_lightning_cli.py @@ -15,12 +15,12 @@ import os import pickle import sys -import yaml from argparse import Namespace from unittest import mock import pytest import torch +import yaml import tests.base.develop_utils as tutils from tests.base import EvalModelTemplate From 04fdd5ae3777a8a3d52cb61cee911c07ab7a98d7 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Fri, 15 Jan 2021 17:35:04 +0100 Subject: [PATCH 14/35] Fix "Check valid import formatting with isort" --- tests/trainer/test_lightning_cli.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/trainer/test_lightning_cli.py b/tests/trainer/test_lightning_cli.py index e38796c78d602..0a3ef00384f27 100644 --- a/tests/trainer/test_lightning_cli.py +++ b/tests/trainer/test_lightning_cli.py @@ -23,15 +23,11 @@ import yaml import tests.base.develop_utils as tutils -from tests.base import EvalModelTemplate -from tests.base.datamodules import TrialMNISTDataModule from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning import Trainer, LightningModule -from pytorch_lightning.utilities.cli import ( - LightningArgumentParser, - SaveConfigCallback, - LightningCLI -) +from pytorch_lightning.utilities.cli import LightningArgumentParser, SaveConfigCallback, LightningCLI +from tests.base import EvalModelTemplate +from tests.base.datamodules import 
TrialMNISTDataModule @mock.patch('argparse.ArgumentParser.parse_args') From a4a4a557ce996f651374b9f4727396a16d9ecde8 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Fri, 15 Jan 2021 17:40:46 +0100 Subject: [PATCH 15/35] Fix "Check valid import formatting with isort" --- tests/trainer/test_lightning_cli.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trainer/test_lightning_cli.py b/tests/trainer/test_lightning_cli.py index 0a3ef00384f27..c49492a1486f1 100644 --- a/tests/trainer/test_lightning_cli.py +++ b/tests/trainer/test_lightning_cli.py @@ -23,9 +23,9 @@ import yaml import tests.base.develop_utils as tutils +from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.utilities import _TPU_AVAILABLE -from pytorch_lightning import Trainer, LightningModule -from pytorch_lightning.utilities.cli import LightningArgumentParser, SaveConfigCallback, LightningCLI +from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI, SaveConfigCallback from tests.base import EvalModelTemplate from tests.base.datamodules import TrialMNISTDataModule From 1403501ac955bc3f79c5895a7d7732987c532cfd Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Thu, 21 Jan 2021 22:50:30 +0100 Subject: [PATCH 16/35] Fix "Check valid import formatting with isort" --- pytorch_lightning/utilities/cli.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 3201d9168956d..d212505244ef1 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -12,17 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import importlib -from typing import Type, Dict, Any, Union -from pytorch_lightning.trainer.trainer import Trainer -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.core.datamodule import LightningDataModule -from pytorch_lightning.callbacks import Callback +import os +from typing import Any, Dict, Type, Union +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.core.datamodule import LightningDataModule +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.trainer.trainer import Trainer if importlib.util.find_spec("jsonargparse") is not None: - from jsonargparse import ArgumentParser, ActionConfigFile + from jsonargparse import ActionConfigFile, ArgumentParser else: ArgumentParser = object From dcbd3d7044e0c4a2b42995046286856b3d126b4f Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Thu, 21 Jan 2021 22:54:20 +0100 Subject: [PATCH 17/35] Fix "Check valid import formatting with isort" --- pytorch_lightning/utilities/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index d212505244ef1..bbeed2b6529ec 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -14,8 +14,8 @@ import importlib import os - from typing import Any, Dict, Type, Union + from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule From 4d7f5b96c774a65fe0978c035eaffa622d4110bf Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Mon, 1 Feb 2021 13:45:02 -0500 Subject: [PATCH 18/35] Update to reflect change in structure in docs/sources --- 
docs/source/{ => common}/lightning_cli.rst | 0 docs/source/index.rst | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename docs/source/{ => common}/lightning_cli.rst (100%) diff --git a/docs/source/lightning_cli.rst b/docs/source/common/lightning_cli.rst similarity index 100% rename from docs/source/lightning_cli.rst rename to docs/source/common/lightning_cli.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index 9ebec882dabf7..40ffbe5d2fc9f 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -102,7 +102,7 @@ PyTorch Lightning Documentation common/early_stopping common/fast_training common/hyperparameters - lightning_cli + common/lightning_cli advanced/lr_finder advanced/multi_gpu advanced/multiple_loaders From a4da435d36d561141ce9fae5a55f56325251876b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 16 Feb 2021 01:32:05 +0100 Subject: [PATCH 19/35] Update to latest changes --- pytorch_lightning/utilities/cli.py | 29 ++-- tests/trainer/test_lightning_cli.py | 250 +++++++++++++++------------- 2 files changed, 143 insertions(+), 136 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index bbeed2b6529ec..d7da22d855cd4 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import importlib import os from typing import Any, Dict, Type, Union @@ -20,8 +18,10 @@ from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.trainer.trainer import Trainer +from pytorch_lightning.utilities import _module_available -if importlib.util.find_spec("jsonargparse") is not None: +_JSONARGPARSE_AVAILABLE = _module_available("jsonargparse") +if _JSONARGPARSE_AVAILABLE: from jsonargparse import ActionConfigFile, ArgumentParser else: ArgumentParser = object @@ -30,22 +30,20 @@ class LightningArgumentParser(ArgumentParser): """Extension of jsonargparse's ArgumentParser for pytorch-lightning""" - def __init__( - self, - *args, - parse_as_dict: bool = True, - **kwargs - ): + def __init__(self, *args, parse_as_dict: bool = True, **kwargs): """Initialize argument parser that supports configuration file input For full details of accepted arguments see `ArgumentParser.__init__ `_. """ + if not _JSONARGPARSE_AVAILABLE: + raise ModuleNotFoundError( + '`jsonargparse` is not installed but it is required for the CLI.' + ' Install it with `pip install jsonargparse[signatures]`.' + ) super().__init__(*args, parse_as_dict=parse_as_dict, **kwargs) self.add_argument( - '--config', - action=ActionConfigFile, - help='Path to a configuration file in json or yaml format.' + '--config', action=ActionConfigFile, help='Path to a configuration file in json or yaml format.' 
) def add_lightning_class_args( @@ -81,6 +79,7 @@ def on_train_start(self, trainer, pl_module): class LightningCLI: + def __init__( self, model_class: Type[LightningModule], @@ -146,11 +145,7 @@ def __init__( self.subclass_mode_model = subclass_mode_model self.subclass_mode_data = subclass_mode_data self.parser_kwargs = {} if parser_kwargs is None else parser_kwargs - self.parser_kwargs.update({ - 'description': description, - 'env_prefix': env_prefix, - 'default_env': env_parse - }) + self.parser_kwargs.update({'description': description, 'env_prefix': env_prefix, 'default_env': env_parse}) self.init_parser() self.add_arguments_to_parser(self.parser) diff --git a/tests/trainer/test_lightning_cli.py b/tests/trainer/test_lightning_cli.py index c49492a1486f1..e05b90ac6c32c 100644 --- a/tests/trainer/test_lightning_cli.py +++ b/tests/trainer/test_lightning_cli.py @@ -19,15 +19,12 @@ from unittest import mock import pytest -import torch import yaml -import tests.base.develop_utils as tutils from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI, SaveConfigCallback -from tests.base import EvalModelTemplate -from tests.base.datamodules import TrialMNISTDataModule +from tests.helpers import BoringDataModule, BoringModel @mock.patch('argparse.ArgumentParser.parse_args') @@ -35,12 +32,8 @@ def test_default_args(mock_argparse, tmpdir): """Tests default argument parser for Trainer""" mock_argparse.return_value = Namespace(**Trainer.default_attributes()) - # logger file to get meta - logger = tutils.get_default_logger(tmpdir) - parser = LightningArgumentParser(add_help=False, parse_as_dict=False) args = parser.parse_args([]) - args.logger = logger args.max_epochs = 5 trainer = Trainer.from_argparse_args(args) @@ -49,11 +42,7 @@ def test_default_args(mock_argparse, tmpdir): assert trainer.max_epochs == 5 -@pytest.mark.parametrize('cli_args', [ - ['--accumulate_grad_batches=22'], - ['--weights_save_path=./'], - [] -]) +@pytest.mark.parametrize('cli_args', [['--accumulate_grad_batches=22'], ['--weights_save_path=./'], []]) def test_add_argparse_args_redefined(cli_args): """Redefines some default Trainer arguments via the cli and tests the Trainer initialization correctness. 
@@ -76,10 +65,7 @@ def test_add_argparse_args_redefined(cli_args): assert isinstance(trainer, Trainer) -@pytest.mark.parametrize('cli_args', [ - ['--callbacks=1', '--logger'], - ['--foo', '--bar=1'] -]) +@pytest.mark.parametrize('cli_args', [['--callbacks=1', '--logger'], ['--foo', '--bar=1']]) def test_add_argparse_args_redefined_error(cli_args, monkeypatch): """Asserts error raised in case of passing not default cli arguments.""" @@ -98,43 +84,64 @@ def _raise(): parser.parse_args(cli_args) -@pytest.mark.parametrize(['cli_args', 'expected'], [ - pytest.param('--auto_lr_find=True --auto_scale_batch_size=power', - {'auto_lr_find': True, 'auto_scale_batch_size': 'power'}), - pytest.param('--auto_lr_find any_string --auto_scale_batch_size ON', - {'auto_lr_find': 'any_string', 'auto_scale_batch_size': True}), - pytest.param('--auto_lr_find=Yes --auto_scale_batch_size=On', - {'auto_lr_find': True, 'auto_scale_batch_size': True}), - pytest.param('--auto_lr_find Off --auto_scale_batch_size No', - {'auto_lr_find': False, 'auto_scale_batch_size': False}), - pytest.param('--auto_lr_find TRUE --auto_scale_batch_size FALSE', - {'auto_lr_find': True, 'auto_scale_batch_size': False}), - pytest.param('--tpu_cores=8', - {'tpu_cores': 8}), - pytest.param('--tpu_cores=1,', - {'tpu_cores': '1,'}), - pytest.param('--limit_train_batches=100', - {'limit_train_batches': 100}), - pytest.param('--limit_train_batches 0.8', - {'limit_train_batches': 0.8}), - pytest.param('--weights_summary=null', - {'weights_summary': None}), - pytest.param( - "", - { - # These parameters are marked as Optional[...] in Trainer.__init__, - # with None as default. They should not be changed by the argparse - # interface. - "min_steps": None, - "max_steps": None, - "log_gpu_memory": None, - "distributed_backend": None, - "weights_save_path": None, - "truncated_bptt_steps": None, - "resume_from_checkpoint": None, - "profiler": None, +@pytest.mark.parametrize( + ['cli_args', 'expected'], + [ + ('--auto_lr_find=True --auto_scale_batch_size=power', { + 'auto_lr_find': True, + 'auto_scale_batch_size': 'power' }), -]) + ( + '--auto_lr_find any_string --auto_scale_batch_size ON', { + 'auto_lr_find': 'any_string', + 'auto_scale_batch_size': True + } + ), + ('--auto_lr_find=Yes --auto_scale_batch_size=On', { + 'auto_lr_find': True, + 'auto_scale_batch_size': True + }), + ('--auto_lr_find Off --auto_scale_batch_size No', { + 'auto_lr_find': False, + 'auto_scale_batch_size': False + }), + ('--auto_lr_find TRUE --auto_scale_batch_size FALSE', { + 'auto_lr_find': True, + 'auto_scale_batch_size': False + }), + ('--tpu_cores=8', { + 'tpu_cores': 8 + }), + ('--tpu_cores=1,', { + 'tpu_cores': '1,' + }), + ('--limit_train_batches=100', { + 'limit_train_batches': 100 + }), + ('--limit_train_batches 0.8', { + 'limit_train_batches': 0.8 + }), + ('--weights_summary=null', { + 'weights_summary': None + }), + ( + "", + { + # These parameters are marked as Optional[...] in Trainer.__init__, + # with None as default. They should not be changed by the argparse + # interface. 
+ "min_steps": None, + "max_steps": None, + "log_gpu_memory": None, + "distributed_backend": None, + "weights_save_path": None, + "truncated_bptt_steps": None, + "resume_from_checkpoint": None, + "profiler": None, + } + ), + ] +) def test_parse_args_parsing(cli_args, expected): """Test parsing simple types and None optionals not modified.""" cli_args = cli_args.split(' ') if cli_args else [] @@ -150,15 +157,18 @@ def test_parse_args_parsing(cli_args, expected): @pytest.mark.parametrize(['cli_args', 'expected', 'instantiate'], [ - pytest.param(['--gpus', '[0, 2]'], - {'gpus': [0, 2]}, - False), - pytest.param(['--tpu_cores=[1,3]'], - {'tpu_cores': [1, 3]}, - False), - pytest.param(['--accumulate_grad_batches={"5":3,"10":20}'], - {'accumulate_grad_batches': {5: 3, 10: 20}}, - True), + (['--gpus', '[0, 2]'], { + 'gpus': [0, 2] + }, False), + (['--tpu_cores=[1,3]'], { + 'tpu_cores': [1, 3] + }, False), + (['--accumulate_grad_batches={"5":3,"10":20}'], { + 'accumulate_grad_batches': { + 5: 3, + 10: 20 + } + }, True), ]) def test_parse_args_parsing_complex_types(cli_args, expected, instantiate): """Test parsing complex types.""" @@ -174,12 +184,13 @@ def test_parse_args_parsing_complex_types(cli_args, expected, instantiate): @pytest.mark.parametrize(['cli_args', 'expected_gpu'], [ - pytest.param('--gpus 1', [0]), - pytest.param('--gpus 0,', [0]), + ('--gpus 1', [0]), + ('--gpus 0,', [0]), + ('--gpus 0,1', [0, 1]), ]) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU") -def test_parse_args_parsing_gpus(cli_args, expected_gpu): +def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu): """Test parsing of gpus and instantiation of Trainer.""" + monkeypatch.setattr("torch.cuda.device_count", lambda: 2) cli_args = cli_args.split(' ') if cli_args else [] parser = LightningArgumentParser(add_help=False, parse_as_dict=False) parser.add_lightning_class_args(Trainer, None) @@ -190,13 +201,25 @@ def test_parse_args_parsing_gpus(cli_args, expected_gpu): assert trainer.data_parallel_device_ids == expected_gpu -@pytest.mark.skipif(sys.version_info < (3, 7), - reason="signature inspection while mocking is not working in Python < 3.7 despite autospec") +@pytest.mark.skipif( + sys.version_info < (3, 7), + reason="signature inspection while mocking is not working in Python < 3.7 despite autospec" +) @pytest.mark.parametrize(['cli_args', 'extra_args'], [ - pytest.param({}, {}), - pytest.param({'logger': False}, {}), - pytest.param({'logger': False}, {'logger': True}), - pytest.param({'logger': False}, {'checkpoint_callback': True}), + ({}, {}), + ({ + 'logger': False + }, {}), + ({ + 'logger': False + }, { + 'logger': True + }), + ({ + 'logger': False + }, { + 'checkpoint_callback': True + }), ]) def test_init_from_argparse_args(cli_args, extra_args): unknown_args = dict(unknown_arg=0) @@ -214,9 +237,11 @@ def test_init_from_argparse_args(cli_args, extra_args): @pytest.mark.parametrize(['cli_args', 'expected_model', 'expected_trainer'], [ - pytest.param(['--model.model_param=7', '--trainer.limit_train_batches=100'], - {'model_param': 7}, - {'limit_train_batches': 100}), + (['--model.model_param=7', '--trainer.limit_train_batches=100'], { + 'model_param': 7 + }, { + 'limit_train_batches': 100 + }), ]) def test_lightning_cli(cli_args, expected_model, expected_trainer, monkeypatch): """Test that LightningCLI correctly instantiates model, trainer and calls fit.""" @@ -242,6 +267,7 @@ def on_train_start(callback, trainer, model): monkeypatch.setattr(SaveConfigCallback, 
'on_train_start', on_train_start) class TestModel(LightningModule): + def __init__(self, model_param: int): super().__init__() self.model_param = model_param @@ -254,48 +280,38 @@ def __init__(self, model_param: int): assert hasattr(cli.trainer, 'ran_asserts') and cli.trainer.ran_asserts -class TestLightningCLI(LightningCLI): - def before_fit(self): - for key in ['validation_step', - 'validation_step_end', - 'validation_epoch_end', - 'test_step', - 'test_step_end', - 'test_epoch_end']: - setattr(self.model, key, None) - - -def test_lightning_cli_mnist_args(tmpdir): +def test_lightning_cli_args(tmpdir): cli_args = [ - '--data.data_dir=' + str(tmpdir), - '--trainer.default_root_dir=' + str(tmpdir), + f'--data.data_dir={tmpdir}', + f'--trainer.default_root_dir={tmpdir}', '--trainer.max_epochs=1', '--trainer.weights_summary=null', ] - with mock.patch('sys.argv', ['trial.py'] + cli_args): - cli = TestLightningCLI(EvalModelTemplate, TrialMNISTDataModule) - assert cli.fit_result == 1 - config_path = os.path.join(str(tmpdir), 'lightning_logs', 'version_0', 'config.yaml') - assert os.path.isfile(config_path) - with open(config_path) as f: - config = yaml.safe_load(f.read()) - assert config['model'] == cli.config['model'] - assert config['data'] == cli.config['data'] - assert config['trainer'] == cli.config['trainer'] + with mock.patch('sys.argv', ['any.py'] + cli_args): + cli = LightningCLI(BoringModel, BoringDataModule) + + assert cli.fit_result == 1 + config_path = tmpdir / 'lightning_logs' / 'version_0' / 'config.yaml' + assert os.path.isfile(config_path) + with open(config_path) as f: + config = yaml.safe_load(f.read()) + assert 'model' not in config and 'model' not in cli.config # no arguments to include + assert config['data'] == cli.config['data'] + assert config['trainer'] == cli.config['trainer'] -def test_lightning_cli_mnist_config_and_subclass_mode(tmpdir): +def test_lightning_cli_config_and_subclass_mode(tmpdir): config = { 'model': { - 'class_path': 'tests.base.EvalModelTemplate', + 'class_path': 'tests.helpers.BoringModel', }, 'data': { - 'class_path': 'tests.base.datamodules.TrialMNISTDataModule', + 'class_path': 'tests.helpers.BoringDataModule', 'init_args': { - 'data_dir': str(tmpdir), + 'data_dir': str(tmpdir) }, }, 'trainer': { @@ -304,22 +320,18 @@ def test_lightning_cli_mnist_config_and_subclass_mode(tmpdir): 'weights_summary': None, }, } - config_path = os.path.join(str(tmpdir), 'config.yaml') + config_path = tmpdir / 'config.yaml' with open(config_path, 'w') as f: f.write(yaml.dump(config)) - with mock.patch('sys.argv', ['trial.py', '--config', config_path]): - cli = TestLightningCLI( - EvalModelTemplate, - TrialMNISTDataModule, - subclass_mode_model=True, - subclass_mode_data=True - ) - assert cli.fit_result == 1 - config_path = os.path.join(str(tmpdir), 'lightning_logs', 'version_0', 'config.yaml') - assert os.path.isfile(config_path) - with open(config_path) as f: - config = yaml.safe_load(f.read()) - assert config['model'] == cli.config['model'] - assert config['data'] == cli.config['data'] - assert config['trainer'] == cli.config['trainer'] + with mock.patch('sys.argv', ['any.py', '--config', str(config_path)]): + cli = LightningCLI(BoringModel, BoringDataModule, subclass_mode_model=True, subclass_mode_data=True) + + assert cli.fit_result == 1 + config_path = tmpdir / 'lightning_logs' / 'version_0' / 'config.yaml' + assert os.path.isfile(config_path) + with open(config_path) as f: + config = yaml.safe_load(f.read()) + assert config['model'] == cli.config['model'] + 
assert config['data'] == cli.config['data'] + assert config['trainer'] == cli.config['trainer'] From f4e0fbb178ba2fdddb640e6b72e3ded14b65ab99 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 16 Feb 2021 01:37:07 +0100 Subject: [PATCH 20/35] Better formatting --- tests/trainer/test_lightning_cli.py | 134 +++++++++++----------------- 1 file changed, 50 insertions(+), 84 deletions(-) diff --git a/tests/trainer/test_lightning_cli.py b/tests/trainer/test_lightning_cli.py index e05b90ac6c32c..91cb37d7f67c5 100644 --- a/tests/trainer/test_lightning_cli.py +++ b/tests/trainer/test_lightning_cli.py @@ -87,43 +87,19 @@ def _raise(): @pytest.mark.parametrize( ['cli_args', 'expected'], [ - ('--auto_lr_find=True --auto_scale_batch_size=power', { - 'auto_lr_find': True, - 'auto_scale_batch_size': 'power' - }), + ('--auto_lr_find=True --auto_scale_batch_size=power', {'auto_lr_find': True, 'auto_scale_batch_size': 'power'}), ( - '--auto_lr_find any_string --auto_scale_batch_size ON', { - 'auto_lr_find': 'any_string', - 'auto_scale_batch_size': True - } + '--auto_lr_find any_string --auto_scale_batch_size ON', + {'auto_lr_find': 'any_string', 'auto_scale_batch_size': True}, ), - ('--auto_lr_find=Yes --auto_scale_batch_size=On', { - 'auto_lr_find': True, - 'auto_scale_batch_size': True - }), - ('--auto_lr_find Off --auto_scale_batch_size No', { - 'auto_lr_find': False, - 'auto_scale_batch_size': False - }), - ('--auto_lr_find TRUE --auto_scale_batch_size FALSE', { - 'auto_lr_find': True, - 'auto_scale_batch_size': False - }), - ('--tpu_cores=8', { - 'tpu_cores': 8 - }), - ('--tpu_cores=1,', { - 'tpu_cores': '1,' - }), - ('--limit_train_batches=100', { - 'limit_train_batches': 100 - }), - ('--limit_train_batches 0.8', { - 'limit_train_batches': 0.8 - }), - ('--weights_summary=null', { - 'weights_summary': None - }), + ('--auto_lr_find=Yes --auto_scale_batch_size=On', {'auto_lr_find': True, 'auto_scale_batch_size': True}), + ('--auto_lr_find Off --auto_scale_batch_size No', {'auto_lr_find': False, 'auto_scale_batch_size': False}), + ('--auto_lr_find TRUE --auto_scale_batch_size FALSE', {'auto_lr_find': True, 'auto_scale_batch_size': False}), + ('--tpu_cores=8', {'tpu_cores': 8}), + ('--tpu_cores=1,', {'tpu_cores': '1,'}), + ('--limit_train_batches=100', {'limit_train_batches': 100}), + ('--limit_train_batches 0.8', {'limit_train_batches': 0.8}), + ('--weights_summary=null', {'weights_summary': None}), ( "", { @@ -138,9 +114,9 @@ def _raise(): "truncated_bptt_steps": None, "resume_from_checkpoint": None, "profiler": None, - } + }, ), - ] + ], ) def test_parse_args_parsing(cli_args, expected): """Test parsing simple types and None optionals not modified.""" @@ -156,20 +132,14 @@ def test_parse_args_parsing(cli_args, expected): assert Trainer.from_argparse_args(args) -@pytest.mark.parametrize(['cli_args', 'expected', 'instantiate'], [ - (['--gpus', '[0, 2]'], { - 'gpus': [0, 2] - }, False), - (['--tpu_cores=[1,3]'], { - 'tpu_cores': [1, 3] - }, False), - (['--accumulate_grad_batches={"5":3,"10":20}'], { - 'accumulate_grad_batches': { - 5: 3, - 10: 20 - } - }, True), -]) +@pytest.mark.parametrize( + ['cli_args', 'expected', 'instantiate'], + [ + (['--gpus', '[0, 2]'], {'gpus': [0, 2]}, False), + (['--tpu_cores=[1,3]'], {'tpu_cores': [1, 3]}, False), + (['--accumulate_grad_batches={"5":3,"10":20}'], {'accumulate_grad_batches': {5: 3, 10: 20}}, True), + ], +) def test_parse_args_parsing_complex_types(cli_args, expected, instantiate): """Test parsing complex types.""" parser = 
LightningArgumentParser(add_help=False, parse_as_dict=False) @@ -183,11 +153,14 @@ def test_parse_args_parsing_complex_types(cli_args, expected, instantiate): assert Trainer.from_argparse_args(args) -@pytest.mark.parametrize(['cli_args', 'expected_gpu'], [ - ('--gpus 1', [0]), - ('--gpus 0,', [0]), - ('--gpus 0,1', [0, 1]), -]) +@pytest.mark.parametrize( + ['cli_args', 'expected_gpu'], + [ + ('--gpus 1', [0]), + ('--gpus 0,', [0]), + ('--gpus 0,1', [0, 1]), + ], +) def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu): """Test parsing of gpus and instantiation of Trainer.""" monkeypatch.setattr("torch.cuda.device_count", lambda: 2) @@ -203,24 +176,17 @@ def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu): @pytest.mark.skipif( sys.version_info < (3, 7), - reason="signature inspection while mocking is not working in Python < 3.7 despite autospec" + reason="signature inspection while mocking is not working in Python < 3.7 despite autospec", +) +@pytest.mark.parametrize( + ['cli_args', 'extra_args'], + [ + ({}, {}), + ({'logger': False}, {}), + ({'logger': False}, {'logger': True}), + ({'logger': False}, {'checkpoint_callback': True}), + ], ) -@pytest.mark.parametrize(['cli_args', 'extra_args'], [ - ({}, {}), - ({ - 'logger': False - }, {}), - ({ - 'logger': False - }, { - 'logger': True - }), - ({ - 'logger': False - }, { - 'checkpoint_callback': True - }), -]) def test_init_from_argparse_args(cli_args, extra_args): unknown_args = dict(unknown_arg=0) @@ -236,13 +202,16 @@ def test_init_from_argparse_args(cli_args, extra_args): Trainer.from_argparse_args(Namespace(**cli_args), **extra_args, **unknown_args) -@pytest.mark.parametrize(['cli_args', 'expected_model', 'expected_trainer'], [ - (['--model.model_param=7', '--trainer.limit_train_batches=100'], { - 'model_param': 7 - }, { - 'limit_train_batches': 100 - }), -]) +@pytest.mark.parametrize( + ['cli_args', 'expected_model', 'expected_trainer'], + [ + ( + ['--model.model_param=7', '--trainer.limit_train_batches=100'], + {'model_param': 7}, + {'limit_train_batches': 100}, + ), + ], +) def test_lightning_cli(cli_args, expected_model, expected_trainer, monkeypatch): """Test that LightningCLI correctly instantiates model, trainer and calls fit.""" @@ -267,7 +236,6 @@ def on_train_start(callback, trainer, model): monkeypatch.setattr(SaveConfigCallback, 'on_train_start', on_train_start) class TestModel(LightningModule): - def __init__(self, model_param: int): super().__init__() self.model_param = model_param @@ -310,9 +278,7 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir): }, 'data': { 'class_path': 'tests.helpers.BoringDataModule', - 'init_args': { - 'data_dir': str(tmpdir) - }, + 'init_args': {'data_dir': str(tmpdir)}, }, 'trainer': { 'default_root_dir': str(tmpdir), From 9c4fc7fe0ee2230fe413d216aebdc02a693f4240 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Tue, 9 Mar 2021 20:33:35 +0100 Subject: [PATCH 21/35] - Change wrapping size to 120 in lightning_cli.rst. - Fix cli urls in docstrings. - Change LightningCLI's trainer_kwargs to trainer_defaults and fix behavior. 
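A sketch of the intended usage of the renamed argument (`MyModel` is a placeholder for any user LightningModule and is
not part of this change):

    from pytorch_lightning.callbacks import LearningRateMonitor
    from pytorch_lightning.utilities.cli import LightningCLI

    cli = LightningCLI(
        MyModel,  # placeholder for the user's LightningModule class
        trainer_defaults={
            # Overrides the Trainer default, but remains configurable from the CLI or a config file.
            'max_epochs': 10,
            # Persistent callbacks are added on top of any callbacks given in the config.
            'callbacks': [LearningRateMonitor()],
        },
    )
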
--- docs/source/common/lightning_cli.rst | 159 +++++++++++---------------- pytorch_lightning/utilities/cli.py | 32 +++--- 2 files changed, 83 insertions(+), 108 deletions(-) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index 51dff7746ce2c..84b5092cf352f 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -45,20 +45,16 @@ Lightning CLI and config files ------------------------------ -Another source of boilerplate code that Lightning can help to reduce is in the -implementation of training command line tools. Furthermore, it provides a -standardized way to configure trainings using a single file that includes -settings for :class:`~pytorch_lightning.trainer.trainer.Trainer` and user -extended :class:`~pytorch_lightning.core.lightning.LightningModule` and -:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes. The -full configuration is automatically saved in the log directory. This has the -benefit of greatly simplifying the reproducibility of experiments. - -The main requirement for user extended classes to be made configurable is that -all relevant init arguments must have type hints. This is not a very demanding -requirement since it is good practice to do anyway. As a bonus if the arguments -are described in the docstrings, then the help of the training tool will display -them. +Another source of boilerplate code that Lightning can help to reduce is in the implementation of training command line +tools. Furthermore, it provides a standardized way to configure trainings using a single file that includes settings for +:class:`~pytorch_lightning.trainer.trainer.Trainer` and user extended +:class:`~pytorch_lightning.core.lightning.LightningModule` and +:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes. The full configuration is automatically saved +in the log directory. This has the benefit of greatly simplifying the reproducibility of experiments. + +The main requirement for user extended classes to be made configurable is that all relevant init arguments must have +type hints. This is not a very demanding requirement since it is good practice to do anyway. As a bonus if the arguments +are described in the docstrings, then the help of the training tool will display them. ---------- @@ -66,10 +62,8 @@ them. LightningCLI ^^^^^^^^^^^^ -The case in which the user's -:class:`~pytorch_lightning.core.lightning.LightningModule` class implements all -required :code:`*_dataloader` methods, a :code:`trainer.py` tool can be as -simple as: +The case in which the user's :class:`~pytorch_lightning.core.lightning.LightningModule` class implements all required +:code:`*_dataloader` methods, a :code:`trainer.py` tool can be as simple as: .. testcode:: @@ -77,11 +71,9 @@ simple as: cli = LightningCLI(MyModel) -The help of the tool describing all configurable options and default values can -be shown by running :code:`python trainer.py --help`. Default options can be -changed by providing individual command line arguments. However, it is better -practice to create a configuration file and provide this to the tool. A way -to do this would be: +The help of the tool describing all configurable options and default values can be shown by running :code:`python +trainer.py --help`. Default options can be changed by providing individual command line arguments. However, it is better +practice to create a configuration file and provide this to the tool. A way to do this would be: .. 
code-block:: bash @@ -92,23 +84,21 @@ to do this would be: # Run training using created configuration python trainer.py --config config.yaml -The instantiation of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` -class takes care of parsing command line and config file options, instantiating -the classes, setting up a callback to save the config in the log directory and -finally running :func:`trainer.fit`. The resulting object :code:`cli` can be -used for instance to get the result of fit, i.e., :code:`cli.fit_result`. +The instantiation of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class takes care of parsing command line +and config file options, instantiating the classes, setting up a callback to save the config in the log directory and +finally running :func:`trainer.fit`. The resulting object :code:`cli` can be used for instance to get the result of fit, +i.e., :code:`cli.fit_result`. -After multiple trainings with different configurations, each run will have in -its respective log directory a :code:`config.yaml` file. This file can be used -for reference to know in detail all the settings that were used for each +After multiple trainings with different configurations, each run will have in its respective log directory a +:code:`config.yaml` file. This file can be used for reference to know in detail all the settings that were used for each particular run, and also could be used to trivially reproduce a training, e.g.: .. code-block:: bash python trainer.py --config lightning_logs/version_7/config.yaml -If a separate :class:`~pytorch_lightning.core.datamodule.LightningDataModule` -class is required, the trainer tool just needs a small modification as follows: +If a separate :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class is required, the trainer tool just +needs a small modification as follows: .. testcode:: @@ -116,9 +106,8 @@ class is required, the trainer tool just needs a small modification as follows: cli = LightningCLI(MyModel, MyDataModule) -The start of a possible implementation of :class:`MyModel` including the -recommended argument descriptions in the docstring could be the one below. Note -that by using type hints and docstrings there is no need to duplicate this +The start of a possible implementation of :class:`MyModel` including the recommended argument descriptions in the +docstring could be the one below. Note that by using type hints and docstrings there is no need to duplicate this information to define its configurable arguments. .. code-block:: python @@ -172,8 +161,8 @@ With this model class, the help of the trainer tool would look as follows: Number of layers for each decoder block (type: List[int], default: [2, 4]) -The default configuration that option :code:`--print_config` gives is in yaml -format and for the example above would look as follows: +The default configuration that option :code:`--print_config` gives is in yaml format and for the example above would +look as follows: .. code-block:: bash @@ -190,24 +179,19 @@ format and for the example above would look as follows: amp_level: O2 ... -Note that there is a section for each class (model and trainer) including all -the init parameters of the class. This grouping is also used in the formatting -of the help shown previously. +Note that there is a section for each class (model and trainer) including all the init parameters of the class. This +grouping is also used in the formatting of the help shown previously. 
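+
+As a small illustration of this grouping (a sketch reusing the example model above), after the tool runs, the parsed
+options are available per group through the :code:`config` attribute of the :code:`cli` object:
+
+.. code-block:: python
+
+    cli = LightningCLI(MyModel)
+    print(cli.config['model']['encoder_layers'])  # 12 unless overridden, grouped like the yaml shown above
+    print(cli.config['trainer']['max_epochs'])
+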
Trainer Callbacks and arguments with class type ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -A very important argument of the -:class:`~pytorch_lightning.trainer.trainer.Trainer` class is the -:code:`callbacks`. In contrast to other more simple arguments which just require -numbers or strings, :code:`callbacks` expects a list of instances of subclasses -of :class:`~pytorch_lightning.callbacks.Callback`. To specify this kind of -argument in a config file, each callback must be given as a dictionary including -a :code:`class_path` entry with an import path of the class, and optionally an -:code:`init_args` entry with arguments required to instantiate it. Therefore, a -simple configuration file example that defines a couple of callbacks is the -following: +A very important argument of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class is the :code:`callbacks`. In +contrast to other more simple arguments which just require numbers or strings, :code:`callbacks` expects a list of +instances of subclasses of :class:`~pytorch_lightning.callbacks.Callback`. To specify this kind of argument in a config +file, each callback must be given as a dictionary including a :code:`class_path` entry with an import path of the class, +and optionally an :code:`init_args` entry with arguments required to instantiate it. Therefore, a simple configuration +file example that defines a couple of callbacks is the following: .. code-block:: yaml @@ -220,23 +204,19 @@ following: init_args: ... -Similar to the callbacks, any arguments in -:class:`~pytorch_lightning.trainer.trainer.Trainer` and user extended +Similar to the callbacks, any arguments in :class:`~pytorch_lightning.trainer.trainer.Trainer` and user extended :class:`~pytorch_lightning.core.lightning.LightningModule` and -:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that -have as type hint a class can be configured the same way using -:code:`class_path` and :code:`init_args`. +:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that have as type hint a class can be configured +the same way using :code:`class_path` and :code:`init_args`. Multiple models and/or datasets ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -In the previous examples :class:`~pytorch_lightning.utilities.cli.LightningCLI` -works only for a single model and datamodule class. However, there are many -cases in which the objective is to easily be able to run many experiments for -multiple models and datasets. For these cases the tool can be configured such -that a model and/or a datamodule is specified by an import path and init -arguments. For example, with a tool implemented as: +In the previous examples :class:`~pytorch_lightning.utilities.cli.LightningCLI` works only for a single model and +datamodule class. However, there are many cases in which the objective is to easily be able to run many experiments for +multiple models and datasets. For these cases the tool can be configured such that a model and/or a datamodule is +specified by an import path and init arguments. For example, with a tool implemented as: .. testcode:: @@ -271,41 +251,33 @@ A possible config file could be as follows: patience: 5 ... -Only model classes that are a subclass of :code:`MyModelBaseClass` would be -allowed, and similarly only subclasses of :code:`MyDataModuleBaseClass`. +Only model classes that are a subclass of :code:`MyModelBaseClass` would be allowed, and similarly only subclasses of +:code:`MyDataModuleBaseClass`. 
Customizing LightningCLI
^^^^^^^^^^^^^^^^^^^^^^^^

-The init parameters of the
-:class:`~pytorch_lightning.utilities.cli.LightningCLI` class can be used to
-customize some things, namely: the description of the tool, enabling parsing of
-environment variables and additional arguments to instantiate the trainer and
-configuration parser.
+The init parameters of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class can be used to customize some
+things, namely: the description of the tool, enabling parsing of environment variables and additional arguments to
+instantiate the trainer and configuration parser.

-Nevertheless the init arguments are not enough for many use cases. For this
-reason the class is designed so that can be extended to customize different
-parts of the command line tool. The argument parser class used by
+Nevertheless the init arguments are not enough for many use cases. For this reason the class is designed so that it
+can be extended to customize different parts of the command line tool. The argument parser class used by
 :class:`~pytorch_lightning.utilities.cli.LightningCLI` is
-:class:`~pytorch_lightning.utilities.cli.LightningArgumentParser` which is an
-extension of python's argparse, thus adding arguments can be done using the
-:func:`add_argument` method. In contrast to argparse it has additional methods
-to add arguments, for example :func:`add_class_arguments` adds all arguments
-from the init of a class, though requiring parameters to have type hints. For
-more details about this please refer to the `respective documentation
+:class:`~pytorch_lightning.utilities.cli.LightningArgumentParser` which is an extension of Python's argparse, thus
+adding arguments can be done using the :func:`add_argument` method. In contrast to argparse it has additional methods to
+add arguments, for example :func:`add_class_arguments` adds all arguments from the init of a class, though requiring
+parameters to have type hints. For more details about this please refer to the `respective documentation
 `_.

 The :class:`~pytorch_lightning.utilities.cli.LightningCLI` class has the
-:meth:`~pytorch_lightning.utilities.cli.LightningCLI.add_arguments_to_parser`
-method which can be implemented to include more arguments. After parsing, the
-configuration is stored in the :code:`config` attribute of the class instance.
-The :class:`~pytorch_lightning.utilities.cli.LightningCLI` class also has two
-methods that can be used to run code before and after :code:`trainer.fit` is
-executed: :meth:`~pytorch_lightning.utilities.cli.LightningCLI.before_fit` and
-:meth:`~pytorch_lightning.utilities.cli.LightningCLI.after_fit`. A realistic
-example for these would be to send an email before and after the execution of
-fit. The code would be something like:
+:meth:`~pytorch_lightning.utilities.cli.LightningCLI.add_arguments_to_parser` method which can be implemented to include
+more arguments. After parsing, the configuration is stored in the :code:`config` attribute of the class instance. The
+:class:`~pytorch_lightning.utilities.cli.LightningCLI` class also has two methods that can be used to run code before
+and after :code:`trainer.fit` is executed: :meth:`~pytorch_lightning.utilities.cli.LightningCLI.before_fit` and
+:meth:`~pytorch_lightning.utilities.cli.LightningCLI.after_fit`. A realistic example for these would be to send an email
+before and after the execution of fit.
     cli = MyLightningCLI(MyModel)

-Note that the config object :code:`self.config` is a dictionary whose keys are
-global options or groups of options. It has the same structure as the yaml
-format as described previously. This means for instance that the parameters used
-for instantiating the trainer class can be found in
-:code:`self.config['trainer']`.
+Note that the config object :code:`self.config` is a dictionary whose keys are global options or groups of options. It
+has the same structure as the yaml format described previously. This means, for instance, that the parameters used for
+instantiating the trainer class can be found in :code:`self.config['trainer']`.

-For more advanced use cases, other methods of the
-:class:`~pytorch_lightning.utilities.cli.LightningCLI` class could be extended.
-For further information have a look at the corresponding API reference.
+For more advanced use cases, other methods of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class could be
+extended. For further information, have a look at the corresponding API reference.
diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py
index d7da22d855cd4..ee01ec54ac742 100644
--- a/pytorch_lightning/utilities/cli.py
+++ b/pytorch_lightning/utilities/cli.py
@@ -34,7 +34,7 @@ def __init__(self, *args, parse_as_dict: bool = True, **kwargs):
         """Initialize argument parser that supports configuration file input

         For full details of accepted arguments see `ArgumentParser.__init__
-        `_.
+        `_.
         """
         if not _JSONARGPARSE_AVAILABLE:
             raise ModuleNotFoundError(
@@ -86,7 +86,7 @@ def __init__(
         datamodule_class: Type[LightningDataModule] = None,
         save_config_callback: Type[SaveConfigCallback] = SaveConfigCallback,
         trainer_class: Type[Trainer] = Trainer,
-        trainer_kwargs: Dict[str, Any] = None,
+        trainer_defaults: Dict[str, Any] = None,
         description: str = 'pytorch-lightning trainer command line tool',
         env_prefix: str = 'PL',
         env_parse: bool = False,
@@ -121,17 +121,17 @@ def __init__(
             datamodule_class: An optional LightningDataModule class.
             save_config_callback: A callback class to save the training config.
             trainer_class: An optional extension of the Trainer class.
-            trainer_kwargs: Additional arguments to instantiate Trainer.
+            trainer_defaults: Set to override Trainer defaults or add persistent callbacks.
             description: Description of the tool shown when running --help.
             env_prefix: Prefix for environment variables.
             env_parse: Whether environment variable parsing is enabled.
             parser_kwargs: Additional arguments to instantiate LightningArgumentParser.
             subclass_mode_model: Whether model can be any `subclass
-                `_ of the
-                given class.
+                `_
+                of the given class.
             subclass_mode_data: Whether datamodule can be any `subclass
-                `_ of the
-                given class.
+                `_
+                of the given class.
""" assert issubclass(trainer_class, Trainer) assert issubclass(model_class, LightningModule) @@ -141,7 +141,7 @@ def __init__( self.datamodule_class = datamodule_class self.save_config_callback = save_config_callback self.trainer_class = trainer_class - self.trainer_kwargs = {} if trainer_kwargs is None else trainer_kwargs + self.trainer_defaults = {} if trainer_defaults is None else trainer_defaults self.subclass_mode_model = subclass_mode_model self.subclass_mode_data = subclass_mode_data self.parser_kwargs = {} if parser_kwargs is None else parser_kwargs @@ -174,6 +174,8 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser): def add_core_arguments_to_parser(self): """Adds arguments from the core classes to the parser""" self.parser.add_lightning_class_args(self.trainer_class, 'trainer') + trainer_defaults = {'trainer.'+k: v for k, v in self.trainer_defaults.items() if k != 'callbacks'} + self.parser.set_defaults(trainer_defaults) self.parser.add_lightning_class_args(self.model_class, 'model', subclass_mode=self.subclass_mode_model) if self.datamodule_class is not None: self.parser.add_lightning_class_args(self.datamodule_class, 'data', subclass_mode=self.subclass_mode_data) @@ -219,12 +221,16 @@ def instantiate_model(self): def instantiate_trainer(self): """Instantiates the trainer using self.config_init['trainer']""" - self.trainer_kwargs.update(self.config_init['trainer']) - if self.trainer_kwargs.get('callbacks') is None: - self.trainer_kwargs['callbacks'] = [] + if self.config_init['trainer'].get('callbacks') is None: + self.config_init['trainer']['callbacks'] = [] + if 'callbacks' in self.trainer_defaults: + if isinstance(self.trainer_defaults['callbacks']): + self.config_init['trainer']['callbacks'].extend(self.trainer_defaults['callbacks']) + else: + self.config_init['trainer']['callbacks'].append(self.trainer_defaults['callbacks']) if self.save_config_callback is not None: - self.trainer_kwargs['callbacks'].append(self.save_config_callback(self.parser, self.config)) - self.trainer = self.trainer_class(**self.trainer_kwargs) + self.config_init['trainer']['callbacks'].append(self.save_config_callback(self.parser, self.config)) + self.trainer = self.trainer_class(**self.config_init['trainer']) def prepare_fit_kwargs(self): """Prepares fit_kwargs including datamodule using self.config_init['data'] if given""" From ebc49449212fe269245cfc260de7651474f69c98 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Tue, 9 Mar 2021 20:38:45 +0100 Subject: [PATCH 22/35] - Fix missing space around operator. --- pytorch_lightning/utilities/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index ee01ec54ac742..6774a737f47d1 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -174,7 +174,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser): def add_core_arguments_to_parser(self): """Adds arguments from the core classes to the parser""" self.parser.add_lightning_class_args(self.trainer_class, 'trainer') - trainer_defaults = {'trainer.'+k: v for k, v in self.trainer_defaults.items() if k != 'callbacks'} + trainer_defaults = {'trainer.' 
+ k: v for k, v in self.trainer_defaults.items() if k != 'callbacks'} self.parser.set_defaults(trainer_defaults) self.parser.add_lightning_class_args(self.model_class, 'model', subclass_mode=self.subclass_mode_model) if self.datamodule_class is not None: From aa0a12925a2780ce26403b512b042f16a66aefd1 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Tue, 9 Mar 2021 23:29:54 +0100 Subject: [PATCH 23/35] - Added trainer_defaults callbacks to unit tests --- pytorch_lightning/utilities/cli.py | 2 +- tests/trainer/test_lightning_cli.py | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 6774a737f47d1..d73fe5722f55b 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -224,7 +224,7 @@ def instantiate_trainer(self): if self.config_init['trainer'].get('callbacks') is None: self.config_init['trainer']['callbacks'] = [] if 'callbacks' in self.trainer_defaults: - if isinstance(self.trainer_defaults['callbacks']): + if isinstance(self.trainer_defaults['callbacks'], list): self.config_init['trainer']['callbacks'].extend(self.trainer_defaults['callbacks']) else: self.config_init['trainer']['callbacks'].append(self.trainer_defaults['callbacks']) diff --git a/tests/trainer/test_lightning_cli.py b/tests/trainer/test_lightning_cli.py index 91cb37d7f67c5..d87439a73bc96 100644 --- a/tests/trainer/test_lightning_cli.py +++ b/tests/trainer/test_lightning_cli.py @@ -22,6 +22,7 @@ import yaml from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import LearningRateMonitor from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI, SaveConfigCallback from tests.helpers import BoringDataModule, BoringModel @@ -258,7 +259,11 @@ def test_lightning_cli_args(tmpdir): ] with mock.patch('sys.argv', ['any.py'] + cli_args): - cli = LightningCLI(BoringModel, BoringDataModule) + cli = LightningCLI( + BoringModel, + BoringDataModule, + trainer_defaults={'callbacks': [LearningRateMonitor()]} + ) assert cli.fit_result == 1 config_path = tmpdir / 'lightning_logs' / 'version_0' / 'config.yaml' @@ -291,7 +296,13 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir): f.write(yaml.dump(config)) with mock.patch('sys.argv', ['any.py', '--config', str(config_path)]): - cli = LightningCLI(BoringModel, BoringDataModule, subclass_mode_model=True, subclass_mode_data=True) + cli = LightningCLI( + BoringModel, + BoringDataModule, + subclass_mode_model=True, + subclass_mode_data=True, + trainer_defaults={'callbacks': LearningRateMonitor()} + ) assert cli.fit_result == 1 config_path = tmpdir / 'lightning_logs' / 'version_0' / 'config.yaml' From 175026602d616ab0c4fc0287c4db46cd8b12e546 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Wed, 31 Mar 2021 08:46:41 +0200 Subject: [PATCH 24/35] Added unit test for LightningCLI with callbacks as argument --- tests/trainer/test_lightning_cli.py | 41 ++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/tests/trainer/test_lightning_cli.py b/tests/trainer/test_lightning_cli.py index d87439a73bc96..41c340869a2f2 100644 --- a/tests/trainer/test_lightning_cli.py +++ b/tests/trainer/test_lightning_cli.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import json import os import pickle import sys @@ -22,7 +23,7 @@ import yaml from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.callbacks import LearningRateMonitor +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI, SaveConfigCallback from tests.helpers import BoringDataModule, BoringModel @@ -249,6 +250,44 @@ def __init__(self, model_param: int): assert hasattr(cli.trainer, 'ran_asserts') and cli.trainer.ran_asserts +def test_lightning_cli_args_callbacks(tmpdir, monkeypatch): + + callbacks = [ + { + 'class_path': 'pytorch_lightning.callbacks.LearningRateMonitor', + 'init_args': { + 'logging_interval': 'epoch', + 'log_momentum': True, + }, + }, + { + 'class_path': 'pytorch_lightning.callbacks.ModelCheckpoint', + 'init_args': { + 'monitor': 'NAME', + }, + }, + ] + + def fit(trainer, model): + callback = [c for c in trainer.callbacks if isinstance(c, LearningRateMonitor)] + assert len(callback) == 1 + assert callback[0].logging_interval == 'epoch' + assert callback[0].log_momentum == True + callback = [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] + assert len(callback) == 1 + assert callback[0].monitor == 'NAME' + trainer.ran_asserts = True + + monkeypatch.setattr(Trainer, 'fit', fit) + + class TestModel(LightningModule): + pass + + with mock.patch('sys.argv', ['any.py', f'--trainer.callbacks={json.dumps(callbacks)}']): + cli = LightningCLI(TestModel, trainer_class=Trainer) + assert hasattr(cli.trainer, 'ran_asserts') and cli.trainer.ran_asserts + + def test_lightning_cli_args(tmpdir): cli_args = [ From 08b3e5cd6ecd60b9327d66879af868498a74dc67 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Wed, 31 Mar 2021 09:12:37 +0200 Subject: [PATCH 25/35] Added to LightningCLI documentation the need for the 'all' extras require --- docs/source/common/lightning_cli.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index 84b5092cf352f..5ccce47939e4e 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -62,6 +62,10 @@ are described in the docstrings, then the help of the training tool will display LightningCLI ^^^^^^^^^^^^ +The implementation of training command line tools is done via the :class:`~pytorch_lightning.utilities.cli.LightningCLI` +class. The minimal installation of pytorch-lightning does not include this support. To enable it either install +lightning with the :code:`all` extras require or install the package :code:`jsonargparse[signatures]`. 
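As a rough sketch of what this requirement means in practice, the presence of the optional dependency can be checked
before importing the CLI utilities. The package name :code:`jsonargparse` comes from the paragraph above; the install
commands in the comment are an assumption about the extras syntax.

.. code-block:: python

    import importlib.util

    # A minimal pytorch-lightning install does not include jsonargparse.
    # Install it with either `pip install "pytorch-lightning[all]"` or
    # `pip install "jsonargparse[signatures]"`.
    if importlib.util.find_spec('jsonargparse') is None:
        raise ModuleNotFoundError('LightningCLI requires jsonargparse[signatures]')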
+ The case in which the user's :class:`~pytorch_lightning.core.lightning.LightningModule` class implements all required :code:`*_dataloader` methods, a :code:`trainer.py` tool can be as simple as: From 4e9320086f35025f8d32e33cb6600e13313979e0 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Wed, 31 Mar 2021 09:15:34 +0200 Subject: [PATCH 26/35] Fixed PEP8 issue in test_lightning_cli.py --- tests/trainer/test_lightning_cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_lightning_cli.py b/tests/trainer/test_lightning_cli.py index 41c340869a2f2..cd9785cad9e29 100644 --- a/tests/trainer/test_lightning_cli.py +++ b/tests/trainer/test_lightning_cli.py @@ -272,7 +272,7 @@ def fit(trainer, model): callback = [c for c in trainer.callbacks if isinstance(c, LearningRateMonitor)] assert len(callback) == 1 assert callback[0].logging_interval == 'epoch' - assert callback[0].log_momentum == True + assert callback[0].log_momentum is True callback = [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] assert len(callback) == 1 assert callback[0].monitor == 'NAME' From bac68b530b93a76f82563519a6a0b39092a3a2fc Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 6 Apr 2021 12:32:47 +0200 Subject: [PATCH 27/35] Add beta warnings --- docs/source/common/lightning_cli.rst | 2 ++ pytorch_lightning/utilities/cli.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index 5ccce47939e4e..6d8ae6701fd40 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -56,6 +56,8 @@ The main requirement for user extended classes to be made configurable is that a type hints. This is not a very demanding requirement since it is good practice to do anyway. As a bonus if the arguments are described in the docstrings, then the help of the training tool will display them. +.. warning:: ``LightningCLI`` is in beta and subject to change. + ---------- diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index d73fe5722f55b..7b8b0a4f4f211 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -116,6 +116,8 @@ def __init__( $ nano config.yaml # modify the config as desired $ python trainer.py --cfg config.yaml + .. warning:: ``LightningCLI`` is in beta and subject to change. + Args: model_class: The LightningModule class to train on. datamodule_class: An optional LightningDataModule class. From ee0ff217a687355ecf61715601145dc7b912b7e5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 6 Apr 2021 12:36:27 +0200 Subject: [PATCH 28/35] Rename test file --- tests/trainer/{test_lightning_cli.py => test_cli.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/trainer/{test_lightning_cli.py => test_cli.py} (100%) diff --git a/tests/trainer/test_lightning_cli.py b/tests/trainer/test_cli.py similarity index 100% rename from tests/trainer/test_lightning_cli.py rename to tests/trainer/test_cli.py From 7475b9f65bbb46a717ca48cc26c38e9809c18dc8 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 6 Apr 2021 12:52:14 +0200 Subject: [PATCH 29/35] Refactor callback test. 
Fix missing log_dir --- pytorch_lightning/utilities/cli.py | 6 +++++- tests/trainer/test_cli.py | 31 ++++++++++++++---------------- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 7b8b0a4f4f211..2f33632440a39 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -74,7 +74,11 @@ def __init__(self, parser, config): self.config = config def on_train_start(self, trainer, pl_module): - config_path = os.path.join(trainer.logger.log_dir, 'config.yaml') + if hasattr(trainer, 'logger') and getattr(trainer.logger, 'log_dir', None) is not None: + config_dir = trainer.logger.log_dir + else: + config_dir = trainer.default_root_dir + config_path = os.path.join(config_dir, 'config.yaml') self.parser.save(self.config, config_path, skip_none=False) diff --git a/tests/trainer/test_cli.py b/tests/trainer/test_cli.py index cd9785cad9e29..be966d05aaeac 100644 --- a/tests/trainer/test_cli.py +++ b/tests/trainer/test_cli.py @@ -250,7 +250,7 @@ def __init__(self, model_param: int): assert hasattr(cli.trainer, 'ran_asserts') and cli.trainer.ran_asserts -def test_lightning_cli_args_callbacks(tmpdir, monkeypatch): +def test_lightning_cli_args_callbacks(tmpdir): callbacks = [ { @@ -268,24 +268,21 @@ def test_lightning_cli_args_callbacks(tmpdir, monkeypatch): }, ] - def fit(trainer, model): - callback = [c for c in trainer.callbacks if isinstance(c, LearningRateMonitor)] - assert len(callback) == 1 - assert callback[0].logging_interval == 'epoch' - assert callback[0].log_momentum is True - callback = [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] - assert len(callback) == 1 - assert callback[0].monitor == 'NAME' - trainer.ran_asserts = True - - monkeypatch.setattr(Trainer, 'fit', fit) - - class TestModel(LightningModule): - pass + class TestModel(BoringModel): + def on_fit_start(self): + callback = [c for c in self.trainer.callbacks if isinstance(c, LearningRateMonitor)] + assert len(callback) == 1 + assert callback[0].logging_interval == 'epoch' + assert callback[0].log_momentum is True + callback = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)] + assert len(callback) == 1 + assert callback[0].monitor == 'NAME' + self.trainer.ran_asserts = True with mock.patch('sys.argv', ['any.py', f'--trainer.callbacks={json.dumps(callbacks)}']): - cli = LightningCLI(TestModel, trainer_class=Trainer) - assert hasattr(cli.trainer, 'ran_asserts') and cli.trainer.ran_asserts + cli = LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True)) + + assert cli.trainer.ran_asserts def test_lightning_cli_args(tmpdir): From c5b0cec3013c300f562d106b7e025837ef4249b6 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 6 Apr 2021 12:58:01 +0200 Subject: [PATCH 30/35] Typing --- pytorch_lightning/utilities/cli.py | 47 +++++++++++++++++------------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 2f33632440a39..7a3b6f93d529f 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os +from argparse import Namespace from typing import Any, Dict, Type, Union from pytorch_lightning.callbacks import Callback @@ -30,7 +31,7 @@ class LightningArgumentParser(ArgumentParser): """Extension of jsonargparse's ArgumentParser for pytorch-lightning""" - def __init__(self, *args, parse_as_dict: bool = True, **kwargs): + def __init__(self, *args, parse_as_dict: bool = True, **kwargs) -> None: """Initialize argument parser that supports configuration file input For full details of accepted arguments see `ArgumentParser.__init__ @@ -51,7 +52,7 @@ def add_lightning_class_args( lightning_class: Union[Type[Trainer], Type[LightningModule], Type[LightningDataModule]], nested_key: str, subclass_mode: bool = False - ): + ) -> None: """ Adds arguments from a lightning class to a nested key of the parser @@ -69,16 +70,22 @@ def add_lightning_class_args( class SaveConfigCallback(Callback): """Saves a LightningCLI config to the log_dir when training starts""" - def __init__(self, parser, config): + def __init__( + self, + parser: LightningArgumentParser, + config: Union[Namespace, Dict[str, Any]], + config_filename: str = 'config.yaml' + ) -> None: self.parser = parser self.config = config + self.config_filename = config_filename - def on_train_start(self, trainer, pl_module): + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: if hasattr(trainer, 'logger') and getattr(trainer.logger, 'log_dir', None) is not None: config_dir = trainer.logger.log_dir else: config_dir = trainer.default_root_dir - config_path = os.path.join(config_dir, 'config.yaml') + config_path = os.path.join(config_dir, self.config_filename) self.parser.save(self.config, config_path, skip_none=False) @@ -97,7 +104,7 @@ def __init__( parser_kwargs: Dict[str, Any] = None, subclass_mode_model: bool = False, subclass_mode_data: bool = False - ): + ) -> None: """ Implementation of a configurable command line tool for pytorch-lightning @@ -165,11 +172,11 @@ def __init__( self.fit() self.after_fit() - def init_parser(self): + def init_parser(self) -> None: """Method that instantiates the argument parser""" self.parser = LightningArgumentParser(**self.parser_kwargs) - def add_arguments_to_parser(self, parser: LightningArgumentParser): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: """Implement to add extra arguments to parser Args: @@ -177,7 +184,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser): """ pass - def add_core_arguments_to_parser(self): + def add_core_arguments_to_parser(self) -> None: """Adds arguments from the core classes to the parser""" self.parser.add_lightning_class_args(self.trainer_class, 'trainer') trainer_defaults = {'trainer.' 
+ k: v for k, v in self.trainer_defaults.items() if k != 'callbacks'} @@ -186,7 +193,7 @@ def add_core_arguments_to_parser(self): if self.datamodule_class is not None: self.parser.add_lightning_class_args(self.datamodule_class, 'data', subclass_mode=self.subclass_mode_data) - def before_parse_arguments(self, parser: LightningArgumentParser): + def before_parse_arguments(self, parser: LightningArgumentParser) -> None: """Implement to run some code before parsing arguments Args: @@ -194,22 +201,22 @@ def before_parse_arguments(self, parser: LightningArgumentParser): """ pass - def parse_arguments(self): + def parse_arguments(self) -> None: """Parses command line arguments and stores it in self.config""" self.config = self.parser.parse_args() - def before_instantiate_classes(self): + def before_instantiate_classes(self) -> None: """Implement to run some code before instantiating the classes""" pass - def instantiate_classes(self): + def instantiate_classes(self) -> None: """Instantiates the classes using settings from self.config""" self.config_init = self.parser.instantiate_subclasses(self.config) self.instantiate_datamodule() self.instantiate_model() self.instantiate_trainer() - def instantiate_datamodule(self): + def instantiate_datamodule(self) -> None: """Instantiates the datamodule using self.config_init['data'] if given""" if self.datamodule_class is None: self.datamodule = None @@ -218,14 +225,14 @@ def instantiate_datamodule(self): else: self.datamodule = self.datamodule_class(**self.config_init.get('data', {})) - def instantiate_model(self): + def instantiate_model(self) -> None: """Instantiates the model using self.config_init['model']""" if self.subclass_mode_model: self.model = self.config_init['model'] else: self.model = self.model_class(**self.config_init.get('model', {})) - def instantiate_trainer(self): + def instantiate_trainer(self) -> None: """Instantiates the trainer using self.config_init['trainer']""" if self.config_init['trainer'].get('callbacks') is None: self.config_init['trainer']['callbacks'] = [] @@ -238,20 +245,20 @@ def instantiate_trainer(self): self.config_init['trainer']['callbacks'].append(self.save_config_callback(self.parser, self.config)) self.trainer = self.trainer_class(**self.config_init['trainer']) - def prepare_fit_kwargs(self): + def prepare_fit_kwargs(self) -> None: """Prepares fit_kwargs including datamodule using self.config_init['data'] if given""" self.fit_kwargs = {'model': self.model} if self.datamodule is not None: self.fit_kwargs['datamodule'] = self.datamodule - def before_fit(self): + def before_fit(self) -> None: """Implement to run some code before fit is started""" pass - def fit(self): + def fit(self) -> None: """Runs fit of the instantiated trainer class and prepared fit keyword arguments""" self.fit_result = self.trainer.fit(**self.fit_kwargs) - def after_fit(self): + def after_fit(self) -> None: """Implement to run some code after fit has finished""" pass From ff43a8b340dacc0dbe1c82ab6974f4f24e045a77 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 6 Apr 2021 13:16:23 +0200 Subject: [PATCH 31/35] Move test file to utilities. 
Use dict constructor which plays nicer with yapf --- tests/{trainer => utilities}/test_cli.py | 115 +++++++++-------------- 1 file changed, 47 insertions(+), 68 deletions(-) rename tests/{trainer => utilities}/test_cli.py (78%) diff --git a/tests/trainer/test_cli.py b/tests/utilities/test_cli.py similarity index 78% rename from tests/trainer/test_cli.py rename to tests/utilities/test_cli.py index be966d05aaeac..9b2c825c7aeb8 100644 --- a/tests/trainer/test_cli.py +++ b/tests/utilities/test_cli.py @@ -89,34 +89,34 @@ def _raise(): @pytest.mark.parametrize( ['cli_args', 'expected'], [ - ('--auto_lr_find=True --auto_scale_batch_size=power', {'auto_lr_find': True, 'auto_scale_batch_size': 'power'}), + ('--auto_lr_find=True --auto_scale_batch_size=power', dict(auto_lr_find=True, auto_scale_batch_size='power')), ( '--auto_lr_find any_string --auto_scale_batch_size ON', - {'auto_lr_find': 'any_string', 'auto_scale_batch_size': True}, + dict(auto_lr_find='any_string', auto_scale_batch_size=True), ), - ('--auto_lr_find=Yes --auto_scale_batch_size=On', {'auto_lr_find': True, 'auto_scale_batch_size': True}), - ('--auto_lr_find Off --auto_scale_batch_size No', {'auto_lr_find': False, 'auto_scale_batch_size': False}), - ('--auto_lr_find TRUE --auto_scale_batch_size FALSE', {'auto_lr_find': True, 'auto_scale_batch_size': False}), - ('--tpu_cores=8', {'tpu_cores': 8}), - ('--tpu_cores=1,', {'tpu_cores': '1,'}), - ('--limit_train_batches=100', {'limit_train_batches': 100}), - ('--limit_train_batches 0.8', {'limit_train_batches': 0.8}), - ('--weights_summary=null', {'weights_summary': None}), + ('--auto_lr_find=Yes --auto_scale_batch_size=On', dict(auto_lr_find=True, auto_scale_batch_size=True)), + ('--auto_lr_find Off --auto_scale_batch_size No', dict(auto_lr_find=False, auto_scale_batch_size=False)), + ('--auto_lr_find TRUE --auto_scale_batch_size FALSE', dict(auto_lr_find=True, auto_scale_batch_size=False)), + ('--tpu_cores=8', dict(tpu_cores=8)), + ('--tpu_cores=1,', dict(tpu_cores='1,')), + ('--limit_train_batches=100', dict(limit_train_batches=100)), + ('--limit_train_batches 0.8', dict(limit_train_batches=0.8)), + ('--weights_summary=null', dict(weights_summary=None)), ( "", - { + dict( # These parameters are marked as Optional[...] in Trainer.__init__, # with None as default. They should not be changed by the argparse # interface. 
- "min_steps": None, - "max_steps": None, - "log_gpu_memory": None, - "distributed_backend": None, - "weights_save_path": None, - "truncated_bptt_steps": None, - "resume_from_checkpoint": None, - "profiler": None, - }, + min_steps=None, + max_steps=None, + log_gpu_memory=None, + distributed_backend=None, + weights_save_path=None, + truncated_bptt_steps=None, + resume_from_checkpoint=None, + profiler=None + ), ), ], ) @@ -137,9 +137,12 @@ def test_parse_args_parsing(cli_args, expected): @pytest.mark.parametrize( ['cli_args', 'expected', 'instantiate'], [ - (['--gpus', '[0, 2]'], {'gpus': [0, 2]}, False), - (['--tpu_cores=[1,3]'], {'tpu_cores': [1, 3]}, False), - (['--accumulate_grad_batches={"5":3,"10":20}'], {'accumulate_grad_batches': {5: 3, 10: 20}}, True), + (['--gpus', '[0, 2]'], dict(gpus=[0, 2]), False), + (['--tpu_cores=[1,3]'], dict(tpu_cores=[1, 3]), False), + (['--accumulate_grad_batches={"5":3,"10":20}'], dict(accumulate_grad_batches={ + 5: 3, + 10: 20 + }), True), ], ) def test_parse_args_parsing_complex_types(cli_args, expected, instantiate): @@ -184,9 +187,9 @@ def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu): ['cli_args', 'extra_args'], [ ({}, {}), - ({'logger': False}, {}), - ({'logger': False}, {'logger': True}), - ({'logger': False}, {'checkpoint_callback': True}), + (dict(logger=False), {}), + (dict(logger=False), dict(logger=True)), + (dict(logger=False), dict(checkpoint_callback=True)), ], ) def test_init_from_argparse_args(cli_args, extra_args): @@ -204,16 +207,11 @@ def test_init_from_argparse_args(cli_args, extra_args): Trainer.from_argparse_args(Namespace(**cli_args), **extra_args, **unknown_args) -@pytest.mark.parametrize( - ['cli_args', 'expected_model', 'expected_trainer'], - [ - ( - ['--model.model_param=7', '--trainer.limit_train_batches=100'], - {'model_param': 7}, - {'limit_train_batches': 100}, - ), - ], -) +@pytest.mark.parametrize(['cli_args', 'expected_model', 'expected_trainer'], [( + ['--model.model_param=7', '--trainer.limit_train_batches=100'], + dict(model_param=7), + dict(limit_train_batches=100), +)]) def test_lightning_cli(cli_args, expected_model, expected_trainer, monkeypatch): """Test that LightningCLI correctly instantiates model, trainer and calls fit.""" @@ -238,6 +236,7 @@ def on_train_start(callback, trainer, model): monkeypatch.setattr(SaveConfigCallback, 'on_train_start', on_train_start) class TestModel(LightningModule): + def __init__(self, model_param: int): super().__init__() self.model_param = model_param @@ -253,22 +252,15 @@ def __init__(self, model_param: int): def test_lightning_cli_args_callbacks(tmpdir): callbacks = [ - { - 'class_path': 'pytorch_lightning.callbacks.LearningRateMonitor', - 'init_args': { - 'logging_interval': 'epoch', - 'log_momentum': True, - }, - }, - { - 'class_path': 'pytorch_lightning.callbacks.ModelCheckpoint', - 'init_args': { - 'monitor': 'NAME', - }, - }, + dict( + class_path='pytorch_lightning.callbacks.LearningRateMonitor', + init_args=dict(logging_interval='epoch', log_momentum=True) + ), + dict(class_path='pytorch_lightning.callbacks.ModelCheckpoint', init_args=dict(monitor='NAME')), ] class TestModel(BoringModel): + def on_fit_start(self): callback = [c for c in self.trainer.callbacks if isinstance(c, LearningRateMonitor)] assert len(callback) == 1 @@ -295,11 +287,7 @@ def test_lightning_cli_args(tmpdir): ] with mock.patch('sys.argv', ['any.py'] + cli_args): - cli = LightningCLI( - BoringModel, - BoringDataModule, - trainer_defaults={'callbacks': [LearningRateMonitor()]} - ) 
+ cli = LightningCLI(BoringModel, BoringDataModule, trainer_defaults={'callbacks': [LearningRateMonitor()]}) assert cli.fit_result == 1 config_path = tmpdir / 'lightning_logs' / 'version_0' / 'config.yaml' @@ -313,20 +301,11 @@ def test_lightning_cli_args(tmpdir): def test_lightning_cli_config_and_subclass_mode(tmpdir): - config = { - 'model': { - 'class_path': 'tests.helpers.BoringModel', - }, - 'data': { - 'class_path': 'tests.helpers.BoringDataModule', - 'init_args': {'data_dir': str(tmpdir)}, - }, - 'trainer': { - 'default_root_dir': str(tmpdir), - 'max_epochs': 1, - 'weights_summary': None, - }, - } + config = dict( + model=dict(class_path='tests.helpers.BoringModel'), + data=dict(class_path='tests.helpers.BoringDataModule', init_args=dict(data_dir=str(tmpdir))), + trainer=dict(default_root_dir=str(tmpdir), max_epochs=1, weights_summary=None) + ) config_path = tmpdir / 'config.yaml' with open(config_path, 'w') as f: f.write(yaml.dump(config)) From c2f0b8687dc865726ff7a8f66631abce3736ae65 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 6 Apr 2021 13:16:40 +0200 Subject: [PATCH 32/35] Refactor fn --- pytorch_lightning/trainer/properties.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 315e3c60c0557..2f01363086478 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -199,13 +199,7 @@ def slurm_job_id(self) -> Optional[int]: @classmethod def default_attributes(cls) -> dict: init_signature = inspect.signature(cls) - - args = {} - for param_name in init_signature.parameters: - value = init_signature.parameters[param_name].default - args[param_name] = value - - return args + return {k: v.default for k, v in init_signature.parameters.items()} @classmethod def get_deprecated_arg_names(cls) -> List: From 35f5d0d0e4b5adfef0784f5554fe1b8a58cb588b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 6 Apr 2021 13:32:40 +0200 Subject: [PATCH 33/35] Use trainer.log_dir --- pytorch_lightning/utilities/cli.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 7a3b6f93d529f..e0fb4d7829d80 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -81,11 +81,8 @@ def __init__( self.config_filename = config_filename def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - if hasattr(trainer, 'logger') and getattr(trainer.logger, 'log_dir', None) is not None: - config_dir = trainer.logger.log_dir - else: - config_dir = trainer.default_root_dir - config_path = os.path.join(config_dir, self.config_filename) + log_dir = trainer.log_dir or trainer.default_root_dir + config_path = os.path.join(log_dir, self.config_filename) self.parser.save(self.config, config_path, skip_none=False) From 956206119746b6707a180c20e451136d4b6ca94e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 6 Apr 2021 14:00:22 +0200 Subject: [PATCH 34/35] Fix docs --- pytorch_lightning/utilities/cli.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index e0fb4d7829d80..7f91443ebd5c8 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -87,6 +87,7 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: class LightningCLI: + """Implementation of a 
configurable command line tool for pytorch-lightning"""

     def __init__(
         self,
@@ -103,8 +104,6 @@ def __init__(
         subclass_mode_data: bool = False
     ) -> None:
         """
-        Implementation of a configurable command line tool for pytorch-lightning
-
         Receives as input pytorch-lightning classes, which are instantiated
         using a parsed configuration file and/or command line args and then runs
         trainer.fit. Parsing of configuration from environment variables can

From 5604c04b02ea9cf5728667250ba043d0237d5ec8 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 6 Apr 2021 14:03:09 +0200
Subject: [PATCH 35/35] Minor docs change

---
 pytorch_lightning/utilities/cli.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py
index 7f91443ebd5c8..33c424829606a 100644
--- a/pytorch_lightning/utilities/cli.py
+++ b/pytorch_lightning/utilities/cli.py
@@ -107,11 +107,11 @@ def __init__(
         Receives as input pytorch-lightning classes, which are instantiated
         using a parsed configuration file and/or command line args and then runs
         trainer.fit. Parsing of configuration from environment variables can
-        be enabled by setting :code:`env_parse=True`. A full configuration yaml would
-        be parsed from :code:`PL_CONFIG` if set. Individual settings are so parsed from
-        variables named for example :code:`PL_TRAINER__MAX_EPOCHS`.
+        be enabled by setting ``env_parse=True``. A full configuration yaml would
+        be parsed from ``PL_CONFIG`` if set. Individual settings are likewise parsed
+        from variables named, for example, ``PL_TRAINER__MAX_EPOCHS``.

-        Example, first implement the trainer.py tool as::
+        Example, first implement the ``trainer.py`` tool as::

             from mymodels import MyModel
             from pytorch_lightning.utilities.cli import LightningCLI
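To make the environment variable parsing described in this docstring concrete, a sketch of a tool that enables it could
look as follows; :code:`mymodels` and :code:`MyModel` are the placeholder names from the example above.

.. code-block:: python

    from mymodels import MyModel
    from pytorch_lightning.utilities.cli import LightningCLI

    # env_parse and env_prefix are init parameters of LightningCLI shown
    # earlier in this patch series; with these settings an environment
    # variable such as PL_TRAINER__MAX_EPOCHS overrides trainer.max_epochs.
    cli = LightningCLI(MyModel, env_parse=True, env_prefix='PL')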