From 6a1c41939598206058ebac98bcb80b9a9534068e Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 18 Feb 2021 18:16:45 +0530 Subject: [PATCH 01/17] add ignore param to save_hyperparameters --- pytorch_lightning/core/lightning.py | 22 ++++++++- tests/models/test_hparams.py | 73 +++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 58d045044d0b4..a56a81f75c5a0 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1571,7 +1571,7 @@ def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]: parents_arguments.update(args) return self_arguments, parents_arguments - def save_hyperparameters(self, *args, frame=None) -> None: + def save_hyperparameters(self, *args, ignore: Optional[Union[List, str]] = None, frame=None) -> None: """Save all model arguments. Args: @@ -1615,11 +1615,31 @@ def save_hyperparameters(self, *args, frame=None) -> None: "p1": 1 "p2": abc "p3": 3.14 + + >>> class ManuallyArgsModel(LightningModule): + ... def __init__(self, arg1, arg2, arg3): + ... super().__init__() + ... # pass argument(s) to ignore as a string or in a list + ... self.save_hyperparameters(ignore="arg2") + ... def forward(self, *args, **kwargs): + ... ... + >>> model = ManuallyArgsModel(1, 'abc', 3.14) + >>> model.hparams + "arg1": 1 + "arg3": 3.14 """ if not frame: frame = inspect.currentframe().f_back init_args = get_init_args(frame) assert init_args, "failed to inspect the self init" + + if ignore: + if isinstance(ignore, str): + ignore = [ignore] + if isinstance(ignore, list): + ignore = [arg for arg in ignore if isinstance(arg, str)] + init_args = {k: v for k, v in init_args.items() if k not in ignore} + if not args: # take all arguments hp = init_args diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 0e32ebea09d85..970f6892e8fb0 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -695,3 +695,76 @@ def test_hparams(self): ) trainer.fit(model) _ = TestHydraModel.load_from_checkpoint(checkpoint_callback.best_model_path) + + +def test_ignore_argument_str_hparams(tmpdir): + """ + Tests that a model can take regular args and assign & + ignore the argument that is mentioned. 
+ """ + + # define model + class LocalModel(BoringModel): + + def __init__(self, test_arg, test_arg2, test_arg3): + super().__init__() + self.save_hyperparameters(ignore="test_arg2") + + model = LocalModel(test_arg=14, test_arg2=90, test_arg3=50) + + # test proper property assignments + assert model.hparams.test_arg == 14 + assert model.hparams.test_arg3 == 50 + assert "test_arg2" not in model.hparams + + # verify we can train + trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5) + trainer.fit(model) + + # make sure the raw checkpoint saved the properties + raw_checkpoint_path = _raw_checkpoint_path(trainer) + raw_checkpoint = torch.load(raw_checkpoint_path) + assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint + assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]["test_arg"] == 14 + + # verify that model loads correctly + model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=123) + assert model.hparams.test_arg == 14 + assert "test_arg2" not in model.hparams # test_arg2 is not registered in class init + + +def test_ignore_args_list_hparams(tmpdir): + """ + Tests that a model can take regular args and assign & + ignore the list of args that are mentioned. + """ + + # define model + class LocalModel(BoringModel): + + def __init__(self, test_arg, test_arg2, test_arg3): + super().__init__() + self.save_hyperparameters(ignore=["test_arg2", "test_arg3"]) + + model = LocalModel(test_arg=14, test_arg2=90, test_arg3=50) + + # test proper property assignments + assert model.hparams.test_arg == 14 + assert "test_arg2" not in model.hparams + assert "test_arg3" not in model.hparams + + # verify we can train + trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5) + trainer.fit(model) + + # make sure the raw checkpoint saved the properties + raw_checkpoint_path = _raw_checkpoint_path(trainer) + raw_checkpoint = torch.load(raw_checkpoint_path) + assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint + assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]["test_arg"] == 14 + + # verify that model loads correctly + model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=123, test_arg3=100) + assert model.hparams.test_arg == 14 + assert "test_arg2" not in model.hparams + assert "test_arg3" not in model.hparams From abbc8ece2c3d0fb28054c868a3478126045b6229 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 18 Feb 2021 18:37:19 +0530 Subject: [PATCH 02/17] add docstring for ignore --- pytorch_lightning/core/lightning.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index a56a81f75c5a0..5774afd3266b5 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1571,12 +1571,14 @@ def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]: parents_arguments.update(args) return self_arguments, parents_arguments - def save_hyperparameters(self, *args, ignore: Optional[Union[List, str]] = None, frame=None) -> None: + def save_hyperparameters(self, *args, ignore: Optional[Union[List[str], str]] = None, frame=None) -> None: """Save all model arguments. 
Args: args: single object of `dict`, `NameSpace` or `OmegaConf` - or string names or arguments from class `__init__` + or string names or arguments from class `__init__` + ignore: an argument or a list of arguments from class `__init__` + to be ignored >>> class ManuallyArgsModel(LightningModule): ... def __init__(self, arg1, arg2, arg3): From 150a1d8a0c069dbc62674319c3eb2256035559df Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 18 Feb 2021 18:59:35 +0530 Subject: [PATCH 03/17] add type for frame object --- pytorch_lightning/core/lightning.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 5774afd3266b5..8a64ede00ff37 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -19,6 +19,7 @@ import os import re import tempfile +import types import uuid from abc import ABC from argparse import Namespace @@ -1571,7 +1572,9 @@ def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]: parents_arguments.update(args) return self_arguments, parents_arguments - def save_hyperparameters(self, *args, ignore: Optional[Union[List[str], str]] = None, frame=None) -> None: + def save_hyperparameters( + self, *args, ignore: Optional[Union[List[str], str]] = None, frame: Optional[types.FrameType] = None + ) -> None: """Save all model arguments. Args: @@ -1579,6 +1582,7 @@ def save_hyperparameters(self, *args, ignore: Optional[Union[List[str], str]] = or string names or arguments from class `__init__` ignore: an argument or a list of arguments from class `__init__` to be ignored + frame: a frame object. Default is None >>> class ManuallyArgsModel(LightningModule): ... def __init__(self, arg1, arg2, arg3): From 5d6436daabdfaa44dcd47fb9424d0e6a7c3336dd Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Thu, 18 Feb 2021 20:01:13 +0530 Subject: [PATCH 04/17] Update pytorch_lightning/core/lightning.py Co-authored-by: Nicki Skafte --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 8a64ede00ff37..402fe2cd72959 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1575,7 +1575,7 @@ def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]: def save_hyperparameters( self, *args, ignore: Optional[Union[List[str], str]] = None, frame: Optional[types.FrameType] = None ) -> None: - """Save all model arguments. + """Save model arguments to ``hparams`` attribute. Args: args: single object of `dict`, `NameSpace` or `OmegaConf` From 5dce21af02daaf7e85af4874c140270904f147cf Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Thu, 18 Feb 2021 20:01:22 +0530 Subject: [PATCH 05/17] Update pytorch_lightning/core/lightning.py Co-authored-by: Nicki Skafte --- pytorch_lightning/core/lightning.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 402fe2cd72959..25c0a2d484f32 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1583,6 +1583,8 @@ def save_hyperparameters( ignore: an argument or a list of arguments from class `__init__` to be ignored frame: a frame object. Default is None + + Example:: >>> class ManuallyArgsModel(LightningModule): ... 
def __init__(self, arg1, arg2, arg3): From a8774aef086ddfc6aec36a3db60956e9ba55b6c2 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 18 Feb 2021 20:05:00 +0530 Subject: [PATCH 06/17] fix whitespace --- pytorch_lightning/core/lightning.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 25c0a2d484f32..026f78d8ba9fb 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1583,9 +1583,8 @@ def save_hyperparameters( ignore: an argument or a list of arguments from class `__init__` to be ignored frame: a frame object. Default is None - - Example:: + Example:: >>> class ManuallyArgsModel(LightningModule): ... def __init__(self, arg1, arg2, arg3): ... super().__init__() From 49171955717058ddb8d7f429e3072064b7a2ecf7 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Thu, 18 Feb 2021 20:26:11 +0530 Subject: [PATCH 07/17] Update pytorch_lightning/core/lightning.py Co-authored-by: Nicki Skafte --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 026f78d8ba9fb..eba51022154f4 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1640,7 +1640,7 @@ def save_hyperparameters( init_args = get_init_args(frame) assert init_args, "failed to inspect the self init" - if ignore: + if ignore is not None: if isinstance(ignore, str): ignore = [ignore] if isinstance(ignore, list): From 5bd3f6340d257c9afa30c44d095d2001721ba8de Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 18 Feb 2021 16:24:29 +0100 Subject: [PATCH 08/17] Parametrize tests --- tests/models/test_hparams.py | 65 ++++++++---------------------------- 1 file changed, 14 insertions(+), 51 deletions(-) diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 970f6892e8fb0..245621f77d2d8 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -697,25 +697,24 @@ def test_hparams(self): _ = TestHydraModel.load_from_checkpoint(checkpoint_callback.best_model_path) -def test_ignore_argument_str_hparams(tmpdir): +@pytest.mark.parametrize("ignore", ("arg2", ("arg2", "arg3"))) +def test_ignore_args_list_hparams(tmpdir, ignore): """ - Tests that a model can take regular args and assign & - ignore the argument that is mentioned. 
+ Tests that args can be ignored in save_hyperparameters """ - # define model class LocalModel(BoringModel): - def __init__(self, test_arg, test_arg2, test_arg3): + def __init__(self, arg1, arg2, arg3): super().__init__() - self.save_hyperparameters(ignore="test_arg2") + self.save_hyperparameters(ignore=ignore) - model = LocalModel(test_arg=14, test_arg2=90, test_arg3=50) + model = LocalModel(arg1=14, arg2=90, arg3=50) # test proper property assignments - assert model.hparams.test_arg == 14 - assert model.hparams.test_arg3 == 50 - assert "test_arg2" not in model.hparams + assert model.hparams.arg1 == 14 + for arg in ignore: + assert arg not in model.hparams # verify we can train trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5) @@ -725,46 +724,10 @@ def __init__(self, test_arg, test_arg2, test_arg3): raw_checkpoint_path = _raw_checkpoint_path(trainer) raw_checkpoint = torch.load(raw_checkpoint_path) assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint - assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]["test_arg"] == 14 + assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]["arg1"] == 14 # verify that model loads correctly - model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=123) - assert model.hparams.test_arg == 14 - assert "test_arg2" not in model.hparams # test_arg2 is not registered in class init - - -def test_ignore_args_list_hparams(tmpdir): - """ - Tests that a model can take regular args and assign & - ignore the list of args that are mentioned. - """ - - # define model - class LocalModel(BoringModel): - - def __init__(self, test_arg, test_arg2, test_arg3): - super().__init__() - self.save_hyperparameters(ignore=["test_arg2", "test_arg3"]) - - model = LocalModel(test_arg=14, test_arg2=90, test_arg3=50) - - # test proper property assignments - assert model.hparams.test_arg == 14 - assert "test_arg2" not in model.hparams - assert "test_arg3" not in model.hparams - - # verify we can train - trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5) - trainer.fit(model) - - # make sure the raw checkpoint saved the properties - raw_checkpoint_path = _raw_checkpoint_path(trainer) - raw_checkpoint = torch.load(raw_checkpoint_path) - assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint - assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]["test_arg"] == 14 - - # verify that model loads correctly - model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=123, test_arg3=100) - assert model.hparams.test_arg == 14 - assert "test_arg2" not in model.hparams - assert "test_arg3" not in model.hparams + model = LocalModel.load_from_checkpoint(raw_checkpoint_path, arg2=123, arg3=100) + assert model.hparams.arg1 == 14 + for arg in ignore: + assert arg not in model.hparams From ebe724ae889f21a21d12fbfcb40fb6c64d8323e1 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Thu, 18 Feb 2021 23:48:22 +0530 Subject: [PATCH 09/17] Update pytorch_lightning/core/lightning.py Co-authored-by: Rohit Gupta --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index eba51022154f4..9168f39e28048 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1573,7 +1573,7 @@ def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]: return self_arguments, 
parents_arguments def save_hyperparameters( - self, *args, ignore: Optional[Union[List[str], str]] = None, frame: Optional[types.FrameType] = None + self, *args, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None ) -> None: """Save model arguments to ``hparams`` attribute. From 7824e8ab779be2f9fff85b1c96de174783e6ac3c Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Thu, 18 Feb 2021 23:48:30 +0530 Subject: [PATCH 10/17] Update pytorch_lightning/core/lightning.py Co-authored-by: Rohit Gupta --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 9168f39e28048..f7ed725b223ed 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1643,7 +1643,7 @@ def save_hyperparameters( if ignore is not None: if isinstance(ignore, str): ignore = [ignore] - if isinstance(ignore, list): + if isinstance(ignore, (list, tuple)): ignore = [arg for arg in ignore if isinstance(arg, str)] init_args = {k: v for k, v in init_args.items() if k not in ignore} From 31fe2e88cd960147771e5569eed25a889bee9735 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Thu, 18 Feb 2021 23:51:43 +0530 Subject: [PATCH 11/17] seq --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index f7ed725b223ed..d90e635fc2a59 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -25,7 +25,7 @@ from argparse import Namespace from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch from torch import ScriptModule, Tensor From 532facf7a3fb4b79058eb22d88f2690bb42d4b37 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 19 Feb 2021 00:10:49 +0530 Subject: [PATCH 12/17] fix docs --- pytorch_lightning/core/lightning.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 026f78d8ba9fb..d5c6009fbe2a2 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1579,8 +1579,8 @@ def save_hyperparameters( Args: args: single object of `dict`, `NameSpace` or `OmegaConf` - or string names or arguments from class `__init__` - ignore: an argument or a list of arguments from class `__init__` + or string names or arguments from class ``__init__`` + ignore: an argument or a list of arguments from class ``__init__`` to be ignored frame: a frame object. Default is None @@ -1627,7 +1627,7 @@ def save_hyperparameters( ... def __init__(self, arg1, arg2, arg3): ... super().__init__() ... # pass argument(s) to ignore as a string or in a list - ... self.save_hyperparameters(ignore="arg2") + ... self.save_hyperparameters(ignore='arg2') ... def forward(self, *args, **kwargs): ... ... 
>>> model = ManuallyArgsModel(1, 'abc', 3.14) From 4a53fab02e9c1c6b932d8eb12339d2e37df5c95a Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Fri, 19 Feb 2021 14:19:20 +0530 Subject: [PATCH 13/17] Update lightning.py --- pytorch_lightning/core/lightning.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index bb62844926389..57252465c75d0 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1600,7 +1600,6 @@ def save_hyperparameters( >>> class AutomaticArgsModel(LightningModule): ... def __init__(self, arg1, arg2, arg3): ... super().__init__() - ... # equivalent automatic ... self.save_hyperparameters() ... def forward(self, *args, **kwargs): ... ... From 6baff9ed2778fdff0211322a9312dcb48b0d6066 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Fri, 19 Feb 2021 14:24:49 +0530 Subject: [PATCH 14/17] Update lightning.py --- pytorch_lightning/core/lightning.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 57252465c75d0..bb62844926389 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1600,6 +1600,7 @@ def save_hyperparameters( >>> class AutomaticArgsModel(LightningModule): ... def __init__(self, arg1, arg2, arg3): ... super().__init__() + ... # equivalent automatic ... self.save_hyperparameters() ... def forward(self, *args, **kwargs): ... ... From 7a055775f43567cd4de0c38f052a654fe7a5d7a9 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 19 Feb 2021 14:32:04 +0530 Subject: [PATCH 15/17] fix docs errors --- pytorch_lightning/core/lightning.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index bb62844926389..7df15d69480dc 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1573,7 +1573,10 @@ def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]: return self_arguments, parents_arguments def save_hyperparameters( - self, *args, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None + self, + *args, + ignore: Optional[Union[Sequence[str], str]] = None, + frame: Optional[types.FrameType] = None ) -> None: """Save model arguments to ``hparams`` attribute. @@ -1584,7 +1587,6 @@ def save_hyperparameters( to be ignored frame: a frame object. Default is None - Example:: >>> class ManuallyArgsModel(LightningModule): ... def __init__(self, arg1, arg2, arg3): ... super().__init__() From 22b05416451eeaa5a2c0262e9c6087a4565ceaaf Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 19 Feb 2021 16:14:55 +0530 Subject: [PATCH 16/17] add example keyword --- pytorch_lightning/core/lightning.py | 99 +++++++++++++++-------------- 1 file changed, 50 insertions(+), 49 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 7df15d69480dc..90c98c0d2427d 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1587,55 +1587,56 @@ def save_hyperparameters( to be ignored frame: a frame object. Default is None - >>> class ManuallyArgsModel(LightningModule): - ... def __init__(self, arg1, arg2, arg3): - ... super().__init__() - ... # manually assign arguments - ... self.save_hyperparameters('arg1', 'arg3') - ... 
def forward(self, *args, **kwargs): - ... ... - >>> model = ManuallyArgsModel(1, 'abc', 3.14) - >>> model.hparams - "arg1": 1 - "arg3": 3.14 - - >>> class AutomaticArgsModel(LightningModule): - ... def __init__(self, arg1, arg2, arg3): - ... super().__init__() - ... # equivalent automatic - ... self.save_hyperparameters() - ... def forward(self, *args, **kwargs): - ... ... - >>> model = AutomaticArgsModel(1, 'abc', 3.14) - >>> model.hparams - "arg1": 1 - "arg2": abc - "arg3": 3.14 - - >>> class SingleArgModel(LightningModule): - ... def __init__(self, params): - ... super().__init__() - ... # manually assign single argument - ... self.save_hyperparameters(params) - ... def forward(self, *args, **kwargs): - ... ... - >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14)) - >>> model.hparams - "p1": 1 - "p2": abc - "p3": 3.14 - - >>> class ManuallyArgsModel(LightningModule): - ... def __init__(self, arg1, arg2, arg3): - ... super().__init__() - ... # pass argument(s) to ignore as a string or in a list - ... self.save_hyperparameters(ignore='arg2') - ... def forward(self, *args, **kwargs): - ... ... - >>> model = ManuallyArgsModel(1, 'abc', 3.14) - >>> model.hparams - "arg1": 1 - "arg3": 3.14 + Example:: + >>> class ManuallyArgsModel(LightningModule): + ... def __init__(self, arg1, arg2, arg3): + ... super().__init__() + ... # manually assign arguments + ... self.save_hyperparameters('arg1', 'arg3') + ... def forward(self, *args, **kwargs): + ... ... + >>> model = ManuallyArgsModel(1, 'abc', 3.14) + >>> model.hparams + "arg1": 1 + "arg3": 3.14 + + >>> class AutomaticArgsModel(LightningModule): + ... def __init__(self, arg1, arg2, arg3): + ... super().__init__() + ... # equivalent automatic + ... self.save_hyperparameters() + ... def forward(self, *args, **kwargs): + ... ... + >>> model = AutomaticArgsModel(1, 'abc', 3.14) + >>> model.hparams + "arg1": 1 + "arg2": abc + "arg3": 3.14 + + >>> class SingleArgModel(LightningModule): + ... def __init__(self, params): + ... super().__init__() + ... # manually assign single argument + ... self.save_hyperparameters(params) + ... def forward(self, *args, **kwargs): + ... ... + >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14)) + >>> model.hparams + "p1": 1 + "p2": abc + "p3": 3.14 + + >>> class ManuallyArgsModel(LightningModule): + ... def __init__(self, arg1, arg2, arg3): + ... super().__init__() + ... # pass argument(s) to ignore as a string or in a list + ... self.save_hyperparameters(ignore='arg2') + ... def forward(self, *args, **kwargs): + ... ... + >>> model = ManuallyArgsModel(1, 'abc', 3.14) + >>> model.hparams + "arg1": 1 + "arg3": 3.14 """ if not frame: frame = inspect.currentframe().f_back From 85c2e44195c4c45740bbb5f71443f85ffbeb54b3 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 19 Feb 2021 18:20:32 +0530 Subject: [PATCH 17/17] update docstring --- pytorch_lightning/core/lightning.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 90c98c0d2427d..bf6dd2f8991c7 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1583,8 +1583,8 @@ def save_hyperparameters( Args: args: single object of `dict`, `NameSpace` or `OmegaConf` or string names or arguments from class ``__init__`` - ignore: an argument or a list of arguments from class ``__init__`` - to be ignored + ignore: an argument name or a list of argument names from + class ``__init__`` to be ignored frame: a frame object. 
Default is None Example::
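
For reference, a short usage sketch of the ``ignore`` argument introduced by the patch series above. This is a minimal illustration, not part of the patches themselves: the class name, argument names, and values are hypothetical, and it assumes a pytorch_lightning build that already includes these commits.

    from pytorch_lightning import LightningModule

    class LitClassifier(LightningModule):
        def __init__(self, hidden_dim: int = 128, lr: float = 1e-3, backbone=None):
            super().__init__()
            # keep ``backbone`` (e.g. a large nn.Module) out of ``hparams`` and the
            # checkpoint's hyperparameter dict; the remaining init args are collected
            self.save_hyperparameters(ignore="backbone")
            # equivalent: self.save_hyperparameters(ignore=["backbone"])
            self.backbone = backbone

    model = LitClassifier(hidden_dim=64, lr=3e-4)
    print(model.hparams)  # contains hidden_dim and lr, but not backbone

As the implementation in the first commit shows, a plain string is normalized to a one-element list, so ``ignore="backbone"`` and ``ignore=["backbone"]`` behave identically.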