Skip to content

Commit 59acf57

Browse files
kaushikb11SkafteNickicarmoccarohitgr7
authored
Add ignore param to save_hyperparameters (#6056)
* add ignore param to save_hyperparameters * add docstring for ignore * add type for frame object * Update pytorch_lightning/core/lightning.py Co-authored-by: Nicki Skafte <[email protected]> * Update pytorch_lightning/core/lightning.py Co-authored-by: Nicki Skafte <[email protected]> * fix whitespace * Update pytorch_lightning/core/lightning.py Co-authored-by: Nicki Skafte <[email protected]> * Parametrize tests * Update pytorch_lightning/core/lightning.py Co-authored-by: Rohit Gupta <[email protected]> * Update pytorch_lightning/core/lightning.py Co-authored-by: Rohit Gupta <[email protected]> * seq * fix docs * Update lightning.py * Update lightning.py * fix docs errors * add example keyword * update docstring Co-authored-by: Nicki Skafte <[email protected]> Co-authored-by: Carlos Mocholi <[email protected]> Co-authored-by: Rohit Gupta <[email protected]>
1 parent 48a10f1 commit 59acf57

File tree

2 files changed

+108
-42
lines changed

2 files changed

+108
-42
lines changed

pytorch_lightning/core/lightning.py

Lines changed: 72 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@
1919
import logging
2020
import os
2121
import tempfile
22+
import types
2223
import uuid
2324
from abc import ABC
2425
from argparse import Namespace
2526
from functools import partial
2627
from pathlib import Path
27-
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
28+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
2829

2930
import torch
3031
from torch import ScriptModule, Tensor
@@ -1591,55 +1592,84 @@ def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]:
15911592
parents_arguments.update(args)
15921593
return self_arguments, parents_arguments
15931594

1594-
def save_hyperparameters(self, *args, frame=None) -> None:
1595-
"""Save all model arguments.
1595+
def save_hyperparameters(
1596+
self,
1597+
*args,
1598+
ignore: Optional[Union[Sequence[str], str]] = None,
1599+
frame: Optional[types.FrameType] = None
1600+
) -> None:
1601+
"""Save model arguments to ``hparams`` attribute.
15961602
15971603
Args:
15981604
args: single object of `dict`, `NameSpace` or `OmegaConf`
1599-
or string names or arguments from class `__init__`
1600-
1601-
>>> class ManuallyArgsModel(LightningModule):
1602-
... def __init__(self, arg1, arg2, arg3):
1603-
... super().__init__()
1604-
... # manually assign arguments
1605-
... self.save_hyperparameters('arg1', 'arg3')
1606-
... def forward(self, *args, **kwargs):
1607-
... ...
1608-
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
1609-
>>> model.hparams
1610-
"arg1": 1
1611-
"arg3": 3.14
1612-
1613-
>>> class AutomaticArgsModel(LightningModule):
1614-
... def __init__(self, arg1, arg2, arg3):
1615-
... super().__init__()
1616-
... # equivalent automatic
1617-
... self.save_hyperparameters()
1618-
... def forward(self, *args, **kwargs):
1619-
... ...
1620-
>>> model = AutomaticArgsModel(1, 'abc', 3.14)
1621-
>>> model.hparams
1622-
"arg1": 1
1623-
"arg2": abc
1624-
"arg3": 3.14
1625-
1626-
>>> class SingleArgModel(LightningModule):
1627-
... def __init__(self, params):
1628-
... super().__init__()
1629-
... # manually assign single argument
1630-
... self.save_hyperparameters(params)
1631-
... def forward(self, *args, **kwargs):
1632-
... ...
1633-
>>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
1634-
>>> model.hparams
1635-
"p1": 1
1636-
"p2": abc
1637-
"p3": 3.14
1605+
or string names or arguments from class ``__init__``
1606+
ignore: an argument name or a list of argument names from
1607+
class ``__init__`` to be ignored
1608+
frame: a frame object. Default is None
1609+
1610+
Example::
1611+
>>> class ManuallyArgsModel(LightningModule):
1612+
... def __init__(self, arg1, arg2, arg3):
1613+
... super().__init__()
1614+
... # manually assign arguments
1615+
... self.save_hyperparameters('arg1', 'arg3')
1616+
... def forward(self, *args, **kwargs):
1617+
... ...
1618+
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
1619+
>>> model.hparams
1620+
"arg1": 1
1621+
"arg3": 3.14
1622+
1623+
>>> class AutomaticArgsModel(LightningModule):
1624+
... def __init__(self, arg1, arg2, arg3):
1625+
... super().__init__()
1626+
... # equivalent automatic
1627+
... self.save_hyperparameters()
1628+
... def forward(self, *args, **kwargs):
1629+
... ...
1630+
>>> model = AutomaticArgsModel(1, 'abc', 3.14)
1631+
>>> model.hparams
1632+
"arg1": 1
1633+
"arg2": abc
1634+
"arg3": 3.14
1635+
1636+
>>> class SingleArgModel(LightningModule):
1637+
... def __init__(self, params):
1638+
... super().__init__()
1639+
... # manually assign single argument
1640+
... self.save_hyperparameters(params)
1641+
... def forward(self, *args, **kwargs):
1642+
... ...
1643+
>>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
1644+
>>> model.hparams
1645+
"p1": 1
1646+
"p2": abc
1647+
"p3": 3.14
1648+
1649+
>>> class ManuallyArgsModel(LightningModule):
1650+
... def __init__(self, arg1, arg2, arg3):
1651+
... super().__init__()
1652+
... # pass argument(s) to ignore as a string or in a list
1653+
... self.save_hyperparameters(ignore='arg2')
1654+
... def forward(self, *args, **kwargs):
1655+
... ...
1656+
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
1657+
>>> model.hparams
1658+
"arg1": 1
1659+
"arg3": 3.14
16381660
"""
16391661
if not frame:
16401662
frame = inspect.currentframe().f_back
16411663
init_args = get_init_args(frame)
16421664
assert init_args, "failed to inspect the self init"
1665+
1666+
if ignore is not None:
1667+
if isinstance(ignore, str):
1668+
ignore = [ignore]
1669+
if isinstance(ignore, (list, tuple)):
1670+
ignore = [arg for arg in ignore if isinstance(arg, str)]
1671+
init_args = {k: v for k, v in init_args.items() if k not in ignore}
1672+
16431673
if not args:
16441674
# take all arguments
16451675
hp = init_args

tests/models/test_hparams.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,3 +661,39 @@ def __init__(self, args_0, args_1, args_2, kwarg_1=None):
661661
)
662662
trainer.fit(model)
663663
_ = TestHydraModel.load_from_checkpoint(checkpoint_callback.best_model_path)
664+
665+
666+
@pytest.mark.parametrize("ignore", ("arg2", ("arg2", "arg3")))
667+
def test_ignore_args_list_hparams(tmpdir, ignore):
668+
"""
669+
Tests that args can be ignored in save_hyperparameters
670+
"""
671+
672+
class LocalModel(BoringModel):
673+
674+
def __init__(self, arg1, arg2, arg3):
675+
super().__init__()
676+
self.save_hyperparameters(ignore=ignore)
677+
678+
model = LocalModel(arg1=14, arg2=90, arg3=50)
679+
680+
# test proper property assignments
681+
assert model.hparams.arg1 == 14
682+
for arg in ignore:
683+
assert arg not in model.hparams
684+
685+
# verify we can train
686+
trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
687+
trainer.fit(model)
688+
689+
# make sure the raw checkpoint saved the properties
690+
raw_checkpoint_path = _raw_checkpoint_path(trainer)
691+
raw_checkpoint = torch.load(raw_checkpoint_path)
692+
assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
693+
assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]["arg1"] == 14
694+
695+
# verify that model loads correctly
696+
model = LocalModel.load_from_checkpoint(raw_checkpoint_path, arg2=123, arg3=100)
697+
assert model.hparams.arg1 == 14
698+
for arg in ignore:
699+
assert arg not in model.hparams

0 commit comments

Comments
 (0)