Combining check_val_every_n_epoch and save_top_k is broken #9163

@nicola-decao

Description

🐛 Bug

When a Trainer is configured with check_val_every_n_epoch = n (n > 1), validation correctly runs only every n epochs. However, when this is combined with a ModelCheckpoint using save_top_k = m (m > 1), a checkpoint is also saved at the end of every epoch, including epochs on which no validation ran; those extra checkpoints carry the valid_loss of the preceding validation run. Checkpoints should instead be saved only every n epochs. This worked in earlier versions (if I remember correctly, it worked in 1.2), but it is now broken.
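
For reference, the epochs on which validation runs can be computed as below. This is a sketch, assuming Lightning's 0-indexed current_epoch and that validation fires at the end of an epoch when (epoch + 1) % check_val_every_n_epoch == 0; that assumption is consistent with the epoch numbers in the listings further down, but the function validation_epochs is hypothetical and not part of Lightning.

```python
from typing import List


def validation_epochs(max_epochs: int, check_val_every_n_epoch: int) -> List[int]:
    """Return the 0-indexed epochs on which validation is assumed to run."""
    # Assumption: validation runs at the end of epoch e when (e + 1) is a multiple of n.
    return [
        epoch
        for epoch in range(max_epochs)
        if (epoch + 1) % check_val_every_n_epoch == 0
    ]


print(validation_epochs(max_epochs=10, check_val_every_n_epoch=2))
```

With the repro's settings (max_epochs=10, check_val_every_n_epoch=2) this gives epochs 1, 3, 5, 7 and 9, which are exactly the epochs in the expected listing.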

To Reproduce

The following code, based on the BoringModel, reproduces the issue: it saves a checkpoint every epoch instead of every n epochs (see the ls output at the bottom).

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=10,
        check_val_every_n_epoch=2,
        weights_summary=None,
        callbacks=[
            ModelCheckpoint(
                monitor="valid_loss",
                mode="min",
                dirpath="./",
                save_top_k=10,
                filename="model-{epoch:02d}-{valid_loss:.2f}",
            )
        ]
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()

>>> ls -l *.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=01-valid_loss=-6.00.ckpt
-rw-r--r--. 1 ndecao Domain Users 2579 Aug 27 09:39 model-epoch=02-valid_loss=-6.00.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=03-valid_loss=-11.57.ckpt
-rw-r--r--. 1 ndecao Domain Users 2643 Aug 27 09:39 model-epoch=04-valid_loss=-11.57.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=05-valid_loss=-17.14.ckpt
-rw-r--r--. 1 ndecao Domain Users 2643 Aug 27 09:39 model-epoch=06-valid_loss=-17.14.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=07-valid_loss=-22.70.ckpt
-rw-r--r--. 1 ndecao Domain Users 2643 Aug 27 09:39 model-epoch=08-valid_loss=-22.70.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=09-valid_loss=-28.27.ckpt
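
The listing is consistent with the checkpoint logic running at the end of every training epoch and reusing the most recently logged valid_loss, so each even epoch gets a copy of the previous validation result; with save_top_k=10, none of the nine files is pruned. The pure-Python model below illustrates that reading of the behavior. It is not Lightning's implementation: saved_every_epoch is a hypothetical helper, and the loss values are copied from the listing.

```python
from typing import Dict, List, Optional

# valid_loss values copied from the listing above; validation only ran on these epochs.
LOGGED_VALID_LOSS: Dict[int, float] = {1: -6.00, 3: -11.57, 5: -17.14, 7: -22.70, 9: -28.27}


def saved_every_epoch(max_epochs: int, logged: Dict[int, float]) -> List[str]:
    """Model of the reported behavior: save at every epoch end with the latest metric."""
    names: List[str] = []
    latest: Optional[float] = None
    for epoch in range(max_epochs):
        if epoch in logged:
            latest = logged[epoch]  # fresh value from a validation run
        if latest is None:
            continue  # nothing to monitor yet, so no checkpoint (no epoch 00 in the listing)
        # Mirrors filename="model-{epoch:02d}-{valid_loss:.2f}" as it appears on disk.
        names.append(f"model-epoch={epoch:02d}-valid_loss={latest:.2f}.ckpt")
    return names


for name in saved_every_epoch(10, LOGGED_VALID_LOSS):
    print(name)
```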

Expected behavior

ModelCheckpoint should evaluate valid_loss and save a checkpoint only on epochs where validation actually runs, i.e. every check_val_every_n_epoch epochs. These are the checkpoints that should be saved:

>>> ls -l *.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=01-valid_loss=-6.00.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=03-valid_loss=-11.57.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=05-valid_loss=-17.14.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=07-valid_loss=-22.70.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=09-valid_loss=-28.27.ckpt
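
Under the expected behavior, only epochs that produced a fresh valid_loss are candidates, and save_top_k then keeps the best of them (mode="min" means lowest loss). The sketch below models that rule; expected_checkpoints is a hypothetical helper, the losses are copied from the listings, and save_top_k=-1 ("keep all") is not modeled.

```python
from typing import Dict, List

# valid_loss values from the validation epochs in the listings above.
LOGGED_VALID_LOSS: Dict[int, float] = {1: -6.00, 3: -11.57, 5: -17.14, 7: -22.70, 9: -28.27}


def expected_checkpoints(logged: Dict[int, float], save_top_k: int) -> List[str]:
    """Keep the save_top_k lowest valid_loss values, one candidate per validation epoch."""
    # mode="min": rank candidates by loss, ascending, and keep the first save_top_k.
    best = sorted(logged.items(), key=lambda item: item[1])[:save_top_k]
    # Report the survivors in epoch order, formatted like the files on disk.
    return [
        f"model-epoch={epoch:02d}-valid_loss={loss:.2f}.ckpt"
        for epoch, loss in sorted(best)
    ]


for name in expected_checkpoints(LOGGED_VALID_LOSS, save_top_k=10):
    print(name)
```

With save_top_k=10 all five validation checkpoints survive, matching the listing above; with a smaller value such as save_top_k=2, only epochs 07 and 09 would be kept.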

Environment

  • CUDA:
    • GPU:
      • TITAN X (Pascal)
      • TITAN X (Pascal)
      • TITAN X (Pascal)
      • TITAN X (Pascal)
    • available: True
    • version: 10.1
  • Packages:
    • numpy: 1.19.5
    • pyTorch_debug: False
    • pyTorch_version: 1.8.1
    • pytorch-lightning: 1.4.4
    • tqdm: 4.62.1
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.8.10
    • version: 1 SMP Wed Feb 3 15:06:38 UTC 2021

Metadata

Labels

bug (Something isn't working), checkpointing (Related to checkpointing), help wanted (Open to be worked on)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests