### Bug description
With a YAML config file for `LightningCLI`, `self.save_hyperparameters()` in the `__init__` of the model and datamodule mistakenly saves a dict containing keys like `class_path` and `init_args` instead of the actual init arguments.
This problem appears in version 2.3.0; version 2.2.5 works correctly.
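To make the mismatch concrete, here is the expected versus observed `hparams` content, written out as plain dicts taken from the output in the reproduction below (not live Lightning objects):

```python
# Expected hparams after self.save_hyperparameters() in __init__ (2.2.5):
expected = {"learning_rate": 0.01}

# Observed hparams on 2.3.0 when the module is built by LightningCLI:
observed = {
    "_instantiator": "lightning.pytorch.cli.instantiate_module",
    "class_path": "model.Model",
    "init_args": {"learning_rate": 0.01},
}

# Downstream code that reads hyperparameters by name therefore breaks:
assert "learning_rate" in expected
assert "learning_rate" not in observed  # the value is nested under init_args
```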
### What version are you seeing the problem on?
2.3.0
### How to reproduce the bug
`config.yaml`

```yaml
ckpt_path: null
seed_everything: 0
model:
  class_path: model.Model
  init_args:
    learning_rate: 1e-2
data:
  class_path: datamodule.DataModule
  init_args:
    data_dir: data
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: null
  fast_dev_run: false
  max_epochs: 100
  min_epochs: null
  max_steps: -1
  min_steps: null
  max_time: null
  limit_train_batches: null
  limit_val_batches: 10
  limit_test_batches: null
  limit_predict_batches: null
  logger:
    class_path: lightning.pytorch.loggers.TensorBoardLogger
    init_args:
      save_dir: lightning_logs
      name: normalized
  callbacks:
    class_path: lightning.pytorch.callbacks.ModelCheckpoint
    init_args:
      save_top_k: 5
      monitor: valid_loss
      filename: "{epoch}-{step}-{valid_loss:.8f}"
  overfit_batches: 0.0
  val_check_interval: 50
  check_val_every_n_epoch: 1
  num_sanity_val_steps: null
  log_every_n_steps: 50
  enable_checkpointing: null
  enable_progress_bar: null
  enable_model_summary: null
  accumulate_grad_batches: 1
  gradient_clip_val: null
  gradient_clip_algorithm: null
  deterministic: true
  benchmark: null
  inference_mode: true
  use_distributed_sampler: true
  profiler: null
  detect_anomaly: false
  barebones: false
  plugins: null
  sync_batchnorm: true
  reload_dataloaders_every_n_epochs: 0
  default_root_dir: null
```
`model.py`

```python
import torch
from torch import nn
import torch.nn.functional as F
import lightning as pl


class Model(pl.LightningModule):
    def __init__(self, learning_rate: float):
        super().__init__()
        print()
        print("Model:")
        print(f"learning_rate: {learning_rate}")
        ## This outputs correctly.
        self.save_hyperparameters()
        print(self.hparams)
        ## This outputs:
        # "_instantiator": lightning.pytorch.cli.instantiate_module
        # "class_path":    model.Model
        # "init_args":     {'learning_rate': 0.01}
```
`datamodule.py`

```python
from lightning import LightningDataModule
from torch.utils.data import DataLoader
from dataset import KaptchaDataset
from transform import Transform


class DataModule(LightningDataModule):
    def __init__(self, data_dir: str):
        super().__init__()
        self.save_hyperparameters()
        print()
        print("DataModule:")
        print(self.hparams)
        ## This outputs:
        # "_instantiator": lightning.pytorch.cli.instantiate_module
        # "class_path":    datamodule.DataModule
        # "init_args":     {'data_dir': 'data'}
```
`main.py`

```python
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from model import Model
from datamodule import DataModule


def cli_main():
    cli = LightningCLI()


if __name__ == "__main__":
    cli_main()
```

Run with `python main.py fit --config config.yaml`.
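For triage, a more self-contained sketch (my own minimal variant, not from the reporter's project) that I would expect to show the same wrapped `hparams` on 2.3.0, assuming the regression does not depend on loading the config from a file:

```python
import lightning as pl
from lightning.pytorch.cli import LightningCLI


class TinyModel(pl.LightningModule):
    def __init__(self, learning_rate: float = 1e-2):
        super().__init__()
        self.save_hyperparameters()


if __name__ == "__main__":
    # run=False builds the model from the parsed config without fitting.
    cli = LightningCLI(TinyModel, run=False, args=[])
    # Expected: {"learning_rate": 0.01}; per this report's pattern, 2.3.0
    # instead prints the _instantiator / class_path / init_args wrapper.
    print(cli.model.hparams)
```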
### Environment
<details>
<summary>Current environment</summary>
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(conda, pip, source):
#- Running environment of LightningApp (e.g. local, cloud):
</details>
cc @carmocca @mauvilsa