Closed
Labels: checkpointing, priority: 0, progress tracking (internal)
Description
🐛 Bug
The change introduced in #11805 causes the logged step number to reset to 0 when training is resumed from a checkpoint.
https://github.com/PyTorchLightning/pytorch-lightning/blob/49a4a36ad45b937dd0124ecfb08eb7400dbf3950/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py#L122
To Reproduce
import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run(ckpt_path=None):
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        num_sanity_val_steps=0,
        max_epochs=2,
        enable_model_summary=False,
        callbacks=ModelCheckpoint(dirpath="checkpoints", save_top_k=-1, filename="{epoch}", save_on_train_epoch_end=False),
        log_every_n_steps=1,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data, ckpt_path=ckpt_path)


if __name__ == "__main__":
    # First run: train from scratch for two epochs.
    run()
    # Second run: resume from the checkpoint saved after epoch 0.
    run("checkpoints/epoch=0.ckpt")
The script creates two TensorBoard logs:
- version_0: steps 0 to 63
- version_1: steps 0 to 31
Expected behavior
- version_1: steps 31 to 63
This was the behavior before #11805
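For reference, the step numbers above follow from the dataset and batch sizes in the script (64 samples, batch size 2, 2 epochs). The snippet below is only a sketch of that arithmetic: `expected_step_range` is a hypothetical helper, not part of PyTorch Lightning, and it assumes the resumed range starts at the last step index of the checkpointed epoch, matching the numbers reported here.

```python
# Hypothetical helper for illustration only; not a PyTorch Lightning API.
def expected_step_range(num_samples, batch_size, max_epochs, resume_after_epoch=None):
    """Return (first_step, last_step) of the logged step indices."""
    # Assumes the dataset divides evenly into batches, as in the repro script.
    steps_per_epoch = num_samples // batch_size
    last_step = steps_per_epoch * max_epochs - 1
    if resume_after_epoch is None:
        first_step = 0
    else:
        # Assumption: logging continues from the last step index of the
        # epoch whose checkpoint is being resumed.
        first_step = steps_per_epoch * (resume_after_epoch + 1) - 1
    return first_step, last_step


fresh = expected_step_range(64, 2, 2)
resumed = expected_step_range(64, 2, 2, resume_after_epoch=0)
print(fresh)    # (0, 63)
print(resumed)  # (31, 63)
```

With the bug, the resumed run instead logs steps 0 to 31, i.e. the counter restarts rather than continuing from the checkpoint.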
Environment
- PyTorch Lightning Version (e.g., 1.5.0): master (49a4a36)
- Fault-tolerant training is off (PL_FAULT_TOLERANT_TRAINING=0)
cc @tchaton @rohitgr7 @akihironitta @awaelchli @ananthsub @ninginthecloud @carmocca