Labels: bug, duplicate, help wanted
Description
🐛 Bug
Currently the lr_scheduler's state is updated after the checkpoint callback runs, so the scheduler state saved in the checkpoint is the previous epoch's.
Note: I think this has the same fix as #1464, but I'm posting it here because (1) I got bitten by this again, (2) it may not be the same bug, and (3) #1464 is not fixed yet.
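To make the off-by-one concrete, here is a plain-PyTorch sketch of what saving a scheduler before its end-of-epoch step does. This is only an illustration under my assumptions, not Lightning's actual internals, and the checkpoint dict keys are made up:

import torch

# Illustration only (assumptions, not Lightning internals): a checkpoint that is
# meant to resume at epoch N + 1 but serializes the scheduler *before* its
# end-of-epoch step ends up with a scheduler that is still at epoch N.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100], gamma=0.1)

current_epoch = 2
for _ in range(current_epoch):
    scheduler.step()  # scheduler has finished epochs 0 and 1, so last_epoch == 2

checkpoint = {
    'epoch': current_epoch + 1,             # "resume at epoch 3"
    'lr_scheduler': scheduler.state_dict()  # saved before the end-of-epoch step
}
scheduler.step()  # the step the checkpoint missed

restored = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100], gamma=0.1)
restored.load_state_dict(checkpoint['lr_scheduler'])
print(checkpoint['epoch'], restored.last_epoch)  # 3 vs 2 -- the same mismatch shown below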
To Reproduce
Steps to reproduce the behavior:
Install using pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@master --upgrade
import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

import pytorch_lightning as pl


class CoolSystem(pl.LightningModule):

    def __init__(self):
        super(CoolSystem, self).__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': F.cross_entropy(y_hat, y)}

    def validation_step(self, batch, batch_nb):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_epoch_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        optimizer = torch.optim.Adam(self.parameters(), lr=0.02)
        return [optimizer], [torch.optim.lr_scheduler.MultiStepLR(optimizer, [100], gamma=0.1)]

    def train_dataloader(self):
        # REQUIRED
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

model = CoolSystem()

checkpoint_callback = ModelCheckpoint(
    filepath='./model_ckpt/whatever_the_name_is_gonna_be_auto_chosen',
    save_top_k=-1,
    verbose=True,
    monitor='val_loss',
    mode='auto'
)

early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    verbose=True,
    mode='auto'
)


class PrintingCallback(pl.Callback):

    def on_epoch_start(self, trainer, pl_module):
        print('Scheduler epoch %d' % trainer.lr_schedulers[0]['scheduler'].last_epoch)
        print('Trainer epoch %d' % trainer.current_epoch)
        print('-' * 80)


trainer = Trainer(max_nb_epochs=1000, train_percent_check=0.1,
                  checkpoint_callback=checkpoint_callback,
                  early_stop_callback=early_stopping,
                  callbacks=[PrintingCallback()])
trainer.fit(model)
Let the model train until convergence, then reload a saved checkpoint and watch how training continues:
trainer = Trainer(max_nb_epochs=1000, train_percent_check=0.1,
                  checkpoint_callback=None,
                  resume_from_checkpoint='model_ckpt/_ckpt_epoch_2.ckpt',
                  early_stop_callback=early_stopping,
                  callbacks=[PrintingCallback()])
trainer.fit(model)
The PrintingCallback would print:
Scheduler epoch 2
Trainer epoch 3
--------------------------------------------------------------------------------
Scheduler epoch 3
Trainer epoch 4
--------------------------------------------------------------------------------
...
and so on.
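To double-check which scheduler state actually landed in the file, the saved checkpoint can be inspected directly. The key names below ('epoch', 'lr_schedulers') are an assumption based on how Lightning serialized checkpoints at the time and may differ between versions:

import torch

# Assumed checkpoint layout; adjust the keys if your Lightning version differs.
ckpt = torch.load('model_ckpt/_ckpt_epoch_2.ckpt', map_location='cpu')
print('epoch stored in checkpoint:', ckpt['epoch'])
print('scheduler last_epoch in checkpoint:', ckpt['lr_schedulers'][0]['last_epoch'])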
Expected behavior
The PrintingCallback should print:
Scheduler epoch 3
Trainer epoch 3
--------------------------------------------------------------------------------
Scheduler epoch 4
Trainer epoch 4
--------------------------------------------------------------------------------
...
Environment
This was run on Google Colab.
https://colab.research.google.com/drive/1pkCSMaApyjH40jwrdl4aQLVYjnGP3JzD?usp=sharing
Additional context
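As a stopgap until this is fixed, a callback along the following lines might fast-forward the scheduler after a resume so it matches the trainer's epoch counter. This is a hypothetical workaround I have not verified against Lightning's restore order, not an official API:

import pytorch_lightning as pl

class SyncSchedulerCallback(pl.Callback):
    # Hypothetical workaround (unverified): once the trainer has restored a
    # checkpoint, step each scheduler until last_epoch catches up with
    # trainer.current_epoch. Assumes the restore happens before on_train_start.
    def on_train_start(self, trainer, pl_module):
        for entry in trainer.lr_schedulers:
            scheduler = entry['scheduler']
            while scheduler.last_epoch < trainer.current_epoch:
                scheduler.step()

It could be passed alongside the other callbacks, e.g. callbacks=[SyncSchedulerCallback(), PrintingCallback()].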