-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Add ignore param to save_hyperparameters #6056
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6a1c419
abbc8ec
150a1d8
5d6436d
5dce21a
a8774ae
4917195
5bd3f63
ebe724a
7824e8a
31fe2e8
532facf
110dbb4
4a53fab
6baff9e
7a05577
22b0541
85c2e44
2c5efa2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,12 +19,13 @@ | |
| import os | ||
| import re | ||
| import tempfile | ||
| import types | ||
| import uuid | ||
| from abc import ABC | ||
| from argparse import Namespace | ||
| from functools import partial | ||
| from pathlib import Path | ||
| from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union | ||
| from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union | ||
|
|
||
| import torch | ||
| from torch import ScriptModule, Tensor | ||
|
|
@@ -1582,55 +1583,84 @@ def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]: | |
| parents_arguments.update(args) | ||
| return self_arguments, parents_arguments | ||
|
|
||
| def save_hyperparameters(self, *args, frame=None) -> None: | ||
| """Save all model arguments. | ||
| def save_hyperparameters( | ||
| self, | ||
| *args, | ||
| ignore: Optional[Union[Sequence[str], str]] = None, | ||
| frame: Optional[types.FrameType] = None | ||
| ) -> None: | ||
| """Save model arguments to ``hparams`` attribute. | ||
|
|
||
| Args: | ||
| args: single object of `dict`, `NameSpace` or `OmegaConf` | ||
| or string names or arguments from class `__init__` | ||
|
|
||
| >>> class ManuallyArgsModel(LightningModule): | ||
| ... def __init__(self, arg1, arg2, arg3): | ||
| ... super().__init__() | ||
| ... # manually assign arguments | ||
| ... self.save_hyperparameters('arg1', 'arg3') | ||
| ... def forward(self, *args, **kwargs): | ||
| ... ... | ||
| >>> model = ManuallyArgsModel(1, 'abc', 3.14) | ||
| >>> model.hparams | ||
| "arg1": 1 | ||
| "arg3": 3.14 | ||
|
|
||
| >>> class AutomaticArgsModel(LightningModule): | ||
| ... def __init__(self, arg1, arg2, arg3): | ||
| ... super().__init__() | ||
| ... # equivalent automatic | ||
| ... self.save_hyperparameters() | ||
| ... def forward(self, *args, **kwargs): | ||
| ... ... | ||
| >>> model = AutomaticArgsModel(1, 'abc', 3.14) | ||
| >>> model.hparams | ||
| "arg1": 1 | ||
| "arg2": abc | ||
| "arg3": 3.14 | ||
|
|
||
| >>> class SingleArgModel(LightningModule): | ||
| ... def __init__(self, params): | ||
| ... super().__init__() | ||
| ... # manually assign single argument | ||
| ... self.save_hyperparameters(params) | ||
| ... def forward(self, *args, **kwargs): | ||
| ... ... | ||
| >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14)) | ||
| >>> model.hparams | ||
| "p1": 1 | ||
| "p2": abc | ||
| "p3": 3.14 | ||
| or string names or arguments from class ``__init__`` | ||
| ignore: an argument name or a list of argument names from | ||
| class ``__init__`` to be ignored | ||
| frame: a frame object. Default is None | ||
|
|
||
| Example:: | ||
| >>> class ManuallyArgsModel(LightningModule): | ||
| ... def __init__(self, arg1, arg2, arg3): | ||
| ... super().__init__() | ||
| ... # manually assign arguments | ||
| ... self.save_hyperparameters('arg1', 'arg3') | ||
| ... def forward(self, *args, **kwargs): | ||
| ... ... | ||
| >>> model = ManuallyArgsModel(1, 'abc', 3.14) | ||
| >>> model.hparams | ||
| "arg1": 1 | ||
| "arg3": 3.14 | ||
|
|
||
| >>> class AutomaticArgsModel(LightningModule): | ||
| ... def __init__(self, arg1, arg2, arg3): | ||
| ... super().__init__() | ||
| ... # equivalent automatic | ||
| ... self.save_hyperparameters() | ||
| ... def forward(self, *args, **kwargs): | ||
| ... ... | ||
| >>> model = AutomaticArgsModel(1, 'abc', 3.14) | ||
| >>> model.hparams | ||
| "arg1": 1 | ||
| "arg2": abc | ||
| "arg3": 3.14 | ||
|
|
||
| >>> class SingleArgModel(LightningModule): | ||
| ... def __init__(self, params): | ||
| ... super().__init__() | ||
| ... # manually assign single argument | ||
| ... self.save_hyperparameters(params) | ||
| ... def forward(self, *args, **kwargs): | ||
| ... ... | ||
| >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14)) | ||
| >>> model.hparams | ||
| "p1": 1 | ||
| "p2": abc | ||
| "p3": 3.14 | ||
|
|
||
| >>> class ManuallyArgsModel(LightningModule): | ||
| ... def __init__(self, arg1, arg2, arg3): | ||
| ... super().__init__() | ||
| ... # pass argument(s) to ignore as a string or in a list | ||
| ... self.save_hyperparameters(ignore='arg2') | ||
| ... def forward(self, *args, **kwargs): | ||
| ... ... | ||
| >>> model = ManuallyArgsModel(1, 'abc', 3.14) | ||
| >>> model.hparams | ||
| "arg1": 1 | ||
| "arg3": 3.14 | ||
kaushikb11 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """ | ||
| if not frame: | ||
| frame = inspect.currentframe().f_back | ||
| init_args = get_init_args(frame) | ||
| assert init_args, "failed to inspect the self init" | ||
|
|
||
| if ignore is not None: | ||
| if isinstance(ignore, str): | ||
| ignore = [ignore] | ||
| if isinstance(ignore, (list, tuple)): | ||
| ignore = [arg for arg in ignore if isinstance(arg, str)] | ||
| init_args = {k: v for k, v in init_args.items() if k not in ignore} | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how about if
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @Borda: I am doing a check for
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And also the other tests will fail, as the default is None :) |
||
|
|
||
| if not args: | ||
| # take all arguments | ||
| hp = init_args | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.