
With yaml config file for LightningCLI, self.save_hyperparameters() behavior abnormal #19977

@t4rf9


Bug description

With a YAML config file for LightningCLI, self.save_hyperparameters() in the __init__ of the model and the datamodule mistakenly saves a dict containing keys like class_path and init_args instead of the flat __init__ arguments.

This problem appears in version 2.3.0; version 2.2.5 works correctly.

What version are you seeing the problem on?

2.3.0

How to reproduce the bug

config.yaml

ckpt_path: null
seed_everything: 0
model:
  class_path: model.Model
  init_args:
    learning_rate: 1e-2
data:
  class_path: datamodule.DataModule
  init_args:
    data_dir: data
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: null
  fast_dev_run: false
  max_epochs: 100
  min_epochs: null
  max_steps: -1
  min_steps: null
  max_time: null
  limit_train_batches: null
  limit_val_batches: 10
  limit_test_batches: null
  limit_predict_batches: null
  logger:
    class_path: lightning.pytorch.loggers.TensorBoardLogger
    init_args:
      save_dir: lightning_logs
      name: normalized
  callbacks:
    class_path: lightning.pytorch.callbacks.ModelCheckpoint
    init_args:
      save_top_k: 5
      monitor: valid_loss
      filename: "{epoch}-{step}-{valid_loss:.8f}"
  overfit_batches: 0.0
  val_check_interval: 50
  check_val_every_n_epoch: 1
  num_sanity_val_steps: null
  log_every_n_steps: 50
  enable_checkpointing: null
  enable_progress_bar: null
  enable_model_summary: null
  accumulate_grad_batches: 1
  gradient_clip_val: null
  gradient_clip_algorithm: null
  deterministic: true
  benchmark: null
  inference_mode: true
  use_distributed_sampler: true
  profiler: null
  detect_anomaly: false
  barebones: false
  plugins: null
  sync_batchnorm: true
  reload_dataloaders_every_n_epochs: 0
  default_root_dir: null
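
For context, LightningCLI resolves each class_path / init_args pair in this config by importing the named class and instantiating it with the given arguments. The sketch below is only an illustrative simplification of that mechanism, not the actual lightning.pytorch.cli / jsonargparse code, and instantiate_from_config is a made-up name for this example:

import importlib

def instantiate_from_config(cfg: dict):
    # Split "model.Model" into module path and class name, import the module,
    # and call the class with the keyword arguments from init_args.
    module_name, _, class_name = cfg["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**cfg.get("init_args", {}))

# e.g. instantiate_from_config({"class_path": "model.Model",
#                               "init_args": {"learning_rate": 1e-2}})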

model.py

import torch
from torch import nn
import torch.nn.functional as F
import lightning as pl


class Model(pl.LightningModule):
    def __init__(self, learning_rate: float):
        super().__init__()

        print()
        print("Model:")

        print(f"learning_rate: {learning_rate}")
        ## This outputs correctly.

        self.save_hyperparameters()
        
        print(self.hparams)
        ## This outputs:
        # "_instantiator": lightning.pytorch.cli.instantiate_module
        # "class_path":    model.Model
        # "init_args":     {'learning_rate': 0.01}

datamodule.py

from lightning import LightningDataModule


class DataModule(LightningDataModule):
    def __init__(self, data_dir: str):
        super().__init__()
        self.save_hyperparameters()

        print()
        print("DataModule:")

        print(self.hparams)
        ## This outputs:
        # "_instantiator": lightning.pytorch.cli.instantiate_module
        # "class_path":    datamodule.DataModule
        # "init_args":     {'data_dir': 'data'}

main.py

from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint

from model import Model
from datamodule import DataModule


def cli_main():
    cli = LightningCLI()


if __name__ == "__main__":
    cli_main()

Run python main.py fit --config config.yaml
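
Until this is resolved, a possible workaround in user code is to unwrap the nested structure that 2.3.0 produces. This is only a sketch based on the hparams layouts shown above, and it mainly matters where values are read back from self.hparams (e.g. after loading a checkpoint) rather than taken directly from the constructor arguments:

# Inside Model.__init__, after self.save_hyperparameters().
if "init_args" in self.hparams:
    # 2.3.0 layout observed above: init arguments are nested under init_args.
    learning_rate = self.hparams["init_args"]["learning_rate"]
else:
    # 2.2.5 layout: flat mapping of __init__ arguments.
    learning_rate = self.hparams["learning_rate"]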



cc @carmocca @mauvilsa
