
Commit 09ff295

Hyperparameters for datamodule (#3792)

Authored by Tilman Krokotsch (tilman151), with Borda, ethanwharris and rohitgr7

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: Tilman Krokotsch <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: Kaushik B <[email protected]>

1 parent 3102922 commit 09ff295

File tree: 7 files changed (+248 −118 lines)
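In short: this commit extracts `save_hyperparameters`, `_set_hparams`, and the `hparams` / `hparams_initial` properties out of `LightningModule` into a new `HyperparametersMixin`, mixes that class into both `LightningModule` and `LightningDataModule`, and updates the `Trainer` to log the merged hyperparameters of the module and the datamodule, raising a `MisconfigurationException` when their keys collide.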

CHANGELOG.md
Lines changed: 3 additions & 0 deletions

@@ -149,6 +149,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `FastForwardSampler` and `CaptureIterableDataset` ([#8307](https://github.com/PyTorchLightning/pytorch-lightning/pull/8307))


+- Added support for `save_hyperparameters` in `LightningDataModule` ([#3792](https://github.com/PyTorchLightning/pytorch-lightning/pull/3792))
+
+
 - Added `LSFEnvironment` for distributed training with the LSF resource manager `jsrun` ([#5102](https://github.com/PyTorchLightning/pytorch-lightning/pull/5102))



pytorch_lightning/core/datamodule.py
Lines changed: 2 additions & 1 deletion

@@ -22,9 +22,10 @@
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
 from pytorch_lightning.utilities import rank_zero_deprecation
 from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
+from pytorch_lightning.utilities.hparams_mixin import HyperparametersMixin


-class LightningDataModule(CheckpointHooks, DataHooks):
+class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
     """
     A DataModule standardizes the training, val, test splits, data preparation and transforms.
     The main advantage is consistent data splits, data preparation and transforms across models.
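With `HyperparametersMixin` in its base-class list, a `LightningDataModule` can record its constructor arguments the same way a `LightningModule` does. A minimal sketch of the new capability; the datamodule name and its arguments are illustrative, not taken from this commit:

    from pytorch_lightning import LightningDataModule


    class MNISTDataModule(LightningDataModule):

        def __init__(self, data_dir: str = "./data", batch_size: int = 32):
            super().__init__()
            # collects data_dir and batch_size into self.hparams
            self.save_hyperparameters()


    dm = MNISTDataModule(batch_size=64)
    print(dm.hparams.batch_size)  # 64
    print(dm.hparams.data_dir)    # ./data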

pytorch_lightning/core/lightning.py
Lines changed: 6 additions & 116 deletions

@@ -14,18 +14,15 @@
 """The LightningModule - an nn.Module with many additional features."""

 import collections
-import copy
 import inspect
 import logging
 import numbers
 import os
 import tempfile
-import types
 import uuid
 from abc import ABC
-from argparse import Namespace
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -38,15 +35,16 @@
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
 from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES
+from pytorch_lightning.core.saving import ModelIO
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator
 from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
 from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters
+from pytorch_lightning.utilities.hparams_mixin import HyperparametersMixin
+from pytorch_lightning.utilities.parsing import collect_init_args
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache
@@ -58,6 +56,7 @@
 class LightningModule(
     ABC,
     DeviceDtypeModuleMixin,
+    HyperparametersMixin,
     GradInformation,
     ModelIO,
     ModelHooks,
@@ -70,8 +69,6 @@ class LightningModule(
     __jit_unused_properties__ = [
         "datamodule",
         "example_input_array",
-        "hparams",
-        "hparams_initial",
         "on_gpu",
         "current_epoch",
         "global_step",
@@ -82,7 +79,7 @@ class LightningModule(
         "automatic_optimization",
         "truncated_bptt_steps",
         "loaded_optimizer_states_dict",
-    ] + DeviceDtypeModuleMixin.__jit_unused_properties__
+    ] + DeviceDtypeModuleMixin.__jit_unused_properties__ + HyperparametersMixin.__jit_unused_properties__

     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
@@ -1832,92 +1829,6 @@ def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]:
         parents_arguments.update(args)
         return self_arguments, parents_arguments

-    def save_hyperparameters(
-        self,
-        *args,
-        ignore: Optional[Union[Sequence[str], str]] = None,
-        frame: Optional[types.FrameType] = None
-    ) -> None:
-        """Save model arguments to the ``hparams`` attribute.
-
-        Args:
-            args: single object of type :class:`dict`, :class:`~argparse.Namespace`, `OmegaConf`
-                or strings representing the argument names in ``__init__``.
-            ignore: an argument name or a list of argument names in ``__init__`` to be ignored
-            frame: a frame object. Default is ``None``.
-
-        Example::
-
-            >>> class ManuallyArgsModel(LightningModule):
-            ...     def __init__(self, arg1, arg2, arg3):
-            ...         super().__init__()
-            ...         # manually assign arguments
-            ...         self.save_hyperparameters('arg1', 'arg3')
-            ...     def forward(self, *args, **kwargs):
-            ...         ...
-            >>> model = ManuallyArgsModel(1, 'abc', 3.14)
-            >>> model.hparams
-            "arg1": 1
-            "arg3": 3.14
-
-            >>> class AutomaticArgsModel(LightningModule):
-            ...     def __init__(self, arg1, arg2, arg3):
-            ...         super().__init__()
-            ...         # equivalent automatic
-            ...         self.save_hyperparameters()
-            ...     def forward(self, *args, **kwargs):
-            ...         ...
-            >>> model = AutomaticArgsModel(1, 'abc', 3.14)
-            >>> model.hparams
-            "arg1": 1
-            "arg2": abc
-            "arg3": 3.14
-
-            >>> class SingleArgModel(LightningModule):
-            ...     def __init__(self, params):
-            ...         super().__init__()
-            ...         # manually assign single argument
-            ...         self.save_hyperparameters(params)
-            ...     def forward(self, *args, **kwargs):
-            ...         ...
-            >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
-            >>> model.hparams
-            "p1": 1
-            "p2": abc
-            "p3": 3.14
-
-            >>> class ManuallyArgsModel(LightningModule):
-            ...     def __init__(self, arg1, arg2, arg3):
-            ...         super().__init__()
-            ...         # pass argument(s) to ignore as a string or in a list
-            ...         self.save_hyperparameters(ignore='arg2')
-            ...     def forward(self, *args, **kwargs):
-            ...         ...
-            >>> model = ManuallyArgsModel(1, 'abc', 3.14)
-            >>> model.hparams
-            "arg1": 1
-            "arg3": 3.14
-        """
-        # the frame needs to be created in this file.
-        if not frame:
-            frame = inspect.currentframe().f_back
-        save_hyperparameters(self, *args, ignore=ignore, frame=frame)
-
-    def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
-        if isinstance(hp, Namespace):
-            hp = vars(hp)
-        if isinstance(hp, dict):
-            hp = AttributeDict(hp)
-        elif isinstance(hp, PRIMITIVE_TYPES):
-            raise ValueError(f"Primitives {PRIMITIVE_TYPES} are not allowed.")
-        elif not isinstance(hp, ALLOWED_CONFIG_TYPES):
-            raise ValueError(f"Unsupported config type of {type(hp)}.")
-
-        if isinstance(hp, dict) and isinstance(self.hparams, dict):
-            self.hparams.update(hp)
-        else:
-            self._hparams = hp
-
     @torch.no_grad()
     def to_onnx(
         self,
@@ -2049,27 +1960,6 @@ def to_torchscript(

         return torchscript_module

-    @property
-    def hparams(self) -> Union[AttributeDict, dict, Namespace]:
-        """
-        The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user.
-        For the frozen set of initial hyperparameters, use :attr:`hparams_initial`.
-        """
-        if not hasattr(self, "_hparams"):
-            self._hparams = AttributeDict()
-        return self._hparams
-
-    @property
-    def hparams_initial(self) -> AttributeDict:
-        """
-        The collection of hyperparameters saved with :meth:`save_hyperparameters`. These contents are read-only.
-        Manual updates to the saved hyperparameters can instead be performed through :attr:`hparams`.
-        """
-        if not hasattr(self, "_hparams_initial"):
-            return AttributeDict()
-        # prevent any change
-        return copy.deepcopy(self._hparams_initial)
-
     @property
     def model_size(self) -> float:
         """

pytorch_lightning/trainer/trainer.py
Lines changed: 14 additions & 1 deletion

@@ -903,11 +903,24 @@ def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT,

     def _pre_dispatch(self):
         self.accelerator.pre_dispatch(self)
+        self._log_hyperparams()

+    def _log_hyperparams(self):
         # log hyper-parameters
         if self.logger is not None:
             # save exp to get started (this is where the first experiment logs are written)
-            self.logger.log_hyperparams(self.lightning_module.hparams_initial)
+            datamodule_hparams = self.datamodule.hparams_initial if self.datamodule is not None else {}
+            lightning_hparams = self.lightning_module.hparams_initial
+            colliding_keys = lightning_hparams.keys() & datamodule_hparams.keys()
+            if colliding_keys:
+                raise MisconfigurationException(
+                    f"Error while merging hparams: the keys {colliding_keys} are present "
+                    "in both the LightningModule's and LightningDataModule's hparams."
+                )
+
+            hparams_initial = {**lightning_hparams, **datamodule_hparams}
+
+            self.logger.log_hyperparams(hparams_initial)
             self.logger.log_graph(self.lightning_module)
         self.logger.save()
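The merge is shallow and refuses duplicate keys: if the module and the datamodule both saved, say, `batch_size`, `_pre_dispatch` now raises a `MisconfigurationException` instead of silently preferring one value. The core of the merge, reduced to a standalone sketch with made-up values:

    lightning_hparams = {"lr": 1e-3, "hidden_dim": 128}
    datamodule_hparams = {"batch_size": 64, "num_workers": 4}

    # disjoint key sets, so the merge succeeds
    colliding_keys = lightning_hparams.keys() & datamodule_hparams.keys()
    assert not colliding_keys

    hparams_initial = {**lightning_hparams, **datamodule_hparams}
    print(hparams_initial)
    # {'lr': 0.001, 'hidden_dim': 128, 'batch_size': 64, 'num_workers': 4}

    # had both sides saved e.g. "batch_size", colliding_keys would be non-empty
    # and the Trainer would raise MisconfigurationException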

pytorch_lightning/utilities/hparams_mixin.py (new file)
Lines changed: 131 additions & 0 deletions

@@ -0,0 +1,131 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+import inspect
+import types
+from argparse import Namespace
+from typing import Optional, Sequence, Union
+
+from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES
+from pytorch_lightning.utilities import AttributeDict
+from pytorch_lightning.utilities.parsing import save_hyperparameters
+
+
+class HyperparametersMixin:
+
+    __jit_unused_properties__ = ["hparams", "hparams_initial"]
+
+    def save_hyperparameters(
+        self,
+        *args,
+        ignore: Optional[Union[Sequence[str], str]] = None,
+        frame: Optional[types.FrameType] = None
+    ) -> None:
+        """Save arguments to ``hparams`` attribute.
+
+        Args:
+            args: single object of `dict`, `NameSpace` or `OmegaConf`
+                or string names or arguments from class ``__init__``
+            ignore: an argument name or a list of argument names from
+                class ``__init__`` to be ignored
+            frame: a frame object. Default is None
+
+        Example::
+            >>> class ManuallyArgsModel(HyperparametersMixin):
+            ...     def __init__(self, arg1, arg2, arg3):
+            ...         super().__init__()
+            ...         # manually assign arguments
+            ...         self.save_hyperparameters('arg1', 'arg3')
+            ...     def forward(self, *args, **kwargs):
+            ...         ...
+            >>> model = ManuallyArgsModel(1, 'abc', 3.14)
+            >>> model.hparams
+            "arg1": 1
+            "arg3": 3.14
+
+            >>> class AutomaticArgsModel(HyperparametersMixin):
+            ...     def __init__(self, arg1, arg2, arg3):
+            ...         super().__init__()
+            ...         # equivalent automatic
+            ...         self.save_hyperparameters()
+            ...     def forward(self, *args, **kwargs):
+            ...         ...
+            >>> model = AutomaticArgsModel(1, 'abc', 3.14)
+            >>> model.hparams
+            "arg1": 1
+            "arg2": abc
+            "arg3": 3.14
+
+            >>> class SingleArgModel(HyperparametersMixin):
+            ...     def __init__(self, params):
+            ...         super().__init__()
+            ...         # manually assign single argument
+            ...         self.save_hyperparameters(params)
+            ...     def forward(self, *args, **kwargs):
+            ...         ...
+            >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
+            >>> model.hparams
+            "p1": 1
+            "p2": abc
+            "p3": 3.14
+
+            >>> class ManuallyArgsModel(HyperparametersMixin):
+            ...     def __init__(self, arg1, arg2, arg3):
+            ...         super().__init__()
+            ...         # pass argument(s) to ignore as a string or in a list
+            ...         self.save_hyperparameters(ignore='arg2')
+            ...     def forward(self, *args, **kwargs):
+            ...         ...
+            >>> model = ManuallyArgsModel(1, 'abc', 3.14)
+            >>> model.hparams
+            "arg1": 1
+            "arg3": 3.14
+        """
+        # the frame needs to be created in this file.
+        if not frame:
+            frame = inspect.currentframe().f_back
+        save_hyperparameters(self, *args, ignore=ignore, frame=frame)
+
+    def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
+        hp = self._to_hparams_dict(hp)
+
+        if isinstance(hp, dict) and isinstance(self.hparams, dict):
+            self.hparams.update(hp)
+        else:
+            self._hparams = hp
+
+    @staticmethod
+    def _to_hparams_dict(hp: Union[dict, Namespace, str]):
+        if isinstance(hp, Namespace):
+            hp = vars(hp)
+        if isinstance(hp, dict):
+            hp = AttributeDict(hp)
+        elif isinstance(hp, PRIMITIVE_TYPES):
+            raise ValueError(f"Primitives {PRIMITIVE_TYPES} are not allowed.")
+        elif not isinstance(hp, ALLOWED_CONFIG_TYPES):
+            raise ValueError(f"Unsupported config type of {type(hp)}.")
+        return hp
+
+    @property
+    def hparams(self) -> Union[AttributeDict, dict, Namespace]:
+        if not hasattr(self, "_hparams"):
+            self._hparams = AttributeDict()
+        return self._hparams
+
+    @property
+    def hparams_initial(self) -> AttributeDict:
+        if not hasattr(self, "_hparams_initial"):
+            return AttributeDict()
+        # prevent any change
+        return copy.deepcopy(self._hparams_initial)
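Because the mixin only depends on `inspect` and Lightning's parsing utilities, it works on any class, not just modules; the docstring examples above subclass `HyperparametersMixin` directly. A short sketch of the `hparams` vs. `hparams_initial` semantics (class name and fields invented for illustration):

    from pytorch_lightning.utilities.hparams_mixin import HyperparametersMixin


    class Config(HyperparametersMixin):

        def __init__(self, depth=3, width=256):
            super().__init__()
            self.save_hyperparameters()


    cfg = Config(depth=5)
    cfg.hparams.depth = 7                  # hparams stays mutable
    assert cfg.hparams.depth == 7
    assert cfg.hparams_initial.depth == 5  # deep-copied, frozen snapshot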

tests/core/test_datamodules.py
Lines changed: 13 additions & 0 deletions

@@ -22,6 +22,7 @@

 from pytorch_lightning import LightningDataModule, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.utilities import AttributeDict
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from tests.helpers import BoringDataModule, BoringModel
 from tests.helpers.datamodules import ClassifDataModule
@@ -551,3 +552,15 @@ def test_dm_init_from_datasets_dataloaders(iterable):
         call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
         call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True)
     ])
+
+
+class DataModuleWithHparams(LightningDataModule):
+
+    def __init__(self, arg0, arg1, kwarg0=None):
+        super().__init__()
+        self.save_hyperparameters()
+
+
+def test_simple_hyperparameters_saving():
+    data = DataModuleWithHparams(10, "foo", kwarg0="bar")
+    assert data.hparams == AttributeDict({"arg0": 10, "arg1": "foo", "kwarg0": "bar"})
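The new test exercises only the automatic capture path. A hypothetical companion test for the `ignore` argument, written in the same style, could look like this (not part of the commit):

    class DataModuleIgnoringHparams(LightningDataModule):

        def __init__(self, arg0, arg1, secret_token=None):
            super().__init__()
            # keep secret_token out of hparams (and hence out of loggers and checkpoints)
            self.save_hyperparameters(ignore="secret_token")


    def test_hyperparameters_saving_with_ignore():
        data = DataModuleIgnoringHparams(10, "foo", secret_token="s3cr3t")
        assert data.hparams == AttributeDict({"arg0": 10, "arg1": "foo"})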
