Skip to content

Commit 01668bf

Browse files
ruroawaelchli
andauthored
Fix save_hyperparameters for multiple inheritance and mixins (#16369)
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent fb12879 commit 01668bf

File tree

4 files changed

+30
-9
lines changed

4 files changed

+30
-9
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2929

3030
### Fixed
3131

32-
-
32+
- Fixed an unintended limitation for calling `save_hyperparameters` on mixin classes that don't subclass `LightningModule`/`LightningDataModule` ([#16369](https://github.com/Lightning-AI/lightning/pull/16369))
3333

3434

3535
## [1.9.0] - 2023-01-17

src/pytorch_lightning/utilities/parsing.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,10 @@ def _get_first_if_any(
137137
return n_self, n_args, n_kwargs
138138

139139

140-
def get_init_args(frame: types.FrameType) -> Dict[str, Any]:
140+
def get_init_args(frame: types.FrameType) -> Tuple[Optional[Any], Dict[str, Any]]:
141141
_, _, _, local_vars = inspect.getargvalues(frame)
142142
if "__class__" not in local_vars:
143-
return {}
143+
return None, {}
144144
cls = local_vars["__class__"]
145145
init_parameters = inspect.signature(cls.__init__).parameters
146146
self_var, args_var, kwargs_var = parse_class_init_keys(cls)
@@ -152,7 +152,8 @@ def get_init_args(frame: types.FrameType) -> Dict[str, Any]:
152152
if kwargs_var:
153153
local_args.update(local_args.get(kwargs_var, {}))
154154
local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames}
155-
return local_args
155+
self_arg = local_vars.get(self_var, None)
156+
return self_arg, local_args
156157

157158

158159
def collect_init_args(
@@ -179,8 +180,8 @@ def collect_init_args(
179180
if not isinstance(frame.f_back, types.FrameType):
180181
return path_args
181182

182-
if "__class__" in local_vars and (not classes or issubclass(local_vars["__class__"], classes)):
183-
local_args = get_init_args(frame)
183+
local_self, local_args = get_init_args(frame)
184+
if "__class__" in local_vars and (not classes or isinstance(local_self, classes)):
184185
# recursive update
185186
path_args.append(local_args)
186187
return collect_init_args(frame.f_back, path_args, inside=True, classes=classes)

tests/tests_pytorch/models/test_hparams.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,24 @@ def __init__(self, *args, subclass_arg=1200, **kwargs):
296296
self.save_hyperparameters()
297297

298298

299+
class MixinForBoringModel:
300+
any_other_loss = torch.nn.CrossEntropyLoss()
301+
302+
def __init__(self, *args, subclass_arg=1200, **kwargs):
303+
super().__init__(*args, **kwargs)
304+
self.save_hyperparameters()
305+
306+
307+
class BoringModelWithMixin(MixinForBoringModel, CustomBoringModel):
308+
pass
309+
310+
311+
class BoringModelWithMixinAndInit(MixinForBoringModel, CustomBoringModel):
312+
def __init__(self, *args, **kwargs):
313+
super().__init__(*args, **kwargs)
314+
self.save_hyperparameters()
315+
316+
299317
class NonSavingSubClassBoringModel(CustomBoringModel):
300318
any_other_loss = torch.nn.CrossEntropyLoss()
301319

@@ -345,6 +363,8 @@ class DictConfSubClassBoringModel:
345363
AggSubClassBoringModel,
346364
UnconventionalArgsBoringModel,
347365
pytest.param(DictConfSubClassBoringModel, marks=RunIf(omegaconf=True)),
366+
BoringModelWithMixin,
367+
BoringModelWithMixinAndInit,
348368
],
349369
)
350370
def test_collect_init_arguments(tmpdir, cls):
@@ -360,7 +380,7 @@ def test_collect_init_arguments(tmpdir, cls):
360380
model = cls(batch_size=179, **extra_args)
361381
assert model.hparams.batch_size == 179
362382

363-
if isinstance(model, (SubClassBoringModel, NonSavingSubClassBoringModel)):
383+
if isinstance(model, (SubClassBoringModel, NonSavingSubClassBoringModel, MixinForBoringModel)):
364384
assert model.hparams.subclass_arg == 1200
365385

366386
if isinstance(model, AggSubClassBoringModel):

tests/tests_pytorch/utilities/test_parsing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,10 +255,10 @@ def get_init_args_wrapper(self):
255255
self.result = get_init_args(frame)
256256

257257
my_class = AutomaticArgsModel("test", anykw=32, otherkw=123)
258-
assert my_class.result == {"anyarg": "test", "anykw": 32, "otherkw": 123}
258+
assert my_class.result == (my_class, {"anyarg": "test", "anykw": 32, "otherkw": 123})
259259

260260
my_class.get_init_args_wrapper()
261-
assert my_class.result == {}
261+
assert my_class.result == (None, {})
262262

263263

264264
def test_collect_init_args():

0 commit comments

Comments
 (0)