From ba9de82097f87ab8773309e2c5abd4561c3c64b5 Mon Sep 17 00:00:00 2001 From: ruro Date: Sun, 15 Jan 2023 21:50:15 +0300 Subject: [PATCH 1/5] add tests for save_hyperparameters and mixins --- tests/tests_pytorch/models/test_hparams.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index b1c47eccf7eca..e238b69b7d77d 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -296,6 +296,24 @@ def __init__(self, *args, subclass_arg=1200, **kwargs): self.save_hyperparameters() +class MixinForBoringModel: + any_other_loss = torch.nn.CrossEntropyLoss() + + def __init__(self, *args, subclass_arg=1200, **kwargs): + super().__init__(*args, **kwargs) + self.save_hyperparameters() + + +class BoringModelWithMixin(MixinForBoringModel, CustomBoringModel): + pass + + +class BoringModelWithMixinAndInit(MixinForBoringModel, CustomBoringModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.save_hyperparameters() + + class NonSavingSubClassBoringModel(CustomBoringModel): any_other_loss = torch.nn.CrossEntropyLoss() @@ -345,6 +363,8 @@ class DictConfSubClassBoringModel: AggSubClassBoringModel, UnconventionalArgsBoringModel, pytest.param(DictConfSubClassBoringModel, marks=RunIf(omegaconf=True)), + BoringModelWithMixin, + BoringModelWithMixinAndInit, ], ) def test_collect_init_arguments(tmpdir, cls): @@ -360,7 +380,7 @@ def test_collect_init_arguments(tmpdir, cls): model = cls(batch_size=179, **extra_args) assert model.hparams.batch_size == 179 - if isinstance(model, (SubClassBoringModel, NonSavingSubClassBoringModel)): + if isinstance(model, (SubClassBoringModel, NonSavingSubClassBoringModel, MixinForBoringModel)): assert model.hparams.subclass_arg == 1200 if isinstance(model, AggSubClassBoringModel): From 8138268b6f1c836379739ebf9bf14d7dfa0ffaab Mon Sep 17 00:00:00 2001 From: ruro Date: Sun, 15 Jan 2023 21:50:38 +0300 Subject: [PATCH 2/5] fix collect_init_args to work for mixins --- src/pytorch_lightning/utilities/parsing.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/pytorch_lightning/utilities/parsing.py b/src/pytorch_lightning/utilities/parsing.py index ae969041da62a..ddf1c33e41d59 100644 --- a/src/pytorch_lightning/utilities/parsing.py +++ b/src/pytorch_lightning/utilities/parsing.py @@ -137,10 +137,10 @@ def _get_first_if_any( return n_self, n_args, n_kwargs -def get_init_args(frame: types.FrameType) -> Dict[str, Any]: +def get_init_args(frame: types.FrameType) -> Tuple[Any, Dict[str, Any]]: _, _, _, local_vars = inspect.getargvalues(frame) if "__class__" not in local_vars: - return {} + return None, {} cls = local_vars["__class__"] init_parameters = inspect.signature(cls.__init__).parameters self_var, args_var, kwargs_var = parse_class_init_keys(cls) @@ -152,7 +152,8 @@ def get_init_args(frame: types.FrameType) -> Dict[str, Any]: if kwargs_var: local_args.update(local_args.get(kwargs_var, {})) local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames} - return local_args + self_arg = local_vars.get(self_var, None) + return self_arg, local_args def collect_init_args( @@ -179,8 +180,8 @@ def collect_init_args( if not isinstance(frame.f_back, types.FrameType): return path_args - if "__class__" in local_vars and (not classes or issubclass(local_vars["__class__"], classes)): - local_args = get_init_args(frame) + local_self, local_args = get_init_args(frame) + if "__class__" in local_vars and (not classes or isinstance(local_self, classes)): # recursive update path_args.append(local_args) return collect_init_args(frame.f_back, path_args, inside=True, classes=classes) From 1eb16e33f13c63d71ae29d0ec6324d3eb6a1876d Mon Sep 17 00:00:00 2001 From: ruro Date: Sun, 15 Jan 2023 21:50:57 +0300 Subject: [PATCH 3/5] fix internal test for get_init_args --- tests/tests_pytorch/utilities/test_parsing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/utilities/test_parsing.py b/tests/tests_pytorch/utilities/test_parsing.py index fea3167db0d15..e9a5f9212497d 100644 --- a/tests/tests_pytorch/utilities/test_parsing.py +++ b/tests/tests_pytorch/utilities/test_parsing.py @@ -255,10 +255,10 @@ def get_init_args_wrapper(self): self.result = get_init_args(frame) my_class = AutomaticArgsModel("test", anykw=32, otherkw=123) - assert my_class.result == {"anyarg": "test", "anykw": 32, "otherkw": 123} + assert my_class.result == (my_class, {"anyarg": "test", "anykw": 32, "otherkw": 123}) my_class.get_init_args_wrapper() - assert my_class.result == {} + assert my_class.result == (None, {}) def test_collect_init_args(): From 37a5e4577c446c8beec3ce058c462d6262110845 Mon Sep 17 00:00:00 2001 From: ruro Date: Mon, 16 Jan 2023 19:25:48 +0300 Subject: [PATCH 4/5] improve get_init_args type hints MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- src/pytorch_lightning/utilities/parsing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/parsing.py b/src/pytorch_lightning/utilities/parsing.py index ddf1c33e41d59..283b4a05a98ba 100644 --- a/src/pytorch_lightning/utilities/parsing.py +++ b/src/pytorch_lightning/utilities/parsing.py @@ -137,7 +137,7 @@ def _get_first_if_any( return n_self, n_args, n_kwargs -def get_init_args(frame: types.FrameType) -> Tuple[Any, Dict[str, Any]]: +def get_init_args(frame: types.FrameType) -> Tuple[Optional[Any], Dict[str, Any]]: _, _, _, local_vars = inspect.getargvalues(frame) if "__class__" not in local_vars: return None, {} From 316e23c7047e5e19c5d7e6a9f8dae038b817006d Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 18 Jan 2023 22:46:35 +0100 Subject: [PATCH 5/5] add changelog --- src/pytorch_lightning/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 8602dec1a854e..e62147a7e0041 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -29,7 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- +- Fixed an unintended limitation for calling `save_hyperparameters` on mixin classes that don't subclass `LightningModule`/`LightningDataModule` ([#16369](https://github.com/Lightning-AI/lightning/pull/16369)) ## [1.9.0] - 2023-01-17