Closed
Labels
argparse (removed) - Related to argument parsing (argparse, Hydra, ...)
design - Includes a design discussion
discussion - In a discussion stage
feature - Is an improvement or enhancement
help wanted - Open to be worked on
priority: 0 - High priority task
Description
🚀 Feature
Implement a CLI store/provider/registry for available LightningModules, LightningDataModules, and Callbacks.
Implementation
class LightningCLI:
    store = Registry()

    @classmethod
    def register(cls, *args, **kwargs):
        # Return the result so that `register` can also be used as a decorator
        return cls.store.register(*args, **kwargs)

The registry idea is heavily influenced by this implementation:
https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/core/registry.py
https://github.com/PyTorchLightning/lightning-flash/blob/master/tests/core/test_registry.py
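To make the proposal more concrete, here is a minimal sketch of what the `Registry` behind `LightningCLI.store` could look like. It is not the Flash implementation linked above. The method names `get` and `available`, the duplicate-name check, and the stand-in classes are assumptions for illustration, and the `arg` parameter from the proposal is left out.

```python
from typing import Dict, Optional


class Registry:
    """Hypothetical name -> class store (not the Flash implementation)."""

    def __init__(self) -> None:
        self._items: Dict[str, type] = {}

    def register(self, name: Optional[str] = None, obj: Optional[type] = None):
        # Direct call: register("my-callback", CallbackA)
        if obj is not None:
            self._add(name or obj.__name__, obj)
            return obj

        # Decorator call: @register(name="some_model")
        def decorator(target: type) -> type:
            # Fall back to the class name when no explicit name is given
            self._add(name or target.__name__, target)
            return target

        return decorator

    def _add(self, name: str, obj: type) -> None:
        # Assumption: registering the same name twice is an error
        if name in self._items:
            raise KeyError(f"'{name}' is already registered")
        self._items[name] = obj

    def get(self, name: str) -> type:
        return self._items[name]

    def available(self):
        return sorted(self._items)


class CallbackA:  # stand-in for a real Callback subclass
    pass


store = Registry()
store.register("my-callback", CallbackA)


@store.register(name="some_model")
class SomeModel:  # stand-in for a LightningModule subclass
    pass
```

Returning the class from the decorator path keeps the decorated name bound to the class itself, so `SomeModel` can still be imported and used normally.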
Usage
# note: these `register` calls could be in different files
LightningCLI.register('my-callback', CallbackA)
LightningCLI.register('vision-model', VisionModel)
LightningCLI.register('mnist-data', MnistDataModule)
LightningCLI.register('cifar10-data', CIFAR10DataModule)
LightningCLI.register('imagenet-data', ImageNetDataModule)
# API might change - naming the model-data pair as an experiment
LightningCLI.register_experiment(VisionModel, MnistDataModule)
LightningCLI.register_experiment(VisionModel, CIFAR10DataModule)
LightningCLI.register_experiment(VisionModel, ImageNetDataModule)
# or with a decorator.
# `name` could be taken from the class or function
# `arg` could be inferred from the parent class
@LightningCLI.register(name='some_model', arg='model')
class SomeModel(LightningModule):
    ...

# in train.py
cli = LightningCLI()

Console interaction
Basic help
python train.py -h

callbacks: {
    'my-callback': CallbackA,
}
models: {
    'vision-model': VisionModel,
}
datamodules: {
    'mnist-data': MnistDataModule,
    'cifar10-data': CIFAR10DataModule,
    'imagenet-data': ImageNetDataModule,
}
experiments: {
    'vision-mnist': {'model': VisionModel, 'data': MnistDataModule},
    'vision-cifar10': {'model': VisionModel, 'data': CIFAR10DataModule},
    'vision-imagenet': {'model': VisionModel, 'data': ImageNetDataModule},
}

Note: the real help output should look nicer; raw dicts are shown here for simplicity.
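The grouping into `callbacks`, `models`, and `datamodules` relies on the idea that the category could be inferred from the parent class. A rough sketch of that inference follows. The base classes here are empty stand-ins so the snippet runs without Lightning installed, and `CATEGORIES`, `infer_category`, and `group_by_category` are hypothetical names, not part of any existing API.

```python
from collections import defaultdict


# Stand-ins for the real Lightning base classes (assumption: the real ones
# would be imported from Lightning instead).
class LightningModule: ...
class LightningDataModule: ...
class Callback: ...


# Assumed mapping from base class to help-section name
CATEGORIES = {
    LightningModule: "models",
    LightningDataModule: "datamodules",
    Callback: "callbacks",
}


def infer_category(cls):
    """Return the help section for `cls` based on its parent class."""
    for base, category in CATEGORIES.items():
        if issubclass(cls, base):
            return category
    raise TypeError(f"{cls.__name__} does not subclass a known base class")


def group_by_category(registered):
    """Turn a flat {name: class} dict into {section: {name: class}}."""
    grouped = defaultdict(dict)
    for name, cls in registered.items():
        grouped[infer_category(cls)][name] = cls
    return dict(grouped)


class CallbackA(Callback): ...
class VisionModel(LightningModule): ...
class MnistDataModule(LightningDataModule): ...


registered = {
    "my-callback": CallbackA,
    "vision-model": VisionModel,
    "mnist-data": MnistDataModule,
}
sections = group_by_category(registered)
```

With this approach a single `register` call is enough, and the explicit `arg` parameter would only be needed for classes that do not inherit from one of the known base classes.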
Filter by category
python train.py -h model=VisionModel
# alternatively
python train.py -h model='vision-model'

callbacks: {
    'my-callback': CallbackA,
}
datamodules: {
    'mnist-data': MnistDataModule,
    'cifar10-data': CIFAR10DataModule,
    'imagenet-data': ImageNetDataModule,
}
experiments: {
    'vision-mnist': {'model': VisionModel, 'data': MnistDataModule},
    'vision-cifar10': {'model': VisionModel, 'data': CIFAR10DataModule},
    'vision-imagenet': {'model': VisionModel, 'data': ImageNetDataModule},
}

python train.py -h data=CIFAR10DataModule

callbacks: {
    'my-callback': CallbackA,
}
models: {
    'vision-model': VisionModel,
}
datamodules: {
    'cifar10-data': CIFAR10DataModule,
}
experiments: {
    'vision-cifar10': {'model': VisionModel, 'data': CIFAR10DataModule},
}

Basic LightningCLI usage (similar to what's currently implemented)
python train.py -h model=VisionModel data=ImageNetDataModule
# same as
python train.py -h experiment='vision-imagenet'
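The equivalence between a `model=... data=...` pair and a named experiment could work roughly as sketched below. The naming rule (drop the `Model` and `DataModule` suffixes, lowercase, join with `-`) is only an assumption that happens to reproduce the names shown in the help output. `experiment_name` and `filter_experiments` are hypothetical helpers, and the classes are empty stand-ins.

```python
# Empty stand-ins for the classes used in the examples above
class VisionModel: ...
class MnistDataModule: ...
class CIFAR10DataModule: ...
class ImageNetDataModule: ...


def _strip_suffix(text, suffix):
    return text[: -len(suffix)] if text.endswith(suffix) else text


def experiment_name(model_cls, data_cls):
    # Assumed rule: VisionModel + MnistDataModule -> "vision-mnist"
    model_part = _strip_suffix(model_cls.__name__, "Model").lower()
    data_part = _strip_suffix(data_cls.__name__, "DataModule").lower()
    return f"{model_part}-{data_part}"


experiments = {}


def register_experiment(model_cls, data_cls):
    name = experiment_name(model_cls, data_cls)
    experiments[name] = {"model": model_cls, "data": data_cls}


def filter_experiments(model=None, data=None):
    """Keep experiments whose model/data class names match the given filters."""
    return {
        name: pair
        for name, pair in experiments.items()
        if (model is None or pair["model"].__name__ == model)
        and (data is None or pair["data"].__name__ == data)
    }


register_experiment(VisionModel, MnistDataModule)
register_experiment(VisionModel, CIFAR10DataModule)
register_experiment(VisionModel, ImageNetDataModule)

# `model=VisionModel data=ImageNetDataModule` selects the same pair
# as `experiment='vision-imagenet'`
by_pair = filter_experiments(model="VisionModel", data="ImageNetDataModule")
by_name = experiments["vision-imagenet"]
```

An explicit `name` argument to `register_experiment` would probably still be needed for class names that do not follow the suffix convention.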