Description
🐛 Bug
Currently, EarlyStopping's state is updated after the checkpoint callback runs, so what gets saved in the checkpoint is the previous epoch's state.
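A minimal, self-contained sketch of the ordering problem (this is not Lightning's actual internal code; EarlyStoppingState, update, end_of_epoch and save_checkpoint are hypothetical names used only for illustration):

class EarlyStoppingState:
    def __init__(self, patience):
        self.patience = patience
        self.wait = 0
        self.best = float('inf')

    def update(self, val_loss):
        # update counters based on this epoch's validation loss
        if val_loss < self.best:
            self.best = val_loss
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience

def end_of_epoch(val_loss, early_stopping, save_checkpoint):
    # buggy ordering: the checkpoint (which serializes early_stopping.wait)
    # is written BEFORE the counters are updated, so the saved value is
    # one epoch stale
    save_checkpoint()
    return early_stopping.update(val_loss)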
To Reproduce
This is somewhat related to #1463 so I am going to use the same code.
Steps to reproduce the behavior:
Install using pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@master --upgrade
import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import pytorch_lightning as pl

class CoolSystem(pl.LightningModule):

    def __init__(self):
        super(CoolSystem, self).__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': F.cross_entropy(y_hat, y)}

    def validation_step(self, batch, batch_nb):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_epoch_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def train_dataloader(self):
        # REQUIRED
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

model = CoolSystem()

checkpoint_callback = ModelCheckpoint(
    filepath='./model_ckpt/whatever_the_name_is_gonna_be_auto_chosen',
    save_top_k=-1,
    verbose=True,
    monitor='val_loss',
    mode='auto'
)

class EarlyStoppingPrinting(EarlyStopping):

    def on_train_start(self, trainer, pl_module):
        print('EarlyStoppingPrinting before on_train_start')
        print('self.wait = ', self.wait)
        super().on_train_start(trainer, pl_module)
        print('EarlyStoppingPrinting after on_train_start')
        print('self.wait = ', self.wait)

    def on_epoch_end(self, trainer, pl_module):
        ret = super().on_epoch_end(trainer, pl_module)
        if self.wait:
            print('Early stopping patience: %d/%d' % (self.patience - self.wait, self.patience))
        return ret

early_stopping = EarlyStoppingPrinting(
    monitor='val_loss',
    patience=5,
    verbose=True,
    mode='auto'
)

trainer = Trainer(max_nb_epochs=1000, train_percent_check=0.1,
                  checkpoint_callback=checkpoint_callback,
                  early_stop_callback=early_stopping)

trainer.fit(model)
Let the model train until convergence, then reload it from a checkpoint and see how it continues:
trainer = Trainer(max_nb_epochs=1000, train_percent_check=0.1,
                  checkpoint_callback=None,
                  resume_from_checkpoint='model_ckpt/_ckpt_epoch_7.ckpt',
                  early_stop_callback=early_stopping)

trainer.fit(model)
The early_stopping callback would print:
EarlyStoppingPrinting before on_train_start
self.wait = 4
...
and the model keeps training.
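To verify what was actually serialized, you can load the checkpoint directly and inspect its contents (the path below is the one from the snippet above; which keys hold the early-stopping counters depends on the Lightning version):

import torch

# Load the checkpoint produced by the run above and look at what was stored.
ckpt = torch.load('model_ckpt/_ckpt_epoch_7.ckpt', map_location='cpu')
print(list(ckpt.keys()))  # check which early-stopping counters (if any) were saved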
Expected behavior
The early_stopping callback should print:
EarlyStoppingPrinting before on_train_start
self.wait = 5
...
and the model should not train at all, since self.wait >= self.patience.
If the model is instead loaded from a checkpoint saved before convergence (e.g. from an interrupted run), then it should still train after resuming, but with the correct self.wait.
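Reusing the hypothetical EarlyStoppingState sketch from above, the expected ordering would be to update the early-stopping counters first and only then write the checkpoint, so the serialized wait matches the epoch that just finished:

def end_of_epoch_expected(val_loss, early_stopping, save_checkpoint):
    # expected ordering: update counters first, then checkpoint, so the
    # serialized `wait` reflects the epoch that just finished
    stop = early_stopping.update(val_loss)
    save_checkpoint()
    return stop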
Environment
This was run on Google Colab.
https://colab.research.google.com/drive/1ZdiFf6ksNpgsqOdSKM6lMO0yIhqpnTHD
Additional context
Somewhat related to #1463.