Labels: bug (Something isn't working), checkpointing (Related to checkpointing), help wanted (Open to be worked on)
🐛 Bug
When using a Trainer with check_val_every_n_epoch = n (n > 1), the trainer runs validation every n epochs, as expected. But when combined with a ModelCheckpoint with save_top_k = m (m > 1), a checkpoint is also saved at every epoch. It should instead save only every n epochs. This worked in previous versions (if I remember correctly, it worked in 1.2), but it is now broken.
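For concreteness, with check_val_every_n_epoch=2 and max_epochs=10 (as in the repro below), validation should only fire on 0-indexed epochs 1, 3, 5, 7, 9. A minimal sketch of that cadence, assuming validation fires when (epoch + 1) % n == 0, consistent with the epoch numbers in the checkpoint filenames below:

n = 2  # check_val_every_n_epoch
validation_epochs = [epoch for epoch in range(10) if (epoch + 1) % n == 0]
print(validation_epochs)  # [1, 3, 5, 7, 9] -- the only epochs where a top-k checkpoint should be written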
To Reproduce
The following code with the BoringModel reproduces the issue: it saves a checkpoint every epoch instead of every n epochs (see the shell output at the bottom).
import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=10,
        check_val_every_n_epoch=2,  # validation should run only every 2 epochs
        weights_summary=None,
        callbacks=[
            ModelCheckpoint(
                monitor="valid_loss",
                mode="min",
                dirpath="./",
                save_top_k=10,  # keep the 10 best checkpoints by valid_loss
                filename="model-{epoch:02d}-{valid_loss:.2f}",
            )
        ],
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()

>>> ls -l *.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=01-valid_loss=-6.00.ckpt
-rw-r--r--. 1 ndecao Domain Users 2579 Aug 27 09:39 model-epoch=02-valid_loss=-6.00.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=03-valid_loss=-11.57.ckpt
-rw-r--r--. 1 ndecao Domain Users 2643 Aug 27 09:39 model-epoch=04-valid_loss=-11.57.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=05-valid_loss=-17.14.ckpt
-rw-r--r--. 1 ndecao Domain Users 2643 Aug 27 09:39 model-epoch=06-valid_loss=-17.14.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=07-valid_loss=-22.70.ckpt
-rw-r--r--. 1 ndecao Domain Users 2643 Aug 27 09:39 model-epoch=08-valid_loss=-22.70.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=09-valid_loss=-28.27.ckpt

Note that valid_loss repeats in consecutive pairs (e.g. epochs 01 and 02 both show -6.00): on epochs where validation does not run, a checkpoint is still saved using the stale metric from the previous validation.

Expected behavior
The trainer should check the validation loss and save a checkpoint only every check_val_every_n_epoch epochs, so only the following models should have been saved:
>>> ls -l *.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=01-valid_loss=-6.00.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=03-valid_loss=-11.57.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=05-valid_loss=-17.14.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=07-valid_loss=-22.70.ckpt
-rw-r--r--. 1 ndecao Domain Users 2378 Aug 27 09:39 model-epoch=09-valid_loss=-28.27.ckpt
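A possible workaround until this is fixed, sketched here and not verified against 1.4.4: ModelCheckpoint accepts an every_n_epochs argument (named every_n_val_epochs in 1.3), which can be set to match check_val_every_n_epoch so that saving is only triggered on epochs where validation actually runs. The exact alignment with epoch numbering is an assumption to verify:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Untested sketch: align the checkpoint cadence with the validation cadence
# explicitly, instead of relying on the (currently broken) implicit coupling.
checkpoint_cb = ModelCheckpoint(
    monitor="valid_loss",
    mode="min",
    dirpath="./",
    save_top_k=10,
    every_n_epochs=2,  # assumption: fires on the same epochs as check_val_every_n_epoch=2
    filename="model-{epoch:02d}-{valid_loss:.2f}",
)
trainer = Trainer(
    max_epochs=10,
    check_val_every_n_epoch=2,
    callbacks=[checkpoint_cb],
)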
Environment
- CUDA:
  - GPU:
    - TITAN X (Pascal)
    - TITAN X (Pascal)
    - TITAN X (Pascal)
    - TITAN X (Pascal)
  - available: True
  - version: 10.1
- Packages:
  - numpy: 1.19.5
  - pyTorch_debug: False
  - pyTorch_version: 1.8.1
  - pytorch-lightning: 1.4.4
  - tqdm: 4.62.1
- System:
  - OS: Linux
  - architecture: 64bit, ELF
  - processor: x86_64
  - python: 3.8.10
  - version: 1 SMP Wed Feb 3 15:06:38 UTC 2021