|
22 | 22 | from typing_extensions import Literal |
23 | 23 |
|
24 | 24 | import pytorch_lightning as pl |
| 25 | +from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE |
25 | 26 | from pytorch_lightning.utilities.warnings import rank_zero_warn |
26 | 27 |
|
| 28 | +if _OMEGACONF_AVAILABLE: |
| 29 | + from omegaconf.dictconfig import DictConfig |
| 30 | + |
27 | 31 |
|
28 | 32 | def str_to_bool_or_str(val: str) -> Union[str, bool]: |
29 | 33 | """Possibly convert a string representation of truth to bool. |
@@ -204,46 +208,57 @@ def save_hyperparameters( |
204 | 208 | obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None |
205 | 209 | ) -> None: |
206 | 210 | """See :meth:`~pytorch_lightning.LightningModule.save_hyperparameters`""" |
207 | | - |
| 211 | + hparams_container_types = [Namespace, dict] |
| 212 | + if _OMEGACONF_AVAILABLE: |
| 213 | + hparams_container_types.append(DictConfig) |
| 214 | + # empty container |
208 | 215 | if len(args) == 1 and not isinstance(args, str) and not args[0]: |
209 | | - # args[0] is an empty container |
210 | 216 | return |
211 | | - |
212 | | - if not frame: |
213 | | - current_frame = inspect.currentframe() |
214 | | - # inspect.currentframe() return type is Optional[types.FrameType]: current_frame.f_back called only if available |
215 | | - if current_frame: |
216 | | - frame = current_frame.f_back |
217 | | - if not isinstance(frame, types.FrameType): |
218 | | - raise AttributeError("There is no `frame` available while being required.") |
219 | | - |
220 | | - if is_dataclass(obj): |
221 | | - init_args = {f.name: getattr(obj, f.name) for f in fields(obj)} |
222 | | - else: |
223 | | - init_args = get_init_args(frame) |
224 | | - assert init_args, "failed to inspect the obj init" |
225 | | - |
226 | | - if ignore is not None: |
227 | | - if isinstance(ignore, str): |
228 | | - ignore = [ignore] |
229 | | - if isinstance(ignore, (list, tuple)): |
230 | | - ignore = [arg for arg in ignore if isinstance(arg, str)] |
231 | | - init_args = {k: v for k, v in init_args.items() if k not in ignore} |
232 | | - |
233 | | - if not args: |
234 | | - # take all arguments |
235 | | - hp = init_args |
236 | | - obj._hparams_name = "kwargs" if hp else None |
| 217 | + # container |
| 218 | + elif len(args) == 1 and isinstance(args[0], tuple(hparams_container_types)): |
| 219 | + hp = args[0] |
| 220 | + obj._hparams_name = "hparams" |
| 221 | + obj._set_hparams(hp) |
| 222 | + obj._hparams_initial = copy.deepcopy(obj._hparams) |
| 223 | + return |
| 224 | + # non-container args parsing |
237 | 225 | else: |
238 | | - # take only listed arguments in `save_hparams` |
239 | | - isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)] |
240 | | - if len(isx_non_str) == 1: |
241 | | - hp = args[isx_non_str[0]] |
242 | | - cand_names = [k for k, v in init_args.items() if v == hp] |
243 | | - obj._hparams_name = cand_names[0] if cand_names else None |
| 226 | + if not frame: |
| 227 | + current_frame = inspect.currentframe() |
| 228 | + # inspect.currentframe() return type is Optional[types.FrameType] |
| 229 | + # current_frame.f_back called only if available |
| 230 | + if current_frame: |
| 231 | + frame = current_frame.f_back |
| 232 | + if not isinstance(frame, types.FrameType): |
| 233 | + raise AttributeError("There is no `frame` available while being required.") |
| 234 | + |
| 235 | + if is_dataclass(obj): |
| 236 | + init_args = {f.name: getattr(obj, f.name) for f in fields(obj)} |
| 237 | + else: |
| 238 | + init_args = get_init_args(frame) |
| 239 | + assert init_args, f"failed to inspect the obj init - {frame}" |
| 240 | + |
| 241 | + if ignore is not None: |
| 242 | + if isinstance(ignore, str): |
| 243 | + ignore = [ignore] |
| 244 | + if isinstance(ignore, (list, tuple, set)): |
| 245 | + ignore = [arg for arg in ignore if isinstance(arg, str)] |
| 246 | + init_args = {k: v for k, v in init_args.items() if k not in ignore} |
| 247 | + |
| 248 | + if not args: |
| 249 | + # take all arguments |
| 250 | + hp = init_args |
| 251 | + obj._hparams_name = "kwargs" if hp else None |
244 | 252 | else: |
245 | | - hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)} |
246 | | - obj._hparams_name = "kwargs" |
| 253 | + # take only listed arguments in `save_hparams` |
| 254 | + isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)] |
| 255 | + if len(isx_non_str) == 1: |
| 256 | + hp = args[isx_non_str[0]] |
| 257 | + cand_names = [k for k, v in init_args.items() if v == hp] |
| 258 | + obj._hparams_name = cand_names[0] if cand_names else None |
| 259 | + else: |
| 260 | + hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)} |
| 261 | + obj._hparams_name = "kwargs" |
247 | 262 |
|
248 | 263 | # `hparams` are expected here |
249 | 264 | if hp: |
|
0 commit comments