From fb7f51a1e6f0ccd5cd0039e8475cede496dda827 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 23 Feb 2023 13:10:23 +0100 Subject: [PATCH 1/2] Backwards compatibility for `get_init_args` --- src/lightning/pytorch/utilities/parsing.py | 10 ++++++++-- tests/tests_pytorch/utilities/test_parsing.py | 4 ++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/lightning/pytorch/utilities/parsing.py b/src/lightning/pytorch/utilities/parsing.py index 40e721b90e206..7ee87b83e2b9d 100644 --- a/src/lightning/pytorch/utilities/parsing.py +++ b/src/lightning/pytorch/utilities/parsing.py @@ -80,7 +80,13 @@ def _get_first_if_any( return n_self, n_args, n_kwargs -def get_init_args(frame: types.FrameType) -> Tuple[Optional[Any], Dict[str, Any]]: +def get_init_args(frame: types.FrameType) -> Dict[str, Any]: # pragma: no-cover + """For backwards compatibility: #16369.""" + _, local_args = _get_init_args(frame) + return local_args + + +def _get_init_args(frame: types.FrameType) -> Tuple[Optional[Any], Dict[str, Any]]: _, _, _, local_vars = inspect.getargvalues(frame) if "__class__" not in local_vars: return None, {} @@ -123,7 +129,7 @@ def collect_init_args( if not isinstance(frame.f_back, types.FrameType): return path_args - local_self, local_args = get_init_args(frame) + local_self, local_args = _get_init_args(frame) if "__class__" in local_vars and (not classes or isinstance(local_self, classes)): # recursive update path_args.append(local_args) diff --git a/tests/tests_pytorch/utilities/test_parsing.py b/tests/tests_pytorch/utilities/test_parsing.py index 1c83bbf31c669..bea31406f97f8 100644 --- a/tests/tests_pytorch/utilities/test_parsing.py +++ b/tests/tests_pytorch/utilities/test_parsing.py @@ -18,10 +18,10 @@ from lightning.pytorch import LightningDataModule, LightningModule, Trainer from lightning.pytorch.utilities.parsing import ( + _get_init_args, AttributeDict, clean_namespace, collect_init_args, - get_init_args, is_picklable, lightning_getattr, lightning_hasattr, @@ -209,7 +209,7 @@ def __init__(self, anyarg, anykw=42, **kwargs): def get_init_args_wrapper(self): frame = inspect.currentframe().f_back - self.result = get_init_args(frame) + self.result = _get_init_args(frame) my_class = AutomaticArgsModel("test", anykw=32, otherkw=123) assert my_class.result == (my_class, {"anyarg": "test", "anykw": 32, "otherkw": 123}) From 45a5df8d0df0d6a174a8c52187f19937a5cf5eaa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 23 Feb 2023 13:13:17 +0100 Subject: [PATCH 2/2] CHANGELOG --- src/lightning/pytorch/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index f4fa30754b92f..438653ff8aeba 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -371,6 +371,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue where `DistributedSampler.set_epoch` wasn't getting called during `trainer.predict` ([#16785](https://github.com/Lightning-AI/lightning/pull/16785), [#16826](https://github.com/Lightning-AI/lightning/pull/16826)) +- Fixed backwards compatibility for `lightning.pytorch.utilities.parsing.get_init_args` ([#16851](https://github.com/Lightning-AI/lightning/pull/16851)) + + - Fixed an issue causing a wrong environment plugin to be selected when `accelerator=tpu` and `devices > 1` ([#16806](https://github.com/Lightning-AI/lightning/pull/16806))