🐛 Bug
When using the Trainer's resume_from_checkpoint together with the EarlyStopping callback, the callback's patience progress (i.e. self.wait) is restored from the checkpoint, but is then reset by its on_train_start method, making the checkpoint restoration moot.
In addition, the EarlyStopping's .best is not saved to or restored from the checkpoint at all, so the callback state cannot be fully restored in any case.
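A minimal user-side workaround sketch (EarlyStoppingKeepState is a hypothetical subclass of mine, not part of the library): snapshot the counters before calling the parent's on_train_start and put them back afterwards, so whatever value was restored from the checkpoint (or is still held in memory) survives the reset.

    from pytorch_lightning.callbacks import EarlyStopping

    class EarlyStoppingKeepState(EarlyStopping):
        # Hypothetical workaround: snapshot wait/best before the parent
        # resets them in on_train_start, then put them back so the values
        # restored from the checkpoint survive the call.
        def on_train_start(self, trainer, pl_module):
            saved_wait = self.wait
            saved_best = getattr(self, 'best', None)
            super().on_train_start(trainer, pl_module)
            self.wait = saved_wait
            if saved_best is not None:
                self.best = saved_best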
To Reproduce
Steps to reproduce the behavior:
Install using pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@master --upgrade
import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import pytorch_lightning as pl


class CoolSystem(pl.LightningModule):

    def __init__(self):
        super(CoolSystem, self).__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': F.cross_entropy(y_hat, y)}

    def validation_step(self, batch, batch_nb):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_epoch_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def train_dataloader(self):
        # REQUIRED
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

model = CoolSystem()

checkpoint_callback = ModelCheckpoint(
    filepath='./model_ckpt/whatever_the_name_is_gonna_be_auto_chosen',
    save_top_k=-1,
    verbose=True,
    monitor='val_loss',
    mode='auto'
)


class EarlyStoppingPrinting(EarlyStopping):

    def on_train_start(self, trainer, pl_module):
        print('EarlyStoppingPrinting before on_train_start')
        print('self.wait = ', self.wait)
        super().on_train_start(trainer, pl_module)
        print('EarlyStoppingPrinting after on_train_start')
        print('self.wait = ', self.wait)

    def on_epoch_end(self, trainer, pl_module):
        ret = super().on_epoch_end(trainer, pl_module)
        if self.wait:
            print('Early stopping patience: %d/%d' % (self.patience - self.wait, self.patience))
        return ret


early_stopping = EarlyStoppingPrinting(
    monitor='val_loss',
    patience=5,
    verbose=True,
    mode='auto'
)

trainer = Trainer(max_nb_epochs=1000, train_percent_check=0.1,
                  checkpoint_callback=checkpoint_callback,
                  early_stop_callback=early_stopping)

trainer.fit(model)
Then interrupt training with KeyboardInterrupt while early_stopping.wait > 0, load the corresponding checkpoint (let's say it's model_ckpt/_ckpt_epoch_5.ckpt), and resume with:
trainer = Trainer(max_nb_epochs=1000, train_percent_check=0.1,
                  checkpoint_callback=None,
                  resume_from_checkpoint='model_ckpt/_ckpt_epoch_5.ckpt',
                  early_stop_callback=early_stopping)

trainer.fit(model)
The early_stopping callback would print:
EarlyStoppingPrinting before on_train_start
self.wait = 2
EarlyStoppingPrinting after on_train_start
self.wait = 0
As for self.best, it is not saved into the checkpoint at all, so there is nothing to restore in the first place.
Expected behavior
The value of self.wait restored from the checkpoint should be preserved rather than reset:
EarlyStoppingPrinting before on_train_start
self.wait = 2
EarlyStoppingPrinting after on_train_start
self.wait = 2
And self.best should be saved and loaded from the checkpoint.
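Until that is supported, a crude manual sketch of what user code could do (the separate early_stopping_state.pt file is an invented convention, not something Lightning writes; early_stopping is the callback from the reproduction above):

    import torch

    # Hypothetical manual workaround: persist the callback's counters next to the checkpoint ...
    torch.save({'wait': early_stopping.wait, 'best': early_stopping.best},
               'model_ckpt/early_stopping_state.pt')

    # ... and put them back before the resumed run.
    state = torch.load('model_ckpt/early_stopping_state.pt')
    early_stopping.wait = state['wait']
    early_stopping.best = state['best']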
Environment
This was run on Google Colab.
https://colab.research.google.com/drive/1ZdiFf6ksNpgsqOdSKM6lMO0yIhqpnTHD
Additional context
From reading the tutorials, it is unclear which member variables Lightning actually saves into checkpoints -- the docs imply a wide range of things is saved, but what gets saved is in fact quite specific.
It is also confusing that there are several ways to restore a checkpoint (the model's load_from_checkpoint method, the Trainer's resume_from_checkpoint parameter, and test_tube). These are not well documented (at least I did not find the relevant page before searching GitHub), and I have no idea if I used the right one.
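For concreteness, the two paths as I understand them (using the checkpoint path from the example above; corrections welcome):

    # Restores only the model weights (no trainer/optimizer state):
    model = CoolSystem.load_from_checkpoint('model_ckpt/_ckpt_epoch_5.ckpt')

    # Restores the training state (epoch, optimizer, etc.) and continues training;
    # this is the path used in the reproduction above:
    trainer = Trainer(resume_from_checkpoint='model_ckpt/_ckpt_epoch_5.ckpt')
    trainer.fit(model)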