Commit 0a7388a

load_from_checkpoint support for LightningCLI when using dependency injection.
1 parent c9929e6 commit 0a7388a
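
In practice, a `LightningModule` whose constructor receives its dependencies as arguments (optimizer/scheduler callables, submodule instances) can now be restored directly from a checkpoint by passing the new `instantiate_module` helper as the `instantiator`. A minimal sketch distilled from the test added at the bottom of this commit; the checkpoint path is a placeholder:

import torch
from lightning.pytorch.cli import LightningCLI, OptimizerCallable, instantiate_module
from lightning.pytorch.demos.boring_classes import BoringModel

class Model(BoringModel):
    def __init__(self, optimizer: OptimizerCallable = torch.optim.Adam):
        super().__init__()
        # under the CLI, the given config (e.g. "torch.optim.Adam") is recorded, not the live callable
        self.save_hyperparameters()
        self.optimizer = optimizer

    def configure_optimizers(self):
        return self.optimizer(self.parameters())

cli = LightningCLI(Model, run=False, auto_configure_optimizers=False)
cli.trainer.fit(cli.model)

# later: rebuild the injected arguments from the saved hyperparameters
model = Model.load_from_checkpoint("path/to/checkpoint.ckpt", instantiator=instantiate_module)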

File tree

7 files changed: +127 additions, -8 deletions

requirements/pytorch/extra.txt

Lines changed: 1 addition & 1 deletion
@@ -5,6 +5,6 @@
 matplotlib>3.1, <3.7.3
 omegaconf >=2.0.5, <2.4.0
 hydra-core >=1.0.5, <1.4.0
-jsonargparse[signatures] >=4.18.0, <4.24.0 # strict
+jsonargparse[signatures] @ https://github.com/omni-us/jsonargparse/zipball/issue-170-class-instantiator
 rich >=12.3.0, <=13.5.2
 tensorboardX >=2.2, <=2.6.2 # min version is set by torch.onnx missing attribute

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -106,6 +106,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for saving sharded checkpoints with FSDP via `FSDPStrategy(state_dict_type="sharded")` ([#18364](https://github.com/Lightning-AI/lightning/pull/18364))


+- Added `load_from_checkpoint` support for `LightningCLI` when using dependency injection ([#18105](https://github.com/Lightning-AI/lightning/pull/18105))
+
+
 ### Changed

 - Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))

src/lightning/pytorch/cli.py

Lines changed: 48 additions & 2 deletions
@@ -15,9 +15,10 @@
 import sys
 from functools import partial, update_wrapper
 from types import MethodType
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union

 import torch
+import yaml
 from lightning_utilities.core.imports import RequirementCache
 from lightning_utilities.core.rank_zero import _warn
 from torch.optim import Optimizer
@@ -26,6 +27,7 @@
 from lightning.fabric.utilities.cloud_io import get_filesystem
 from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
 from lightning.pytorch import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
+from lightning.pytorch.core.mixins.hparams_mixin import given_hyperparameters_context
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn
@@ -196,6 +198,30 @@ def add_lr_scheduler_args(
         self.add_class_arguments(lr_scheduler_class, nested_key, sub_configs=True, **kwargs)
         self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to)

+    def class_instantiator(self, class_type, *args, **kwargs):
+        for key, (base_type, hparams) in getattr(self, "_hparam_context", {}).items():
+            if issubclass(class_type, base_type):
+                with given_hyperparameters_context(hparams):
+                    return super().class_instantiator(class_type, *args, **kwargs)
+        return super().class_instantiator(class_type, *args, **kwargs)
+
+    def instantiate_classes(
+        self,
+        cfg: Namespace,
+        instantiate_groups: bool = True,
+        hparam_context: Optional[Dict[str, type]] = None,
+    ) -> Namespace:
+        if hparam_context:
+            cfg_dict = yaml.safe_load(self.dump(cfg))  # TODO: do not remove link targets!
+            self._hparam_context = {}
+            for key, base_type in hparam_context.items():
+                hparams = cfg_dict.get(key, {})
+                self._hparam_context[key] = (base_type, hparams)
+        init = super().instantiate_classes(cfg, instantiate_groups=instantiate_groups)
+        if hparam_context:
+            delattr(self, "_hparam_context")
+        return init
+

 class SaveConfigCallback(Callback):
     """Saves a LightningCLI config to the log_dir when training starts.
@@ -530,7 +556,13 @@ def before_instantiate_classes(self) -> None:

     def instantiate_classes(self) -> None:
         """Instantiates the classes and sets their attributes."""
-        self.config_init = self.parser.instantiate_classes(self.config)
+        hparam_prefix = ""
+        if "subcommand" in self.config:
+            hparam_prefix = self.config["subcommand"] + "."
+        hparam_context = {hparam_prefix + "model": self._model_class}
+        if self.datamodule_class is not None:
+            hparam_context[hparam_prefix + "data"] = self._datamodule_class
+        self.config_init = self.parser.instantiate_classes(self.config, hparam_context=hparam_context)
         self.datamodule = self._get(self.config_init, "data")
         self.model = self._get(self.config_init, "model")
         self._add_configure_optimizers_method_to_model(self.subcommand)
@@ -754,3 +786,17 @@ def _get_short_description(component: object) -> Optional[str]:
         return docstring.short_description
     except (ValueError, docstring_parser.ParseError) as ex:
         rank_zero_warn(f"Failed parsing docstring for {component}: {ex}")
+
+
+ModuleType = TypeVar("ModuleType")
+
+
+def instantiate_module(class_type: Type[ModuleType], config: Dict[str, Any]) -> ModuleType:
+    parser = ArgumentParser(exit_on_error=False)
+    if "class_path" in config:
+        parser.add_subclass_arguments(class_type, "module")
+    else:
+        parser.add_class_arguments(class_type, "module")
+    cfg = parser.parse_object({"module": config})
+    init = parser.instantiate_classes(cfg)
+    return init.module
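
The new module-level `instantiate_module` helper is what `load_from_checkpoint` plugs in as the `instantiator` (see `core/saving.py` below). An illustrative sketch of what it does with the two hyperparameter formats the CLI saves; treat it as a sketch, not a public contract:

import torch
from lightning.pytorch.cli import instantiate_module

# a `class_path`/`init_args` dict takes the `add_subclass_arguments` branch
activation = instantiate_module(
    torch.nn.Module,
    {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05}},
)
assert isinstance(activation, torch.nn.LeakyReLU)

# a plain keyword dict takes the `add_class_arguments` branch
activation = instantiate_module(torch.nn.LeakyReLU, {"negative_slope": 0.1})
assert activation.negative_slope == 0.1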

src/lightning/pytorch/core/mixins/hparams_mixin.py

Lines changed: 17 additions & 2 deletions
@@ -15,6 +15,8 @@
 import inspect
 import types
 from argparse import Namespace
+from contextlib import contextmanager
+from contextvars import ContextVar
 from typing import Any, List, MutableMapping, Optional, Sequence, Union

 from lightning.pytorch.utilities.parsing import AttributeDict, save_hyperparameters
@@ -23,6 +25,18 @@
 _ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)


+given_hyperparameters: ContextVar = ContextVar("given_hyperparameters", default=None)
+
+
+@contextmanager
+def given_hyperparameters_context(value):
+    token = given_hyperparameters.set(value)
+    try:
+        yield
+    finally:
+        given_hyperparameters.reset(token)
+
+
 class HyperparametersMixin:
     __jit_unused_properties__: List[str] = ["hparams", "hparams_initial"]

@@ -103,12 +117,13 @@ class ``__init__`` to be ignored
             "arg3": 3.14
         """
         self._log_hyperparams = logger
+        given_hparams = given_hyperparameters.get()
         # the frame needs to be created in this file.
-        if not frame:
+        if given_hparams is None and not frame:
             current_frame = inspect.currentframe()
             if current_frame:
                 frame = current_frame.f_back
-        save_hyperparameters(self, *args, ignore=ignore, frame=frame)
+        save_hyperparameters(self, *args, ignore=ignore, frame=frame, given_hparams=given_hparams)

     def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None:
         hp = self._to_hparams_dict(hp)
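
A rough sketch of the mechanism (internal, not a public API): while the context manager is active, `save_hyperparameters()` records the provided dict instead of inspecting the caller's frame, which is how the parser's `class_instantiator` above controls what ends up in `hparams`:

from lightning.pytorch.core.mixins.hparams_mixin import given_hyperparameters_context
from lightning.pytorch.demos.boring_classes import BoringModel

class Model(BoringModel):
    def __init__(self, optimizer: str = "SGD"):
        super().__init__()
        self.save_hyperparameters()  # uses the context value when one is set

with given_hyperparameters_context({"optimizer": "torch.optim.Adam"}):
    model = Model()  # hparams come from the given dict, not from frame inspection

assert model.hparams["optimizer"] == "torch.optim.Adam"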

src/lightning/pytorch/core/saving.py

Lines changed: 2 additions & 1 deletion
@@ -123,6 +123,7 @@ def _load_state(
     cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
     checkpoint: Dict[str, Any],
     strict: Optional[bool] = None,
+    instantiator=None,
     **cls_kwargs_new: Any,
 ) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
     cls_spec = inspect.getfullargspec(cls.__init__)
@@ -160,7 +161,7 @@
     # filter kwargs according to class init unless it allows any argument via kwargs
     _cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name}

-    obj = cls(**_cls_kwargs)
+    obj = instantiator(cls, _cls_kwargs) if instantiator else cls(**_cls_kwargs)

     if isinstance(obj, pl.LightningModule):
         # give model a chance to load something
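
The `instantiator` hook is any callable that takes the class and the dict of filtered init kwargs and returns the instance; when omitted, behaviour is unchanged. A hedged sketch of a custom hook (the checkpoint path is a placeholder for one produced by an earlier run):

from typing import Any, Dict, Type
from lightning.pytorch.demos.boring_classes import BoringModel

def logging_instantiator(cls: Type[Any], kwargs: Dict[str, Any]) -> Any:
    # hypothetical hook: report what is being constructed, then do the default call
    print(f"instantiating {cls.__qualname__} with keys {sorted(kwargs)}")
    return cls(**kwargs)

model = BoringModel.load_from_checkpoint("path/to/last.ckpt", instantiator=logging_instantiator)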

src/lightning/pytorch/utilities/parsing.py

Lines changed: 8 additions & 2 deletions
@@ -138,7 +138,11 @@ def collect_init_args(


 def save_hyperparameters(
-    obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None
+    obj: Any,
+    *args: Any,
+    ignore: Optional[Union[Sequence[str], str]] = None,
+    frame: Optional[types.FrameType] = None,
+    given_hparams: Optional[Dict[str, Any]] = None,
 ) -> None:
     """See :meth:`~lightning.pytorch.LightningModule.save_hyperparameters`"""

@@ -154,7 +158,9 @@
     if not isinstance(frame, types.FrameType):
         raise AttributeError("There is no `frame` available while being required.")

-    if is_dataclass(obj):
+    if given_hparams is not None:
+        init_args = given_hparams
+    elif is_dataclass(obj):
         init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
     else:
         init_args = {}

tests/tests_pytorch/test_cli.py

Lines changed: 48 additions & 0 deletions
@@ -40,6 +40,7 @@
 from lightning.pytorch.cli import (
     _JSONARGPARSE_SIGNATURES_AVAILABLE,
     instantiate_class,
+    instantiate_module,
     LightningArgumentParser,
     LightningCLI,
     LRSchedulerCallable,
@@ -833,6 +834,53 @@ def configure_optimizers(self):
     assert init[1]["lr_scheduler"].gamma == 0.3


+def test_lightning_cli_load_from_checkpoint_dependency_injection(cleandir):
+    class TestModel(BoringModel):
+        def __init__(
+            self,
+            optimizer: OptimizerCallable = torch.optim.Adam,
+            scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
+            activation: torch.nn.Module = lazy_instance(torch.nn.LeakyReLU, negative_slope=0.05),
+        ):
+            super().__init__()
+            self.save_hyperparameters()
+            self.optimizer = optimizer
+            self.scheduler = scheduler
+            self.activation = activation
+
+        def configure_optimizers(self):
+            optimizer = self.optimizer(self.parameters())
+            scheduler = self.scheduler(optimizer)
+            return {"optimizer": optimizer, "lr_scheduler": scheduler}
+
+    with mock.patch("sys.argv", ["any.py", "--trainer.max_epochs=1"]):
+        cli = LightningCLI(TestModel, run=False, auto_configure_optimizers=False)
+    cli.trainer.fit(cli.model)
+
+    hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
+    assert hparams_path.is_file()
+    hparams = yaml.safe_load(hparams_path.read_text())
+    expected = {
+        "optimizer": "torch.optim.Adam",
+        "scheduler": "torch.optim.lr_scheduler.ConstantLR",
+        "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
+    }
+    assert hparams == expected
+
+    checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
+    assert checkpoint_path.is_file()
+    ckpt = torch.load(checkpoint_path)
+    assert ckpt["hyper_parameters"] == expected
+
+    model = TestModel.load_from_checkpoint(checkpoint_path, instantiator=instantiate_module)
+    assert isinstance(model, TestModel)
+    assert isinstance(model.activation, torch.nn.LeakyReLU)
+    assert model.activation.negative_slope == 0.05
+    optimizer, lr_scheduler = model.configure_optimizers().values()
+    assert isinstance(optimizer, torch.optim.Adam)
+    assert isinstance(lr_scheduler, torch.optim.lr_scheduler.ConstantLR)
+
+
 @pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn])
 def test_lightning_cli_trainer_fn(fn):
     class TestCLI(LightningCLI):
