diff --git a/CHANGELOG.md b/CHANGELOG.md index bd7bd5338de9d..e7f62f399049b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -286,6 +286,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed error handling in DDP process reconciliation when `_sync_dir` was not initialized ([#9267](https://github.com/PyTorchLightning/pytorch-lightning/pull/9267)) +- Fixed inspection of other args when a container is specified in `save_hyperparameters` ([#9125](https://github.com/PyTorchLightning/pytorch-lightning/pull/9125)) + + - Fixed `move_metrics_to_cpu` moving the loss on cpu while training on device ([#9308](https://github.com/PyTorchLightning/pytorch-lightning/pull/9308)) diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index afa5c98f8b46c..8bb055da87781 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -22,8 +22,12 @@ from typing_extensions import Literal import pytorch_lightning as pl +from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE from pytorch_lightning.utilities.warnings import rank_zero_warn +if _OMEGACONF_AVAILABLE: + from omegaconf.dictconfig import DictConfig + def str_to_bool_or_str(val: str) -> Union[str, bool]: """Possibly convert a string representation of truth to bool. @@ -204,46 +208,57 @@ def save_hyperparameters( obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None ) -> None: """See :meth:`~pytorch_lightning.LightningModule.save_hyperparameters`""" - + hparams_container_types = [Namespace, dict] + if _OMEGACONF_AVAILABLE: + hparams_container_types.append(DictConfig) + # empty container if len(args) == 1 and not isinstance(args, str) and not args[0]: - # args[0] is an empty container return - - if not frame: - current_frame = inspect.currentframe() - # inspect.currentframe() return type is Optional[types.FrameType]: current_frame.f_back called only if available - if current_frame: - frame = current_frame.f_back - if not isinstance(frame, types.FrameType): - raise AttributeError("There is no `frame` available while being required.") - - if is_dataclass(obj): - init_args = {f.name: getattr(obj, f.name) for f in fields(obj)} - else: - init_args = get_init_args(frame) - assert init_args, "failed to inspect the obj init" - - if ignore is not None: - if isinstance(ignore, str): - ignore = [ignore] - if isinstance(ignore, (list, tuple)): - ignore = [arg for arg in ignore if isinstance(arg, str)] - init_args = {k: v for k, v in init_args.items() if k not in ignore} - - if not args: - # take all arguments - hp = init_args - obj._hparams_name = "kwargs" if hp else None + # container + elif len(args) == 1 and isinstance(args[0], tuple(hparams_container_types)): + hp = args[0] + obj._hparams_name = "hparams" + obj._set_hparams(hp) + obj._hparams_initial = copy.deepcopy(obj._hparams) + return + # non-container args parsing else: - # take only listed arguments in `save_hparams` - isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)] - if len(isx_non_str) == 1: - hp = args[isx_non_str[0]] - cand_names = [k for k, v in init_args.items() if v == hp] - obj._hparams_name = cand_names[0] if cand_names else None + if not frame: + current_frame = inspect.currentframe() + # inspect.currentframe() return type is Optional[types.FrameType] + # current_frame.f_back called only if available + if current_frame: + frame = current_frame.f_back + if not isinstance(frame, types.FrameType): + raise AttributeError("There is no `frame` available while being required.") + + if is_dataclass(obj): + init_args = {f.name: getattr(obj, f.name) for f in fields(obj)} + else: + init_args = get_init_args(frame) + assert init_args, f"failed to inspect the obj init - {frame}" + + if ignore is not None: + if isinstance(ignore, str): + ignore = [ignore] + if isinstance(ignore, (list, tuple, set)): + ignore = [arg for arg in ignore if isinstance(arg, str)] + init_args = {k: v for k, v in init_args.items() if k not in ignore} + + if not args: + # take all arguments + hp = init_args + obj._hparams_name = "kwargs" if hp else None else: - hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)} - obj._hparams_name = "kwargs" + # take only listed arguments in `save_hparams` + isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)] + if len(isx_non_str) == 1: + hp = args[isx_non_str[0]] + cand_names = [k for k, v in init_args.items() if v == hp] + obj._hparams_name = cand_names[0] if cand_names else None + else: + hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)} + obj._hparams_name = "kwargs" # `hparams` are expected here if hp: diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 512ad1ff1332f..c3bd24546f4f3 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import pickle -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace from dataclasses import dataclass from typing import Any, Dict from unittest import mock @@ -20,6 +20,7 @@ import pytest import torch +from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer from pytorch_lightning.callbacks import ModelCheckpoint @@ -528,16 +529,33 @@ def test_dm_init_from_datasets_dataloaders(iterable): ) -class DataModuleWithHparams(LightningDataModule): +# all args +class DataModuleWithHparams_0(LightningDataModule): def __init__(self, arg0, arg1, kwarg0=None): super().__init__() self.save_hyperparameters() -def test_simple_hyperparameters_saving(): - data = DataModuleWithHparams(10, "foo", kwarg0="bar") +# single arg +class DataModuleWithHparams_1(LightningDataModule): + def __init__(self, arg0, *args, **kwargs): + super().__init__() + self.save_hyperparameters(arg0) + + +def test_hyperparameters_saving(): + data = DataModuleWithHparams_0(10, "foo", kwarg0="bar") assert data.hparams == AttributeDict({"arg0": 10, "arg1": "foo", "kwarg0": "bar"}) + data = DataModuleWithHparams_1(Namespace(**{"hello": "world"}), "foo", kwarg0="bar") + assert data.hparams == AttributeDict({"hello": "world"}) + + data = DataModuleWithHparams_1({"hello": "world"}, "foo", kwarg0="bar") + assert data.hparams == AttributeDict({"hello": "world"}) + + data = DataModuleWithHparams_1(OmegaConf.create({"hello": "world"}), "foo", kwarg0="bar") + assert data.hparams == OmegaConf.create({"hello": "world"}) + def test_define_as_dataclass(): # makes sure that no functionality is broken and the user can still manually make