diff --git a/CHANGELOG.md b/CHANGELOG.md
index a5a9f88de72da..a91f7fc3189ff 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -149,6 +149,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `FastForwardSampler` and `CaptureIterableDataset` ([#8307](https://github.com/PyTorchLightning/pytorch-lightning/pull/8307))
 
+- Added support for `save_hyperparameters` in `LightningDataModule` ([#3792](https://github.com/PyTorchLightning/pytorch-lightning/pull/3792))
+
+
 - Added `LSFEnvironment` for distributed training with the LSF resource manager `jsrun` ([#5102](https://github.com/PyTorchLightning/pytorch-lightning/pull/5102))
diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py
index 4b62198542118..b1f42c6ea4390 100644
--- a/pytorch_lightning/core/datamodule.py
+++ b/pytorch_lightning/core/datamodule.py
@@ -22,9 +22,10 @@
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
 from pytorch_lightning.utilities import rank_zero_deprecation
 from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
+from pytorch_lightning.utilities.hparams_mixin import HyperparametersMixin
 
 
-class LightningDataModule(CheckpointHooks, DataHooks):
+class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
     """
     A DataModule standardizes the training, val, test splits, data preparation and transforms. The main
     advantage is consistent data splits, data preparation and transforms across models.
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 3ec5fe0e8ffb1..735f8ab160c1f 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -14,18 +14,15 @@
 """The LightningModule - an nn.Module with many additional features."""
 
 import collections
-import copy
 import inspect
 import logging
 import numbers
 import os
 import tempfile
-import types
 import uuid
 from abc import ABC
-from argparse import Namespace
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -38,7 +35,7 @@
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
 from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES
+from pytorch_lightning.core.saving import ModelIO
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator
 from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
@@ -46,7 +43,8 @@
 from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
 from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters
+from pytorch_lightning.utilities.hparams_mixin import HyperparametersMixin
+from pytorch_lightning.utilities.parsing import collect_init_args
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache
@@ -58,6 +56,7 @@
 class LightningModule(
     ABC,
     DeviceDtypeModuleMixin,
+    HyperparametersMixin,
     GradInformation,
     ModelIO,
     ModelHooks,
@@ -70,8 +69,6 @@ class LightningModule(
     __jit_unused_properties__ = [
         "datamodule",
         "example_input_array",
-        "hparams",
-        "hparams_initial",
         "on_gpu",
         "current_epoch",
         "global_step",
@@ -82,7 +79,7 @@ class LightningModule(
         "automatic_optimization",
         "truncated_bptt_steps",
         "loaded_optimizer_states_dict",
-    ] + DeviceDtypeModuleMixin.__jit_unused_properties__
+    ] + DeviceDtypeModuleMixin.__jit_unused_properties__ + HyperparametersMixin.__jit_unused_properties__
 
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
@@ -1832,92 +1829,6 @@ def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]:
             parents_arguments.update(args)
         return self_arguments, parents_arguments
 
-    def save_hyperparameters(
-        self,
-        *args,
-        ignore: Optional[Union[Sequence[str], str]] = None,
-        frame: Optional[types.FrameType] = None
-    ) -> None:
-        """Save model arguments to the ``hparams`` attribute.
-
-        Args:
-            args: single object of type :class:`dict`, :class:`~argparse.Namespace`, `OmegaConf`
-                or strings representing the argument names in ``__init__``.
-            ignore: an argument name or a list of argument names in ``__init__`` to be ignored
-            frame: a frame object. Default is ``None``.
-
-        Example::
-
-            >>> class ManuallyArgsModel(LightningModule):
-            ...     def __init__(self, arg1, arg2, arg3):
-            ...         super().__init__()
-            ...         # manually assign arguments
-            ...         self.save_hyperparameters('arg1', 'arg3')
-            ...     def forward(self, *args, **kwargs):
-            ...         ...
-            >>> model = ManuallyArgsModel(1, 'abc', 3.14)
-            >>> model.hparams
-            "arg1": 1
-            "arg3": 3.14
-
-            >>> class AutomaticArgsModel(LightningModule):
-            ...     def __init__(self, arg1, arg2, arg3):
-            ...         super().__init__()
-            ...         # equivalent automatic
-            ...         self.save_hyperparameters()
-            ...     def forward(self, *args, **kwargs):
-            ...         ...
-            >>> model = AutomaticArgsModel(1, 'abc', 3.14)
-            >>> model.hparams
-            "arg1": 1
-            "arg2": abc
-            "arg3": 3.14
-
-            >>> class SingleArgModel(LightningModule):
-            ...     def __init__(self, params):
-            ...         super().__init__()
-            ...         # manually assign single argument
-            ...         self.save_hyperparameters(params)
-            ...     def forward(self, *args, **kwargs):
-            ...         ...
-            >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
-            >>> model.hparams
-            "p1": 1
-            "p2": abc
-            "p3": 3.14
-
-            >>> class ManuallyArgsModel(LightningModule):
-            ...     def __init__(self, arg1, arg2, arg3):
-            ...         super().__init__()
-            ...         # pass argument(s) to ignore as a string or in a list
-            ...         self.save_hyperparameters(ignore='arg2')
-            ...     def forward(self, *args, **kwargs):
-            ...         ...
-            >>> model = ManuallyArgsModel(1, 'abc', 3.14)
-            >>> model.hparams
-            "arg1": 1
-            "arg3": 3.14
-        """
-        # the frame needs to be created in this file.
-        if not frame:
-            frame = inspect.currentframe().f_back
-        save_hyperparameters(self, *args, ignore=ignore, frame=frame)
-
-    def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
-        if isinstance(hp, Namespace):
-            hp = vars(hp)
-        if isinstance(hp, dict):
-            hp = AttributeDict(hp)
-        elif isinstance(hp, PRIMITIVE_TYPES):
-            raise ValueError(f"Primitives {PRIMITIVE_TYPES} are not allowed.")
-        elif not isinstance(hp, ALLOWED_CONFIG_TYPES):
-            raise ValueError(f"Unsupported config type of {type(hp)}.")
-
-        if isinstance(hp, dict) and isinstance(self.hparams, dict):
-            self.hparams.update(hp)
-        else:
-            self._hparams = hp
-
     @torch.no_grad()
     def to_onnx(
         self,
@@ -2049,27 +1960,6 @@ def to_torchscript(
 
         return torchscript_module
 
-    @property
-    def hparams(self) -> Union[AttributeDict, dict, Namespace]:
-        """
-        The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user.
-        For the frozen set of initial hyperparameters, use :attr:`hparams_initial`.
-        """
-        if not hasattr(self, "_hparams"):
-            self._hparams = AttributeDict()
-        return self._hparams
-
-    @property
-    def hparams_initial(self) -> AttributeDict:
-        """
-        The collection of hyperparameters saved with :meth:`save_hyperparameters`. These contents are read-only.
-        Manual updates to the saved hyperparameters can instead be performed through :attr:`hparams`.
-        """
-        if not hasattr(self, "_hparams_initial"):
-            return AttributeDict()
-        # prevent any change
-        return copy.deepcopy(self._hparams_initial)
-
     @property
     def model_size(self) -> float:
         """
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 7475cd9c81326..4c034ac843361 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -903,11 +903,24 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT,
     def _pre_dispatch(self):
         self.accelerator.pre_dispatch(self)
+        self._log_hyperparams()
+
+    def _log_hyperparams(self):
         # log hyper-parameters
         if self.logger is not None:
             # save exp to get started (this is where the first experiment logs are written)
-            self.logger.log_hyperparams(self.lightning_module.hparams_initial)
+            datamodule_hparams = self.datamodule.hparams_initial if self.datamodule is not None else {}
+            lightning_hparams = self.lightning_module.hparams_initial
+            colliding_keys = lightning_hparams.keys() & datamodule_hparams.keys()
+            if colliding_keys:
+                raise MisconfigurationException(
+                    f"Error while merging hparams: the keys {colliding_keys} are present "
+                    "in both the LightningModule's and LightningDataModule's hparams."
+                )
+
+            hparams_initial = {**lightning_hparams, **datamodule_hparams}
+
+            self.logger.log_hyperparams(hparams_initial)
             self.logger.log_graph(self.lightning_module)
             self.logger.save()
diff --git a/pytorch_lightning/utilities/hparams_mixin.py b/pytorch_lightning/utilities/hparams_mixin.py
new file mode 100644
index 0000000000000..8dd4b23c89398
--- /dev/null
+++ b/pytorch_lightning/utilities/hparams_mixin.py
@@ -0,0 +1,131 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+import inspect
+import types
+from argparse import Namespace
+from typing import Optional, Sequence, Union
+
+from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES
+from pytorch_lightning.utilities import AttributeDict
+from pytorch_lightning.utilities.parsing import save_hyperparameters
+
+
+class HyperparametersMixin:
+
+    __jit_unused_properties__ = ["hparams", "hparams_initial"]
+
+    def save_hyperparameters(
+        self,
+        *args,
+        ignore: Optional[Union[Sequence[str], str]] = None,
+        frame: Optional[types.FrameType] = None
+    ) -> None:
+        """Save arguments to the ``hparams`` attribute.
+
+        Args:
+            args: single object of type :class:`dict`, :class:`~argparse.Namespace`, `OmegaConf`
+                or strings representing the argument names in ``__init__``
+            ignore: an argument name or a list of argument names from
+                class ``__init__`` to be ignored
+            frame: a frame object. Default is ``None``.
+
+        Example::
+            >>> class ManuallyArgsModel(HyperparametersMixin):
+            ...     def __init__(self, arg1, arg2, arg3):
+            ...         super().__init__()
+            ...         # manually assign arguments
+            ...         self.save_hyperparameters('arg1', 'arg3')
+            ...     def forward(self, *args, **kwargs):
+            ...         ...
+            >>> model = ManuallyArgsModel(1, 'abc', 3.14)
+            >>> model.hparams
+            "arg1": 1
+            "arg3": 3.14
+
+            >>> class AutomaticArgsModel(HyperparametersMixin):
+            ...     def __init__(self, arg1, arg2, arg3):
+            ...         super().__init__()
+            ...         # equivalent automatic
+            ...         self.save_hyperparameters()
+            ...     def forward(self, *args, **kwargs):
+            ...         ...
+            >>> model = AutomaticArgsModel(1, 'abc', 3.14)
+            >>> model.hparams
+            "arg1": 1
+            "arg2": abc
+            "arg3": 3.14
+
+            >>> class SingleArgModel(HyperparametersMixin):
+            ...     def __init__(self, params):
+            ...         super().__init__()
+            ...         # manually assign single argument
+            ...         self.save_hyperparameters(params)
+            ...     def forward(self, *args, **kwargs):
+            ...         ...
+            >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
+            >>> model.hparams
+            "p1": 1
+            "p2": abc
+            "p3": 3.14
+
+            >>> class ManuallyArgsModel(HyperparametersMixin):
+            ...     def __init__(self, arg1, arg2, arg3):
+            ...         super().__init__()
+            ...         # pass argument(s) to ignore as a string or in a list
+            ...         self.save_hyperparameters(ignore='arg2')
+            ...     def forward(self, *args, **kwargs):
+            ...         ...
+            >>> model = ManuallyArgsModel(1, 'abc', 3.14)
+            >>> model.hparams
+            "arg1": 1
+            "arg3": 3.14
+        """
+        # the frame needs to be created in this file.
+        if not frame:
+            frame = inspect.currentframe().f_back
+        save_hyperparameters(self, *args, ignore=ignore, frame=frame)
+
+    def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
+        hp = self._to_hparams_dict(hp)
+
+        if isinstance(hp, dict) and isinstance(self.hparams, dict):
+            self.hparams.update(hp)
+        else:
+            self._hparams = hp
+
+    @staticmethod
+    def _to_hparams_dict(hp: Union[dict, Namespace, str]):
+        if isinstance(hp, Namespace):
+            hp = vars(hp)
+        if isinstance(hp, dict):
+            hp = AttributeDict(hp)
+        elif isinstance(hp, PRIMITIVE_TYPES):
+            raise ValueError(f"Primitives {PRIMITIVE_TYPES} are not allowed.")
+        elif not isinstance(hp, ALLOWED_CONFIG_TYPES):
+            raise ValueError(f"Unsupported config type of {type(hp)}.")
+        return hp
+
+    @property
+    def hparams(self) -> Union[AttributeDict, dict, Namespace]:
+        if not hasattr(self, "_hparams"):
+            self._hparams = AttributeDict()
+        return self._hparams
+
+    @property
+    def hparams_initial(self) -> AttributeDict:
+        if not hasattr(self, "_hparams_initial"):
+            return AttributeDict()
+        # prevent any change
+        return copy.deepcopy(self._hparams_initial)
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index c056ab1aa4fbf..6203e93e63e2f 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -22,6 +22,7 @@
 from pytorch_lightning import LightningDataModule, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.utilities import AttributeDict
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from tests.helpers import BoringDataModule, BoringModel
 from tests.helpers.datamodules import ClassifDataModule
@@ -551,3 +552,15 @@ def test_dm_init_from_datasets_dataloaders(iterable):
         call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
         call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True)
     ])
+
+
+class DataModuleWithHparams(LightningDataModule):
+
+    def __init__(self, arg0, arg1, kwarg0=None):
+        super().__init__()
+        self.save_hyperparameters()
+
+
+def test_simple_hyperparameters_saving():
+    data = DataModuleWithHparams(10, "foo", kwarg0="bar")
+    assert data.hparams == AttributeDict({"arg0": 10, "arg1": "foo", "kwarg0": "bar"})
diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py
index 7fa8872036a73..1ba92bd754d60 100644
--- a/tests/models/test_hparams.py
+++ b/tests/models/test_hparams.py
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import functools
 import os
 import pickle
 from argparse import Namespace
 from dataclasses import dataclass
+from unittest import mock
 
 import cloudpickle
 import pytest
@@ -26,8 +28,10 @@
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml
 from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, AttributeDict, is_picklable
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
 
 if _HYDRA_EXPERIMENTAL_AVAILABLE:
@@ -738,3 +742,78 @@ def test_dataclass_lightning_module(tmpdir):
     """ Test that save_hyperparameters() works with a LightningModule as a dataclass. """
     model = DataClassModel(33, optional="cocofruit")
     assert model.hparams == dict(mandatory=33, optional="cocofruit")
+
+
+class NoHparamsModel(BoringModel):
+    """ Tests a model without hparams. """
+
+
+class DataModuleWithoutHparams(LightningDataModule):
+
+    def train_dataloader(self, *args, **kwargs) -> DataLoader:
+        return DataLoader(RandomDataset(32, 64), batch_size=32)
+
+
+class DataModuleWithHparams(LightningDataModule):
+
+    def __init__(self, hparams):
+        super().__init__()
+        self.save_hyperparameters(hparams)
+
+    def train_dataloader(self, *args, **kwargs) -> DataLoader:
+        return DataLoader(RandomDataset(32, 64), batch_size=32)
+
+
+def _get_mock_logger(tmpdir):
+    mock_logger = mock.MagicMock(name="logger")
+    mock_logger.name = "mock_logger"
+    mock_logger.save_dir = tmpdir
+    mock_logger.version = "0"
+    del mock_logger.__iter__
+    return mock_logger
+
+
+@pytest.mark.parametrize("model", (SaveHparamsModel({'arg1': 5, 'arg2': 'abc'}), NoHparamsModel()))
+@pytest.mark.parametrize("data", (DataModuleWithHparams({'data_dir': 'foo'}), DataModuleWithoutHparams()))
+def test_adding_datamodule_hparams(tmpdir, model, data):
+    """Test that hparams from datamodule and model are logged."""
+    org_model_hparams = copy.deepcopy(model.hparams_initial)
+    org_data_hparams = copy.deepcopy(data.hparams_initial)
+
+    mock_logger = _get_mock_logger(tmpdir)
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, logger=mock_logger)
+    trainer.fit(model, datamodule=data)
+
+    # Hparams of model and data were not modified
+    assert org_model_hparams == model.hparams
+    assert org_data_hparams == data.hparams
+
+    # Merged hparams were logged
+    merged_hparams = copy.deepcopy(org_model_hparams)
+    merged_hparams.update(org_data_hparams)
+    mock_logger.log_hyperparams.assert_called_with(merged_hparams)
+
+
+def test_no_datamodule_for_hparams(tmpdir):
+    """Test that model hparams are logged when no datamodule is used."""
+    model = SaveHparamsModel({'arg1': 5, 'arg2': 'abc'})
+    org_model_hparams = copy.deepcopy(model.hparams_initial)
+    data = DataModuleWithoutHparams()
+    data.setup()
+
+    mock_logger = _get_mock_logger(tmpdir)
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, logger=mock_logger)
+    trainer.fit(model, datamodule=data)
+
+    # Model hparams were logged
+    mock_logger.log_hyperparams.assert_called_with(org_model_hparams)
+
+
+def test_colliding_hparams(tmpdir):
+
+    model = SaveHparamsModel({'data_dir': 'abc', 'arg2': 'abc'})
+    data = DataModuleWithHparams({'data_dir': 'foo'})
+
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
+    with pytest.raises(MisconfigurationException, match=r'Error while merging hparams:'):
+        trainer.fit(model, datamodule=data)