Closed
Labels
needs triage (Waiting to be triaged by maintainers), won't fix (This will not be worked on)
Description
Bug description
When validating a trained model, the WandbLogger incorrectly logs the global_step as zero, even when the full training state is loaded from a checkpoint. The trainer itself seems to hold the correct value: inspecting self.global_step during the validation stage shows the expected step. However, the logger always logs the value as 0 (although it does seem to log the epoch correctly).
To reproduce:
- Train a model that outputs checkpoints each epoch
- Run `trainer.validate` with a checkpoint specified
I have reproduced this in this git repo: https://github.com/nharada1/wandb-pl-repro
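Before blaming the logger, it can help to confirm that the checkpoint itself carries the expected step. The sketch below is only an illustration: `read_training_progress` is a hypothetical helper, and it assumes the checkpoint is a dict with top-level `epoch` and `global_step` keys, as Lightning checkpoints normally have. A plain dict stands in for the loaded file so the snippet runs without any dependencies.

```python
from typing import Any, Mapping, Tuple


def read_training_progress(checkpoint: Mapping[str, Any]) -> Tuple[int, int]:
    """Return (epoch, global_step) from an already-loaded checkpoint dict.

    Assumes top-level "epoch" and "global_step" keys; raises KeyError if
    either one is missing.
    """
    return int(checkpoint["epoch"]), int(checkpoint["global_step"])


# Stand-in for a loaded checkpoint (hypothetical values chosen to match the
# epoch and step reported in this issue). With a real file this would be
# something like torch.load(checkpoint_file, map_location="cpu").
example_checkpoint = {"epoch": 4, "global_step": 1565, "state_dict": {}}

epoch, global_step = read_training_progress(example_checkpoint)
print(f"checkpoint epoch={epoch}, global_step={global_step}")
```

If the value read this way is non-zero while the W&B run still reports 0, the discrepancy is on the logging side rather than in the saved state.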
How to reproduce the bug
import glob
import os
import setuptools
import uuid
import torch
import pytorch_lightning as pl
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import FakeData
import segmentation_models_pytorch as smp
BATCH_SIZE = 32


class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        aux_params = dict(
            pooling='avg',
            activation='softmax',
            classes=100,
        )
        self.model = smp.Unet(
            encoder_name="resnet50",
            in_channels=3,
            classes=10,
            aux_params=aux_params,
        )

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_nb):
        x, y = batch
        _, y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("loss", loss, prog_bar=True, logger=True, on_step=True)
        return loss

    def validation_step(self, batch, batch_nb):
        x, y = batch
        _, y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("loss", loss, prog_bar=True, logger=True, on_epoch=True)
        return loss

    def on_validation_end(self):
        print(f"Final global step is {self.global_step}")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


def get_trainer(train, prefix):
    # Init our model
    mnist_model = Model()

    # Init DataLoader from a synthetic FakeData dataset
    train_ds = FakeData(size=10000, image_size=[3, 128, 128], num_classes=10, transform=transforms.ToTensor())
    loader = DataLoader(train_ds, batch_size=BATCH_SIZE)

    postfix = 'train' if train else 'val'
    name = f"{prefix}-{postfix}"
    wandb_logger = pl.loggers.WandbLogger(project='project-debug', name=name)

    epoch_checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath="/tmp/pl_wandb",
        save_top_k=-1,
        verbose=True,
    )

    # Only save checkpoints during training
    callbacks = []
    if train:
        callbacks = [epoch_checkpoint_callback]

    # Initialize a trainer
    trainer = pl.Trainer(
        accelerator="auto",
        max_epochs=5,
        callbacks=callbacks,
        logger=wandb_logger,
    )
    return mnist_model, loader, trainer


def main():
    prefix = str(uuid.uuid4()).split('-')[0]

    # Train the model
    model, data, trainer = get_trainer(train=True, prefix=prefix)
    trainer.fit(model, data)

    # Eval the model
    checkpoint_file = sorted(glob.glob(os.path.join("/tmp/pl_wandb/*.ckpt")))[-1]
    model, data, trainer = get_trainer(train=False, prefix=prefix)
    trainer.validate(model, data, ckpt_path=checkpoint_file)


if __name__ == "__main__":
    main()

Error messages and logs
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:06<00:00, 50.16it/s]
Final global step is 1565
wandb: Waiting for W&B process to finish... (success).
wandb: - 0.003 MB of 0.016 MB uploaded (0.000 MB deduped)
wandb: Run history:
wandb: epoch ▁▁▁▁▁▁▃▃▃▃▃▃▅▅▅▅▅▅▆▆▆▆▆▆▆███████
wandb: loss ▅▆▄▆▆▆▇▄█▄▆▁▆▅▄▅▃▇▄▇▅▁▄▇▇▆▆▆▅█▄▅
wandb: trainer/global_step ▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇███▁
wandb:
wandb: Run summary:
wandb: epoch 4
wandb: loss 4.5202
wandb: trainer/global_step 0
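As a sanity check on the log above, the printed step of 1565 can be derived from the script's own settings (10000 samples, batch size 32, 5 epochs), and the `trainer/global_step` sparkline can be decoded to show where the drop happens. This is a rough sketch: `spark_to_levels` is a hypothetical helper, and it assumes the eight block glyphs are ordered from lowest to highest and that one optimizer step is taken per batch.

```python
import math
from typing import List

# Settings taken from the reproduction script.
DATASET_SIZE = 10_000
BATCH_SIZE = 32
MAX_EPOCHS = 5

# The DataLoader keeps the final partial batch, so round up.
batches_per_epoch = math.ceil(DATASET_SIZE / BATCH_SIZE)
# Assumption: one optimizer step per batch (no gradient accumulation).
expected_global_step = batches_per_epoch * MAX_EPOCHS

# Assumption: glyphs are ordered from the lowest level to the highest.
SPARK_LEVELS = "▁▂▃▄▅▆▇█"


def spark_to_levels(sparkline: str) -> List[int]:
    """Map each sparkline glyph to an integer level from 0 to 7."""
    return [SPARK_LEVELS.index(ch) for ch in sparkline]


# Copied from the "trainer/global_step" line of the run history.
global_step_history = "▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇███▁"
levels = spark_to_levels(global_step_history)

# Index of the first point that is lower than its predecessor, if any.
first_drop = next(
    (i for i in range(1, len(levels)) if levels[i] < levels[i - 1]),
    None,
)

print(f"batches per epoch: {batches_per_epoch}")
print(f"expected global step after training: {expected_global_step}")
print(f"first drop at point {first_drop} of {len(levels)}")
```

The arithmetic gives 313 batches per epoch and 1565 steps in total, matching both the `313/313` progress bar and the value printed by `on_validation_end`. The only decrease in the sparkline is at its final point, which is consistent with the run summary reporting `trainer/global_step 0`.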
Environment
Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 1.10):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
More info
No response