Skip to content

Fix interaction between `save_last` and `every_n_epochs` #12391

@carmocca

Description

@carmocca

@carmocca With the following snippet, a `last.ckpt` file was generated before this PR, but it is no longer generated after it:

import uuid

import torch
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    """Fixed-size dataset of random feature vectors.

    ``length`` rows of dimension ``size`` are drawn once from a standard
    normal distribution (``torch.randn``) at construction time and then
    served unchanged by index.
    """

    def __init__(self, size, length):
        self.len = length
        samples = torch.randn(length, size)
        self.data = samples

    def __len__(self):
        # Row count requested at construction.
        return self.len

    def __getitem__(self, index):
        # One row of shape (size,) per integer index.
        return self.data[index]


class BoringModel(LightningModule):
    """Minimal LightningModule used only to drive the Trainer in this repro.

    It holds a single 32 -> 2 linear layer. Every step uses the sum of the
    model output as its loss and logs it under a stage-specific key.
    """

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def _summed_output(self, batch):
        # Invoke the module itself (not ``forward`` directly), exactly as
        # each step did inline before.
        return self(batch).sum()

    def training_step(self, batch, batch_idx):
        train_loss = self._summed_output(batch)
        self.log("train_loss", train_loss)
        return {"loss": train_loss}

    def validation_step(self, batch, batch_idx):
        val_loss = self._summed_output(batch)
        self.log("valid_loss", val_loss)

    def test_step(self, batch, batch_idx):
        test_loss = self._summed_output(batch)
        self.log("test_loss", test_loss)

    def configure_optimizers(self):
        trainable = self.layer.parameters()
        return torch.optim.SGD(trainable, lr=0.1)


if __name__ == "__main__":
    import os
    import tempfile

    # Use the platform temp directory (honours TMPDIR/TEMP/TMP) instead of a
    # hard-coded "/tmp", which does not exist on e.g. Windows. The uuid suffix
    # keeps each run in a fresh directory.
    tmpdir = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()))
    print(tmpdir)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        # every_n_epochs=10 exceeds max_epochs=1, so no periodic checkpoint is
        # due in this run; with save_last=True a last.ckpt is still expected
        # (the contract this repro checks).
        callbacks=[ModelCheckpoint(dirpath=tmpdir, every_n_epochs=10, save_last=True)],
        enable_checkpointing=True,
    )
    model = BoringModel()
    trainer.fit(
        model, train_dataloaders=DataLoader(RandomDataset(32, 64), batch_size=2)
    )

    # Report the outcome directly so the repro is self-checking.
    last_ckpt = os.path.join(tmpdir, "last.ckpt")
    print(f"last.ckpt exists: {os.path.exists(last_ckpt)}")

The following contract was respected before this PR, but is no longer respected after it:

`save_last`: When `True`, always saves the model at the end of the epoch to a file `last.ckpt`

This is a backward-compatibility-breaking (BC-breaking) change: it alters behavior that some users rely on, regardless of whether the previous behavior is considered a "bug".

Originally posted by @yifuwang in #11805 (comment)

cc @tchaton @rohitgr7 @akihironitta @carmocca @awaelchli @ninginthecloud @jjenniferdai

Metadata

Metadata

Assignees

Type

No type

Projects

Milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions