diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 8602dec1a854e..e62147a7e0041 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -29,7 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Fixed an unintended limitation for calling `save_hyperparameters` on mixin classes that don't subclass `LightningModule`/`LightningDataModule` ([#16369](https://github.com/Lightning-AI/lightning/pull/16369))
 
 
 ## [1.9.0] - 2023-01-17
diff --git a/src/pytorch_lightning/utilities/parsing.py b/src/pytorch_lightning/utilities/parsing.py
index ae969041da62a..283b4a05a98ba 100644
--- a/src/pytorch_lightning/utilities/parsing.py
+++ b/src/pytorch_lightning/utilities/parsing.py
@@ -137,10 +137,10 @@ def _get_first_if_any(
     return n_self, n_args, n_kwargs
 
 
-def get_init_args(frame: types.FrameType) -> Dict[str, Any]:
+def get_init_args(frame: types.FrameType) -> Tuple[Optional[Any], Dict[str, Any]]:
     _, _, _, local_vars = inspect.getargvalues(frame)
     if "__class__" not in local_vars:
-        return {}
+        return None, {}
     cls = local_vars["__class__"]
     init_parameters = inspect.signature(cls.__init__).parameters
     self_var, args_var, kwargs_var = parse_class_init_keys(cls)
@@ -152,7 +152,8 @@ def get_init_args(frame: types.FrameType) -> Dict[str, Any]:
     if kwargs_var:
         local_args.update(local_args.get(kwargs_var, {}))
     local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames}
-    return local_args
+    self_arg = local_vars.get(self_var, None)
+    return self_arg, local_args
 
 
 def collect_init_args(
@@ -179,8 +180,8 @@ def collect_init_args(
     if not isinstance(frame.f_back, types.FrameType):
         return path_args
 
-    if "__class__" in local_vars and (not classes or issubclass(local_vars["__class__"], classes)):
-        local_args = get_init_args(frame)
+    local_self, local_args = get_init_args(frame)
+    if "__class__" in local_vars and (not classes or isinstance(local_self, classes)):
         # recursive update
         path_args.append(local_args)
         return collect_init_args(frame.f_back, path_args, inside=True, classes=classes)
diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py
index b1c47eccf7eca..e238b69b7d77d 100644
--- a/tests/tests_pytorch/models/test_hparams.py
+++ b/tests/tests_pytorch/models/test_hparams.py
@@ -296,6 +296,24 @@ def __init__(self, *args, subclass_arg=1200, **kwargs):
         self.save_hyperparameters()
 
 
+class MixinForBoringModel:
+    any_other_loss = torch.nn.CrossEntropyLoss()
+
+    def __init__(self, *args, subclass_arg=1200, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.save_hyperparameters()
+
+
+class BoringModelWithMixin(MixinForBoringModel, CustomBoringModel):
+    pass
+
+
+class BoringModelWithMixinAndInit(MixinForBoringModel, CustomBoringModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.save_hyperparameters()
+
+
 class NonSavingSubClassBoringModel(CustomBoringModel):
     any_other_loss = torch.nn.CrossEntropyLoss()
 
@@ -345,6 +363,8 @@ class DictConfSubClassBoringModel:
         AggSubClassBoringModel,
         UnconventionalArgsBoringModel,
         pytest.param(DictConfSubClassBoringModel, marks=RunIf(omegaconf=True)),
+        BoringModelWithMixin,
+        BoringModelWithMixinAndInit,
     ],
 )
 def test_collect_init_arguments(tmpdir, cls):
@@ -360,7 +380,7 @@ def test_collect_init_arguments(tmpdir, cls):
     model = cls(batch_size=179, **extra_args)
     assert model.hparams.batch_size == 179
 
-    if isinstance(model, (SubClassBoringModel, NonSavingSubClassBoringModel)):
+    if isinstance(model, (SubClassBoringModel, NonSavingSubClassBoringModel, MixinForBoringModel)):
         assert model.hparams.subclass_arg == 1200
 
     if isinstance(model, AggSubClassBoringModel):
diff --git a/tests/tests_pytorch/utilities/test_parsing.py b/tests/tests_pytorch/utilities/test_parsing.py
index fea3167db0d15..e9a5f9212497d 100644
--- a/tests/tests_pytorch/utilities/test_parsing.py
+++ b/tests/tests_pytorch/utilities/test_parsing.py
@@ -255,10 +255,10 @@ def get_init_args_wrapper(self):
             self.result = get_init_args(frame)
 
     my_class = AutomaticArgsModel("test", anykw=32, otherkw=123)
-    assert my_class.result == {"anyarg": "test", "anykw": 32, "otherkw": 123}
+    assert my_class.result == (my_class, {"anyarg": "test", "anykw": 32, "otherkw": 123})
 
     my_class.get_init_args_wrapper()
-    assert my_class.result == (None, {})
+    assert my_class.result == (None, {})
 
 
 def test_collect_init_args():
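
Why the change works: `get_init_args` now also returns the frame's bound `self`, and `collect_init_args` filters frames with `isinstance(local_self, classes)` instead of `issubclass(local_vars["__class__"], classes)`. The `isinstance` check consults the runtime MRO of the actual instance being constructed, so an `__init__` defined on a plain mixin still qualifies as long as the final object is a `LightningModule`/`LightningDataModule`. Below is a minimal sketch of the pattern the new tests exercise; the names `HParamsMixin`, `MyModel`, `dropout`, and `hidden_dim` are illustrative, not from the patch, and it assumes a Lightning version that includes this fix:

```python
import pytorch_lightning as pl


class HParamsMixin:
    """A plain mixin: deliberately not a LightningModule subclass."""

    def __init__(self, dropout: float = 0.1, **kwargs):
        super().__init__(**kwargs)
        # Before this fix, this frame was skipped because HParamsMixin fails
        # the issubclass(..., LightningModule) test. The isinstance check on
        # the frame's `self` accepts it, since `self` is a MyModel instance
        # (and therefore a LightningModule).
        self.save_hyperparameters()


class MyModel(HParamsMixin, pl.LightningModule):
    def __init__(self, hidden_dim: int = 64, **kwargs):
        super().__init__(**kwargs)
        self.save_hyperparameters()


# Init args from both the mixin's and the model's __init__ frames are collected.
model = MyModel(hidden_dim=128, dropout=0.2)
assert model.hparams.hidden_dim == 128
assert model.hparams.dropout == 0.2
```

This mirrors the `BoringModelWithMixinAndInit` test case above, where both the mixin and the concrete class call `save_hyperparameters` and the collected hyperparameters span every `__init__` in the chain.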