Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements/pytorch/extra.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
matplotlib>3.1, <3.5.3
omegaconf>=2.0.5, <2.3.0
hydra-core>=1.0.5, <1.3.0
jsonargparse[signatures]>=4.12.0, <=4.12.0
jsonargparse[signatures]>=4.12.0, <4.14.0
gcsfs>=2021.5.0, <2022.8.0
rich>=10.14.0, !=10.15.0.a, <13.0.0
protobuf<=3.20.1 # strict # an extra is updating protobuf, this pin prevents TensorBoard failure
11 changes: 6 additions & 5 deletions src/pytorch_lightning/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def add_optimizer_args(
assert all(issubclass(o, Optimizer) for o in optimizer_class)
else:
assert issubclass(optimizer_class, Optimizer)
kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
kwargs: Dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
if isinstance(optimizer_class, tuple):
self.add_subclass_arguments(optimizer_class, nested_key, **kwargs)
else:
Expand All @@ -170,7 +170,7 @@ def add_lr_scheduler_args(
assert all(issubclass(o, LRSchedulerTypeTuple) for o in lr_scheduler_class)
else:
assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple)
kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}}
kwargs: Dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}}
if isinstance(lr_scheduler_class, tuple):
self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs)
else:
Expand Down Expand Up @@ -436,6 +436,7 @@ def subcommands() -> Dict[str, Set[str]]:

def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> None:
"""Adds subcommands to the input parser."""
self._subcommand_parsers: Dict[str, LightningArgumentParser] = {}
parser_subcommands = parser.add_subcommands()
# the user might have passed a builder function
trainer_class = (
Expand All @@ -444,6 +445,7 @@ def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> No
# register all subcommands in separate subcommand parsers under the main parser
for subcommand in self.subcommands():
subcommand_parser = self._prepare_subcommand_parser(trainer_class, subcommand, **kwargs.get(subcommand, {}))
self._subcommand_parsers[subcommand] = subcommand_parser
fn = getattr(trainer_class, subcommand)
# extract the first line description in the docstring for the subcommand help message
description = _get_short_description(fn)
Expand Down Expand Up @@ -528,8 +530,7 @@ def _parser(self, subcommand: Optional[str]) -> LightningArgumentParser:
if subcommand is None:
return self.parser
# return the subcommand parser for the subcommand passed
action_subcommand = self.parser._subcommands_action
return action_subcommand._name_parser_map[subcommand]
return self._subcommand_parsers[subcommand]

@staticmethod
def configure_optimizers(
Expand Down Expand Up @@ -611,7 +612,7 @@ def get_automatic(
# override the existing method
self.model.configure_optimizers = MethodType(fn, self.model)

def _get(self, config: Dict[str, Any], key: str, default: Optional[Any] = None) -> Any:
def _get(self, config: Namespace, key: str, default: Optional[Any] = None) -> Any:
"""Utility to get a config value which might be inside a subcommand."""
return config.get(str(self.subcommand), config).get(key, default)

Expand Down