Pass all running stages to DataModule.setup #5658

@carmocca

🚀 Feature

Currently, DataModule.setup is only called with the stages fit or test, but there are several more:

Stages:

https://github.com/PyTorchLightning/pytorch-lightning/blob/5f3372871a333c3229968f1af1b10a925d7ec3ec/pytorch_lightning/trainer/states.py#L39-L49

Note that this is a bit tricky because fit is not itself a RunningStage; it indicates either train or eval.
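To make the fit case concrete, here is a minimal sketch, assuming a stand-in RunningStage enum whose values (train, eval, test, predict) are guesses that should be checked against the linked states.py. The ENTRY_POINT_STAGES mapping and the running_stages_for helper are hypothetical, not Lightning API; they show one way the fit entry point could be expanded into the running stages it covers, so that setup always receives a real RunningStage.

```python
from enum import Enum
from typing import Dict, List


class RunningStage(str, Enum):
    # Hypothetical stand-in for pytorch_lightning.trainer.states.RunningStage.
    # The values are assumed; verify them against the linked states.py.
    TRAINING = "train"
    EVALUATING = "eval"
    TESTING = "test"
    PREDICTING = "predict"


# Hypothetical mapping from Trainer entry points to the running stages they go through.
# `fit` is not a RunningStage itself: it covers both training and evaluation.
ENTRY_POINT_STAGES: Dict[str, List[RunningStage]] = {
    "fit": [RunningStage.TRAINING, RunningStage.EVALUATING],
    "test": [RunningStage.TESTING],
    "predict": [RunningStage.PREDICTING],
}


def running_stages_for(entry_point: str) -> List[RunningStage]:
    """Return the running stages `setup` would receive for a Trainer entry point."""
    try:
        return ENTRY_POINT_STAGES[entry_point]
    except KeyError:
        raise ValueError(f"unknown entry point: {entry_point!r}") from None


# Because the enum mixes in `str`, members compare equal to their plain-string values.
print([stage.value for stage in running_stages_for("fit")])
```

Mixing str into the enum keeps existing string-based checks such as stage == "test" working, which matters if user code already compares stage against plain strings.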

Motivation

This would allow custom setup logic for each stage.

Pitch

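```python
from typing import Optional
from pytorch_lightning.trainer.states import RunningStage

def setup(self, stage: Optional[str] = None):  # method of a LightningDataModule subclass
    assert stage is None or stage in list(RunningStage)
    ...
```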
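The following plain-Python sketch illustrates the kind of per-stage logic this would enable. ToyDataModule is a hypothetical stand-in for a LightningDataModule (it does not import Lightning), the RunningStage enum again mirrors assumed values, and the splits are placeholder data. As in the pitch, a stage that is not a RunningStage (such as fit) is rejected, while None keeps the meaning of preparing everything.

```python
from enum import Enum
from typing import Dict, List, Optional


class RunningStage(str, Enum):
    # Hypothetical stand-in mirroring the assumed values of
    # pytorch_lightning.trainer.states.RunningStage.
    TRAINING = "train"
    EVALUATING = "eval"
    TESTING = "test"
    PREDICTING = "predict"


class ToyDataModule:
    """Plain-Python stand-in for a LightningDataModule (no Lightning import)."""

    def __init__(self) -> None:
        self.splits: Dict[str, List[int]] = {}

    def setup(self, stage: Optional[str] = None) -> None:
        # Mirror the pitch: only None or a RunningStage value is accepted.
        if stage is not None and stage not in list(RunningStage):
            raise ValueError(f"unexpected stage: {stage!r}")
        data = list(range(10))  # placeholder dataset
        # Each branch runs only for its own stage, or for all stages when stage is None.
        if stage in (None, RunningStage.TRAINING):
            self.splits["train"] = data[:8]
        if stage in (None, RunningStage.EVALUATING):
            self.splits["val"] = data[8:]
        if stage in (None, RunningStage.TESTING):
            self.splits["test"] = data[8:]
        if stage in (None, RunningStage.PREDICTING):
            # Prediction can use every input without holding anything out.
            self.splits["predict"] = data


dm = ToyDataModule()
dm.setup("predict")
print(sorted(dm.splits))
```

With per-stage values, setup("predict") on a fresh ToyDataModule populates only the predict split. With the current behavior, a prediction run arrives as 'test' and would build the test split instead.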

Additional context

Currently, 'test' is passed as the stage when predicting, as seen in #5579:
https://github.com/PyTorchLightning/pytorch-lightning/blob/9137b16068fe03e6db8df548235363e5f5476aac/pytorch_lightning/trainer/trainer.py#L909

Labels: feature, help wanted, refactor