From dd5901bfbaa90bbcb1316a33ffcec6dffa8e2797 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Tue, 18 Jul 2023 06:40:44 +0200 Subject: [PATCH 1/8] load_from_checkpoint support for LightningCLI when using dependency injection. --- requirements/pytorch/extra.txt | 7 ++- src/lightning/pytorch/CHANGELOG.md | 3 ++ src/lightning/pytorch/cli.py | 50 ++++++++++++++++++- .../pytorch/core/mixins/hparams_mixin.py | 19 ++++++- src/lightning/pytorch/core/saving.py | 3 +- src/lightning/pytorch/utilities/parsing.py | 10 +++- tests/tests_pytorch/test_cli.py | 47 +++++++++++++++++ 7 files changed, 128 insertions(+), 11 deletions(-) diff --git a/requirements/pytorch/extra.txt b/requirements/pytorch/extra.txt index fddaece918b4a..8da02e414442d 100644 --- a/requirements/pytorch/extra.txt +++ b/requirements/pytorch/extra.txt @@ -5,7 +5,6 @@ matplotlib>3.1, <3.9.0 omegaconf >=2.0.5, <2.4.0 hydra-core >=1.0.5, <1.4.0 -jsonargparse[signatures] >=4.26.1, <4.28.0 -rich >=12.3.0, <13.6.0 -tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute -bitsandbytes ==0.41.0 # strict +jsonargparse[signatures] @ https://github.com/omni-us/jsonargparse/zipball/issue-170-class-instantiator +rich >=12.3.0, <=13.5.2 +tensorboardX >=2.2, <=2.6.2 # min version is set by torch.onnx missing attribute diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 0f50c3e842173..a701e2d5ebc9f 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -20,6 +20,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added shortcut name `strategy='deepspeed_stage_1_offload'` to the strategy registry ([#19075](https://github.com/Lightning-AI/lightning/pull/19075)) - Added support for non-strict state-dict loading in Trainer via the new `LightningModule.strict_loading = True | False` attribute ([#19404](https://github.com/Lightning-AI/lightning/pull/19404)) +- Added `load_from_checkpoint` support for `LightningCLI` when using dependency injection ([#18105](https://github.com/Lightning-AI/lightning/pull/18105)) + + ### Changed - `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846)) diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index f2ecc8d12b08f..167abb54faa69 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -15,9 +15,10 @@ import sys from functools import partial, update_wrapper from types import MethodType -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union import torch +import yaml from lightning_utilities.core.imports import RequirementCache from lightning_utilities.core.rank_zero import _warn from torch.optim import Optimizer @@ -27,6 +28,7 @@ from lightning.fabric.utilities.cloud_io import get_filesystem from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, seed_everything +from lightning.pytorch.core.mixins.hparams_mixin import given_hyperparameters_context from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.model_helpers import is_overridden from lightning.pytorch.utilities.rank_zero import rank_zero_warn @@ 
-197,6 +199,30 @@ def add_lr_scheduler_args( self.add_class_arguments(lr_scheduler_class, nested_key, sub_configs=True, **kwargs) self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to) + def class_instantiator(self, class_type, *args, **kwargs): + for key, (base_type, hparams) in getattr(self, "_hparam_context", {}).items(): + if issubclass(class_type, base_type): + with given_hyperparameters_context(hparams): + return super().class_instantiator(class_type, *args, **kwargs) + return super().class_instantiator(class_type, *args, **kwargs) + + def instantiate_classes( + self, + cfg: Namespace, + instantiate_groups: bool = True, + hparam_context: Optional[Dict[str, type]] = None, + ) -> Namespace: + if hparam_context: + cfg_dict = yaml.safe_load(self.dump(cfg)) # TODO: do not remove link targets! + self._hparam_context = {} + for key, base_type in hparam_context.items(): + hparams = cfg_dict.get(key, {}) + self._hparam_context[key] = (base_type, hparams) + init = super().instantiate_classes(cfg, instantiate_groups=instantiate_groups) + if hparam_context: + delattr(self, "_hparam_context") + return init + class SaveConfigCallback(Callback): """Saves a LightningCLI config to the log_dir when training starts. @@ -532,7 +558,13 @@ def before_instantiate_classes(self) -> None: def instantiate_classes(self) -> None: """Instantiates the classes and sets their attributes.""" - self.config_init = self.parser.instantiate_classes(self.config) + hparam_prefix = "" + if "subcommand" in self.config: + hparam_prefix = self.config["subcommand"] + "." + hparam_context = {hparam_prefix + "model": self._model_class} + if self.datamodule_class is not None: + hparam_context[hparam_prefix + "data"] = self._datamodule_class + self.config_init = self.parser.instantiate_classes(self.config, hparam_context=hparam_context) self.datamodule = self._get(self.config_init, "data") self.model = self._get(self.config_init, "model") self._add_configure_optimizers_method_to_model(self.subcommand) @@ -755,3 +787,17 @@ def _get_short_description(component: object) -> Optional[str]: return docstring.short_description except (ValueError, docstring_parser.ParseError) as ex: rank_zero_warn(f"Failed parsing docstring for {component}: {ex}") + + +ModuleType = TypeVar("ModuleType") + + +def instantiate_module(class_type: Type[ModuleType], config: Dict[str, Any]) -> ModuleType: + parser = ArgumentParser(exit_on_error=False) + if "class_path" in config: + parser.add_subclass_arguments(class_type, "module") + else: + parser.add_class_arguments(class_type, "module") + cfg = parser.parse_object({"module": config}) + init = parser.instantiate_classes(cfg) + return init.module diff --git a/src/lightning/pytorch/core/mixins/hparams_mixin.py b/src/lightning/pytorch/core/mixins/hparams_mixin.py index 5a4ea1782b34e..01063cef9f0b7 100644 --- a/src/lightning/pytorch/core/mixins/hparams_mixin.py +++ b/src/lightning/pytorch/core/mixins/hparams_mixin.py @@ -15,6 +15,8 @@ import inspect import types from argparse import Namespace +from contextlib import contextmanager +from contextvars import ContextVar from typing import Any, List, MutableMapping, Optional, Sequence, Union from lightning.fabric.utilities.data import AttributeDict @@ -24,6 +26,18 @@ _ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace) +given_hyperparameters: ContextVar = ContextVar("given_hyperparameters", default=None) + + +@contextmanager +def given_hyperparameters_context(value): + token = given_hyperparameters.set(value) + try: + yield + finally: + 
given_hyperparameters.reset(token) + + class HyperparametersMixin: __jit_unused_properties__: List[str] = ["hparams", "hparams_initial"] @@ -105,12 +119,13 @@ class ``__init__`` to be ignored """ self._log_hyperparams = logger + given_hparams = given_hyperparameters.get() # the frame needs to be created in this file. - if not frame: + if given_hparams is None and not frame: current_frame = inspect.currentframe() if current_frame: frame = current_frame.f_back - save_hyperparameters(self, *args, ignore=ignore, frame=frame) + save_hyperparameters(self, *args, ignore=ignore, frame=frame, given_hparams=given_hparams) def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None: hp = self._to_hparams_dict(hp) diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py index 78e449abd3e34..e9f288c965c8f 100644 --- a/src/lightning/pytorch/core/saving.py +++ b/src/lightning/pytorch/core/saving.py @@ -118,6 +118,7 @@ def _load_state( cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]], checkpoint: Dict[str, Any], strict: Optional[bool] = None, + instantiator=None, **cls_kwargs_new: Any, ) -> Union["pl.LightningModule", "pl.LightningDataModule"]: cls_spec = inspect.getfullargspec(cls.__init__) @@ -155,7 +156,7 @@ def _load_state( # filter kwargs according to class init unless it allows any argument via kwargs _cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name} - obj = cls(**_cls_kwargs) + obj = instantiator(cls, _cls_kwargs) if instantiator else cls(**_cls_kwargs) if isinstance(obj, pl.LightningDataModule): if obj.__class__.__qualname__ in checkpoint: diff --git a/src/lightning/pytorch/utilities/parsing.py b/src/lightning/pytorch/utilities/parsing.py index ed14d06dc4c79..fd9209435468b 100644 --- a/src/lightning/pytorch/utilities/parsing.py +++ b/src/lightning/pytorch/utilities/parsing.py @@ -140,7 +140,11 @@ def collect_init_args( def save_hyperparameters( - obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None + obj: Any, + *args: Any, + ignore: Optional[Union[Sequence[str], str]] = None, + frame: Optional[types.FrameType] = None, + given_hparams: Optional[Dict[str, Any]] = None, ) -> None: """See :meth:`~lightning.pytorch.LightningModule.save_hyperparameters`""" @@ -156,7 +160,9 @@ def save_hyperparameters( if not isinstance(frame, types.FrameType): raise AttributeError("There is no `frame` available while being required.") - if is_dataclass(obj): + if given_hparams is not None: + init_args = given_hparams + elif is_dataclass(obj): init_args = {f.name: getattr(obj, f.name) for f in fields(obj)} else: init_args = {} diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index fc3793a07b6ed..c03f156aa1791 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -835,6 +835,53 @@ def configure_optimizers(self): assert init[1]["lr_scheduler"].gamma == 0.3 +def test_lightning_cli_load_from_checkpoint_dependency_injection(cleandir): + class TestModel(BoringModel): + def __init__( + self, + optimizer: OptimizerCallable = torch.optim.Adam, + scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR, + activation: torch.nn.Module = lazy_instance(torch.nn.LeakyReLU, negative_slope=0.05), + ): + super().__init__() + self.save_hyperparameters() + self.optimizer = optimizer + self.scheduler = scheduler + self.activation = activation + + def configure_optimizers(self): + optimizer = 
self.optimizer(self.parameters()) + scheduler = self.scheduler(optimizer) + return {"optimizer": optimizer, "lr_scheduler": scheduler} + + with mock.patch("sys.argv", ["any.py", "--trainer.max_epochs=1"]): + cli = LightningCLI(TestModel, run=False, auto_configure_optimizers=False) + cli.trainer.fit(cli.model) + + hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml" + assert hparams_path.is_file() + hparams = yaml.safe_load(hparams_path.read_text()) + expected = { + "optimizer": "torch.optim.Adam", + "scheduler": "torch.optim.lr_scheduler.ConstantLR", + "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}}, + } + assert hparams == expected + + checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None) + assert checkpoint_path.is_file() + ckpt = torch.load(checkpoint_path) + assert ckpt["hyper_parameters"] == expected + + model = TestModel.load_from_checkpoint(checkpoint_path, instantiator=instantiate_module) + assert isinstance(model, TestModel) + assert isinstance(model.activation, torch.nn.LeakyReLU) + assert model.activation.negative_slope == 0.05 + optimizer, lr_scheduler = model.configure_optimizers().values() + assert isinstance(optimizer, torch.optim.Adam) + assert isinstance(lr_scheduler, torch.optim.lr_scheduler.ConstantLR) + + @pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn]) def test_lightning_cli_trainer_fn(fn): class TestCLI(LightningCLI): From df9e4b4c1c6352b4913c725658c933689d2ee289 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Tue, 22 Aug 2023 16:26:54 +0200 Subject: [PATCH 2/8] Change implementation to use add_instantiator. --- requirements/pytorch/extra.txt | 7 +- src/lightning/pytorch/cli.py | 67 ++++++++++--------- .../pytorch/core/mixins/hparams_mixin.py | 4 +- src/lightning/pytorch/core/saving.py | 2 +- tests/tests_pytorch/test_cli.py | 1 + 5 files changed, 43 insertions(+), 38 deletions(-) diff --git a/requirements/pytorch/extra.txt b/requirements/pytorch/extra.txt index 8da02e414442d..fddaece918b4a 100644 --- a/requirements/pytorch/extra.txt +++ b/requirements/pytorch/extra.txt @@ -5,6 +5,7 @@ matplotlib>3.1, <3.9.0 omegaconf >=2.0.5, <2.4.0 hydra-core >=1.0.5, <1.4.0 -jsonargparse[signatures] @ https://github.com/omni-us/jsonargparse/zipball/issue-170-class-instantiator -rich >=12.3.0, <=13.5.2 -tensorboardX >=2.2, <=2.6.2 # min version is set by torch.onnx missing attribute +jsonargparse[signatures] >=4.26.1, <4.28.0 +rich >=12.3.0, <13.6.0 +tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute +bitsandbytes ==0.41.0 # strict diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index 167abb54faa69..7f9722ce3b547 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
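# A minimal, self-contained sketch of the ContextVar hand-off introduced in PATCH 1/8
# (given_hyperparameters_context + save_hyperparameters(..., given_hparams=...)) and kept,
# under protected names, by the later patches. The class and names below are illustrative,
# not the Lightning internals: the point is that an instantiator can pass the parsed config
# to __init__ without changing any constructor signature or inspecting caller frames.
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Dict, Iterator, Optional

_given_hparams: ContextVar[Optional[Dict[str, Any]]] = ContextVar("_given_hparams", default=None)


@contextmanager
def given_hparams_context(hparams: Dict[str, Any]) -> Iterator[None]:
    token = _given_hparams.set(hparams)
    try:
        yield
    finally:
        _given_hparams.reset(token)


class Model:
    def __init__(self, activation: str = "relu") -> None:
        # Prefer the config supplied by whoever instantiated us inside the context;
        # only fall back to the local arguments otherwise.
        provided = _given_hparams.get()
        self.hparams = dict(provided) if provided is not None else {"activation": activation}


# The CLI-side instantiator wraps construction in the context with the full parsed config,
# so the recorded hyperparameters end up holding serializable entries such as
# class_path/init_args dicts for injected dependencies:
with given_hparams_context({"activation": {"class_path": "torch.nn.LeakyReLU"}}):
    model = Model()
assert model.hparams == {"activation": {"class_path": "torch.nn.LeakyReLU"}}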
+import inspect import os import sys from functools import partial, update_wrapper @@ -52,6 +53,8 @@ locals()["ArgumentParser"] = object locals()["Namespace"] = object +ModuleType = TypeVar("ModuleType") + class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau): def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None: @@ -199,30 +202,6 @@ def add_lr_scheduler_args( self.add_class_arguments(lr_scheduler_class, nested_key, sub_configs=True, **kwargs) self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to) - def class_instantiator(self, class_type, *args, **kwargs): - for key, (base_type, hparams) in getattr(self, "_hparam_context", {}).items(): - if issubclass(class_type, base_type): - with given_hyperparameters_context(hparams): - return super().class_instantiator(class_type, *args, **kwargs) - return super().class_instantiator(class_type, *args, **kwargs) - - def instantiate_classes( - self, - cfg: Namespace, - instantiate_groups: bool = True, - hparam_context: Optional[Dict[str, type]] = None, - ) -> Namespace: - if hparam_context: - cfg_dict = yaml.safe_load(self.dump(cfg)) # TODO: do not remove link targets! - self._hparam_context = {} - for key, base_type in hparam_context.items(): - hparams = cfg_dict.get(key, {}) - self._hparam_context[key] = (base_type, hparams) - init = super().instantiate_classes(cfg, instantiate_groups=instantiate_groups) - if hparam_context: - delattr(self, "_hparam_context") - return init - class SaveConfigCallback(Callback): """Saves a LightningCLI config to the log_dir when training starts. @@ -407,6 +386,7 @@ def __init__( self._set_seed() + self._add_instantiators() self.before_instantiate_classes() self.instantiate_classes() @@ -553,18 +533,28 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No else: self.config = parser.parse_args(args) + def _add_instantiators(self) -> None: + self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False)) + if "subcommand" in self.config: + self.config_dump = self.config_dump[self.config.subcommand] + + self.parser.add_instantiator( + _InstantiatorFn(cli=self, key="model"), + _get_module_type(self._model_class), + subclasses=self.subclass_mode_model, + ) + self.parser.add_instantiator( + _InstantiatorFn(cli=self, key="data"), + _get_module_type(self._datamodule_class), + subclasses=self.subclass_mode_data, + ) + def before_instantiate_classes(self) -> None: """Implement to run some code before instantiating the classes.""" def instantiate_classes(self) -> None: """Instantiates the classes and sets their attributes.""" - hparam_prefix = "" - if "subcommand" in self.config: - hparam_prefix = self.config["subcommand"] + "." 
- hparam_context = {hparam_prefix + "model": self._model_class} - if self.datamodule_class is not None: - hparam_context[hparam_prefix + "data"] = self._datamodule_class - self.config_init = self.parser.instantiate_classes(self.config, hparam_context=hparam_context) + self.config_init = self.parser.instantiate_classes(self.config) self.datamodule = self._get(self.config_init, "data") self.model = self._get(self.config_init, "model") self._add_configure_optimizers_method_to_model(self.subcommand) @@ -789,7 +779,20 @@ def _get_short_description(component: object) -> Optional[str]: rank_zero_warn(f"Failed parsing docstring for {component}: {ex}") -ModuleType = TypeVar("ModuleType") +def _get_module_type(value: Union[Callable, type]) -> type: + if callable(value) and not isinstance(value, type): + return inspect.signature(value).return_annotation + return value + + +class _InstantiatorFn: + def __init__(self, cli: LightningCLI, key: str) -> None: + self.cli = cli + self.key = key + + def __call__(self, class_type: Type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType: + with given_hyperparameters_context(self.cli.config_dump.get(self.key, {})): + return class_type(*args, **kwargs) def instantiate_module(class_type: Type[ModuleType], config: Dict[str, Any]) -> ModuleType: diff --git a/src/lightning/pytorch/core/mixins/hparams_mixin.py b/src/lightning/pytorch/core/mixins/hparams_mixin.py index 01063cef9f0b7..1beaa8354634e 100644 --- a/src/lightning/pytorch/core/mixins/hparams_mixin.py +++ b/src/lightning/pytorch/core/mixins/hparams_mixin.py @@ -17,7 +17,7 @@ from argparse import Namespace from contextlib import contextmanager from contextvars import ContextVar -from typing import Any, List, MutableMapping, Optional, Sequence, Union +from typing import Any, Iterator, List, MutableMapping, Optional, Sequence, Union from lightning.fabric.utilities.data import AttributeDict from lightning.pytorch.utilities.parsing import save_hyperparameters @@ -30,7 +30,7 @@ @contextmanager -def given_hyperparameters_context(value): +def given_hyperparameters_context(value: dict) -> Iterator[None]: token = given_hyperparameters.set(value) try: yield diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py index e9f288c965c8f..2190b015901ca 100644 --- a/src/lightning/pytorch/core/saving.py +++ b/src/lightning/pytorch/core/saving.py @@ -118,7 +118,7 @@ def _load_state( cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]], checkpoint: Dict[str, Any], strict: Optional[bool] = None, - instantiator=None, + instantiator: Optional[Callable] = None, **cls_kwargs_new: Any, ) -> Union["pl.LightningModule", "pl.LightningDataModule"]: cls_spec = inspect.getfullargspec(cls.__init__) diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index c03f156aa1791..d6641fce47556 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -39,6 +39,7 @@ OptimizerCallable, SaveConfigCallback, instantiate_class, + instantiate_module, ) from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger From b5145e29175fdee9fbe0da7376b491869642e0e3 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Tue, 13 Feb 2024 15:39:11 +0100 Subject: [PATCH 3/8] Require newer version of jsonargparse --- requirements/pytorch/extra.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/requirements/pytorch/extra.txt b/requirements/pytorch/extra.txt index fddaece918b4a..39e3ff61d4e00 100644 --- a/requirements/pytorch/extra.txt +++ b/requirements/pytorch/extra.txt @@ -5,7 +5,7 @@ matplotlib>3.1, <3.9.0 omegaconf >=2.0.5, <2.4.0 hydra-core >=1.0.5, <1.4.0 -jsonargparse[signatures] >=4.26.1, <4.28.0 +jsonargparse[signatures] >=4.27.5, <4.28.0 rich >=12.3.0, <13.6.0 tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute bitsandbytes ==0.41.0 # strict From fa4074b9ae1c92a97d2dd133a963e6bb7952ab4c Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Tue, 20 Feb 2024 02:42:00 +0100 Subject: [PATCH 4/8] Address review comments: protected names and removal of instantiator parameter. --- src/lightning/pytorch/cli.py | 7 +++++-- src/lightning/pytorch/core/mixins/hparams_mixin.py | 13 +++++++------ src/lightning/pytorch/core/saving.py | 8 +++++++- tests/tests_pytorch/test_cli.py | 4 ++-- 4 files changed, 21 insertions(+), 11 deletions(-) diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index fb12b11c70d74..3f6e72cc4225a 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -29,7 +29,7 @@ from lightning.fabric.utilities.cloud_io import get_filesystem from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, seed_everything -from lightning.pytorch.core.mixins.hparams_mixin import given_hyperparameters_context +from lightning.pytorch.core.mixins.hparams_mixin import _given_hyperparameters_context from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.model_helpers import is_overridden from lightning.pytorch.utilities.rank_zero import rank_zero_warn @@ -791,7 +791,10 @@ def __init__(self, cli: LightningCLI, key: str) -> None: self.key = key def __call__(self, class_type: Type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType: - with given_hyperparameters_context(self.cli.config_dump.get(self.key, {})): + with _given_hyperparameters_context( + hparams=self.cli.config_dump.get(self.key, {}), + instantiator="lightning.pytorch.cli.instantiate_module", + ): return class_type(*args, **kwargs) diff --git a/src/lightning/pytorch/core/mixins/hparams_mixin.py b/src/lightning/pytorch/core/mixins/hparams_mixin.py index 1beaa8354634e..a125a00bf719a 100644 --- a/src/lightning/pytorch/core/mixins/hparams_mixin.py +++ b/src/lightning/pytorch/core/mixins/hparams_mixin.py @@ -26,16 +26,17 @@ _ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace) -given_hyperparameters: ContextVar = ContextVar("given_hyperparameters", default=None) - +_given_hyperparameters: ContextVar = ContextVar("_given_hyperparameters", default=None) @contextmanager -def given_hyperparameters_context(value: dict) -> Iterator[None]: - token = given_hyperparameters.set(value) +def _given_hyperparameters_context(hparams: dict, instantiator: str) -> Iterator[None]: + hparams = hparams.copy() + hparams["_instantiator"] = instantiator + token = _given_hyperparameters.set(hparams) try: yield finally: - given_hyperparameters.reset(token) + _given_hyperparameters.reset(token) class HyperparametersMixin: @@ -119,7 +120,7 @@ class ``__init__`` to be ignored """ self._log_hyperparams = logger - given_hparams = given_hyperparameters.get() + given_hparams = _given_hyperparameters.get() # the frame needs to be created in this file. 
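# Sketch of the "_instantiator" round trip that PATCH 4/8 sets up. _given_hyperparameters_context
# copies the hparams and records an import path under the protected "_instantiator" key (see the
# hparams_mixin hunk above), and _load_state resolves that path back to a callable at
# load_from_checkpoint time (see the saving.py hunk that follows). The hparams dict here is
# illustrative; the resolution logic mirrors the patch.
hparams = {
    "_instantiator": "lightning.pytorch.cli.instantiate_module",
    "optimizer": "torch.optim.Adam",
}

instantiator = None
instantiator_path = hparams.pop("_instantiator", None)
if instantiator_path is not None:
    # import the custom instantiator from its dotted path
    module_path, name = instantiator_path.rsplit(".", 1)
    instantiator = getattr(__import__(module_path, fromlist=[name]), name)

# instantiator is now lightning.pytorch.cli.instantiate_module. Storing only a dotted path keeps
# the checkpoint plain data rather than a pickled function object, and plain cls(**kwargs)
# construction still happens whenever the key is absent.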
if given_hparams is None and not frame: current_frame = inspect.currentframe() diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py index 2190b015901ca..f8e9c8300337a 100644 --- a/src/lightning/pytorch/core/saving.py +++ b/src/lightning/pytorch/core/saving.py @@ -118,7 +118,6 @@ def _load_state( cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]], checkpoint: Dict[str, Any], strict: Optional[bool] = None, - instantiator: Optional[Callable] = None, **cls_kwargs_new: Any, ) -> Union["pl.LightningModule", "pl.LightningDataModule"]: cls_spec = inspect.getfullargspec(cls.__init__) @@ -152,6 +151,13 @@ def _load_state( _cls_kwargs.update(cls_kwargs_loaded) _cls_kwargs.update(cls_kwargs_new) + instantiator = None + instantiator_path = _cls_kwargs.pop("_instantiator", None) + if instantiator_path is not None: + # import custom instantiator + module_path, name = instantiator_path.rsplit(".", 1) + instantiator = getattr(__import__(module_path, fromlist=[name]), name) + if not cls_spec.varkw: # filter kwargs according to class init unless it allows any argument via kwargs _cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name} diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index d6641fce47556..25f593493411e 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -39,7 +39,6 @@ OptimizerCallable, SaveConfigCallback, instantiate_class, - instantiate_module, ) from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger @@ -863,6 +862,7 @@ def configure_optimizers(self): assert hparams_path.is_file() hparams = yaml.safe_load(hparams_path.read_text()) expected = { + "_instantiator": "lightning.pytorch.cli.instantiate_module", "optimizer": "torch.optim.Adam", "scheduler": "torch.optim.lr_scheduler.ConstantLR", "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}}, @@ -874,7 +874,7 @@ def configure_optimizers(self): ckpt = torch.load(checkpoint_path) assert ckpt["hyper_parameters"] == expected - model = TestModel.load_from_checkpoint(checkpoint_path, instantiator=instantiate_module) + model = TestModel.load_from_checkpoint(checkpoint_path) assert isinstance(model, TestModel) assert isinstance(model.activation, torch.nn.LeakyReLU) assert model.activation.negative_slope == 0.05 From 9b9e5c626358eb4734536539d515fe51eb073a46 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Feb 2024 01:43:34 +0000 Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/core/mixins/hparams_mixin.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning/pytorch/core/mixins/hparams_mixin.py b/src/lightning/pytorch/core/mixins/hparams_mixin.py index a125a00bf719a..94ece0039d4f4 100644 --- a/src/lightning/pytorch/core/mixins/hparams_mixin.py +++ b/src/lightning/pytorch/core/mixins/hparams_mixin.py @@ -28,6 +28,7 @@ _given_hyperparameters: ContextVar = ContextVar("_given_hyperparameters", default=None) + @contextmanager def _given_hyperparameters_context(hparams: dict, instantiator: str) -> Iterator[None]: hparams = hparams.copy() From 8917d3da64c9c73fb090e4a7e32a022e240380c3 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Tue, 20 Feb 
2024 09:31:25 +0100 Subject: [PATCH 6/8] Mention load_from_checkpoint support in docs and add unit test for subclass mode. --- .../cli/lightning_cli_advanced_3.rst | 10 +++ tests/tests_pytorch/test_cli.py | 70 +++++++++++++------ 2 files changed, 60 insertions(+), 20 deletions(-) diff --git a/docs/source-pytorch/cli/lightning_cli_advanced_3.rst b/docs/source-pytorch/cli/lightning_cli_advanced_3.rst index e63844ae66576..ddc09dec2b6e3 100644 --- a/docs/source-pytorch/cli/lightning_cli_advanced_3.rst +++ b/docs/source-pytorch/cli/lightning_cli_advanced_3.rst @@ -197,6 +197,7 @@ Since the init parameters of the model have as a type hint a class, in the confi decoder: Instance of a module for decoding """ super().__init__() + self.save_hyperparameters() self.encoder = encoder self.decoder = decoder @@ -216,6 +217,13 @@ If the CLI is implemented as ``LightningCLI(MyMainModel)`` the configuration wou It is also possible to combine ``subclass_mode_model=True`` and submodules, thereby having two levels of ``class_path``. +.. tip:: + + By having ``self.save_hyperparameters()`` it becomes possible to load the model from a checkpoint. Simply do + ``ModelClass.load_from_checkpoint("path/to/checkpoint.ckpt")``. In the case of using ``subclass_mode_model=True``, + then load it like ``LightningModule.load_from_checkpoint("path/to/checkpoint.ckpt")``. ``save_hyperparameters`` is + optional and can be safely removed if there is no need to load from a checkpoint. + Fixed optimizer and scheduler ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -279,6 +287,7 @@ An example of a model that uses two optimizers is the following: class MyModel(LightningModule): def __init__(self, optimizer1: OptimizerCallable, optimizer2: OptimizerCallable): super().__init__() + self.save_hyperparameters() self.optimizer1 = optimizer1 self.optimizer2 = optimizer2 @@ -318,6 +327,7 @@ that uses dependency injection for an optimizer and a learning scheduler is: scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR, ): super().__init__() + self.save_hyperparameters() self.optimizer = optimizer self.scheduler = scheduler diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 25f593493411e..7afe17503920e 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -835,27 +835,28 @@ def configure_optimizers(self): assert init[1]["lr_scheduler"].gamma == 0.3 -def test_lightning_cli_load_from_checkpoint_dependency_injection(cleandir): - class TestModel(BoringModel): - def __init__( - self, - optimizer: OptimizerCallable = torch.optim.Adam, - scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR, - activation: torch.nn.Module = lazy_instance(torch.nn.LeakyReLU, negative_slope=0.05), - ): - super().__init__() - self.save_hyperparameters() - self.optimizer = optimizer - self.scheduler = scheduler - self.activation = activation +class TestModelSaveHparams(BoringModel): + def __init__( + self, + optimizer: OptimizerCallable = torch.optim.Adam, + scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR, + activation: torch.nn.Module = lazy_instance(torch.nn.LeakyReLU, negative_slope=0.05), + ): + super().__init__() + self.save_hyperparameters() + self.optimizer = optimizer + self.scheduler = scheduler + self.activation = activation + + def configure_optimizers(self): + optimizer = self.optimizer(self.parameters()) + scheduler = self.scheduler(optimizer) + return {"optimizer": optimizer, "lr_scheduler": scheduler} - def configure_optimizers(self): - optimizer = 
self.optimizer(self.parameters()) - scheduler = self.scheduler(optimizer) - return {"optimizer": optimizer, "lr_scheduler": scheduler} +def test_lightning_cli_load_from_checkpoint_dependency_injection(cleandir): with mock.patch("sys.argv", ["any.py", "--trainer.max_epochs=1"]): - cli = LightningCLI(TestModel, run=False, auto_configure_optimizers=False) + cli = LightningCLI(TestModelSaveHparams, run=False, auto_configure_optimizers=False) cli.trainer.fit(cli.model) hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml" @@ -874,8 +875,37 @@ def configure_optimizers(self): ckpt = torch.load(checkpoint_path) assert ckpt["hyper_parameters"] == expected - model = TestModel.load_from_checkpoint(checkpoint_path) - assert isinstance(model, TestModel) + model = TestModelSaveHparams.load_from_checkpoint(checkpoint_path) + assert isinstance(model, TestModelSaveHparams) + assert isinstance(model.activation, torch.nn.LeakyReLU) + assert model.activation.negative_slope == 0.05 + optimizer, lr_scheduler = model.configure_optimizers().values() + assert isinstance(optimizer, torch.optim.Adam) + assert isinstance(lr_scheduler, torch.optim.lr_scheduler.ConstantLR) + + +def test_lightning_cli_load_from_checkpoint_dependency_injection_subclass_mode(cleandir): + with mock.patch("sys.argv", ["any.py", "--trainer.max_epochs=1", "--model=TestModelSaveHparams"]): + cli = LightningCLI(TestModelSaveHparams, run=False, auto_configure_optimizers=False, subclass_mode_model=True) + cli.trainer.fit(cli.model) + + expected = { + "_instantiator": "lightning.pytorch.cli.instantiate_module", + "class_path": f"{__name__}.TestModelSaveHparams", + "init_args": { + "optimizer": "torch.optim.Adam", + "scheduler": "torch.optim.lr_scheduler.ConstantLR", + "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}}, + }, + } + + checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None) + assert checkpoint_path.is_file() + ckpt = torch.load(checkpoint_path) + assert ckpt["hyper_parameters"] == expected + + model = LightningModule.load_from_checkpoint(checkpoint_path) + assert isinstance(model, TestModelSaveHparams) assert isinstance(model.activation, torch.nn.LeakyReLU) assert model.activation.negative_slope == 0.05 optimizer, lr_scheduler = model.configure_optimizers().values() From 8fb083ef7e08ec85f3c026d1b4e2b84da37f6d39 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Tue, 20 Feb 2024 14:58:19 +0100 Subject: [PATCH 7/8] Updated _JSONARGPARSE_SIGNATURES_AVAILABLE --- src/lightning/pytorch/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index 3f6e72cc4225a..30318c072bfd2 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -34,7 +34,7 @@ from lightning.pytorch.utilities.model_helpers import is_overridden from lightning.pytorch.utilities.rank_zero import rank_zero_warn -_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.26.1") +_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.27.5") if _JSONARGPARSE_SIGNATURES_AVAILABLE: import docstring_parser From a892efbd1deca5d27c41088a4a16cdfd46f4d7f2 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Wed, 21 Feb 2024 10:56:47 +0100 Subject: [PATCH 8/8] Move up the changelog entry. 
--- src/lightning/pytorch/CHANGELOG.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 9bc857b7f1b93..3486f714f5987 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- +- Added `load_from_checkpoint` support for `LightningCLI` when using dependency injection ([#18105](https://github.com/Lightning-AI/lightning/pull/18105)) - @@ -62,8 +62,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added shortcut name `strategy='deepspeed_stage_1_offload'` to the strategy registry ([#19075](https://github.com/Lightning-AI/lightning/pull/19075)) - Added support for non-strict state-dict loading in Trainer via the new `LightningModule.strict_loading = True | False` attribute ([#19404](https://github.com/Lightning-AI/lightning/pull/19404)) -- Added `load_from_checkpoint` support for `LightningCLI` when using dependency injection ([#18105](https://github.com/Lightning-AI/lightning/pull/18105)) - ### Changed
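Taken together, the series enables the workflow described in the documentation tip added in PATCH 6/8: a module that uses dependency injection and calls `self.save_hyperparameters()` can later be restored directly from a checkpoint. The condensed sketch below illustrates that end-to-end usage; the script name, CLI arguments and checkpoint path are placeholders, and it assumes the jsonargparse version required by the series is installed.

# main.py -- illustrative end-to-end usage of the feature added by this series.
import torch
from lightning.pytorch import LightningModule
from lightning.pytorch.cli import LightningCLI, LRSchedulerCallable, OptimizerCallable
from lightning.pytorch.demos.boring_classes import BoringModel


class MyModel(BoringModel):
    def __init__(
        self,
        optimizer: OptimizerCallable = torch.optim.Adam,
        scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
    ):
        super().__init__()
        # Needed for load_from_checkpoint: records the CLI-parsed config,
        # including the injected callables, in the checkpoint.
        self.save_hyperparameters()
        self.optimizer = optimizer
        self.scheduler = scheduler

    def configure_optimizers(self):
        optimizer = self.optimizer(self.parameters())
        return {"optimizer": optimizer, "lr_scheduler": self.scheduler(optimizer)}


if __name__ == "__main__":
    LightningCLI(MyModel, auto_configure_optimizers=False)

# After e.g. `python main.py fit --trainer.max_epochs=1 --model.optimizer=torch.optim.SGD`
# the model can be restored without re-specifying the injected arguments:
#
#   model = MyModel.load_from_checkpoint("path/to/checkpoint.ckpt")
#
# With subclass_mode_model=True, load through the base class instead, as in the
# subclass-mode test above:
#
#   model = LightningModule.load_from_checkpoint("path/to/checkpoint.ckpt")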