
You're resuming from a checkpoint that ended mid-epoch with every_n_epochs=1 #11809

@OverLordGoldDragon

Description

ModelCheckpoint(every_n_epochs=1, save_top_k=-1) saves checkpoints only at epoch ends, yet resuming from one of them raises the warning in the title, as if the checkpoint had been written mid-epoch. Minimal reproduction below.

Colab MRE

Direct code:
import os, torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

BATCH_SIZE = 64

class MNISTModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        self.log('loss', loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

class Data(Dataset):
    def __len__(self):
        return 1000

    def __getitem__(self, idx):
        # random input and a random class index as the cross_entropy target
        return torch.randn(28 * 28), torch.randint(0, 10, ())

# Save a checkpoint at every epoch end, keeping all of them
make_ckpt_cb = lambda: ModelCheckpoint(dirpath='.', monitor='loss', mode='min',
                                       every_n_epochs=1, save_top_k=-1)

# First run: train for a single epoch; a checkpoint is written at its end
mnist_model = MNISTModel()
train_loader = DataLoader(Data(), batch_size=BATCH_SIZE)
trainer = Trainer(gpus=0, max_epochs=1, callbacks=make_ckpt_cb())
trainer.fit(mnist_model, train_loader)

# Pick up the checkpoint written at the end of epoch 0
ckpt_path = [nm for nm in os.listdir() if nm.endswith('.ckpt')][0]

# Second run: resuming from that checkpoint raises the mid-epoch warning
mnist_model = MNISTModel()
train_loader = DataLoader(Data(), batch_size=BATCH_SIZE)
trainer = Trainer(gpus=0, max_epochs=2, callbacks=make_ckpt_cb())
trainer.fit(mnist_model, train_loader, ckpt_path=ckpt_path)
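
For reference, the checkpoint's own metadata can be inspected to confirm it was written at an epoch boundary. A minimal sketch, assuming a standard Lightning checkpoint (the 'epoch' and 'global_step' keys are part of Lightning's checkpoint format):

ckpt = torch.load(ckpt_path, map_location='cpu')
# With 1000 samples and batch_size=64, one epoch is ceil(1000 / 64) = 16
# optimizer steps, so a checkpoint saved at the end of epoch 0 should
# report global_step == 16 rather than some mid-epoch value.
print(ckpt['epoch'], ckpt['global_step'])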
