
EarlyStopping checkpointed state is lagging one epoch behind #1464

@lizhitwo


🐛 Bug

Currently, EarlyStopping's state is updated after the checkpoint callback runs, so the EarlyStopping state saved in each checkpoint is the previous epoch's state.
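To make the ordering problem concrete, here is a minimal, framework-free simulation. `ToyEarlyStopping` and `run_epochs` are hypothetical stand-ins, not Lightning's API. They only model the assumption that, at the end of each epoch, the checkpoint copies the early-stopping counter either before or after that counter is updated.

```python
class ToyEarlyStopping:
    """Hypothetical stand-in for EarlyStopping: counts epochs without improvement."""

    def __init__(self, patience):
        self.patience = patience
        self.wait = 0
        self.best = float("inf")

    def update(self, val_loss):
        if val_loss < self.best:
            self.best = val_loss
            self.wait = 0
        else:
            self.wait += 1

    def should_stop(self):
        return self.wait >= self.patience


def run_epochs(val_losses, patience, checkpoint_first):
    """Return (final wait, list of wait values stored in each epoch's checkpoint)."""
    es = ToyEarlyStopping(patience)
    checkpoints = []
    for loss in val_losses:
        if checkpoint_first:
            # Snapshot taken before this epoch's update: it lags one epoch behind.
            checkpoints.append(es.wait)
            es.update(loss)
        else:
            # Update first, so the snapshot reflects the current epoch.
            es.update(loss)
            checkpoints.append(es.wait)
        if es.should_stop():
            break
    return es.wait, checkpoints
```

For example, `run_epochs([1.0, 0.5, 0.6, 0.7], 2, True)` returns `(2, [0, 0, 0, 1])`: training stopped with `wait == 2`, but the last checkpoint recorded `wait == 1`, the same one-epoch lag this issue reports for `self.wait`. With `checkpoint_first=False` the last checkpoint records `2`.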

To Reproduce

This is somewhat related to #1463, so I am going to use the same code.

Steps to reproduce the behavior:
Install using pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@master --upgrade

import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

import pytorch_lightning as pl

class CoolSystem(pl.LightningModule):

    def __init__(self):
        super(CoolSystem, self).__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': F.cross_entropy(y_hat, y)}

    def validation_step(self, batch, batch_nb):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_epoch_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def train_dataloader(self):
        # REQUIRED
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

model = CoolSystem()

checkpoint_callback = ModelCheckpoint(
    filepath='./model_ckpt/whatever_the_name_is_gonna_be_auto_chosen',
    save_top_k=-1,
    verbose=True,
    monitor='val_loss',
    mode='auto'
)

class EarlyStoppingPrinting(EarlyStopping):

    def on_train_start(self, trainer, pl_module):
        print('EarlyStoppingPrinting before on_train_start')
        print('self.wait = ', self.wait)
        super().on_train_start(trainer, pl_module)
        print('EarlyStoppingPrinting after on_train_start')
        print('self.wait = ', self.wait)

    def on_epoch_end(self, trainer, pl_module):
        ret = super().on_epoch_end(trainer, pl_module)
        if self.wait:
            print('Early stopping patience: %d/%d' % (self.patience-self.wait, self.patience))
        return ret


early_stopping = EarlyStoppingPrinting(
    monitor='val_loss',
    patience=5,
    verbose=True,
    mode='auto'
)

trainer = Trainer(max_nb_epochs=1000, train_percent_check=0.1, 
                  checkpoint_callback=checkpoint_callback, 
                  early_stop_callback=early_stopping)

trainer.fit(model)

Let the model train until early stopping ends the run. Then reload the model from a checkpoint and see how it continues:

trainer = Trainer(max_nb_epochs=1000, train_percent_check=0.1,
                  checkpoint_callback=None,
                  resume_from_checkpoint='model_ckpt/_ckpt_epoch_7.ckpt',
                  early_stop_callback=early_stopping)
trainer.fit(model)

The early_stopping callback prints:

EarlyStoppingPrinting before on_train_start
self.wait =  4
...

and keeps training.

Expected behavior

The early_stopping callback should print:

EarlyStoppingPrinting before on_train_start
self.wait =  5
...

and the model should not train any further, since self.wait >= self.patience.

If the model is loaded from a checkpoint saved by an interrupted run, training should still continue after resuming, but with the correct self.wait.
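The expected resume behavior can be sketched as follows. `ToyEarlyStopping`, its `state_dict`/`load_state_dict` methods and `resume` are hypothetical illustrations, not Lightning's implementation; the sketch assumes the stopping condition is checked right after the state is restored, before any further epoch is trained.

```python
class ToyEarlyStopping:
    """Hypothetical early-stopping counter with save/restore; not Lightning's API."""

    def __init__(self, patience):
        self.patience = patience
        self.wait = 0
        self.best = float("inf")

    def state_dict(self):
        return {"wait": self.wait, "best": self.best}

    def load_state_dict(self, state):
        self.wait = state["wait"]
        self.best = state["best"]

    def should_stop(self):
        return self.wait >= self.patience

    def update(self, val_loss):
        if val_loss < self.best:
            self.best = val_loss
            self.wait = 0
        else:
            self.wait += 1


def resume(state, patience, future_losses):
    """Restore a saved state and return how many further epochs are trained."""
    es = ToyEarlyStopping(patience)
    es.load_state_dict(state)
    epochs = 0
    for loss in future_losses:
        # Check before training: a run whose patience was already exhausted
        # should not train again after being restored.
        if es.should_stop():
            break
        epochs += 1
        es.update(loss)
    return epochs
```

With patience 5 and no further improvement, `resume({"wait": 5, "best": 0.3}, 5, [0.4] * 10)` trains 0 epochs, while the lagging state `{"wait": 4, ...}` trains 1 extra epoch, matching the observed versus expected behavior above. A genuinely interrupted run (e.g. `wait == 0`) still trains until its remaining patience runs out.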

Environment

This was run on Google Colab.
https://colab.research.google.com/drive/1ZdiFf6ksNpgsqOdSKM6lMO0yIhqpnTHD

Additional context

Somewhat related to #1463.


Labels: bug, discussion, help wanted
