1414"""The LightningModule - an nn.Module with many additional features."""
1515
1616import collections
17- import copy
1817import inspect
1918import logging
2019import numbers
2120import os
2221import tempfile
23- import types
2422import uuid
2523from abc import ABC
26- from argparse import Namespace
2724from pathlib import Path
28- from typing import Any , Callable , Dict , List , Mapping , Optional , Sequence , Tuple , Union
25+ from typing import Any , Callable , Dict , List , Mapping , Optional , Tuple , Union
2926
3027import numpy as np
3128import torch
3835from pytorch_lightning .core .hooks import CheckpointHooks , DataHooks , ModelHooks
3936from pytorch_lightning .core .memory import ModelSummary
4037from pytorch_lightning .core .optimizer import LightningOptimizer
41- from pytorch_lightning .core .saving import ALLOWED_CONFIG_TYPES , ModelIO , PRIMITIVE_TYPES
38+ from pytorch_lightning .core .saving import ModelIO
4239from pytorch_lightning .trainer .connectors .logger_connector .fx_validator import FxValidator
4340from pytorch_lightning .utilities import rank_zero_deprecation , rank_zero_warn
4441from pytorch_lightning .utilities .apply_func import apply_to_collection , convert_to_tensors
4542from pytorch_lightning .utilities .cloud_io import get_filesystem
4643from pytorch_lightning .utilities .device_dtype_mixin import DeviceDtypeModuleMixin
4744from pytorch_lightning .utilities .distributed import distributed_available , sync_ddp
4845from pytorch_lightning .utilities .exceptions import MisconfigurationException
49- from pytorch_lightning .utilities .parsing import AttributeDict , collect_init_args , save_hyperparameters
46+ from pytorch_lightning .utilities .hparams_mixin import HyperparametersMixin
47+ from pytorch_lightning .utilities .parsing import collect_init_args
5048from pytorch_lightning .utilities .signature_utils import is_param_in_hook_signature
5149from pytorch_lightning .utilities .types import _METRIC_COLLECTION , EPOCH_OUTPUT , STEP_OUTPUT
5250from pytorch_lightning .utilities .warnings import WarningCache
5856class LightningModule (
5957 ABC ,
6058 DeviceDtypeModuleMixin ,
59+ HyperparametersMixin ,
6160 GradInformation ,
6261 ModelIO ,
6362 ModelHooks ,
@@ -70,8 +69,6 @@ class LightningModule(
7069 __jit_unused_properties__ = [
7170 "datamodule" ,
7271 "example_input_array" ,
73- "hparams" ,
74- "hparams_initial" ,
7572 "on_gpu" ,
7673 "current_epoch" ,
7774 "global_step" ,
@@ -82,7 +79,7 @@ class LightningModule(
8279 "automatic_optimization" ,
8380 "truncated_bptt_steps" ,
8481 "loaded_optimizer_states_dict" ,
85- ] + DeviceDtypeModuleMixin .__jit_unused_properties__
82+ ] + DeviceDtypeModuleMixin .__jit_unused_properties__ + HyperparametersMixin . __jit_unused_properties__
8683
8784 def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
8885 super ().__init__ (* args , ** kwargs )
@@ -1832,92 +1829,6 @@ def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]:
18321829 parents_arguments .update (args )
18331830 return self_arguments , parents_arguments
18341831
1835- def save_hyperparameters (
1836- self ,
1837- * args ,
1838- ignore : Optional [Union [Sequence [str ], str ]] = None ,
1839- frame : Optional [types .FrameType ] = None
1840- ) -> None :
1841- """Save model arguments to the ``hparams`` attribute.
1842-
1843- Args:
1844- args: single object of type :class:`dict`, :class:`~argparse.Namespace`, `OmegaConf`
1845- or strings representing the argument names in ``__init__``.
1846- ignore: an argument name or a list of argument names in ``__init__`` to be ignored
1847- frame: a frame object. Default is ``None``.
1848-
1849- Example::
1850-
1851- >>> class ManuallyArgsModel(LightningModule):
1852- ... def __init__(self, arg1, arg2, arg3):
1853- ... super().__init__()
1854- ... # manually assign arguments
1855- ... self.save_hyperparameters('arg1', 'arg3')
1856- ... def forward(self, *args, **kwargs):
1857- ... ...
1858- >>> model = ManuallyArgsModel(1, 'abc', 3.14)
1859- >>> model.hparams
1860- "arg1": 1
1861- "arg3": 3.14
1862-
1863- >>> class AutomaticArgsModel(LightningModule):
1864- ... def __init__(self, arg1, arg2, arg3):
1865- ... super().__init__()
1866- ... # equivalent automatic
1867- ... self.save_hyperparameters()
1868- ... def forward(self, *args, **kwargs):
1869- ... ...
1870- >>> model = AutomaticArgsModel(1, 'abc', 3.14)
1871- >>> model.hparams
1872- "arg1": 1
1873- "arg2": abc
1874- "arg3": 3.14
1875-
1876- >>> class SingleArgModel(LightningModule):
1877- ... def __init__(self, params):
1878- ... super().__init__()
1879- ... # manually assign single argument
1880- ... self.save_hyperparameters(params)
1881- ... def forward(self, *args, **kwargs):
1882- ... ...
1883- >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
1884- >>> model.hparams
1885- "p1": 1
1886- "p2": abc
1887- "p3": 3.14
1888-
1889- >>> class ManuallyArgsModel(LightningModule):
1890- ... def __init__(self, arg1, arg2, arg3):
1891- ... super().__init__()
1892- ... # pass argument(s) to ignore as a string or in a list
1893- ... self.save_hyperparameters(ignore='arg2')
1894- ... def forward(self, *args, **kwargs):
1895- ... ...
1896- >>> model = ManuallyArgsModel(1, 'abc', 3.14)
1897- >>> model.hparams
1898- "arg1": 1
1899- "arg3": 3.14
1900- """
1901- # the frame needs to be created in this file.
1902- if not frame :
1903- frame = inspect .currentframe ().f_back
1904- save_hyperparameters (self , * args , ignore = ignore , frame = frame )
1905-
1906- def _set_hparams (self , hp : Union [dict , Namespace , str ]) -> None :
1907- if isinstance (hp , Namespace ):
1908- hp = vars (hp )
1909- if isinstance (hp , dict ):
1910- hp = AttributeDict (hp )
1911- elif isinstance (hp , PRIMITIVE_TYPES ):
1912- raise ValueError (f"Primitives { PRIMITIVE_TYPES } are not allowed." )
1913- elif not isinstance (hp , ALLOWED_CONFIG_TYPES ):
1914- raise ValueError (f"Unsupported config type of { type (hp )} ." )
1915-
1916- if isinstance (hp , dict ) and isinstance (self .hparams , dict ):
1917- self .hparams .update (hp )
1918- else :
1919- self ._hparams = hp
1920-
19211832 @torch .no_grad ()
19221833 def to_onnx (
19231834 self ,
@@ -2049,27 +1960,6 @@ def to_torchscript(
20491960
20501961 return torchscript_module
20511962
2052- @property
2053- def hparams (self ) -> Union [AttributeDict , dict , Namespace ]:
2054- """
2055- The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user.
2056- For the frozen set of initial hyperparameters, use :attr:`hparams_initial`.
2057- """
2058- if not hasattr (self , "_hparams" ):
2059- self ._hparams = AttributeDict ()
2060- return self ._hparams
2061-
2062- @property
2063- def hparams_initial (self ) -> AttributeDict :
2064- """
2065- The collection of hyperparameters saved with :meth:`save_hyperparameters`. These contents are read-only.
2066- Manual updates to the saved hyperparameters can instead be performed through :attr:`hparams`.
2067- """
2068- if not hasattr (self , "_hparams_initial" ):
2069- return AttributeDict ()
2070- # prevent any change
2071- return copy .deepcopy (self ._hparams_initial )
2072-
20731963 @property
20741964 def model_size (self ) -> float :
20751965 """
0 commit comments