diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 57aa264244a68..f47c5a7d3a1fb 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -19,12 +19,13 @@ import os import re import tempfile +import types import uuid from abc import ABC from argparse import Namespace from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union import torch from torch import ScriptModule, Tensor @@ -1582,55 +1583,84 @@ def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]: parents_arguments.update(args) return self_arguments, parents_arguments - def save_hyperparameters(self, *args, frame=None) -> None: - """Save all model arguments. + def save_hyperparameters( + self, + *args, + ignore: Optional[Union[Sequence[str], str]] = None, + frame: Optional[types.FrameType] = None + ) -> None: + """Save model arguments to ``hparams`` attribute. Args: args: single object of `dict`, `NameSpace` or `OmegaConf` - or string names or arguments from class `__init__` - - >>> class ManuallyArgsModel(LightningModule): - ... def __init__(self, arg1, arg2, arg3): - ... super().__init__() - ... # manually assign arguments - ... self.save_hyperparameters('arg1', 'arg3') - ... def forward(self, *args, **kwargs): - ... ... - >>> model = ManuallyArgsModel(1, 'abc', 3.14) - >>> model.hparams - "arg1": 1 - "arg3": 3.14 - - >>> class AutomaticArgsModel(LightningModule): - ... def __init__(self, arg1, arg2, arg3): - ... super().__init__() - ... # equivalent automatic - ... self.save_hyperparameters() - ... def forward(self, *args, **kwargs): - ... ... - >>> model = AutomaticArgsModel(1, 'abc', 3.14) - >>> model.hparams - "arg1": 1 - "arg2": abc - "arg3": 3.14 - - >>> class SingleArgModel(LightningModule): - ... def __init__(self, params): - ... super().__init__() - ... # manually assign single argument - ... self.save_hyperparameters(params) - ... def forward(self, *args, **kwargs): - ... ... - >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14)) - >>> model.hparams - "p1": 1 - "p2": abc - "p3": 3.14 + or string names or arguments from class ``__init__`` + ignore: an argument name or a list of argument names from + class ``__init__`` to be ignored + frame: a frame object. Default is None + + Example:: + >>> class ManuallyArgsModel(LightningModule): + ... def __init__(self, arg1, arg2, arg3): + ... super().__init__() + ... # manually assign arguments + ... self.save_hyperparameters('arg1', 'arg3') + ... def forward(self, *args, **kwargs): + ... ... + >>> model = ManuallyArgsModel(1, 'abc', 3.14) + >>> model.hparams + "arg1": 1 + "arg3": 3.14 + + >>> class AutomaticArgsModel(LightningModule): + ... def __init__(self, arg1, arg2, arg3): + ... super().__init__() + ... # equivalent automatic + ... self.save_hyperparameters() + ... def forward(self, *args, **kwargs): + ... ... + >>> model = AutomaticArgsModel(1, 'abc', 3.14) + >>> model.hparams + "arg1": 1 + "arg2": abc + "arg3": 3.14 + + >>> class SingleArgModel(LightningModule): + ... def __init__(self, params): + ... super().__init__() + ... # manually assign single argument + ... self.save_hyperparameters(params) + ... def forward(self, *args, **kwargs): + ... ... + >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14)) + >>> model.hparams + "p1": 1 + "p2": abc + "p3": 3.14 + + >>> class ManuallyArgsModel(LightningModule): + ... def __init__(self, arg1, arg2, arg3): + ... super().__init__() + ... # pass argument(s) to ignore as a string or in a list + ... self.save_hyperparameters(ignore='arg2') + ... def forward(self, *args, **kwargs): + ... ... + >>> model = ManuallyArgsModel(1, 'abc', 3.14) + >>> model.hparams + "arg1": 1 + "arg3": 3.14 """ if not frame: frame = inspect.currentframe().f_back init_args = get_init_args(frame) assert init_args, "failed to inspect the self init" + + if ignore is not None: + if isinstance(ignore, str): + ignore = [ignore] + if isinstance(ignore, (list, tuple)): + ignore = [arg for arg in ignore if isinstance(arg, str)] + init_args = {k: v for k, v in init_args.items() if k not in ignore} + if not args: # take all arguments hp = init_args diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 0e32ebea09d85..245621f77d2d8 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -695,3 +695,39 @@ def test_hparams(self): ) trainer.fit(model) _ = TestHydraModel.load_from_checkpoint(checkpoint_callback.best_model_path) + + +@pytest.mark.parametrize("ignore", ("arg2", ("arg2", "arg3"))) +def test_ignore_args_list_hparams(tmpdir, ignore): + """ + Tests that args can be ignored in save_hyperparameters + """ + + class LocalModel(BoringModel): + + def __init__(self, arg1, arg2, arg3): + super().__init__() + self.save_hyperparameters(ignore=ignore) + + model = LocalModel(arg1=14, arg2=90, arg3=50) + + # test proper property assignments + assert model.hparams.arg1 == 14 + for arg in ignore: + assert arg not in model.hparams + + # verify we can train + trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5) + trainer.fit(model) + + # make sure the raw checkpoint saved the properties + raw_checkpoint_path = _raw_checkpoint_path(trainer) + raw_checkpoint = torch.load(raw_checkpoint_path) + assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint + assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]["arg1"] == 14 + + # verify that model loads correctly + model = LocalModel.load_from_checkpoint(raw_checkpoint_path, arg2=123, arg3=100) + assert model.hparams.arg1 == 14 + for arg in ignore: + assert arg not in model.hparams