
save_hyperparameters attempts to parse unspecified args when a namespace is specified #8948

@s-rog


🐛 Bug

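save_hyperparameters inspects the other __init__ arguments even when a Namespace is passed to it explicitly, and fails when one of them (here a pandas DataFrame) cannot be safely compared with ==.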

To Reproduce

import argparse

import pandas as pd
import pytorch_lightning as pl

class plmodule(pl.LightningModule):
    def __init__(self, hp, df):
        super().__init__()
        self.save_hyperparameters(hp)


hp = argparse.Namespace(**{"hello": "world"})
df = pd.DataFrame(["hello", "world"])
x = plmodule(hp, df)

Trace

ValueError                                Traceback (most recent call last)
<ipython-input-2-8ea9fd40a08a> in <module>
     12 hp = argparse.Namespace(**{"hello": "world"})
     13 df = pd.DataFrame(["hello", "world"])
---> 14 x = plmodule(hp, df)

<ipython-input-2-8ea9fd40a08a> in __init__(self, hp, df)
      7     def __init__(self, hp, df):
      8         super().__init__()
----> 9         self.save_hyperparameters(hp)
     10 
     11 

/opt/conda/lib/python3.8/site-packages/pytorch_lightning/core/mixins/hparams_mixin.py in save_hyperparameters(self, ignore, frame, logger, *args)
    103         if not frame:
    104             frame = inspect.currentframe().f_back
--> 105         save_hyperparameters(self, *args, ignore=ignore, frame=frame)
    106 
    107     def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:

/opt/conda/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py in save_hyperparameters(obj, ignore, frame, *args)
    240         if len(isx_non_str) == 1:
    241             hp = args[isx_non_str[0]]
--> 242             cand_names = [k for k, v in init_args.items() if v == hp]
    243             obj._hparams_name = cand_names[0] if cand_names else None
    244         else:

/opt/conda/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py in <listcomp>(.0)
    240         if len(isx_non_str) == 1:
    241             hp = args[isx_non_str[0]]
--> 242             cand_names = [k for k, v in init_args.items() if v == hp]
    243             obj._hparams_name = cand_names[0] if cand_names else None
    244         else:

/opt/conda/lib/python3.8/site-packages/pandas/core/generic.py in __nonzero__(self)
   1532     @final
   1533     def __nonzero__(self):
-> 1534         raise ValueError(
   1535             f"The truth value of a {type(self).__name__} is ambiguous. "
   1536             "Use a.empty, a.bool(), a.item(), a.any() or a.all()."

ValueError: The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().
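The comprehension at parsing.py line 242 can be reproduced without Lightning or pandas in a short sketch, which may make the failure easier to see. It relies on hypothetical stand-ins: ElementwiseFrame mimics how a pandas DataFrame answers == with an element-wise result, and AmbiguousResult mimics that result refusing to collapse to a single bool. find_hparams_name is a hypothetical helper that copies only the list comprehension shown in the trace; it is not Lightning's actual function.

```python
from argparse import Namespace


class AmbiguousResult:
    """Hypothetical stand-in for a boolean DataFrame: has no single truth value."""

    def __bool__(self):
        raise ValueError("The truth value of a DataFrame is ambiguous.")


class ElementwiseFrame:
    """Hypothetical stand-in for pandas.DataFrame: == returns an element-wise result."""

    def __eq__(self, other):
        return AmbiguousResult()


def find_hparams_name(init_args, hp):
    # Same shape as the comprehension in the trace: every init arg is compared
    # to hp with ==, including args that have nothing to do with hp.
    cand_names = [k for k, v in init_args.items() if v == hp]
    return cand_names[0] if cand_names else None


hp = Namespace(hello="world")

# With df = 1 the lookup works: int == Namespace falls back to identity, i.e. False.
print(find_hparams_name({"hp": hp, "df": 1}, hp))  # hp

# With a DataFrame-like arg, `if v == hp` calls bool() on the element-wise result.
try:
    find_hparams_name({"hp": hp, "df": ElementwiseFrame()}, hp)
except ValueError as err:
    print("ValueError:", err)
```

This matches the observation that df = 1 works: the error comes from evaluating the truth value of the comparison, not from the namespace itself.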

Expected behavior

The module is created with hparams {"hello": "world"}.

Environment

  • PyTorch Lightning Version: reproduced in 1.3.7, 1.3.8, 1.4.2
  • PyTorch Version: 1.8.0a0+52ea372
  • Python version: 3.8

Additional context

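This seems to be an old bug rather than a new regression. Setting df = 1 gives the correct behavior, but save_hyperparameters shouldn't be accessing the other args when a namespace is already specified. The failing line in the trace compares every init arg with the namespace via v == hp; for a DataFrame that comparison returns an element-wise DataFrame instead of a bool, so the if raises.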
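One possible fix, sketched here and not necessarily what upstream adopted, is to look the namespace up by identity (is) instead of equality (==), so no other argument's __eq__ ever runs. ExplodingEq and find_hparams_name are hypothetical names used only for illustration; ExplodingEq stands in for any init arg (like a DataFrame) whose == cannot be used in an if.

```python
from argparse import Namespace


class ExplodingEq:
    """Hypothetical stand-in for an init arg whose == must not be evaluated."""

    def __eq__(self, other):
        raise ValueError("The truth value of a DataFrame is ambiguous.")


def find_hparams_name(init_args, hp):
    # `is` compares object identity, so no argument's __eq__ is called;
    # only the exact object passed to save_hyperparameters can match.
    cand_names = [k for k, v in init_args.items() if v is hp]
    return cand_names[0] if cand_names else None


hp = Namespace(hello="world")
init_args = {"hp": hp, "df": ExplodingEq()}
print(find_hparams_name(init_args, hp))  # hp

# Trade-off: an equal but distinct Namespace no longer matches.
print(find_hparams_name({"hp": Namespace(hello="world")}, hp))  # None
```

The identity check seems sufficient for the reproducer above, since save_hyperparameters(hp) receives the very object bound to hp in the __init__ frame; it would only miss a match if a caller passed an equal copy rather than the argument itself.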

I encountered this while trying to set hparams in a datamodule, following the change in 1.4 (#3792).

Relevant code block:
https://github.com/PyTorchLightning/pytorch-lightning/blob/938a191406fff5f51fba03fcf824f22d8d23c2e0/pytorch_lightning/utilities/parsing.py#L203

Labels: bug (Something isn't working), priority: 1 (Medium priority task)
