|
22 | 22 | from typing_extensions import Literal |
23 | 23 |
|
24 | 24 | import pytorch_lightning as pl |
25 | | -from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE |
26 | 25 | from pytorch_lightning.utilities.warnings import rank_zero_warn |
27 | 26 |
|
28 | | -if _OMEGACONF_AVAILABLE: |
29 | | - from omegaconf.dictconfig import DictConfig |
30 | | - |
31 | 27 |
|
32 | 28 | def str_to_bool_or_str(val: str) -> Union[str, bool]: |
33 | 29 | """Possibly convert a string representation of truth to bool. |
@@ -208,57 +204,46 @@ def save_hyperparameters( |
208 | 204 | obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None |
209 | 205 | ) -> None: |
210 | 206 | """See :meth:`~pytorch_lightning.LightningModule.save_hyperparameters`""" |
211 | | - hparams_container_types = [Namespace, dict] |
212 | | - if _OMEGACONF_AVAILABLE: |
213 | | - hparams_container_types.append(DictConfig) |
214 | | - # empty container |
| 207 | + |
215 | 208 | if len(args) == 1 and not isinstance(args, str) and not args[0]: |
| 209 | + # args[0] is an empty container |
216 | 210 | return |
217 | | - # container |
218 | | - elif len(args) == 1 and isinstance(args[0], tuple(hparams_container_types)): |
219 | | - hp = args[0] |
220 | | - obj._hparams_name = "hparams" |
221 | | - obj._set_hparams(hp) |
222 | | - obj._hparams_initial = copy.deepcopy(obj._hparams) |
223 | | - return |
224 | | - # non-container args parsing |
| 211 | + |
| 212 | + if not frame: |
| 213 | + current_frame = inspect.currentframe() |
| 214 | + # inspect.currentframe() return type is Optional[types.FrameType]: current_frame.f_back called only if available |
| 215 | + if current_frame: |
| 216 | + frame = current_frame.f_back |
| 217 | + if not isinstance(frame, types.FrameType): |
| 218 | + raise AttributeError("There is no `frame` available while being required.") |
| 219 | + |
| 220 | + if is_dataclass(obj): |
| 221 | + init_args = {f.name: getattr(obj, f.name) for f in fields(obj)} |
225 | 222 | else: |
226 | | - if not frame: |
227 | | - current_frame = inspect.currentframe() |
228 | | - # inspect.currentframe() return type is Optional[types.FrameType] |
229 | | - # current_frame.f_back called only if available |
230 | | - if current_frame: |
231 | | - frame = current_frame.f_back |
232 | | - if not isinstance(frame, types.FrameType): |
233 | | - raise AttributeError("There is no `frame` available while being required.") |
234 | | - |
235 | | - if is_dataclass(obj): |
236 | | - init_args = {f.name: getattr(obj, f.name) for f in fields(obj)} |
237 | | - else: |
238 | | - init_args = get_init_args(frame) |
239 | | - assert init_args, f"failed to inspect the obj init - {frame}" |
240 | | - |
241 | | - if ignore is not None: |
242 | | - if isinstance(ignore, str): |
243 | | - ignore = [ignore] |
244 | | - if isinstance(ignore, (list, tuple, set)): |
245 | | - ignore = [arg for arg in ignore if isinstance(arg, str)] |
246 | | - init_args = {k: v for k, v in init_args.items() if k not in ignore} |
247 | | - |
248 | | - if not args: |
249 | | - # take all arguments |
250 | | - hp = init_args |
251 | | - obj._hparams_name = "kwargs" if hp else None |
| 223 | + init_args = get_init_args(frame) |
| 224 | + assert init_args, "failed to inspect the obj init" |
| 225 | + |
| 226 | + if ignore is not None: |
| 227 | + if isinstance(ignore, str): |
| 228 | + ignore = [ignore] |
| 229 | + if isinstance(ignore, (list, tuple)): |
| 230 | + ignore = [arg for arg in ignore if isinstance(arg, str)] |
| 231 | + init_args = {k: v for k, v in init_args.items() if k not in ignore} |
| 232 | + |
| 233 | + if not args: |
| 234 | + # take all arguments |
| 235 | + hp = init_args |
| 236 | + obj._hparams_name = "kwargs" if hp else None |
| 237 | + else: |
| 238 | + # take only listed arguments in `save_hparams` |
| 239 | + isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)] |
| 240 | + if len(isx_non_str) == 1: |
| 241 | + hp = args[isx_non_str[0]] |
| 242 | + cand_names = [k for k, v in init_args.items() if v == hp] |
| 243 | + obj._hparams_name = cand_names[0] if cand_names else None |
252 | 244 | else: |
253 | | - # take only listed arguments in `save_hparams` |
254 | | - isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)] |
255 | | - if len(isx_non_str) == 1: |
256 | | - hp = args[isx_non_str[0]] |
257 | | - cand_names = [k for k, v in init_args.items() if v == hp] |
258 | | - obj._hparams_name = cand_names[0] if cand_names else None |
259 | | - else: |
260 | | - hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)} |
261 | | - obj._hparams_name = "kwargs" |
| 245 | + hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)} |
| 246 | + obj._hparams_name = "kwargs" |
262 | 247 |
|
263 | 248 | # `hparams` are expected here |
264 | 249 | if hp: |
|
0 commit comments