
resume_from_checkpoint loads duplicate ModelCheckpoint #4014

Description

@awaelchli

🐛 Bug

When resuming the Trainer from a checkpoint, the ModelCheckpoint callback appears twice in the callbacks list.
This ONLY happens if one provides both a checkpoint_callback AND a custom callbacks list.

To Reproduce

Code below

Code sample

import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import Trainer, LightningModule, Callback
from pytorch_lightning.callbacks import ModelCheckpoint


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def training_step_end(self, training_step_outputs):
        return training_step_outputs

    def training_epoch_end(self, outputs) -> None:
        torch.stack([x["loss"] for x in outputs]).mean()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]


def run_test():
    # fake data
    train_data = DataLoader(RandomDataset(32, 64))

    class RandomCallback(Callback):
        pass

    model = BoringModel()
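    # Passing the ModelCheckpoint instance via `checkpoint_callback` AND providing
    # a custom `callbacks` list is the combination that triggers the duplication.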
    checkpoint_callback = ModelCheckpoint()
    trainer_args = dict(
        max_epochs=2,
        logger=False,
        progress_bar_refresh_rate=0,
        checkpoint_callback=checkpoint_callback,
        weights_summary=None,
        callbacks=[RandomCallback()]  # try to remove it and see what happens :)
    )
    trainer = Trainer(**trainer_args)
    trainer.fit(model, train_dataloader=train_data)
    print(trainer.callbacks)
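    # Resuming from the first run's best checkpoint is what reproduces the duplicate entry.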
    trainer = Trainer(**trainer_args, resume_from_checkpoint=trainer.checkpoint_callback.best_model_path)
    trainer.fit(model, train_dataloader=train_data)
    print(trainer.callbacks)


if __name__ == '__main__':
    run_test()

Produces:

callbacks = [ModelCheckpoint]   # first training
callbacks = [ModelCheckpoint, ModelCheckpoint]  # after resuming
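
The duplication can also be verified programmatically right after the second fit call; a minimal check (just a sketch, using only the trainer and ModelCheckpoint objects already defined in the code sample above):

# Sketch: count ModelCheckpoint instances among the resumed trainer's callbacks.
# With the bug present this finds 2; the expected count is 1.
n_ckpt = sum(isinstance(cb, ModelCheckpoint) for cb in trainer.callbacks)
print(f"ModelCheckpoint instances after resuming: {n_ckpt}")
assert n_ckpt == 1, f"expected exactly 1 ModelCheckpoint, found {n_ckpt}"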

Expected behavior

callbacks = [ModelCheckpoint]   # first training
callbacks = [ModelCheckpoint]  # after resuming
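
Until this is fixed, a possible user-side workaround is to deduplicate the callbacks list by hand. This is an untested sketch; it assumes the extra entry is simply a second ModelCheckpoint instance in trainer.callbacks, and dedupe_model_checkpoints is a hypothetical helper, not part of the Lightning API:

# Untested workaround sketch: keep only the first ModelCheckpoint found in
# trainer.callbacks and drop any further ones.
def dedupe_model_checkpoints(trainer):
    seen = False
    deduped = []
    for cb in trainer.callbacks:
        if isinstance(cb, ModelCheckpoint):
            if seen:
                continue  # drop additional ModelCheckpoint instances
            seen = True
        deduped.append(cb)
    trainer.callbacks = deduped

Where to call it depends on the point at which the duplicate is introduced, so this is only a stopgap until the root cause in the resume path is fixed.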

Environment

PyTorch Lightning: master branch (0.10.0)
Python: 3.7
PyTorch: 1.5.1

Labels

bug, checkpointing, help wanted, priority: 0
