add support for save_hyperparameters with Python Data Class #3494

@tbenst

Python data classes are convenient because they automatically generate much of the boilerplate for assigning constructor arguments to instance attributes. They are particularly useful for PyTorch models that have many hyperparameters and therefore a lot of that boilerplate.

🐛 Bug

I believe #1896 introduced a new bug: when using a data class, save_hyperparameters no longer works, because it depends on the __init__ arguments, whereas with dataclasses we do our setup in __post_init__ instead. Explicitly passing the argument names as strings does not work either. Perhaps when passing

Code sample

import pytorch_lightning as pl
from dataclasses import dataclass

@dataclass()
class ConvDecoder(pl.LightningModule):
    imageChannels: int = 3

    def __post_init__(self):
        super().__init__()
        # both calls fail
        # self.save_hyperparameters()
        self.save_hyperparameters('imageChannels')

model = ConvDecoder()
model.hparams
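
One possible reading of why the sample above fails, sketched in plain Python without Lightning. It assumes that save_hyperparameters gathers its values from the local variables of the frame that calls it; that is an assumption about the mechanism, not something confirmed in this issue. Under that assumption, __post_init__ exposes only self, because the constructor arguments live in the dataclass-generated __init__. The helper caller_arg_names and both classes are hypothetical stand-ins.

```python
import inspect
from dataclasses import dataclass


def caller_arg_names():
    """Hypothetical helper: names of the caller's local variables, minus self."""
    frame = inspect.currentframe().f_back  # frame of the method that called us
    return sorted(name for name in frame.f_locals if name != "self")


class PlainModel:
    # Hand-written __init__: the argument is a local variable of this frame.
    def __init__(self, image_channels=3):
        self.seen = caller_arg_names()


@dataclass
class DataModel:
    image_channels: int = 3

    # The generated __init__ receives image_channels; this frame only has self.
    def __post_init__(self):
        self.seen = caller_arg_names()


print(PlainModel().seen)  # ['image_channels']
print(DataModel().seen)   # []
```

If the assumption holds, the values are still reachable as attributes on self, so the information is not lost; it is just not where a frame-based lookup expects it.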

Expected behavior

There should be a way to use save_hyperparameters with data classes.
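
As a rough illustration of what such support could look like, here is a minimal sketch that reads hyperparameters from the dataclass fields instead of from __init__ arguments. It is plain Python and does not touch Lightning: the mixin HParamsFromFields, its method collect_hparams, the class ConvDecoderConfig, and the extra field hiddenSize are all hypothetical names introduced for this example, not part of any existing API.

```python
from dataclasses import dataclass, fields


class HParamsFromFields:
    """Hypothetical mixin: gather hyperparameters from dataclass fields."""

    def collect_hparams(self, *names):
        declared = [f.name for f in fields(self)]
        unknown = set(names) - set(declared)
        if unknown:
            raise KeyError(f"not dataclass fields: {sorted(unknown)}")
        # With no names given, fall back to every declared field.
        selected = names or declared
        return {name: getattr(self, name) for name in selected}


@dataclass
class ConvDecoderConfig(HParamsFromFields):
    imageChannels: int = 3
    hiddenSize: int = 64  # hypothetical second hyperparameter

    def __post_init__(self):
        # Fields are already assigned by the generated __init__ at this point.
        self.hparams = self.collect_hparams()


config = ConvDecoderConfig(imageChannels=1)
print(config.hparams)
print(config.collect_hparams("imageChannels"))
```

Reading from dataclasses.fields keeps the lookup independent of which method makes the call, which is the property that __post_init__ needs here.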

Labels: feature (is an improvement or enhancement), help wanted (open to be worked on), priority: 2 (low priority task)