Skip to content

Commit ca21da4

Browse files
tchatoncarmoccaawaelchli
authored
Move save_hyperparameters to its own function (#7119)
* move hyper_parameters * Update pytorch_lightning/core/lightning.py Co-authored-by: Carlos Mocholí <[email protected]> * Update pytorch_lightning/utilities/parsing.py Co-authored-by: Carlos Mocholí <[email protected]> * resolve flake8 * update * resolve tests * Update pytorch_lightning/core/lightning.py Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: Carlos Mocholí <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 1a68dcd commit ca21da4

File tree

2 files changed

+47
-33
lines changed

2 files changed

+47
-33
lines changed

pytorch_lightning/core/lightning.py

Lines changed: 3 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
4343
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
4444
from pytorch_lightning.utilities.exceptions import MisconfigurationException
45-
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args
45+
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters
4646
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
4747

4848
log = logging.getLogger(__name__)
@@ -1645,38 +1645,10 @@ class ``__init__`` to be ignored
16451645
"arg1": 1
16461646
"arg3": 3.14
16471647
"""
1648+
# the frame needs to be created in this file.
16481649
if not frame:
16491650
frame = inspect.currentframe().f_back
1650-
init_args = get_init_args(frame)
1651-
assert init_args, "failed to inspect the self init"
1652-
1653-
if ignore is not None:
1654-
if isinstance(ignore, str):
1655-
ignore = [ignore]
1656-
if isinstance(ignore, (list, tuple)):
1657-
ignore = [arg for arg in ignore if isinstance(arg, str)]
1658-
init_args = {k: v for k, v in init_args.items() if k not in ignore}
1659-
1660-
if not args:
1661-
# take all arguments
1662-
hp = init_args
1663-
self._hparams_name = "kwargs" if hp else None
1664-
else:
1665-
# take only listed arguments in `save_hparams`
1666-
isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)]
1667-
if len(isx_non_str) == 1:
1668-
hp = args[isx_non_str[0]]
1669-
cand_names = [k for k, v in init_args.items() if v == hp]
1670-
self._hparams_name = cand_names[0] if cand_names else None
1671-
else:
1672-
hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)}
1673-
self._hparams_name = "kwargs"
1674-
1675-
# `hparams` are expected here
1676-
if hp:
1677-
self._set_hparams(hp)
1678-
# make deep copy so there is not other runtime changes reflected
1679-
self._hparams_initial = copy.deepcopy(self._hparams)
1651+
save_hyperparameters(self, *args, ignore=ignore, frame=frame)
16801652

16811653
def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
16821654
if isinstance(hp, Namespace):

pytorch_lightning/utilities/parsing.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
import copy
1515
import inspect
1616
import pickle
17+
import types
1718
from argparse import Namespace
18-
from typing import Dict, Tuple, Union
19+
from typing import Any, Dict, Optional, Sequence, Tuple, Union
1920

2021
from pytorch_lightning.utilities import rank_zero_warn
2122

@@ -161,6 +162,47 @@ def flatten_dict(source, result=None):
161162
return result
162163

163164

165+
def save_hyperparameters(
166+
obj: Any,
167+
*args,
168+
ignore: Optional[Union[Sequence[str], str]] = None,
169+
frame: Optional[types.FrameType] = None
170+
) -> None:
171+
"""See :meth:`~pytorch_lightning.LightningModule.save_hyperparameters`"""
172+
if not frame:
173+
frame = inspect.currentframe().f_back
174+
init_args = get_init_args(frame)
175+
assert init_args, "failed to inspect the obj init"
176+
177+
if ignore is not None:
178+
if isinstance(ignore, str):
179+
ignore = [ignore]
180+
if isinstance(ignore, (list, tuple)):
181+
ignore = [arg for arg in ignore if isinstance(arg, str)]
182+
init_args = {k: v for k, v in init_args.items() if k not in ignore}
183+
184+
if not args:
185+
# take all arguments
186+
hp = init_args
187+
obj._hparams_name = "kwargs" if hp else None
188+
else:
189+
# take only listed arguments in `save_hparams`
190+
isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)]
191+
if len(isx_non_str) == 1:
192+
hp = args[isx_non_str[0]]
193+
cand_names = [k for k, v in init_args.items() if v == hp]
194+
obj._hparams_name = cand_names[0] if cand_names else None
195+
else:
196+
hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)}
197+
obj._hparams_name = "kwargs"
198+
199+
# `hparams` are expected here
200+
if hp:
201+
obj._set_hparams(hp)
202+
# make deep copy so there is not other runtime changes reflected
203+
obj._hparams_initial = copy.deepcopy(obj._hparams)
204+
205+
164206
class AttributeDict(Dict):
165207
"""Extended dictionary accesisable with dot notation.
166208

0 commit comments

Comments
 (0)