diff --git a/docs/source-pytorch/cli/lightning_cli_advanced_3.rst b/docs/source-pytorch/cli/lightning_cli_advanced_3.rst
index e63844ae66576..ddc09dec2b6e3 100644
--- a/docs/source-pytorch/cli/lightning_cli_advanced_3.rst
+++ b/docs/source-pytorch/cli/lightning_cli_advanced_3.rst
@@ -197,6 +197,7 @@ Since the init parameters of the model have as a type hint a class, in the confi
                 decoder: Instance of a module for decoding
             """
             super().__init__()
+            self.save_hyperparameters()
             self.encoder = encoder
             self.decoder = decoder
 
@@ -216,6 +217,13 @@ If the CLI is implemented as ``LightningCLI(MyMainModel)`` the configuration wou
 
 It is also possible to combine ``subclass_mode_model=True`` and submodules, thereby having two levels of ``class_path``.
 
+.. tip::
+
+    Calling ``self.save_hyperparameters()`` makes it possible to load the model from a checkpoint. Simply do
+    ``ModelClass.load_from_checkpoint("path/to/checkpoint.ckpt")``. When using ``subclass_mode_model=True``, load it
+    with ``LightningModule.load_from_checkpoint("path/to/checkpoint.ckpt")`` instead. ``save_hyperparameters`` is
+    optional and can be safely removed if there is no need to load from a checkpoint.
+
 Fixed optimizer and scheduler
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
@@ -279,6 +287,7 @@ An example of a model that uses two optimizers is the following:
     class MyModel(LightningModule):
         def __init__(self, optimizer1: OptimizerCallable, optimizer2: OptimizerCallable):
             super().__init__()
+            self.save_hyperparameters()
             self.optimizer1 = optimizer1
             self.optimizer2 = optimizer2
 
@@ -318,6 +327,7 @@ that uses dependency injection for an optimizer and a learning scheduler is:
             scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
         ):
             super().__init__()
+            self.save_hyperparameters()
             self.optimizer = optimizer
             self.scheduler = scheduler
 
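The documentation change above enables the following end-to-end flow. A minimal usage sketch, assuming a ``MyMainModel`` defined as in the docs example and a checkpoint written by a CLI ``fit`` run (the paths are illustrative):

    from lightning.pytorch import LightningModule

    # MyMainModel is the example class from the docs section above.
    # Default case: load with the concrete class that was passed to LightningCLI.
    model = MyMainModel.load_from_checkpoint("path/to/checkpoint.ckpt")

    # With subclass_mode_model=True the class_path is stored in the hyperparameters,
    # so loading through the LightningModule base class resolves the right subclass.
    model = LightningModule.load_from_checkpoint("path/to/checkpoint.ckpt")
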
diff --git a/requirements/pytorch/extra.txt b/requirements/pytorch/extra.txt
index fddaece918b4a..39e3ff61d4e00 100644
--- a/requirements/pytorch/extra.txt
+++ b/requirements/pytorch/extra.txt
@@ -5,7 +5,7 @@
 matplotlib>3.1, <3.9.0
 omegaconf >=2.0.5, <2.4.0
 hydra-core >=1.0.5, <1.4.0
-jsonargparse[signatures] >=4.26.1, <4.28.0
+jsonargparse[signatures] >=4.27.5, <4.28.0
 rich >=12.3.0, <13.6.0
 tensorboardX >=2.2, <2.7.0  # min version is set by torch.onnx missing attribute
 bitsandbytes ==0.41.0  # strict
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index aff7a54ef4db4..44d33dc05bf6c 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - The `ModelSummary` and `RichModelSummary` callbacks now display the training mode of each layer in the column "Mode" ([#19468](https://github.com/Lightning-AI/lightning/pull/19468))
 
+- Added `load_from_checkpoint` support for `LightningCLI` when using dependency injection ([#18105](https://github.com/Lightning-AI/lightning/pull/18105))
+
 -
 
 -
@@ -64,6 +66,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added shortcut name `strategy='deepspeed_stage_1_offload'` to the strategy registry ([#19075](https://github.com/Lightning-AI/lightning/pull/19075))
 - Added support for non-strict state-dict loading in Trainer via the new `LightningModule.strict_loading = True | False` attribute ([#19404](https://github.com/Lightning-AI/lightning/pull/19404))
 
+
 ### Changed
 
 - `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))
diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py
index 64879d3398cb5..30318c072bfd2 100644
--- a/src/lightning/pytorch/cli.py
+++ b/src/lightning/pytorch/cli.py
@@ -11,13 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import os
 import sys
 from functools import partial, update_wrapper
 from types import MethodType
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union
 
 import torch
+import yaml
 from lightning_utilities.core.imports import RequirementCache
 from lightning_utilities.core.rank_zero import _warn
 from torch.optim import Optimizer
@@ -27,11 +29,12 @@
 from lightning.fabric.utilities.cloud_io import get_filesystem
 from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
 from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, seed_everything
+from lightning.pytorch.core.mixins.hparams_mixin import _given_hyperparameters_context
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 
-_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.26.1")
+_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.27.5")
 
 if _JSONARGPARSE_SIGNATURES_AVAILABLE:
     import docstring_parser
@@ -50,6 +53,8 @@
     locals()["ArgumentParser"] = object
     locals()["Namespace"] = object
 
+ModuleType = TypeVar("ModuleType")
+
 
 class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
     def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None:
@@ -381,6 +386,7 @@ def __init__(
 
         self._set_seed()
 
+        self._add_instantiators()
         self.before_instantiate_classes()
         self.instantiate_classes()
 
@@ -527,6 +533,22 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No
         else:
            self.config = parser.parse_args(args)
 
+    def _add_instantiators(self) -> None:
+        self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False))
+        if "subcommand" in self.config:
+            self.config_dump = self.config_dump[self.config.subcommand]
+
+        self.parser.add_instantiator(
+            _InstantiatorFn(cli=self, key="model"),
+            _get_module_type(self._model_class),
+            subclasses=self.subclass_mode_model,
+        )
+        self.parser.add_instantiator(
+            _InstantiatorFn(cli=self, key="data"),
+            _get_module_type(self._datamodule_class),
+            subclasses=self.subclass_mode_data,
+        )
+
     def before_instantiate_classes(self) -> None:
         """Implement to run some code before instantiating the classes."""
 
@@ -755,3 +777,33 @@ def _get_short_description(component: object) -> Optional[str]:
         return docstring.short_description
     except (ValueError, docstring_parser.ParseError) as ex:
         rank_zero_warn(f"Failed parsing docstring for {component}: {ex}")
+
+
+def _get_module_type(value: Union[Callable, type]) -> type:
+    if callable(value) and not isinstance(value, type):
+        return inspect.signature(value).return_annotation
+    return value
+
+
+class _InstantiatorFn:
+    def __init__(self, cli: LightningCLI, key: str) -> None:
+        self.cli = cli
+        self.key = key
+
+    def __call__(self, class_type: Type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType:
+        with _given_hyperparameters_context(
+            hparams=self.cli.config_dump.get(self.key, {}),
+            instantiator="lightning.pytorch.cli.instantiate_module",
+        ):
+            return class_type(*args, **kwargs)
+
+
+def instantiate_module(class_type: Type[ModuleType], config: Dict[str, Any]) -> ModuleType:
+    parser = ArgumentParser(exit_on_error=False)
+    if "class_path" in config:
+        parser.add_subclass_arguments(class_type, "module")
+    else:
+        parser.add_class_arguments(class_type, "module")
+    cfg = parser.parse_object({"module": config})
+    init = parser.instantiate_classes(cfg)
+    return init.module
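To make the load path concrete: the dotted path recorded under ``_instantiator`` resolves back to the ``instantiate_module`` function added above. A rough, illustrative sketch of that resolution (the ``torch.nn.LeakyReLU`` stand-in and its config are assumptions for illustration, and jsonargparse must be installed):

    import torch

    # Mirrors what _load_state does with the stored dotted path (see the saving.py hunk below).
    instantiator_path = "lightning.pytorch.cli.instantiate_module"
    module_path, name = instantiator_path.rsplit(".", 1)
    instantiator = getattr(__import__(module_path, fromlist=[name]), name)

    # instantiate_module re-parses the remaining hyperparameters with jsonargparse, so
    # class-path and callable entries become objects again before construction.
    activation = instantiator(torch.nn.LeakyReLU, {"negative_slope": 0.05})
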
diff --git a/src/lightning/pytorch/core/mixins/hparams_mixin.py b/src/lightning/pytorch/core/mixins/hparams_mixin.py
index 5a4ea1782b34e..94ece0039d4f4 100644
--- a/src/lightning/pytorch/core/mixins/hparams_mixin.py
+++ b/src/lightning/pytorch/core/mixins/hparams_mixin.py
@@ -15,7 +15,9 @@
 import inspect
 import types
 from argparse import Namespace
-from typing import Any, List, MutableMapping, Optional, Sequence, Union
+from contextlib import contextmanager
+from contextvars import ContextVar
+from typing import Any, Iterator, List, MutableMapping, Optional, Sequence, Union
 
 from lightning.fabric.utilities.data import AttributeDict
 from lightning.pytorch.utilities.parsing import save_hyperparameters
@@ -24,6 +26,20 @@
 _ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
 
 
+_given_hyperparameters: ContextVar = ContextVar("_given_hyperparameters", default=None)
+
+
+@contextmanager
+def _given_hyperparameters_context(hparams: dict, instantiator: str) -> Iterator[None]:
+    hparams = hparams.copy()
+    hparams["_instantiator"] = instantiator
+    token = _given_hyperparameters.set(hparams)
+    try:
+        yield
+    finally:
+        _given_hyperparameters.reset(token)
+
+
 class HyperparametersMixin:
     __jit_unused_properties__: List[str] = ["hparams", "hparams_initial"]
 
@@ -105,12 +121,13 @@ class ``__init__`` to be ignored
         """
         self._log_hyperparams = logger
+        given_hparams = _given_hyperparameters.get()
         # the frame needs to be created in this file.
-        if not frame:
+        if given_hparams is None and not frame:
             current_frame = inspect.currentframe()
             if current_frame:
                 frame = current_frame.f_back
-        save_hyperparameters(self, *args, ignore=ignore, frame=frame)
+        save_hyperparameters(self, *args, ignore=ignore, frame=frame, given_hparams=given_hparams)
 
     def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None:
         hp = self._to_hparams_dict(hp)
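The ``ContextVar`` pattern introduced in ``hparams_mixin.py`` can be illustrated in isolation. A self-contained sketch of the mechanism, using illustrative names rather than the Lightning ones:

    from contextlib import contextmanager
    from contextvars import ContextVar

    _given: ContextVar = ContextVar("_given", default=None)

    @contextmanager
    def given_context(hparams: dict):
        # Make externally supplied hyperparameters visible to code running inside the block.
        token = _given.set(hparams)
        try:
            yield
        finally:
            _given.reset(token)

    def save_hparams_like():
        # Inside the context the provided dict wins over frame inspection; outside it is None.
        return _given.get()

    with given_context({"lr": 0.1}):
        assert save_hparams_like() == {"lr": 0.1}
    assert save_hparams_like() is None
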
diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py
index 78e449abd3e34..f8e9c8300337a 100644
--- a/src/lightning/pytorch/core/saving.py
+++ b/src/lightning/pytorch/core/saving.py
@@ -151,11 +151,18 @@ def _load_state(
     _cls_kwargs.update(cls_kwargs_loaded)
     _cls_kwargs.update(cls_kwargs_new)
 
+    instantiator = None
+    instantiator_path = _cls_kwargs.pop("_instantiator", None)
+    if instantiator_path is not None:
+        # import custom instantiator
+        module_path, name = instantiator_path.rsplit(".", 1)
+        instantiator = getattr(__import__(module_path, fromlist=[name]), name)
+
     if not cls_spec.varkw:
         # filter kwargs according to class init unless it allows any argument via kwargs
         _cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name}
 
-    obj = cls(**_cls_kwargs)
+    obj = instantiator(cls, _cls_kwargs) if instantiator else cls(**_cls_kwargs)
 
     if isinstance(obj, pl.LightningDataModule):
         if obj.__class__.__qualname__ in checkpoint:
diff --git a/src/lightning/pytorch/utilities/parsing.py b/src/lightning/pytorch/utilities/parsing.py
index ed14d06dc4c79..fd9209435468b 100644
--- a/src/lightning/pytorch/utilities/parsing.py
+++ b/src/lightning/pytorch/utilities/parsing.py
@@ -140,7 +140,11 @@ def collect_init_args(
 
 
 def save_hyperparameters(
-    obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None
+    obj: Any,
+    *args: Any,
+    ignore: Optional[Union[Sequence[str], str]] = None,
+    frame: Optional[types.FrameType] = None,
+    given_hparams: Optional[Dict[str, Any]] = None,
 ) -> None:
     """See :meth:`~lightning.pytorch.LightningModule.save_hyperparameters`"""
 
@@ -156,7 +160,9 @@ def save_hyperparameters(
     if not isinstance(frame, types.FrameType):
         raise AttributeError("There is no `frame` available while being required.")
 
-    if is_dataclass(obj):
+    if given_hparams is not None:
+        init_args = given_hparams
+    elif is_dataclass(obj):
         init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
     else:
         init_args = {}
diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py
index 71b2ff45cb328..95ed71680153a 100644
--- a/tests/tests_pytorch/test_cli.py
+++ b/tests/tests_pytorch/test_cli.py
@@ -833,6 +833,84 @@ def configure_optimizers(self):
     assert init[1]["lr_scheduler"].gamma == 0.3
 
 
+class TestModelSaveHparams(BoringModel):
+    def __init__(
+        self,
+        optimizer: OptimizerCallable = torch.optim.Adam,
+        scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
+        activation: torch.nn.Module = lazy_instance(torch.nn.LeakyReLU, negative_slope=0.05),
+    ):
+        super().__init__()
+        self.save_hyperparameters()
+        self.optimizer = optimizer
+        self.scheduler = scheduler
+        self.activation = activation
+
+    def configure_optimizers(self):
+        optimizer = self.optimizer(self.parameters())
+        scheduler = self.scheduler(optimizer)
+        return {"optimizer": optimizer, "lr_scheduler": scheduler}
+
+
+def test_lightning_cli_load_from_checkpoint_dependency_injection(cleandir):
+    with mock.patch("sys.argv", ["any.py", "--trainer.max_epochs=1"]):
+        cli = LightningCLI(TestModelSaveHparams, run=False, auto_configure_optimizers=False)
+    cli.trainer.fit(cli.model)
+
+    hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
+    assert hparams_path.is_file()
+    hparams = yaml.safe_load(hparams_path.read_text())
+    expected = {
+        "_instantiator": "lightning.pytorch.cli.instantiate_module",
+        "optimizer": "torch.optim.Adam",
+        "scheduler": "torch.optim.lr_scheduler.ConstantLR",
+        "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
+    }
+    assert hparams == expected
+
+    checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
+    assert checkpoint_path.is_file()
+    ckpt = torch.load(checkpoint_path)
+    assert ckpt["hyper_parameters"] == expected
+
+    model = TestModelSaveHparams.load_from_checkpoint(checkpoint_path)
+    assert isinstance(model, TestModelSaveHparams)
+    assert isinstance(model.activation, torch.nn.LeakyReLU)
+    assert model.activation.negative_slope == 0.05
+    optimizer, lr_scheduler = model.configure_optimizers().values()
+    assert isinstance(optimizer, torch.optim.Adam)
+    assert isinstance(lr_scheduler, torch.optim.lr_scheduler.ConstantLR)
+
+
+def test_lightning_cli_load_from_checkpoint_dependency_injection_subclass_mode(cleandir):
+    with mock.patch("sys.argv", ["any.py", "--trainer.max_epochs=1", "--model=TestModelSaveHparams"]):
+        cli = LightningCLI(TestModelSaveHparams, run=False, auto_configure_optimizers=False, subclass_mode_model=True)
+    cli.trainer.fit(cli.model)
+
+    expected = {
+        "_instantiator": "lightning.pytorch.cli.instantiate_module",
+        "class_path": f"{__name__}.TestModelSaveHparams",
+        "init_args": {
+            "optimizer": "torch.optim.Adam",
+            "scheduler": "torch.optim.lr_scheduler.ConstantLR",
+            "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
+        },
+    }
+
+    checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
+    assert checkpoint_path.is_file()
+    ckpt = torch.load(checkpoint_path)
+    assert ckpt["hyper_parameters"] == expected
+
+    model = LightningModule.load_from_checkpoint(checkpoint_path)
+    assert isinstance(model, TestModelSaveHparams)
+    assert isinstance(model.activation, torch.nn.LeakyReLU)
+    assert model.activation.negative_slope == 0.05
+    optimizer, lr_scheduler = model.configure_optimizers().values()
+    assert isinstance(optimizer, torch.optim.Adam)
+    assert isinstance(lr_scheduler, torch.optim.lr_scheduler.ConstantLR)
+
+
 @pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn])
 def test_lightning_cli_trainer_fn(fn):
     class TestCLI(LightningCLI):