Description
There are several critical design issues surrounding both the global_step and current_epoch values and their interaction with the training loop and ModelCheckpoint.
I'll use this issue to track and summarize some of the challenges and design decisions.
Global step issues
global_step refers to the number of optimizer steps applied. It starts at 0 (because no optimizer steps have been applied yet).
Most people expect global_step to be updated like this:
def training_step(data):
    loss = model(data)
    loss.backward()
    return loss

global_step = 0
for epoch in range(epochs):
    on_train_epoch_start()
    for data in loader:
        on_train_batch_start()
        loss = training_step(data)
        on_train_batch_end()
        optimizer.step()
        global_step += 1
    on_train_epoch_end()

Any of the previous hooks can call self.log(). These self.log() calls use the global_step value as the x-axis (if you were to plot them).
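One way to see which step each self.log() call was actually recorded against (a sketch, not part of the original report) is to run the repro script below with a CSVLogger instead of logger=False and read back the step column it writes to metrics.csv:

import csv

def print_logged_steps(metrics_file):
    # Each row's "step" value is the trainer.global_step at the time of the
    # corresponding self.log() call, i.e. the x-axis a plot would use.
    with open(metrics_file) as f:
        for row in csv.DictReader(f):
            print(row.get("step"), {k: v for k, v in row.items() if v})

# Hypothetical path, assuming the repro below is run with
# logger=CSVLogger(os.getcwd()) instead of logger=False:
# print_logged_steps("lightning_logs/version_0/metrics.csv")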
But that is not what actually happens. If you print global_step inside each hook (see the repro script below), with max_epochs=1 and limit_train_batches=1 the output is:
on_train_epoch_start: 0
on_train_batch_start: 0
on_train_batch_end: 0
on_train_epoch_end: 0 <-- Should be 1!
import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def on_train_epoch_start(self) -> None:
        print("on_train_epoch_start", self.trainer.global_step)

    def on_train_batch_start(self, *_, **__):
        print("on_train_batch_start", self.trainer.global_step)

    def on_train_batch_end(self, *_, **__):
        print("on_train_batch_end", self.trainer.global_step)

    def on_train_epoch_end(self):
        print("on_train_epoch_end", self.trainer.global_step)

    def on_train_end(self):
        print("on_train_end", self.trainer.global_step)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=0,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        enable_progress_bar=False,
        enable_checkpointing=False,
        logger=False,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()

Global step issues when saving/loading checkpoints
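To ground that discussion, here is a minimal sketch (not from the original report) of how to inspect what ends up in a saved checkpoint; "boring.ckpt" is a hypothetical path:

import torch

# Assumes the repro above was run with trainer.save_checkpoint("boring.ckpt")
# added after trainer.fit(...). Lightning checkpoints are plain dicts, and the
# "global_step" and "epoch" entries are what a resumed Trainer restores, so any
# off-by-one at save time carries over into the resumed run.
ckpt = torch.load("boring.ckpt")
print(ckpt["global_step"], ckpt["epoch"])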
Current epoch increment design on train end
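A similar probe (again a sketch, not from the original report) makes the current_epoch side visible. It reuses the BoringModel from the repro above and only adds prints, so it can be dropped into run() in place of BoringModel:

class EpochProbeModel(BoringModel):
    def on_train_epoch_end(self):
        # value seen by the last hook that runs inside the epoch
        print("on_train_epoch_end, current_epoch =", self.trainer.current_epoch)

    def on_train_end(self):
        # value seen after the training loop finishes; whether this has been
        # incremented relative to the line above is the design question here
        print("on_train_end, current_epoch =", self.trainer.current_epoch)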