Skip to content

Commit a59ee10

Browse files
awaelchlirohitgr7
andcommitted
Fix saving hyperparameters in a composition where parent is not a LM or LDM (#14151)
Co-authored-by: Rohit Gupta <[email protected]>
1 parent 552f496 commit a59ee10

File tree

3 files changed

+35
-5
lines changed

3 files changed

+35
-5
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1111
- Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061))
1212

1313

14+
- Fixed saving hyperparameters in a composition where the parent class is not a `LightningModule` or `LightningDataModule` ([#14151](https://github.com/Lightning-AI/lightning/pull/14151))
15+
16+
17+
1418
## [1.7.1] - 2022-08-09
1519

1620
### Fixed

src/pytorch_lightning/utilities/parsing.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -160,14 +160,18 @@ def get_init_args(frame: types.FrameType) -> Dict[str, Any]:
160160

161161

162162
def collect_init_args(
163-
frame: types.FrameType, path_args: List[Dict[str, Any]], inside: bool = False
163+
frame: types.FrameType,
164+
path_args: List[Dict[str, Any]],
165+
inside: bool = False,
166+
classes: Tuple[Type, ...] = (),
164167
) -> List[Dict[str, Any]]:
165168
"""Recursively collects the arguments passed to the child constructors in the inheritance tree.
166169
167170
Args:
168171
frame: the current stack frame
169172
path_args: a list of dictionaries containing the constructor args in all parent classes
170173
inside: track if we are inside inheritance path, avoid terminating too soon
174+
classes: the classes in which to inspect the frames
171175
172176
Return:
173177
A list of dictionaries where each dictionary contains the arguments passed to the
@@ -179,13 +183,13 @@ def collect_init_args(
179183
if not isinstance(frame.f_back, types.FrameType):
180184
return path_args
181185

182-
if "__class__" in local_vars:
186+
if "__class__" in local_vars and (not classes or issubclass(local_vars["__class__"], classes)):
183187
local_args = get_init_args(frame)
184188
# recursive update
185189
path_args.append(local_args)
186-
return collect_init_args(frame.f_back, path_args, inside=True)
190+
return collect_init_args(frame.f_back, path_args, inside=True, classes=classes)
187191
if not inside:
188-
return collect_init_args(frame.f_back, path_args, inside)
192+
return collect_init_args(frame.f_back, path_args, inside, classes=classes)
189193
return path_args
190194

191195

@@ -223,7 +227,10 @@ def save_hyperparameters(
223227
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
224228
else:
225229
init_args = {}
226-
for local_args in collect_init_args(frame, []):
230+
231+
from pytorch_lightning.core.mixins import HyperparametersMixin
232+
233+
for local_args in collect_init_args(frame, [], classes=(HyperparametersMixin,)):
227234
init_args.update(local_args)
228235

229236
if ignore is None:

tests/tests_pytorch/models/test_hparams.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from pytorch_lightning import LightningModule, Trainer
3030
from pytorch_lightning.callbacks import ModelCheckpoint
3131
from pytorch_lightning.core.datamodule import LightningDataModule
32+
from pytorch_lightning.core.mixins import HyperparametersMixin
3233
from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml
3334
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
3435
from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, _OMEGACONF_AVAILABLE, AttributeDict, is_picklable
@@ -401,6 +402,24 @@ def _raw_checkpoint_path(trainer) -> str:
401402
return raw_checkpoint_path
402403

403404

405+
@pytest.mark.parametrize("base_class", (HyperparametersMixin, LightningModule, LightningDataModule))
406+
def test_save_hyperparameters_under_composition(base_class):
407+
"""Test that in a composition where the parent is not a Lightning-like module, the parent's arguments don't get
408+
collected."""
409+
410+
class ChildInComposition(base_class):
411+
def __init__(self, same_arg):
412+
super().__init__()
413+
self.save_hyperparameters()
414+
415+
class NotPLSubclass: # intentionally not subclassing LightningModule/LightningDataModule
416+
def __init__(self, same_arg="parent_default", other_arg="other"):
417+
self.child = ChildInComposition(same_arg="cocofruit")
418+
419+
parent = NotPLSubclass()
420+
assert parent.child.hparams == dict(same_arg="cocofruit")
421+
422+
404423
class LocalVariableModelSuperLast(BoringModel):
405424
"""This model has the super().__init__() call at the end."""
406425

0 commit comments

Comments
 (0)