@@ -145,7 +145,7 @@ def add_optimizer_args(
145145 assert all (issubclass (o , Optimizer ) for o in optimizer_class )
146146 else :
147147 assert issubclass (optimizer_class , Optimizer )
148- kwargs = {"instantiate" : False , "fail_untyped" : False , "skip" : {"params" }}
148+ kwargs : Dict [ str , Any ] = {"instantiate" : False , "fail_untyped" : False , "skip" : {"params" }}
149149 if isinstance (optimizer_class , tuple ):
150150 self .add_subclass_arguments (optimizer_class , nested_key , ** kwargs )
151151 else :
@@ -170,7 +170,7 @@ def add_lr_scheduler_args(
170170 assert all (issubclass (o , LRSchedulerTypeTuple ) for o in lr_scheduler_class )
171171 else :
172172 assert issubclass (lr_scheduler_class , LRSchedulerTypeTuple )
173- kwargs = {"instantiate" : False , "fail_untyped" : False , "skip" : {"optimizer" }}
173+ kwargs : Dict [ str , Any ] = {"instantiate" : False , "fail_untyped" : False , "skip" : {"optimizer" }}
174174 if isinstance (lr_scheduler_class , tuple ):
175175 self .add_subclass_arguments (lr_scheduler_class , nested_key , ** kwargs )
176176 else :
@@ -436,6 +436,7 @@ def subcommands() -> Dict[str, Set[str]]:
436436
437437 def _add_subcommands (self , parser : LightningArgumentParser , ** kwargs : Any ) -> None :
438438 """Adds subcommands to the input parser."""
439+ self ._subcommand_parsers : Dict [str , LightningArgumentParser ] = {}
439440 parser_subcommands = parser .add_subcommands ()
440441 # the user might have passed a builder function
441442 trainer_class = (
@@ -444,6 +445,7 @@ def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> No
444445 # register all subcommands in separate subcommand parsers under the main parser
445446 for subcommand in self .subcommands ():
446447 subcommand_parser = self ._prepare_subcommand_parser (trainer_class , subcommand , ** kwargs .get (subcommand , {}))
448+ self ._subcommand_parsers [subcommand ] = subcommand_parser
447449 fn = getattr (trainer_class , subcommand )
448450 # extract the first line description in the docstring for the subcommand help message
449451 description = _get_short_description (fn )
@@ -528,8 +530,7 @@ def _parser(self, subcommand: Optional[str]) -> LightningArgumentParser:
528530 if subcommand is None :
529531 return self .parser
530532 # return the subcommand parser for the subcommand passed
531- action_subcommand = self .parser ._subcommands_action
532- return action_subcommand ._name_parser_map [subcommand ]
533+ return self ._subcommand_parsers [subcommand ]
533534
534535 @staticmethod
535536 def configure_optimizers (
@@ -611,7 +612,7 @@ def get_automatic(
611612 # override the existing method
612613 self .model .configure_optimizers = MethodType (fn , self .model )
613614
614- def _get (self , config : Dict [ str , Any ] , key : str , default : Optional [Any ] = None ) -> Any :
615+ def _get (self , config : Namespace , key : str , default : Optional [Any ] = None ) -> Any :
615616 """Utility to get a config value which might be inside a subcommand."""
616617 return config .get (str (self .subcommand ), config ).get (key , default )
617618
0 commit comments