Commit 3f7872d

[CLI] Shorthand notation to instantiate models (#9588)
1 parent 8f1c855 commit 3f7872d

File tree

3 files changed: +97 −19 lines


docs/source/common/lightning_cli.rst

Lines changed: 48 additions & 11 deletions

@@ -415,22 +415,59 @@ as described above:

     $ python ... --trainer.callbacks=CustomCallback ...

-This callback will be included in the generated config:
+.. note::

-.. code-block:: yaml
+    This shorthand notation is only supported in the shell and not inside a configuration file. The configuration file
+    generated by calling the previous command with ``--print_config`` will have the ``class_path`` notation.
+
+    .. code-block:: yaml
+
+        trainer:
+          callbacks:
+            - class_path: your_class_path.CustomCallback
+              init_args:
+                ...

-    trainer:
-      callbacks:
-        - class_path: your_class_path.CustomCallback
-          init_args:
-            ...

 Multiple models and/or datasets
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

 In the previous examples :class:`~pytorch_lightning.utilities.cli.LightningCLI` works only for a single model and
 datamodule class. However, there are many cases in which the objective is to easily be able to run many experiments for
-multiple models and datasets. For these cases the tool can be configured such that a model and/or a datamodule is
+multiple models and datasets.
+
+The model argument can be left unset if a model has been registered first. This is particularly interesting for library
+authors who want to provide their users a range of models to choose from:
+
+.. code-block:: python
+
+    import flash.image
+    from pytorch_lightning.utilities.cli import MODEL_REGISTRY
+
+
+    @MODEL_REGISTRY
+    class MyModel(LightningModule):
+        ...
+
+
+    # register all `LightningModule` subclasses from a package
+    MODEL_REGISTRY.register_classes(flash.image, LightningModule)
+    # print(MODEL_REGISTRY)
+    # >>> Registered objects: ['MyModel', 'ImageClassifier', 'ObjectDetector', 'StyleTransfer', ...]
+
+    cli = LightningCLI()
+
+.. code-block:: bash
+
+    $ python trainer.py fit --model=MyModel --model.feat_dim=64
+
+.. note::
+
+    This shorthand notation is only supported in the shell and not inside a configuration file. The configuration file
+    generated by calling the previous command with ``--print_config`` will have the ``class_path`` notation described
+    below.
+
+Additionally, the tool can be configured such that a model and/or a datamodule is
 specified by an import path and init arguments. For example, with a tool implemented as:

 .. code-block:: python
@@ -750,7 +787,7 @@ A corresponding example of the config file would be:

 .. note::

-    This short-hand notation is only supported in the shell and not inside a configuration file. The configuration file
+    This shorthand notation is only supported in the shell and not inside a configuration file. The configuration file
     generated by calling the previous command with ``--print_config`` will have the ``class_path`` notation.

 Furthermore, you can register your own optimizers and/or learning rate schedulers as follows:
@@ -894,8 +931,8 @@ Notes related to reproducibility
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

 The topic of reproducibility is complex and it is impossible to guarantee reproducibility by just providing a class that
-people can use in unexpected ways. Nevertheless :class:`~pytorch_lightning.utilities.cli.LightningCLI` tries to give a
-framework and recommendations to make reproducibility simpler.
+people can use in unexpected ways. Nevertheless, the :class:`~pytorch_lightning.utilities.cli.LightningCLI` tries to
+give a framework and recommendations to make reproducibility simpler.

 When an experiment is run, it is good practice to use a stable version of the source code, either being a released
 package or at least a commit of some version controlled repository. For each run of a CLI the config file is
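The registry pattern documented above (decorate a class to make it selectable by name on the command line) can be illustrated with a minimal, self-contained sketch. This is an illustration of the idea only, not PyTorch Lightning's actual `_Registry` implementation; all names in it are hypothetical stand-ins:

```python
import inspect


class Registry(dict):
    """Toy class registry mapping class names to classes."""

    def __call__(self, cls):
        # used as a decorator: @MODEL_REGISTRY above a class definition
        self[cls.__name__] = cls
        return cls

    def register_classes(self, module, base_cls):
        # register every subclass of `base_cls` found in `module`
        for _, cls in inspect.getmembers(module, inspect.isclass):
            if issubclass(cls, base_cls) and cls is not base_cls:
                self[cls.__name__] = cls


MODEL_REGISTRY = Registry()


class Model:  # stand-in for LightningModule
    pass


@MODEL_REGISTRY
class MyModel(Model):
    def __init__(self, feat_dim=32):
        self.feat_dim = feat_dim


# "--model=MyModel --model.feat_dim=64" can now be resolved by name:
model = MODEL_REGISTRY["MyModel"](feat_dim=64)
print(type(model).__name__, model.feat_dim)  # MyModel 64
```

Registering by name is what makes the shell shorthand possible: the CLI only needs a string to recover the class, while a config file still records the unambiguous `class_path`.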

pytorch_lightning/utilities/cli.py

Lines changed: 18 additions & 7 deletions

@@ -89,6 +89,8 @@ def __str__(self) -> str:
 CALLBACK_REGISTRY = _Registry()
 CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.callbacks.Callback)

+MODEL_REGISTRY = _Registry()
+

 class LightningArgumentParser(ArgumentParser):
     """Extension of jsonargparse's ArgumentParser for pytorch-lightning."""
@@ -147,7 +149,7 @@ def add_lightning_class_args(
         if issubclass(lightning_class, Callback):
             self.callback_keys.append(nested_key)
         if subclass_mode:
-            return self.add_subclass_arguments(lightning_class, nested_key, required=True)
+            return self.add_subclass_arguments(lightning_class, nested_key, fail_untyped=False, required=True)
         return self.add_class_arguments(
             lightning_class, nested_key, fail_untyped=False, instantiate=not issubclass(lightning_class, Trainer)
         )
@@ -385,7 +387,7 @@ class LightningCLI:

     def __init__(
         self,
-        model_class: Union[Type[LightningModule], Callable[..., LightningModule]],
+        model_class: Optional[Union[Type[LightningModule], Callable[..., LightningModule]]] = None,
         datamodule_class: Optional[Union[Type[LightningDataModule], Callable[..., LightningDataModule]]] = None,
         save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback,
         save_config_filename: str = "config.yaml",
@@ -413,8 +415,9 @@ def __init__(
         .. warning:: ``LightningCLI`` is in beta and subject to change.

         Args:
-            model_class: :class:`~pytorch_lightning.core.lightning.LightningModule` class to train on or a callable
-                which returns a :class:`~pytorch_lightning.core.lightning.LightningModule` instance when called.
+            model_class: An optional :class:`~pytorch_lightning.core.lightning.LightningModule` class to train on or a
+                callable which returns a :class:`~pytorch_lightning.core.lightning.LightningModule` instance when
+                called. If ``None``, you can pass a registered model with ``--model=MyModel``.
             datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class or a
                 callable which returns a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` instance when
                 called.
@@ -439,17 +442,20 @@ def __init__(
             run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer`
                 method. If set to ``False``, the trainer and model classes will be instantiated only.
         """
-        self.model_class = model_class
         self.datamodule_class = datamodule_class
         self.save_config_callback = save_config_callback
         self.save_config_filename = save_config_filename
         self.save_config_overwrite = save_config_overwrite
         self.trainer_class = trainer_class
         self.trainer_defaults = trainer_defaults or {}
         self.seed_everything_default = seed_everything_default
-        self.subclass_mode_model = subclass_mode_model
         self.subclass_mode_data = subclass_mode_data

+        self.model_class = model_class
+        # used to differentiate between the original value and the processed value
+        self._model_class = model_class or LightningModule
+        self.subclass_mode_model = (model_class is None) or subclass_mode_model
+
         main_kwargs, subparser_kwargs = self._setup_parser_kwargs(
             parser_kwargs or {},  # type: ignore  # github.com/python/mypy/issues/6463
             {"description": description, "env_prefix": env_prefix, "default_env": env_parse},
@@ -509,7 +515,12 @@ def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
         parser.set_choices("trainer.callbacks", CALLBACK_REGISTRY.classes, is_list=True)
         trainer_defaults = {"trainer." + k: v for k, v in self.trainer_defaults.items() if k != "callbacks"}
         parser.set_defaults(trainer_defaults)
-        parser.add_lightning_class_args(self.model_class, "model", subclass_mode=self.subclass_mode_model)
+
+        parser.add_lightning_class_args(self._model_class, "model", subclass_mode=self.subclass_mode_model)
+        if self.model_class is None and MODEL_REGISTRY:
+            # did not pass a model and there are models registered
+            parser.set_choices("model", MODEL_REGISTRY.classes)
+
         if self.datamodule_class is not None:
             parser.add_lightning_class_args(self.datamodule_class, "data", subclass_mode=self.subclass_mode_data)
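The model-resolution logic in the diff above (keep the user-supplied `model_class`, fall back to the base class in subclass mode, and consult the registry only when no class was given) can be condensed into a standalone sketch. The function name and simplified signature here are hypothetical, not the real CLI API:

```python
class Model:  # stand-in base class for the sketch
    pass


class MyModel(Model):
    pass


MODEL_REGISTRY = {"MyModel": MyModel}


def resolve_model(model_class=None, subclass_mode_model=False, cli_name=None):
    # mirrors the diff: the processed value falls back to the base class
    _model_class = model_class or Model
    # with no explicit class, any registered subclass may be chosen
    subclass_mode = (model_class is None) or subclass_mode_model
    if model_class is None and MODEL_REGISTRY and cli_name is not None:
        # no model was passed and models are registered: look the name up
        return MODEL_REGISTRY[cli_name], subclass_mode
    return _model_class, subclass_mode


print(resolve_model(cli_name="MyModel"))   # registry hit, subclass mode forced on
print(resolve_model(model_class=MyModel))  # an explicit class wins
```

Keeping both `model_class` (the original value) and a processed fallback is what lets the parser accept `--model=<name>` without breaking the existing explicit-class API.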

tests/utilities/test_cli.py

Lines changed: 31 additions & 1 deletion

@@ -22,6 +22,7 @@
 from io import StringIO
 from typing import List, Optional, Union
 from unittest import mock
+from unittest.mock import ANY

 import pytest
 import torch
@@ -39,6 +40,7 @@
     LightningArgumentParser,
     LightningCLI,
     LR_SCHEDULER_REGISTRY,
+    MODEL_REGISTRY,
     OPTIMIZER_REGISTRY,
     SaveConfigCallback,
 )
@@ -888,6 +890,32 @@ def test_registries(tmpdir):
     assert isinstance(CustomCallback(), CustomCallback)


+@MODEL_REGISTRY
+class TestModel(BoringModel):
+    def __init__(self, foo, bar=5):
+        super().__init__()
+        self.foo = foo
+        self.bar = bar
+
+
+MODEL_REGISTRY(cls=BoringModel)
+
+
+def test_lightning_cli_model_choices():
+    with mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]), mock.patch(
+        "pytorch_lightning.Trainer._fit_impl"
+    ) as run:
+        cli = LightningCLI(trainer_defaults={"fast_dev_run": 1})
+        assert isinstance(cli.model, BoringModel)
+        run.assert_called_once_with(cli.model, ANY, ANY, ANY)
+
+    with mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]):
+        cli = LightningCLI(run=False)
+        assert isinstance(cli.model, TestModel)
+        assert cli.model.foo == 123
+        assert cli.model.bar == 5
+
+
 @pytest.mark.parametrize("use_class_path_callbacks", [False, True])
 def test_registries_resolution(use_class_path_callbacks):
     """This test validates registries are used when simplified command line are being used."""
@@ -899,6 +927,7 @@
         "--trainer.callbacks=LearningRateMonitor",
         "--trainer.callbacks.logging_interval=epoch",
         "--trainer.callbacks.log_momentum=True",
+        "--model=BoringModel",
         "--trainer.callbacks=ModelCheckpoint",
         "--trainer.callbacks.monitor=loss",
         "--lr_scheduler",
@@ -916,8 +945,9 @@
     extras = [Callback, Callback]

     with mock.patch("sys.argv", ["any.py"] + cli_args):
-        cli = LightningCLI(BoringModel, run=False)
+        cli = LightningCLI(run=False)

+    assert isinstance(cli.model, BoringModel)
     optimizers, lr_scheduler = cli.model.configure_optimizers()
     assert isinstance(optimizers[0], torch.optim.Adam)
     assert optimizers[0].param_groups[0]["lr"] == 0.0001
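The tests above drive the CLI by patching `sys.argv`, a pattern that works for any argparse-based entry point. A minimal standalone illustration of the pattern (the tiny parser here is a hypothetical stand-in, not `LightningCLI`):

```python
import argparse
from unittest import mock


def run_cli():
    # a tiny CLI entry point that reads sys.argv, as LightningCLI does
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", choices=["BoringModel", "TestModel"])
    parser.add_argument("--foo", type=int, default=5)
    return parser.parse_args()


# patch sys.argv so the entry point sees controlled arguments
with mock.patch("sys.argv", ["any.py", "--model=TestModel", "--foo", "123"]):
    args = run_cli()

print(args.model, args.foo)  # TestModel 123
```

Patching `sys.argv` (rather than threading an `argv` parameter through) exercises the entry point exactly as a user's shell invocation would.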
