global_step/current_epoch issues #7406

@carmocca

Description

There are several critical design issues surrounding both the global_step and current_epoch values and their interaction with the training loop and ModelCheckpoint.

I'll use this issue to track and summarize some of the challenges and design decisions.

Global step issues

global_step refers to the number of optimizer steps applied. It starts at 0 (because no optimizer steps have been applied yet).

Most people expect global_step to be updated like this:

def training_step(data):
    loss = model(data)
    loss.backward()
    return loss

global_step = 0
for epoch in range(epochs):
    on_train_epoch_start()
    for data in loader:
        on_train_batch_start()
        loss = training_step(data)
        on_train_batch_end()
        optimizer.step()
        global_step += 1
    on_train_epoch_end()

Any of the previous hooks can call self.log(). These self.log() calls use the global_step value as the x-axis (if you were to plot the logged metric).
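
A rough sketch of why this matters (this is not Lightning's actual logging code, just an illustration): whatever global_step holds when the hook runs is the step the logger records the metric at.

# Illustration only, not Lightning internals: a metric logged from a hook is
# recorded against whatever trainer.global_step happens to be at that moment.
class SketchLogger:
    def log_metrics(self, metrics: dict, step: int) -> None:
        print(f"step={step}: {metrics}")


def hook_level_log(logger: SketchLogger, global_step: int, name: str, value: float) -> None:
    # The current global_step becomes the x-axis value of the logged point.
    logger.log_metrics({name: value}, step=global_step)


# With the off-by-one described below, a metric logged in on_train_epoch_end
# after a single optimizer step lands at step 0 instead of step 1.
hook_level_log(SketchLogger(), global_step=0, name="train_loss", value=1.23)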

But global_step is not actually updated that way. If you print global_step inside each hook, this is the output (with max_epochs=1, limit_train_batches=1):

on_train_epoch_start: 0
on_train_batch_start: 0
on_train_batch_end: 0
on_train_epoch_end: 0 <-- Should be 1!

Reproduction script:

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    # These hooks only print trainer.global_step so its value can be observed.
    def on_train_epoch_start(self) -> None:
        print("on_train_epoch_start", self.trainer.global_step)

    def on_train_batch_start(self, *_, **__):
        print("on_train_batch_start", self.trainer.global_step)

    def on_train_batch_end(self, *_, **__):
        print("on_train_batch_end", self.trainer.global_step)

    def on_train_epoch_end(self):
        print("on_train_epoch_end", self.trainer.global_step)

    def on_train_end(self):
        print("on_train_end", self.trainer.global_step)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    # A single epoch with a single batch: exactly one optimizer step runs.
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=0,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        enable_progress_bar=False,
        enable_checkpointing=False,
        logger=False,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()

Global step issues when saving/loading checkpoints

Remove the line at https://github.com/PyTorchLightning/pytorch-lightning/blob/5ad5ba54c0c477546d21daf75ac7b4748d7963a7/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L343
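
To see what these values end up as on disk, a saved checkpoint can be loaded directly (a minimal sketch; "example.ckpt" is a placeholder path for any checkpoint written by Lightning, e.g. produced by re-running the script above with enable_checkpointing=True):

import torch

# "example.ckpt" is a placeholder path; point it at a real Lightning checkpoint.
ckpt = torch.load("example.ckpt", map_location="cpu")

# Lightning checkpoints store these top-level entries; comparing them with the
# values printed by the hooks above makes any off-by-one visible.
print("epoch:", ckpt["epoch"])
print("global_step:", ckpt["global_step"])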

Current epoch increment design on train end

See #8578.
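
As an illustrative sketch (not part of the original report), the BoringModel from the reproduction script above can be extended to also print current_epoch, to see when the epoch counter is incremented relative to on_train_epoch_end and on_train_end:

class EpochDebugModel(BoringModel):
    # Prints current_epoch next to the global_step output above to show when
    # the epoch counter is bumped, in particular at train end.
    def on_train_epoch_end(self):
        print("on_train_epoch_end, current_epoch:", self.trainer.current_epoch)

    def on_train_end(self):
        print("on_train_end, current_epoch:", self.trainer.current_epoch)

Swapping EpochDebugModel into run() is enough to observe the increment on train end discussed in #8578.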

Labels

bug (Something isn't working), design (Includes a design discussion), discussion (In a discussion stage)
