Skip to content

Commit 185c4fd

Browse files
justusschockpre-commit-ci[bot]Borda
authored andcommitted
Fix inspection of unspecified args for container hparams (#9125)
* Update parsing.py * add todo (for single arg) * unblock non container single arg * init test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update CHANGELOG.md * pep8 line length * Update pytorch_lightning/utilities/parsing.py * remove dict namespace conversion * add omegaconf support * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add dict test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add omegaconf test * Update CHANGELOG.md * Update pytorch_lightning/utilities/parsing.py Co-authored-by: Jirka Borovec <[email protected]> * Update pytorch_lightning/utilities/parsing.py Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <[email protected]>
1 parent 96541cf commit 185c4fd

File tree

3 files changed

+149
-40
lines changed

3 files changed

+149
-40
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3737
- Fixed error handling in DDP process reconciliation when `_sync_dir` was not initialized ([#9267](https://github.com/PyTorchLightning/pytorch-lightning/pull/9267))
3838

3939

40+
- Fixed inspection of other args when a container is specified in `save_hyperparameters` ([#9125](https://github.com/PyTorchLightning/pytorch-lightning/pull/9125))
41+
4042
## [1.4.5] - 2021-08-31
4143

4244
- Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))

pytorch_lightning/utilities/parsing.py

Lines changed: 51 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,12 @@
2222
from typing_extensions import Literal
2323

2424
import pytorch_lightning as pl
25+
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
2526
from pytorch_lightning.utilities.warnings import rank_zero_warn
2627

28+
if _OMEGACONF_AVAILABLE:
29+
from omegaconf.dictconfig import DictConfig
30+
2731

2832
def str_to_bool_or_str(val: str) -> Union[str, bool]:
2933
"""Possibly convert a string representation of truth to bool.
@@ -204,46 +208,57 @@ def save_hyperparameters(
204208
obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None
205209
) -> None:
206210
"""See :meth:`~pytorch_lightning.LightningModule.save_hyperparameters`"""
207-
211+
hparams_container_types = [Namespace, dict]
212+
if _OMEGACONF_AVAILABLE:
213+
hparams_container_types.append(DictConfig)
214+
# empty container
208215
if len(args) == 1 and not isinstance(args, str) and not args[0]:
209-
# args[0] is an empty container
210216
return
211-
212-
if not frame:
213-
current_frame = inspect.currentframe()
214-
# inspect.currentframe() return type is Optional[types.FrameType]: current_frame.f_back called only if available
215-
if current_frame:
216-
frame = current_frame.f_back
217-
if not isinstance(frame, types.FrameType):
218-
raise AttributeError("There is no `frame` available while being required.")
219-
220-
if is_dataclass(obj):
221-
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
222-
else:
223-
init_args = get_init_args(frame)
224-
assert init_args, "failed to inspect the obj init"
225-
226-
if ignore is not None:
227-
if isinstance(ignore, str):
228-
ignore = [ignore]
229-
if isinstance(ignore, (list, tuple)):
230-
ignore = [arg for arg in ignore if isinstance(arg, str)]
231-
init_args = {k: v for k, v in init_args.items() if k not in ignore}
232-
233-
if not args:
234-
# take all arguments
235-
hp = init_args
236-
obj._hparams_name = "kwargs" if hp else None
217+
# container
218+
elif len(args) == 1 and isinstance(args[0], tuple(hparams_container_types)):
219+
hp = args[0]
220+
obj._hparams_name = "hparams"
221+
obj._set_hparams(hp)
222+
obj._hparams_initial = copy.deepcopy(obj._hparams)
223+
return
224+
# non-container args parsing
237225
else:
238-
# take only listed arguments in `save_hparams`
239-
isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)]
240-
if len(isx_non_str) == 1:
241-
hp = args[isx_non_str[0]]
242-
cand_names = [k for k, v in init_args.items() if v == hp]
243-
obj._hparams_name = cand_names[0] if cand_names else None
226+
if not frame:
227+
current_frame = inspect.currentframe()
228+
# inspect.currentframe() return type is Optional[types.FrameType]
229+
# current_frame.f_back called only if available
230+
if current_frame:
231+
frame = current_frame.f_back
232+
if not isinstance(frame, types.FrameType):
233+
raise AttributeError("There is no `frame` available while being required.")
234+
235+
if is_dataclass(obj):
236+
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
237+
else:
238+
init_args = get_init_args(frame)
239+
assert init_args, f"failed to inspect the obj init - {frame}"
240+
241+
if ignore is not None:
242+
if isinstance(ignore, str):
243+
ignore = [ignore]
244+
if isinstance(ignore, (list, tuple, set)):
245+
ignore = [arg for arg in ignore if isinstance(arg, str)]
246+
init_args = {k: v for k, v in init_args.items() if k not in ignore}
247+
248+
if not args:
249+
# take all arguments
250+
hp = init_args
251+
obj._hparams_name = "kwargs" if hp else None
244252
else:
245-
hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)}
246-
obj._hparams_name = "kwargs"
253+
# take only listed arguments in `save_hparams`
254+
isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)]
255+
if len(isx_non_str) == 1:
256+
hp = args[isx_non_str[0]]
257+
cand_names = [k for k, v in init_args.items() if v == hp]
258+
obj._hparams_name = cand_names[0] if cand_names else None
259+
else:
260+
hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)}
261+
obj._hparams_name = "kwargs"
247262

248263
# `hparams` are expected here
249264
if hp:

tests/core/test_datamodules.py

Lines changed: 96 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,20 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import pickle
15-
from argparse import ArgumentParser
15+
from argparse import ArgumentParser, Namespace
16+
from dataclasses import dataclass
1617
from typing import Any, Dict
1718
from unittest import mock
1819
from unittest.mock import call, PropertyMock
1920

2021
import pytest
2122
import torch
23+
from omegaconf import OmegaConf
2224

2325
from pytorch_lightning import LightningDataModule, Trainer
2426
from pytorch_lightning.callbacks import ModelCheckpoint
2527
from pytorch_lightning.utilities import AttributeDict
28+
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2629
from pytorch_lightning.utilities.model_helpers import is_overridden
2730
from tests.helpers import BoringDataModule, BoringModel
2831
from tests.helpers.datamodules import ClassifDataModule
@@ -532,12 +535,101 @@ def test_dm_init_from_datasets_dataloaders(iterable):
532535
)
533536

534537

535-
class DataModuleWithHparams(LightningDataModule):
538+
# all args
539+
class DataModuleWithHparams_0(LightningDataModule):
536540
def __init__(self, arg0, arg1, kwarg0=None):
537541
super().__init__()
538542
self.save_hyperparameters()
539543

540544

541-
def test_simple_hyperparameters_saving():
542-
data = DataModuleWithHparams(10, "foo", kwarg0="bar")
545+
# single arg
546+
class DataModuleWithHparams_1(LightningDataModule):
547+
def __init__(self, arg0, *args, **kwargs):
548+
super().__init__()
549+
self.save_hyperparameters(arg0)
550+
551+
552+
def test_hyperparameters_saving():
553+
data = DataModuleWithHparams_0(10, "foo", kwarg0="bar")
543554
assert data.hparams == AttributeDict({"arg0": 10, "arg1": "foo", "kwarg0": "bar"})
555+
556+
data = DataModuleWithHparams_1(Namespace(**{"hello": "world"}), "foo", kwarg0="bar")
557+
assert data.hparams == AttributeDict({"hello": "world"})
558+
559+
data = DataModuleWithHparams_1({"hello": "world"}, "foo", kwarg0="bar")
560+
assert data.hparams == AttributeDict({"hello": "world"})
561+
562+
data = DataModuleWithHparams_1(OmegaConf.create({"hello": "world"}), "foo", kwarg0="bar")
563+
assert data.hparams == OmegaConf.create({"hello": "world"})
564+
565+
566+
def test_define_as_dataclass():
567+
# makes sure that no functionality is broken and the user can still manually make
568+
# super().__init__ call with parameters
569+
# also tests all the dataclass features that can be enabled without breaking anything
570+
@dataclass(init=True, repr=True, eq=True, order=True, unsafe_hash=True, frozen=False)
571+
class BoringDataModule1(LightningDataModule):
572+
batch_size: int
573+
dims: int = 2
574+
575+
def __post_init__(self):
576+
super().__init__(dims=self.dims)
577+
578+
# asserts for the different dunder methods added by dataclass, when __init__ is implemented, i.e.
579+
# __repr__, __eq__, __lt__, __le__, etc.
580+
assert BoringDataModule1(batch_size=64).dims == 2
581+
assert BoringDataModule1(batch_size=32)
582+
assert hasattr(BoringDataModule1, "__repr__")
583+
assert BoringDataModule1(batch_size=32) == BoringDataModule1(batch_size=32)
584+
585+
# asserts inherent calling of super().__init__ in case user doesn't make the call
586+
@dataclass
587+
class BoringDataModule2(LightningDataModule):
588+
batch_size: int
589+
590+
# asserts for the different dunder methods added by dataclass, when super class is inherently initialized, i.e.
591+
# __init__, __repr__, __eq__, __lt__, __le__, etc.
592+
assert BoringDataModule2(batch_size=32)
593+
assert hasattr(BoringDataModule2, "__repr__")
594+
assert BoringDataModule2(batch_size=32).prepare_data() is None
595+
assert BoringDataModule2(batch_size=32) == BoringDataModule2(batch_size=32)
596+
597+
# checking for all the different multilevel inhertiance scenarios, for init call on LightningDataModule
598+
@dataclass
599+
class BoringModuleBase1(LightningDataModule):
600+
num_features: int
601+
602+
class BoringModuleBase2(LightningDataModule):
603+
def __init__(self, num_features: int):
604+
self.num_features = num_features
605+
606+
@dataclass
607+
class BoringModuleDerived1(BoringModuleBase1):
608+
...
609+
610+
class BoringModuleDerived2(BoringModuleBase1):
611+
def __init__(self):
612+
...
613+
614+
@dataclass
615+
class BoringModuleDerived3(BoringModuleBase2):
616+
...
617+
618+
class BoringModuleDerived4(BoringModuleBase2):
619+
def __init__(self):
620+
...
621+
622+
assert hasattr(BoringModuleDerived1(num_features=2), "_has_prepared_data")
623+
assert hasattr(BoringModuleDerived2(), "_has_prepared_data")
624+
assert hasattr(BoringModuleDerived3(), "_has_prepared_data")
625+
assert hasattr(BoringModuleDerived4(), "_has_prepared_data")
626+
627+
628+
def test_inconsistent_prepare_data_per_node(tmpdir):
629+
with pytest.raises(MisconfigurationException, match="Inconsistent settings found for `prepare_data_per_node`."):
630+
model = BoringModel()
631+
dm = BoringDataModule()
632+
trainer = Trainer(prepare_data_per_node=False)
633+
trainer.model = model
634+
trainer.datamodule = dm
635+
trainer.data_connector.prepare_data()

0 commit comments

Comments
 (0)