Skip to content

Commit cd67124

Browse files
authored
Fix mypy errors in pytorch_lightning/cli.py (#14653)
1 parent b679fc2 commit cd67124

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

requirements/pytorch/extra.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
matplotlib>3.1, <3.5.3
66
omegaconf>=2.0.5, <2.3.0
77
hydra-core>=1.0.5, <1.3.0
8-
jsonargparse[signatures]>=4.12.0, <=4.12.0
8+
jsonargparse[signatures]>=4.12.0, <4.14.0
99
gcsfs>=2021.5.0, <2022.8.0
1010
rich>=10.14.0, !=10.15.0.a, <13.0.0
1111
protobuf<=3.20.1 # strict # an extra is updating protobuf, this pin prevents TensorBoard failure

src/pytorch_lightning/cli.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def add_optimizer_args(
145145
assert all(issubclass(o, Optimizer) for o in optimizer_class)
146146
else:
147147
assert issubclass(optimizer_class, Optimizer)
148-
kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
148+
kwargs: Dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
149149
if isinstance(optimizer_class, tuple):
150150
self.add_subclass_arguments(optimizer_class, nested_key, **kwargs)
151151
else:
@@ -170,7 +170,7 @@ def add_lr_scheduler_args(
170170
assert all(issubclass(o, LRSchedulerTypeTuple) for o in lr_scheduler_class)
171171
else:
172172
assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple)
173-
kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}}
173+
kwargs: Dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}}
174174
if isinstance(lr_scheduler_class, tuple):
175175
self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs)
176176
else:
@@ -436,6 +436,7 @@ def subcommands() -> Dict[str, Set[str]]:
436436

437437
def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> None:
438438
"""Adds subcommands to the input parser."""
439+
self._subcommand_parsers: Dict[str, LightningArgumentParser] = {}
439440
parser_subcommands = parser.add_subcommands()
440441
# the user might have passed a builder function
441442
trainer_class = (
@@ -444,6 +445,7 @@ def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> No
444445
# register all subcommands in separate subcommand parsers under the main parser
445446
for subcommand in self.subcommands():
446447
subcommand_parser = self._prepare_subcommand_parser(trainer_class, subcommand, **kwargs.get(subcommand, {}))
448+
self._subcommand_parsers[subcommand] = subcommand_parser
447449
fn = getattr(trainer_class, subcommand)
448450
# extract the first line description in the docstring for the subcommand help message
449451
description = _get_short_description(fn)
@@ -528,8 +530,7 @@ def _parser(self, subcommand: Optional[str]) -> LightningArgumentParser:
528530
if subcommand is None:
529531
return self.parser
530532
# return the subcommand parser for the subcommand passed
531-
action_subcommand = self.parser._subcommands_action
532-
return action_subcommand._name_parser_map[subcommand]
533+
return self._subcommand_parsers[subcommand]
533534

534535
@staticmethod
535536
def configure_optimizers(
@@ -611,7 +612,7 @@ def get_automatic(
611612
# override the existing method
612613
self.model.configure_optimizers = MethodType(fn, self.model)
613614

614-
def _get(self, config: Dict[str, Any], key: str, default: Optional[Any] = None) -> Any:
615+
def _get(self, config: Namespace, key: str, default: Optional[Any] = None) -> Any:
615616
"""Utility to get a config value which might be inside a subcommand."""
616617
return config.get(str(self.subcommand), config).get(key, default)
617618

0 commit comments

Comments
 (0)