Skip to content

Commit 3b18da3

Browse files
awaelchlirohitgr7
andauthored
Fix saving hyperparameters in a composition where parent is not a LM or LDM (#14151)
Co-authored-by: Rohit Gupta <[email protected]>
1 parent 98ded45 commit 3b18da3

File tree

3 files changed

+35
-5
lines changed

3 files changed

+35
-5
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
7070
- Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061))
7171

7272

73+
- Fixed saving hyperparameters in a composition where the parent class is not a `LightningModule` or `LightningDataModule` ([#14151](https://github.com/Lightning-AI/lightning/pull/14151))
74+
75+
76+
7377
## [1.7.1] - 2022-08-09
7478

7579
### Fixed

src/pytorch_lightning/utilities/parsing.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,14 +162,18 @@ def get_init_args(frame: types.FrameType) -> Dict[str, Any]:
162162

163163

164164
def collect_init_args(
165-
frame: types.FrameType, path_args: List[Dict[str, Any]], inside: bool = False
165+
frame: types.FrameType,
166+
path_args: List[Dict[str, Any]],
167+
inside: bool = False,
168+
classes: Tuple[Type, ...] = (),
166169
) -> List[Dict[str, Any]]:
167170
"""Recursively collects the arguments passed to the child constructors in the inheritance tree.
168171
169172
Args:
170173
frame: the current stack frame
171174
path_args: a list of dictionaries containing the constructor args in all parent classes
172175
inside: track if we are inside inheritance path, avoid terminating too soon
176+
classes: the classes in which to inspect the frames
173177
174178
Return:
175179
A list of dictionaries where each dictionary contains the arguments passed to the
@@ -181,13 +185,13 @@ def collect_init_args(
181185
if not isinstance(frame.f_back, types.FrameType):
182186
return path_args
183187

184-
if "__class__" in local_vars:
188+
if "__class__" in local_vars and (not classes or issubclass(local_vars["__class__"], classes)):
185189
local_args = get_init_args(frame)
186190
# recursive update
187191
path_args.append(local_args)
188-
return collect_init_args(frame.f_back, path_args, inside=True)
192+
return collect_init_args(frame.f_back, path_args, inside=True, classes=classes)
189193
if not inside:
190-
return collect_init_args(frame.f_back, path_args, inside)
194+
return collect_init_args(frame.f_back, path_args, inside, classes=classes)
191195
return path_args
192196

193197

@@ -225,7 +229,10 @@ def save_hyperparameters(
225229
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
226230
else:
227231
init_args = {}
228-
for local_args in collect_init_args(frame, []):
232+
233+
from pytorch_lightning.core.mixins import HyperparametersMixin
234+
235+
for local_args in collect_init_args(frame, [], classes=(HyperparametersMixin,)):
229236
init_args.update(local_args)
230237

231238
if ignore is None:

tests/tests_pytorch/models/test_hparams.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from pytorch_lightning import LightningModule, Trainer
3030
from pytorch_lightning.callbacks import ModelCheckpoint
3131
from pytorch_lightning.core.datamodule import LightningDataModule
32+
from pytorch_lightning.core.mixins import HyperparametersMixin
3233
from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml
3334
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
3435
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict, is_picklable
@@ -399,6 +400,24 @@ def _raw_checkpoint_path(trainer) -> str:
399400
return raw_checkpoint_path
400401

401402

403+
@pytest.mark.parametrize("base_class", (HyperparametersMixin, LightningModule, LightningDataModule))
404+
def test_save_hyperparameters_under_composition(base_class):
405+
"""Test that in a composition where the parent is not a Lightning-like module, the parent's arguments don't get
406+
collected."""
407+
408+
class ChildInComposition(base_class):
409+
def __init__(self, same_arg):
410+
super().__init__()
411+
self.save_hyperparameters()
412+
413+
class NotPLSubclass: # intentionally not subclassing LightningModule/LightningDataModule
414+
def __init__(self, same_arg="parent_default", other_arg="other"):
415+
self.child = ChildInComposition(same_arg="cocofruit")
416+
417+
parent = NotPLSubclass()
418+
assert parent.child.hparams == dict(same_arg="cocofruit")
419+
420+
402421
class LocalVariableModelSuperLast(BoringModel):
403422
"""This model has the super().__init__() call at the end."""
404423

0 commit comments

Comments
 (0)