Description
Is your feature request related to a problem? Please describe.
This is half feedback, half feature request. Maybe our approach is not the right one, but here is what we ran into when trying this awesome library:
We would like to use a LightningModule in our pipelines, but we have some constraints that make this difficult.
We have an experiment framework where we can register models (e.g. a LightningModule) by instantiating them. The framework then trains the various models on train/val/test data specified at runtime and generates performance reports.
Pseudo code:
class TorchModel:
    def fit(self, x_train, y_train, x_val, y_val):
        trainer = Trainer(...)
        trainer.fit(self.module)

models = [
    ModelA(...),
    # TorchModel is a wrapper exposing a common interface to sklearn/Keras/PyTorch models
    TorchModel(module=CoolModel()),
]
experiment_runner = Runner(models)
experiment_runner.run(train_dataset, val_dataset, test_dataset)

Or Uber's Ludwig would do:
from ludwig.api import LudwigModel

# train a model
model_definition = {...}
model = LudwigModel(model_definition)
train_stats = model.train(training_dataframe)

Describe the solution you'd like
For us, the datasets / input tensors do not belong in the definition of the module. We understand that coupling them improves reproducibility, but it may reduce the portability of models.
They should probably be provided to the trainer at instantiation instead:
Trainer(train_dataset=..., val_dataset=...)

# And maybe
class CoolModel(pl.LightningModule):
    ...
    @pl.data_loader
    def tng_dataloader(self, dataset):
        return DataLoader(dataset, batch_size=32)
    ...
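To make this concrete, here is a sketch of how our wrapper's fit could then look, assuming the proposed train_dataset / val_dataset Trainer arguments from above (they do not exist today):

from pytorch_lightning import Trainer
from torch.utils.data import TensorDataset

class TorchModel:
    def __init__(self, module):
        self.module = module

    def fit(self, x_train, y_train, x_val, y_val):
        # The module stays data-agnostic; the runtime data goes to the Trainer
        trainer = Trainer(
            train_dataset=TensorDataset(x_train, y_train),  # proposed argument
            val_dataset=TensorDataset(x_val, y_val),        # proposed argument
        )
        trainer.fit(self.module)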
Describe alternatives you've considered
A temporary solution could be:
class TorchModel:
    def fit(self, x_train, y_train, x_val, y_val):
        self.module.set_train_dataset(x_train, y_train)
        self.module.set_val_dataset(x_val, y_val)
        trainer = Trainer(...)
        trainer.fit(self.module)
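A minimal sketch of the module side, assuming the setters simply stash TensorDatasets that the dataloader hooks read back (set_train_dataset / set_val_dataset are hypothetical helpers, not part of the Lightning API; the decorator and hook names follow the ones used above):

import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

class CoolModel(pl.LightningModule):
    def set_train_dataset(self, x_train, y_train):
        # Stash the data so the dataloader hooks can build loaders later
        self._train_dataset = TensorDataset(x_train, y_train)

    def set_val_dataset(self, x_val, y_val):
        self._val_dataset = TensorDataset(x_val, y_val)

    @pl.data_loader
    def tng_dataloader(self):
        return DataLoader(self._train_dataset, batch_size=32)

    @pl.data_loader
    def val_dataloader(self):
        return DataLoader(self._val_dataset, batch_size=32)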
Additional context
Thanks for creating this library; it makes PyTorch so much easier to use!