🚀 Feature
Add a save_hyperparameters function to LightningDataModule and let the Trainer log it together with the model's hyperparameters.
Motivation
DataModules are a great way to decouple the model code from the data it runs on.
People using datasets where pre-processing is not as fixed as in the common benchmarks need their DataModules to be configurable.
Examples would be:
- size of sliding windows
- maximum sequence length
- type of scaling (min-max, standardization, etc.)
Logging these hyperparameters is just as important for evaluating model performance as logging the model's own hyperparameters.
Therefore, the Trainer should log them automatically, too.
Pitch
Are you still searching for the perfect way to pre-process your data for maximum performance?
Keep all your efforts in order by logging the hyperparameters of your LightningDataModule alongside those of your model.
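A minimal sketch of how the proposed behavior could work, using plain Python instead of the real Lightning classes (the `HyperparametersMixin` name and the frame-inspection approach are illustrative assumptions, not an existing Lightning API): `save_hyperparameters()` inspects the calling `__init__` and records its arguments, which the Trainer could then merge into the logged hyperparameters.

```python
import inspect
from argparse import Namespace

class HyperparametersMixin:
    """Illustrative stand-in for the proposed DataModule behavior:
    capture the caller's __init__ arguments into self.hparams."""

    def save_hyperparameters(self):
        # Inspect the calling __init__ frame and record its named arguments.
        frame = inspect.currentframe().f_back
        args, _, _, local_vars = inspect.getargvalues(frame)
        hparams = {name: local_vars[name] for name in args if name != "self"}
        self.hparams = Namespace(**hparams)

class MyDataModule(HyperparametersMixin):
    def __init__(self, fd, batch_size, window_size=30):
        self.save_hyperparameters()  # defaults are captured as well

dm = MyDataModule(fd=1, batch_size=32)
print(vars(dm.hparams))  # {'fd': 1, 'batch_size': 32, 'window_size': 30}
```

On the Trainer side, logging would then amount to merging `datamodule.hparams` with `model.hparams` before passing them to the logger.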
Alternatives
Right now, I manually define a hyperparameter dictionary as a member of my DataModule.
Afterward, I call `update` on the `hparams` property of my LightningModule.
This is pretty low-level code sitting at the top level of my script.
Additional context
Code example for the current solution:

```python
class MyDataModule(pl.LightningDataModule):
    def __init__(self,
                 fd,
                 batch_size,
                 max_rul=125,
                 window_size=30,
                 percent_fail_runs=None,
                 percent_broken=None,
                 feature_select=None):
        ...
        self.hparams = {'fd': self.fd,
                        'batch_size': self.batch_size,
                        'window_size': self.window_size,
                        'max_rul': self.max_rul,
                        'percent_broken': self.percent_broken,
                        'percent_fail_runs': self.percent_fail_runs}
        ...

data = datasets.MyDataModule(...)
model.hparams.update(data.hparams)
```
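Until such a feature lands, a slightly less repetitive workaround is to capture `locals()` at the top of `__init__` instead of listing every key by hand; this also picks up arguments like `feature_select` that the manual dictionary above omits. A sketch, with the same hypothetical `MyDataModule` parameters as above:

```python
class MyDataModule:
    def __init__(self, fd, batch_size, max_rul=125, window_size=30,
                 percent_fail_runs=None, percent_broken=None,
                 feature_select=None):
        # Capture all constructor arguments before any other locals exist.
        self.hparams = {k: v for k, v in locals().items() if k != 'self'}
        ...

data = MyDataModule(fd='FD001', batch_size=64)
# then, as before: model.hparams.update(data.hparams)
```

The comprehension must run as the first statement of `__init__`, since `locals()` reflects every local name defined so far.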