
WandbLogger incorrectly logs global_step when validating a checkpoint #16207

@nharada1


Bug description

When validating a trained model, the WandbLogger incorrectly logs the global_step as zero, even when the full training state is loaded from the checkpoint. The trainer itself has the correct value: inspecting self.global_step during the validation stage shows the expected step count. The logger, however, always records trainer/global_step as 0 (the epoch, by contrast, is logged correctly).
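
To make the discrepancy visible inside the run itself, the trainer's step counter can be logged as an ordinary metric during validation (a minimal sketch; "global_step_check" is a hypothetical metric name, not part of the repro script below). The metric value comes out correct even though the step axis the logger attaches it to resets to 0:

    def validation_step(self, batch, batch_nb):
        x, y = batch
        _, y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("loss", loss, prog_bar=True, logger=True, on_epoch=True)
        # Log the trainer's step counter as a plain metric so its value is
        # visible in the W&B run even while the logger's own step axis reads 0.
        self.log("global_step_check", float(self.global_step))
        return loss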

To reproduce:

  1. Train a model that outputs checkpoints each epoch
  2. Run trainer.validate with a checkpoint specified

I have reproduced this in this git repo: https://github.com/nharada1/wandb-pl-repro

How to reproduce the bug

import glob
import os
import setuptools
import uuid

import torch
import pytorch_lightning as pl
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import FakeData
import segmentation_models_pytorch as smp

BATCH_SIZE = 32

class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        aux_params=dict(
            pooling='avg',
            activation='softmax',
            classes=100,
        )
        self.model = smp.Unet(
            encoder_name="resnet50",
            in_channels=3,
            classes=10,
            aux_params=aux_params,
        )

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_nb):
        x, y = batch
        _, y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("loss", loss, prog_bar=True, logger=True, on_step=True)
        return loss

    def validation_step(self, batch, batch_nb):
        x, y = batch
        _, y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("loss", loss, prog_bar=True, logger=True, on_epoch=True)
        return loss

    def on_validation_end(self):
        print(f"Final global step is {self.global_step}")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

def get_trainer(train, prefix):
    # Init our model
    mnist_model = Model()

    # Init DataLoader from a FakeData dataset
    train_ds = FakeData(size=10000, image_size=[3, 128, 128], num_classes=10, transform=transforms.ToTensor())
    loader = DataLoader(train_ds, batch_size=BATCH_SIZE)

    postfix = 'train' if train else 'val'
    name = f"{prefix}-{postfix}"
    wandb_logger = pl.loggers.WandbLogger(project='project-debug', name=name)

    epoch_checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath="/tmp/pl_wandb",
        save_top_k=-1,
        verbose=True,
    )

    callbacks = []
    if train:
        callbacks = [epoch_checkpoint_callback]

    # Initialize a trainer
    trainer = pl.Trainer(
        accelerator="auto",
        max_epochs=5,
        callbacks=callbacks,
        logger=wandb_logger,
    )

    return mnist_model, loader, trainer

def main():
    prefix = str(uuid.uuid4()).split('-')[0]

    # Train the model
    model, data, trainer = get_trainer(train=True, prefix=prefix)
    trainer.fit(model, data)

    # Eval the model
    checkpoint_file = sorted(glob.glob(os.path.join("/tmp/pl_wandb", "*.ckpt")))[-1]
    model, data, trainer = get_trainer(train=False, prefix=prefix)
    trainer.validate(model, data, ckpt_path=checkpoint_file)


if __name__ == "__main__":
    main()

Error messages and logs

Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:06<00:00, 50.16it/s]
Final global step is 1565
wandb: Waiting for W&B process to finish... (success).
wandb: - 0.003 MB of 0.016 MB uploaded (0.000 MB deduped)
wandb: Run history:
wandb:               epoch ▁▁▁▁▁▁▃▃▃▃▃▃▅▅▅▅▅▅▆▆▆▆▆▆▆███████
wandb:                loss ▅▆▄▆▆▆▇▄█▄▆▁▆▅▄▅▃▇▄▇▅▁▄▇▇▆▆▆▅█▄▅
wandb: trainer/global_step ▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇███▁
wandb:
wandb: Run summary:
wandb:               epoch 4
wandb:                loss 4.5202
wandb: trainer/global_step 0
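
The run summary ends with trainer/global_step 0 even though the trainer itself reports 1565. A possible workaround (a sketch only, not verified against this exact Lightning/W&B combination; "val/loss" is just an illustrative metric name) is to bypass self.log for the step-sensitive metric and write to the underlying wandb run directly, passing the trainer's step explicitly:

    def validation_step(self, batch, batch_nb):
        x, y = batch
        _, y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        # WandbLogger.experiment is the underlying wandb Run object; logging
        # through it lets us attach the metric to the trainer's real global_step.
        self.logger.experiment.log(
            {"val/loss": loss.item()}, step=self.trainer.global_step
        )
        return loss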

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 1.10):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response

Metadata

Assignees

No one assigned

    Labels

    needs triage (Waiting to be triaged by maintainers)
    won't fix (This will not be worked on)
