
EarlyStopping reinitializes to .wait=0 even with Trainer resume_from_checkpoint #1463

@lizhitwo

Description


🐛 Bug

When using Trainer's resume_from_checkpoint with the EarlyStopping callback, the callback's patience progress (i.e. self.wait) is loaded from the checkpoint, but is then reset by its on_train_start method, making the checkpoint restoration moot.

Also, EarlyStopping's .best is not saved or restored at all, which makes restoring the callback even less useful.
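For context, the reset appears to come from on_train_start re-initializing the callback's counters unconditionally. A minimal sketch of that pattern, assuming roughly what the implementation looks like (names and details are illustrative, not the actual pytorch_lightning source):

import pytorch_lightning as pl

# Illustrative sketch only, not the real EarlyStopping implementation.
class EarlyStoppingSketch(pl.callbacks.Callback):

    def __init__(self, patience=5):
        self.patience = patience
        self.wait = 0
        self.best = None

    def on_train_start(self, trainer, pl_module):
        # Counters are re-initialized unconditionally, so any self.wait
        # restored from a checkpoint is discarded right here.
        self.wait = 0
        self.best = None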

To Reproduce

Steps to reproduce the behavior:
Install using pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@master --upgrade

import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

import pytorch_lightning as pl

class CoolSystem(pl.LightningModule):

    def __init__(self):
        super(CoolSystem, self).__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': F.cross_entropy(y_hat, y)}

    def validation_step(self, batch, batch_nb):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_epoch_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def train_dataloader(self):
        # REQUIRED
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

model = CoolSystem()

checkpoint_callback = ModelCheckpoint(
    filepath='./model_ckpt/whatever_the_name_is_gonna_be_auto_chosen',
    save_top_k=-1,
    verbose=True,
    monitor='val_loss',
    mode='auto'
)

class EarlyStoppingPrinting(EarlyStopping):

    def on_train_start(self, trainer, pl_module):
        print('EarlyStoppingPrinting before on_train_start')
        print('self.wait = ', self.wait)
        super().on_train_start(trainer, pl_module)
        print('EarlyStoppingPrinting after on_train_start')
        print('self.wait = ', self.wait)

    def on_epoch_end(self, trainer, pl_module):
        ret = super().on_epoch_end(trainer, pl_module)
        if self.wait:
            print('Early stopping patience: %d/%d' % (self.patience-self.wait, self.patience))
        return ret


early_stopping = EarlyStoppingPrinting(
    monitor='val_loss',
    patience=5,
    verbose=True,
    mode='auto'
)

trainer = Trainer(max_nb_epochs=1000, train_percent_check=0.1, 
                  checkpoint_callback=checkpoint_callback, 
                  early_stop_callback=early_stopping)

trainer.fit(model)

Then interrupt training with KeyboardInterrupt while early_stopping.wait > 0, load the corresponding checkpoint (say it is model_ckpt/_ckpt_epoch_5.ckpt), and resume with

trainer = Trainer(max_nb_epochs=1000, train_percent_check=0.1, 
                  checkpoint_callback=None, 
                  resume_from_checkpoint = 'model_ckpt/_ckpt_epoch_5.ckpt',
                  early_stop_callback=early_stopping)
trainer.fit(model)

The early_stopping callback would print:

EarlyStoppingPrinting before on_train_start
self.wait =  2
EarlyStoppingPrinting after on_train_start
self.wait =  0

As for self.best, it is not saved to the checkpoint at all, so there is nothing to restore in the first place.
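Until this is addressed, the only workaround I can think of is persisting the callback's counters manually next to the checkpoint and re-applying them after the reset; a sketch, assuming plain torch.save/torch.load and an arbitrary file name, building on the EarlyStoppingPrinting subclass above:

# Hypothetical workaround (illustrative only).
# At checkpoint time / before interrupting, save the counters:
torch.save({'wait': early_stopping.wait, 'best': early_stopping.best},
           'model_ckpt/early_stopping_state.pt')

class EarlyStoppingResumable(EarlyStoppingPrinting):

    def on_train_start(self, trainer, pl_module):
        super().on_train_start(trainer, pl_module)  # this is what resets self.wait
        state = torch.load('model_ckpt/early_stopping_state.pt')
        self.wait = state['wait']   # re-apply the saved patience progress
        self.best = state['best']   # re-apply the saved best metric value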

Expected behavior

Checkpoint value of self.wait should be preserved rather than reset:

EarlyStoppingPrinting before on_train_start
self.wait =  2
EarlyStoppingPrinting after on_train_start
self.wait =  2

And self.best should be saved and loaded from the checkpoint.
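Presumably that means making the reset in on_train_start conditional, something along these lines (a sketch only; the _state_restored flag is hypothetical and not part of the actual API):

    def on_train_start(self, trainer, pl_module):
        # Only initialize the counters when nothing was restored from a checkpoint.
        if not getattr(self, '_state_restored', False):
            self.wait = 0
            self.best = None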

Environment

This was run on Google Colab.
https://colab.research.google.com/drive/1ZdiFf6ksNpgsqOdSKM6lMO0yIhqpnTHD

Additional context

From reading the tutorials, it is unclear which member variables Lightning saves into its checkpoints -- they imply a wide range of things is saved, but what actually gets saved is quite specific.

It is also confusing that there are several ways to restore a checkpoint (the model's load_from_checkpoint method, the Trainer's resume_from_checkpoint parameter, and test_tube). These are not well documented (at least I did not find the relevant page before searching GitHub), and I have no idea whether I used the right one.
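For reference, the two restore paths I tried look like this (as far as I understand the API in the version used above):

# Restores only the module's weights/hyperparameters into a fresh instance.
model = CoolSystem.load_from_checkpoint('model_ckpt/_ckpt_epoch_5.ckpt')

# Restores the training state (epoch, optimizer state, etc.) and resumes fitting.
trainer = Trainer(resume_from_checkpoint='model_ckpt/_ckpt_epoch_5.ckpt')
trainer.fit(model)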
