Skip to content

Commit 904dde7

Browse files
s-rogpre-commit-ci[bot]Borda
authored
Fix inspection of unspecified args for container hparams (#9125)
* Update parsing.py * add todo (for single arg) * unblock non container single arg * init test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update CHANGELOG.md * pep8 line length * Update pytorch_lightning/utilities/parsing.py * remove dict namespace conversion * add omegaconf support * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add dict test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add omegaconf test * Update CHANGELOG.md * Update pytorch_lightning/utilities/parsing.py Co-authored-by: Jirka Borovec <[email protected]> * Update pytorch_lightning/utilities/parsing.py Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <[email protected]>
1 parent 73fca23 commit 904dde7

File tree

3 files changed

+76
-40
lines changed

3 files changed

+76
-40
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
294294
- Fixed error handling in DDP process reconciliation when `_sync_dir` was not initialized ([#9267](https://github.com/PyTorchLightning/pytorch-lightning/pull/9267))
295295

296296

297+
- Fixed inspection of other args when a container is specified in `save_hyperparameters` ([#9125](https://github.com/PyTorchLightning/pytorch-lightning/pull/9125))
298+
299+
297300
- Fixed `move_metrics_to_cpu` moving the loss on cpu while training on device ([#9308](https://github.com/PyTorchLightning/pytorch-lightning/pull/9308))
298301

299302

pytorch_lightning/utilities/parsing.py

Lines changed: 51 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,12 @@
2222
from typing_extensions import Literal
2323

2424
import pytorch_lightning as pl
25+
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
2526
from pytorch_lightning.utilities.warnings import rank_zero_warn
2627

28+
if _OMEGACONF_AVAILABLE:
29+
from omegaconf.dictconfig import DictConfig
30+
2731

2832
def str_to_bool_or_str(val: str) -> Union[str, bool]:
2933
"""Possibly convert a string representation of truth to bool.
@@ -204,46 +208,57 @@ def save_hyperparameters(
204208
obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None
205209
) -> None:
206210
"""See :meth:`~pytorch_lightning.LightningModule.save_hyperparameters`"""
207-
211+
hparams_container_types = [Namespace, dict]
212+
if _OMEGACONF_AVAILABLE:
213+
hparams_container_types.append(DictConfig)
214+
# empty container
208215
if len(args) == 1 and not isinstance(args, str) and not args[0]:
209-
# args[0] is an empty container
210216
return
211-
212-
if not frame:
213-
current_frame = inspect.currentframe()
214-
# inspect.currentframe() return type is Optional[types.FrameType]: current_frame.f_back called only if available
215-
if current_frame:
216-
frame = current_frame.f_back
217-
if not isinstance(frame, types.FrameType):
218-
raise AttributeError("There is no `frame` available while being required.")
219-
220-
if is_dataclass(obj):
221-
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
222-
else:
223-
init_args = get_init_args(frame)
224-
assert init_args, "failed to inspect the obj init"
225-
226-
if ignore is not None:
227-
if isinstance(ignore, str):
228-
ignore = [ignore]
229-
if isinstance(ignore, (list, tuple)):
230-
ignore = [arg for arg in ignore if isinstance(arg, str)]
231-
init_args = {k: v for k, v in init_args.items() if k not in ignore}
232-
233-
if not args:
234-
# take all arguments
235-
hp = init_args
236-
obj._hparams_name = "kwargs" if hp else None
217+
# container
218+
elif len(args) == 1 and isinstance(args[0], tuple(hparams_container_types)):
219+
hp = args[0]
220+
obj._hparams_name = "hparams"
221+
obj._set_hparams(hp)
222+
obj._hparams_initial = copy.deepcopy(obj._hparams)
223+
return
224+
# non-container args parsing
237225
else:
238-
# take only listed arguments in `save_hparams`
239-
isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)]
240-
if len(isx_non_str) == 1:
241-
hp = args[isx_non_str[0]]
242-
cand_names = [k for k, v in init_args.items() if v == hp]
243-
obj._hparams_name = cand_names[0] if cand_names else None
226+
if not frame:
227+
current_frame = inspect.currentframe()
228+
# inspect.currentframe() return type is Optional[types.FrameType]
229+
# current_frame.f_back called only if available
230+
if current_frame:
231+
frame = current_frame.f_back
232+
if not isinstance(frame, types.FrameType):
233+
raise AttributeError("There is no `frame` available while being required.")
234+
235+
if is_dataclass(obj):
236+
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
237+
else:
238+
init_args = get_init_args(frame)
239+
assert init_args, f"failed to inspect the obj init - {frame}"
240+
241+
if ignore is not None:
242+
if isinstance(ignore, str):
243+
ignore = [ignore]
244+
if isinstance(ignore, (list, tuple, set)):
245+
ignore = [arg for arg in ignore if isinstance(arg, str)]
246+
init_args = {k: v for k, v in init_args.items() if k not in ignore}
247+
248+
if not args:
249+
# take all arguments
250+
hp = init_args
251+
obj._hparams_name = "kwargs" if hp else None
244252
else:
245-
hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)}
246-
obj._hparams_name = "kwargs"
253+
# take only listed arguments in `save_hparams`
254+
isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)]
255+
if len(isx_non_str) == 1:
256+
hp = args[isx_non_str[0]]
257+
cand_names = [k for k, v in init_args.items() if v == hp]
258+
obj._hparams_name = cand_names[0] if cand_names else None
259+
else:
260+
hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)}
261+
obj._hparams_name = "kwargs"
247262

248263
# `hparams` are expected here
249264
if hp:

tests/core/test_datamodules.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import pickle
15-
from argparse import ArgumentParser
15+
from argparse import ArgumentParser, Namespace
1616
from dataclasses import dataclass
1717
from typing import Any, Dict
1818
from unittest import mock
1919
from unittest.mock import call, PropertyMock
2020

2121
import pytest
2222
import torch
23+
from omegaconf import OmegaConf
2324

2425
from pytorch_lightning import LightningDataModule, Trainer
2526
from pytorch_lightning.callbacks import ModelCheckpoint
@@ -528,16 +529,33 @@ def test_dm_init_from_datasets_dataloaders(iterable):
528529
)
529530

530531

531-
class DataModuleWithHparams(LightningDataModule):
532+
# all args
533+
class DataModuleWithHparams_0(LightningDataModule):
532534
def __init__(self, arg0, arg1, kwarg0=None):
533535
super().__init__()
534536
self.save_hyperparameters()
535537

536538

537-
def test_simple_hyperparameters_saving():
538-
data = DataModuleWithHparams(10, "foo", kwarg0="bar")
539+
# single arg
540+
class DataModuleWithHparams_1(LightningDataModule):
541+
def __init__(self, arg0, *args, **kwargs):
542+
super().__init__()
543+
self.save_hyperparameters(arg0)
544+
545+
546+
def test_hyperparameters_saving():
547+
data = DataModuleWithHparams_0(10, "foo", kwarg0="bar")
539548
assert data.hparams == AttributeDict({"arg0": 10, "arg1": "foo", "kwarg0": "bar"})
540549

550+
data = DataModuleWithHparams_1(Namespace(**{"hello": "world"}), "foo", kwarg0="bar")
551+
assert data.hparams == AttributeDict({"hello": "world"})
552+
553+
data = DataModuleWithHparams_1({"hello": "world"}, "foo", kwarg0="bar")
554+
assert data.hparams == AttributeDict({"hello": "world"})
555+
556+
data = DataModuleWithHparams_1(OmegaConf.create({"hello": "world"}), "foo", kwarg0="bar")
557+
assert data.hparams == OmegaConf.create({"hello": "world"})
558+
541559

542560
def test_define_as_dataclass():
543561
# makes sure that no functionality is broken and the user can still manually make

0 commit comments

Comments
 (0)