@@ -89,6 +89,8 @@ def __str__(self) -> str:
8989CALLBACK_REGISTRY = _Registry ()
9090CALLBACK_REGISTRY .register_classes (pl .callbacks , pl .callbacks .Callback )
9191
92+ MODEL_REGISTRY = _Registry ()
93+
9294
9395class LightningArgumentParser (ArgumentParser ):
9496 """Extension of jsonargparse's ArgumentParser for pytorch-lightning."""
@@ -147,7 +149,7 @@ def add_lightning_class_args(
147149 if issubclass (lightning_class , Callback ):
148150 self .callback_keys .append (nested_key )
149151 if subclass_mode :
150- return self .add_subclass_arguments (lightning_class , nested_key , required = True )
152+ return self .add_subclass_arguments (lightning_class , nested_key , fail_untyped = False , required = True )
151153 return self .add_class_arguments (
152154 lightning_class , nested_key , fail_untyped = False , instantiate = not issubclass (lightning_class , Trainer )
153155 )
@@ -385,7 +387,7 @@ class LightningCLI:
385387
386388 def __init__ (
387389 self ,
388- model_class : Union [Type [LightningModule ], Callable [..., LightningModule ]],
390+ model_class : Optional [ Union [Type [LightningModule ], Callable [..., LightningModule ]]] = None ,
389391 datamodule_class : Optional [Union [Type [LightningDataModule ], Callable [..., LightningDataModule ]]] = None ,
390392 save_config_callback : Optional [Type [SaveConfigCallback ]] = SaveConfigCallback ,
391393 save_config_filename : str = "config.yaml" ,
@@ -413,8 +415,9 @@ def __init__(
413415 .. warning:: ``LightningCLI`` is in beta and subject to change.
414416
415417 Args:
416- model_class: :class:`~pytorch_lightning.core.lightning.LightningModule` class to train on or a callable
417- which returns a :class:`~pytorch_lightning.core.lightning.LightningModule` instance when called.
418+ model_class: An optional :class:`~pytorch_lightning.core.lightning.LightningModule` class to train on or a
419+ callable which returns a :class:`~pytorch_lightning.core.lightning.LightningModule` instance when
420+ called. If ``None``, you can pass a registered model with ``--model=MyModel``.
418421 datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class or a
419422 callable which returns a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` instance when
420423 called.
@@ -439,17 +442,20 @@ def __init__(
439442 run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer`
440443 method. If set to ``False``, the trainer and model classes will be instantiated only.
441444 """
442- self .model_class = model_class
443445 self .datamodule_class = datamodule_class
444446 self .save_config_callback = save_config_callback
445447 self .save_config_filename = save_config_filename
446448 self .save_config_overwrite = save_config_overwrite
447449 self .trainer_class = trainer_class
448450 self .trainer_defaults = trainer_defaults or {}
449451 self .seed_everything_default = seed_everything_default
450- self .subclass_mode_model = subclass_mode_model
451452 self .subclass_mode_data = subclass_mode_data
452453
454+ self .model_class = model_class
455+ # used to differentiate between the original value and the processed value
456+ self ._model_class = model_class or LightningModule
457+ self .subclass_mode_model = (model_class is None ) or subclass_mode_model
458+
453459 main_kwargs , subparser_kwargs = self ._setup_parser_kwargs (
454460 parser_kwargs or {}, # type: ignore # github.com/python/mypy/issues/6463
455461 {"description" : description , "env_prefix" : env_prefix , "default_env" : env_parse },
@@ -509,7 +515,12 @@ def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
509515 parser .set_choices ("trainer.callbacks" , CALLBACK_REGISTRY .classes , is_list = True )
510516 trainer_defaults = {"trainer." + k : v for k , v in self .trainer_defaults .items () if k != "callbacks" }
511517 parser .set_defaults (trainer_defaults )
512- parser .add_lightning_class_args (self .model_class , "model" , subclass_mode = self .subclass_mode_model )
518+
519+ parser .add_lightning_class_args (self ._model_class , "model" , subclass_mode = self .subclass_mode_model )
520+ if self .model_class is None and MODEL_REGISTRY :
521+ # did not pass a model and there are models registered
522+ parser .set_choices ("model" , MODEL_REGISTRY .classes )
523+
513524 if self .datamodule_class is not None :
514525 parser .add_lightning_class_args (self .datamodule_class , "data" , subclass_mode = self .subclass_mode_data )
515526
0 commit comments