Extend the LightningCLI to register models, datamodules, and callbacks. #7250

@tchaton

🚀 Feature

Implement a CLI store/provider/registry for available LightningModules, LightningDataModules, and Callbacks.

Implementation

class LightningCLI:
    store = Registry()

    @classmethod
    def register(cls, *args, **kwargs):
        cls.store.register(*args, **kwargs)

The registry idea is heavily influenced by this implementation and its tests:

https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/core/registry.py
https://github.com/PyTorchLightning/lightning-flash/blob/master/tests/core/test_registry.py
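To make the proposal concrete, here is a minimal sketch of what such a registry could look like. It is not the Flash implementation linked above: the `Registry` methods (`register`, `get`, `available_keys`), the `override` flag, and the `CLISketch` stand-in for `LightningCLI` are all assumptions made for illustration.

```python
from typing import Any, Dict, List


class Registry:
    """Hypothetical minimal name -> object store; not the Flash implementation."""

    def __init__(self) -> None:
        self._items: Dict[str, Any] = {}

    def register(self, name: str, obj: Any, override: bool = False) -> Any:
        # Refuse silent overwrites so two files cannot claim the same name by accident.
        if name in self._items and not override:
            raise KeyError(f"'{name}' is already registered")
        self._items[name] = obj
        return obj

    def get(self, name: str) -> Any:
        try:
            return self._items[name]
        except KeyError:
            raise KeyError(
                f"'{name}' not found; available: {self.available_keys()}"
            ) from None

    def available_keys(self) -> List[str]:
        return sorted(self._items)


class CLISketch:
    """Stand-in for LightningCLI, holding one class-level store shared by all callers."""

    store = Registry()

    @classmethod
    def register(cls, *args: Any, **kwargs: Any) -> Any:
        return cls.store.register(*args, **kwargs)
```

Because the store lives on the class, `register` calls made in different files all end up in the same registry once those files are imported.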

Usage

# note: these `register` calls could be in different files
LightningCLI.register('my-callback', CallbackA)

LightningCLI.register('vision-model', VisionModel)

LightningCLI.register('mnist-data', MnistDataModule)
LightningCLI.register('cifar10-data', CIFAR10DataModule)
LightningCLI.register('imagenet-data', ImageNetDataModule)

# API might change - naming the model-data pair as an experiment
LightningCLI.register_experiment(VisionModel, MnistDataModule)
LightningCLI.register_experiment(VisionModel, CIFAR10DataModule)
LightningCLI.register_experiment(VisionModel, ImageNetDataModule)

# or with a decorator.
# `name` could be taken from the class or function
# `arg` could be inferred from the parent class
@LightningCLI.register(name='some_model', arg='model')
class SomeModel(LightningModule):
    ...

# in train.py
cli = LightningCLI()
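The usage above needs `register` to work both as a plain call and as a decorator, with `name` and `arg` optionally inferred. One possible way to support that is sketched below. The `CLIStore` class and the category keys are hypothetical, and the three base classes are empty stand-ins so the snippet runs without Lightning installed.

```python
from typing import Callable, Dict, Optional, Union


# Empty stand-ins for the real Lightning base classes (assumption for this sketch).
class LightningModule: ...
class LightningDataModule: ...
class Callback: ...


_CATEGORIES = {"model": LightningModule, "data": LightningDataModule, "callback": Callback}


class CLIStore:
    """Hypothetical store keyed first by category (`arg`), then by registered name."""

    def __init__(self) -> None:
        self.items: Dict[str, Dict[str, type]] = {arg: {} for arg in _CATEGORIES}

    @staticmethod
    def _infer_arg(cls: type) -> str:
        # `arg` is inferred from the parent class when not given explicitly.
        for arg, base in _CATEGORIES.items():
            if issubclass(cls, base):
                return arg
        raise TypeError(f"cannot infer a category for {cls.__name__}")

    def register(
        self, name: Optional[str] = None, obj: Optional[type] = None, arg: Optional[str] = None
    ) -> Union[type, Callable[[type], type]]:
        def _add(cls: type) -> type:
            key = name or cls.__name__  # `name` falls back to the class name
            self.items[arg or self._infer_arg(cls)][key] = cls
            return cls

        # Direct call: register('my-callback', CallbackA); otherwise act as a decorator.
        return _add(obj) if obj is not None else _add


store = CLIStore()


class CallbackA(Callback): ...


store.register("my-callback", CallbackA)


@store.register(name="some_model", arg="model")
class SomeModel(LightningModule): ...


@store.register()
class MnistDataModule(LightningDataModule): ...
```

Returning the class unchanged from `_add` keeps the decorated class usable as normal after registration.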

Console interaction

Basic help

python train.py -h 
callbacks: {
    'my-callback': CallbackA,
}
models: {
    'vision-model': VisionModel,
}
datamodules: {
    'mnist-data': MnistDataModule,
    'cifar10-data': CIFAR10DataModule,
    'imagenet-data': ImageNetDataModule,
}
experiments: {
    'vision-mnist': {'model': VisionModel, 'data': MnistDataModule},
    'vision-cifar10': {'model': VisionModel, 'data': CIFAR10DataModule},
    'vision-imagenet': {'model': VisionModel, 'data': ImageNetDataModule},  
}

Note: the raw dicts above are shown for simplicity; the actual help output should be formatted more readably.
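As one illustration of a friendlier layout, the sketch below renders a registry as indented, column-aligned text instead of raw dicts. The function name `format_registry_help` and the input shape (category -> name -> class) are assumptions, not part of the proposal.

```python
from typing import Dict


def format_registry_help(registry: Dict[str, Dict[str, type]]) -> str:
    """Render each category as a heading followed by aligned `name  ClassName` rows."""
    lines = []
    for category, entries in registry.items():
        lines.append(f"{category}:")
        # Pad names to the longest one in this category so class names line up.
        width = max((len(name) for name in entries), default=0)
        for name, cls in entries.items():
            lines.append(f"  {name.ljust(width)}  {cls.__name__}")
    return "\n".join(lines)
```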

Filter by category

python train.py -h model=VisionModel
# alternatively
python train.py -h model='vision-model'
callbacks: {
    'my-callback': CallbackA,
}
datamodules: {
    'mnist-data': MnistDataModule,
    'cifar10-data': CIFAR10DataModule,
    'imagenet-data': ImageNetDataModule,
}
experiments: {
    'vision-mnist': {'model': VisionModel, 'data': MnistDataModule},
    'vision-cifar10': {'model': VisionModel, 'data': CIFAR10DataModule},
    'vision-imagenet': {'model': VisionModel, 'data': ImageNetDataModule},  
}

python train.py -h data=CIFAR10DataModule
callbacks: {
    'my-callback': CallbackA,
}
models: {
    'vision-model': VisionModel,
}
datamodules: {
    'cifar10-data': CIFAR10DataModule,
}
experiments: {
    'vision-cifar10': {'model': VisionModel, 'data': CIFAR10DataModule},
}
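The filtering shown above could be implemented along these lines. This is a sketch under assumptions: class names are plain strings, `aliases` maps registered names such as 'vision-model' to class names, and `filter_experiments` is a hypothetical helper. Only the experiments listing is filtered here.

```python
from typing import Dict

Experiments = Dict[str, Dict[str, str]]


def resolve(value: str, aliases: Dict[str, str]) -> str:
    """Map a registered name to its class name; class names pass through unchanged."""
    return aliases.get(value, value)


def filter_experiments(
    experiments: Experiments, aliases: Dict[str, str], **selected: str
) -> Experiments:
    """Keep only the experiments whose parts match every `arg=value` selection."""
    wanted = {arg: resolve(value, aliases) for arg, value in selected.items()}
    return {
        name: parts
        for name, parts in experiments.items()
        if all(parts.get(arg) == cls for arg, cls in wanted.items())
    }


ALIASES = {"vision-model": "VisionModel", "cifar10-data": "CIFAR10DataModule"}
EXPERIMENTS = {
    "vision-mnist": {"model": "VisionModel", "data": "MnistDataModule"},
    "vision-cifar10": {"model": "VisionModel", "data": "CIFAR10DataModule"},
    "vision-imagenet": {"model": "VisionModel", "data": "ImageNetDataModule"},
}

by_class = filter_experiments(EXPERIMENTS, ALIASES, data="CIFAR10DataModule")
by_name = filter_experiments(EXPERIMENTS, ALIASES, data="cifar10-data")
```

Resolving aliases first means `model=VisionModel` and `model='vision-model'` select the same entries, matching the two equivalent commands shown earlier.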

Basic LightningCLI usage (similar to what's currently implemented)

python train.py -h model=VisionModel data=ImageNetDataModule
# same as
python train.py -h experiment='vision-imagenet'
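The equivalence between the two commands above suggests that an `experiment` name simply expands into its `model` and `data` parts. A minimal sketch of that expansion follows. The helper `expand_overrides` is hypothetical, and letting explicit `model=`/`data=` values take precedence over the experiment is an assumption, not something the proposal specifies.

```python
from typing import Dict, List


def expand_overrides(tokens: List[str], experiments: Dict[str, Dict[str, str]]) -> Dict[str, str]:
    """Turn `key=value` tokens into a dict, expanding `experiment=` into its parts."""
    opts = dict(token.split("=", 1) for token in tokens)
    name = opts.pop("experiment", None)
    if name is None:
        return opts
    if name not in experiments:
        raise KeyError(f"unknown experiment '{name}'; available: {sorted(experiments)}")
    expanded = dict(experiments[name])
    expanded.update(opts)  # assumption: explicit model=/data= values win
    return expanded


EXPERIMENTS = {
    "vision-imagenet": {"model": "VisionModel", "data": "ImageNetDataModule"},
}

explicit = expand_overrides(["model=VisionModel", "data=ImageNetDataModule"], EXPERIMENTS)
shorthand = expand_overrides(["experiment=vision-imagenet"], EXPERIMENTS)
```

The shell removes the quotes around `'vision-imagenet'` before the script sees the argument, so the tokens are matched without any quote handling.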

Metadata

Labels: argparse (removed), design, discussion, feature, help wanted, priority: 0