From ecc7970c64b7291a9d9be3d290b4c64a7d59c444 Mon Sep 17 00:00:00 2001
From: Mauricio Villegas
Date: Mon, 12 Sep 2022 08:28:28 +0200
Subject: [PATCH] Fix mypy errors in pytorch_lightning/cli.py.

---
 requirements/pytorch/extra.txt |  2 +-
 src/pytorch_lightning/cli.py   | 11 ++++++-----
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/requirements/pytorch/extra.txt b/requirements/pytorch/extra.txt
index b547e7a62f2ed..5c38a286ef21d 100644
--- a/requirements/pytorch/extra.txt
+++ b/requirements/pytorch/extra.txt
@@ -5,7 +5,7 @@
 matplotlib>3.1, <3.5.3
 omegaconf>=2.0.5, <2.3.0
 hydra-core>=1.0.5, <1.3.0
-jsonargparse[signatures]>=4.12.0, <=4.12.0
+jsonargparse[signatures]>=4.12.0, <4.14.0
 gcsfs>=2021.5.0, <2022.8.0
 rich>=10.14.0, !=10.15.0.a, <13.0.0
 protobuf<=3.20.1  # strict  # an extra is updating protobuf, this pin prevents TensorBoard failure
diff --git a/src/pytorch_lightning/cli.py b/src/pytorch_lightning/cli.py
index 82156c6b4ab90..875f9e00660ff 100644
--- a/src/pytorch_lightning/cli.py
+++ b/src/pytorch_lightning/cli.py
@@ -145,7 +145,7 @@ def add_optimizer_args(
             assert all(issubclass(o, Optimizer) for o in optimizer_class)
         else:
             assert issubclass(optimizer_class, Optimizer)
-        kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
+        kwargs: Dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
         if isinstance(optimizer_class, tuple):
             self.add_subclass_arguments(optimizer_class, nested_key, **kwargs)
         else:
@@ -170,7 +170,7 @@ def add_lr_scheduler_args(
             assert all(issubclass(o, LRSchedulerTypeTuple) for o in lr_scheduler_class)
         else:
             assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple)
-        kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}}
+        kwargs: Dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}}
         if isinstance(lr_scheduler_class, tuple):
             self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs)
         else:
@@ -436,6 +436,7 @@ def subcommands() -> Dict[str, Set[str]]:

     def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> None:
         """Adds subcommands to the input parser."""
+        self._subcommand_parsers: Dict[str, LightningArgumentParser] = {}
         parser_subcommands = parser.add_subcommands()
         # the user might have passed a builder function
         trainer_class = (
@@ -444,6 +445,7 @@ def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> No
         # register all subcommands in separate subcommand parsers under the main parser
         for subcommand in self.subcommands():
             subcommand_parser = self._prepare_subcommand_parser(trainer_class, subcommand, **kwargs.get(subcommand, {}))
+            self._subcommand_parsers[subcommand] = subcommand_parser
             fn = getattr(trainer_class, subcommand)
             # extract the first line description in the docstring for the subcommand help message
             description = _get_short_description(fn)
@@ -528,8 +530,7 @@ def _parser(self, subcommand: Optional[str]) -> LightningArgumentParser:
         if subcommand is None:
             return self.parser
         # return the subcommand parser for the subcommand passed
-        action_subcommand = self.parser._subcommands_action
-        return action_subcommand._name_parser_map[subcommand]
+        return self._subcommand_parsers[subcommand]

     @staticmethod
     def configure_optimizers(
@@ -611,7 +612,7 @@ def get_automatic(
         # override the existing method
         self.model.configure_optimizers = MethodType(fn, self.model)

-    def _get(self, config: Dict[str, Any], key: str, default: Optional[Any] = None) -> Any:
+    def _get(self, config: Namespace, key: str, default: Optional[Any] = None) -> Any:
         """Utility to get a config value which might be inside a subcommand."""
         return config.get(str(self.subcommand), config).get(key, default)
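
Note: the two kwargs annotations above address a standard mypy dict-literal
inference issue. A minimal standalone sketch of the pattern follows; add_args
is an illustrative stand-in for jsonargparse's keyword-typed
add_class_arguments, not its real signature:

    from typing import Any, Dict, Set

    def add_args(*, instantiate: bool, fail_untyped: bool, skip: Set[str]) -> None:
        """Illustrative stand-in for a method with typed keyword parameters."""

    # Without the annotation, mypy infers a narrow value type for the dict
    # literal (roughly Dict[str, object], the join of bool and Set[str]), so
    # expanding it as **kwargs into typed keyword parameters fails type
    # checking. Annotating the values as Any makes the expansion acceptable.
    kwargs: Dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
    add_args(**kwargs)  # type-checks only with the Dict[str, Any] annotation

Caching each subcommand parser in the typed self._subcommand_parsers dict
serves the same goal: _parser can return a LightningArgumentParser without
reaching into private jsonargparse attributes that mypy cannot type.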