Closed
Labels: bug, checkpointing, help wanted, priority: 0
Description
🐛 Bug
When reloading a Trainer from a checkpoint, the ModelCheckpoint callback appears twice in the callbacks list.
This ONLY happens if one provides both a checkpoint_callback AND a custom callbacks list.
To Reproduce
Code below
Code sample
import torch
from torch.utils.data import Dataset
from pytorch_lightning import Trainer, LightningModule, Callback
from pytorch_lightning.callbacks import ModelCheckpoint


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def training_step_end(self, training_step_outputs):
        return training_step_outputs

    def training_epoch_end(self, outputs) -> None:
        torch.stack([x["loss"] for x in outputs]).mean()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]


def run_test():
    # fake data
    train_data = torch.utils.data.DataLoader(RandomDataset(32, 64))

    class RandomCallback(Callback):
        pass

    model = BoringModel()
    checkpoint_callback = ModelCheckpoint()
    trainer_args = dict(
        max_epochs=2,
        logger=False,
        progress_bar_refresh_rate=0,
        checkpoint_callback=checkpoint_callback,
        weights_summary=None,
        callbacks=[RandomCallback()],  # try to remove it and see what happens :)
    )

    # first training run
    trainer = Trainer(**trainer_args)
    trainer.fit(model, train_dataloader=train_data)
    print(trainer.callbacks)

    # second run: resume from the best checkpoint of the first run
    trainer = Trainer(**trainer_args, resume_from_checkpoint=trainer.checkpoint_callback.best_model_path)
    trainer.fit(model, train_dataloader=train_data)
    print(trainer.callbacks)


if __name__ == '__main__':
    run_test()

Produces:
callbacks = [ModelCheckpoint]                    # first training
callbacks = [ModelCheckpoint, ModelCheckpoint]   # after resuming

Expected behavior

callbacks = [ModelCheckpoint]  # first training
callbacks = [ModelCheckpoint]  # after resuming
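To make the comparison concrete, the duplication can be checked programmatically. The helper below is only an illustrative sketch (count_model_checkpoints is not a Lightning API) that counts the ModelCheckpoint instances attached to a Trainer after fit, assuming trainer.callbacks is a plain list of callback instances in this version:

from pytorch_lightning.callbacks import ModelCheckpoint


def count_model_checkpoints(trainer):
    # Count how many ModelCheckpoint callbacks the trainer currently holds.
    return sum(isinstance(cb, ModelCheckpoint) for cb in trainer.callbacks)

# With the repro above, this returns 1 after the first run,
# but 2 after resuming (expected: still 1).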
Environment

pytorch-lightning: master branch (0.10.0)
python: 3.7
torch: 1.5.1
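Possible workaround (sketch)

Until this is fixed, one interim option is to drop the duplicated ModelCheckpoint from trainer.callbacks after constructing the resumed Trainer and before calling fit. This is only a rough sketch, under the assumption that trainer.callbacks is a mutable Python list holding the extra instance; dedupe_model_checkpoints is an illustrative helper name, not part of Lightning:

from pytorch_lightning.callbacks import ModelCheckpoint


def dedupe_model_checkpoints(trainer):
    # Keep only the first ModelCheckpoint attached to the trainer.
    seen = False
    deduped = []
    for cb in trainer.callbacks:
        if isinstance(cb, ModelCheckpoint):
            if seen:
                continue  # drop the duplicate picked up on resume
            seen = True
        deduped.append(cb)
    trainer.callbacks = deduped


# usage with the repro above:
# trainer = Trainer(**trainer_args, resume_from_checkpoint=...)
# dedupe_model_checkpoints(trainer)
# trainer.fit(model, train_dataloader=train_data)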