🐛 Bug
When resuming model training from a checkpoint, the TensorBoardLogger and WandbLogger log metrics as if the global_step had been reset to 0 (even though the global_step on the trainer and pl_module is accurate). This happens both when manually resuming training from a checkpoint via the ckpt_path argument of Trainer.fit and when using fault-tolerant training as shown here: https://github.com/PyTorchLightning/pytorch-lightning/blob/1.6.3/pl_examples/fault_tolerant/automatic.py
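For the manual-resume case, here is a minimal sketch of the pattern that triggers the behavior (SimpleMLP is the module from the reproduction script below; the checkpoint path is a placeholder, not a real file from my runs):

# Minimal sketch of the manual-resume case. SimpleMLP is defined in the
# reproduction script below; the checkpoint path is a placeholder for whatever
# ModelCheckpoint wrote during the first run. After the second fit() restores
# the checkpoint, trainer.global_step continues correctly, but the metrics
# logged via self.log() show up in TensorBoard starting again at step 0.
from pytorch_lightning import Trainer

first_trainer = Trainer(max_epochs=1, log_every_n_steps=1)
first_trainer.fit(SimpleMLP())

resumed_trainer = Trainer(max_epochs=3, log_every_n_steps=1)
resumed_trainer.fit(SimpleMLP(), ckpt_path="path/to/checkpoint-from-first-run.ckpt")  # placeholder path

The fault-tolerant script below exercises the same resume path automatically via the .pl_auto_save.ckpt checkpoint.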
To Reproduce
I've adapted the script linked above to test this, running pytorch-lightning v1.6.3:
import os
import random as python_random
from argparse import ArgumentParser
from time import sleep

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import _logger as log
from pytorch_lightning import LightningModule, seed_everything, Trainer
from pytorch_lightning.loggers import WandbLogger
import wandb


class RandomGetItemDataset(Dataset):
    """A dataset with random elements generated using global rng from torch, numpy and python."""

    def __init__(self, length, size):
        self.size = size
        self.len = length

    def __getitem__(self, index):
        t = torch.rand(self.size)
        n = torch.from_numpy(np.random.rand(self.size))
        p = torch.tensor([python_random.random() for _ in range(self.size)])
        sample = (index + (t + n + p) / 10).float()
        return sample

    def __len__(self):
        return self.len


class SimpleMLP(LightningModule):
    def __init__(self, fail_on_step: int = -1):
        super().__init__()
        self.layer = torch.nn.Linear(1, 2)
        self.seen_batches = []
        self.fail_on_step = fail_on_step

    def training_step(self, batch, batch_idx):
        if self.global_step == self.fail_on_step:
            log.info(
                f"READY TO BE KILLED WITH SIGTERM SIGNAL. "
                f"Run `kill -SIGTERM {os.getpid()}` in another terminal."
            )
            # wait here until the SIGTERM signal arrives so the process exits gracefully.
            while not self.trainer._terminate_gracefully:
                sleep(0.1)
        batch = batch["data"] if isinstance(batch, dict) else batch
        self.seen_batches.append(torch.stack(batch) if isinstance(batch, list) else batch)
        loss = sum(self.layer(b).sum() for b in batch)
        self.log("loss", loss.item())
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomGetItemDataset(3, 1))


def _run_training(default_root_dir=".", max_epochs=3, fail_on_step: int = -1, ckpt_path=None, logger=True):
    model = SimpleMLP(fail_on_step=fail_on_step)
    trainer = Trainer(
        default_root_dir=default_root_dir,
        max_epochs=max_epochs,
        logger=logger,
        log_every_n_steps=1,
    )
    trainer.fit(model, ckpt_path=ckpt_path)
    wandb.finish()
    return model.seen_batches, model.parameters()


def main(args):
    seed_everything(42)
    os.environ["PL_FAULT_TOLERANT_TRAINING"] = "automatic"  # activate automatic fault-tolerant training
    ckpt_path = ".pl_auto_save.ckpt"
    auto_restart_ckpt_path_exists = os.path.exists(ckpt_path)

    if args.emulate_kill_signal:
        fail_on_step = -1 if auto_restart_ckpt_path_exists else 4
        completed_batches = 4 if auto_restart_ckpt_path_exists else 5
    else:
        fail_on_step = -1
        completed_batches = 9

    if args.use_tb:
        logger = True
    else:
        logger = WandbLogger(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=args.wandb_run,
            id=args.wandb_run,
        )

    complete_batches, weights = _run_training(fail_on_step=fail_on_step, logger=logger)
    assert len(complete_batches) == completed_batches

    if not auto_restart_ckpt_path_exists and args.emulate_kill_signal:
        assert os.path.exists(ckpt_path)

    if auto_restart_ckpt_path_exists or not args.emulate_kill_signal:
        log.info([w for w in weights])


if __name__ == "__main__":
    parser = ArgumentParser(description="Fault Tolerant Under Signal Example")
    parser.add_argument(
        "--emulate_kill_signal",
        action="store_true",
        help="Whether you should gracefully kill the process with a `SIGTERM` signal.",
    )
    parser.add_argument(
        "--use_tb",
        action="store_true",
        help="Use TensorBoard instead of WandB.",
    )
    parser.add_argument(
        "-e", "--wandb_entity",
        type=str,
        default=None,
        help="Wandb entity.",
    )
    parser.add_argument(
        "-p", "--wandb_project",
        type=str,
        default=None,
        help="Wandb project.",
    )
    parser.add_argument(
        "-r", "--wandb_run",
        type=str,
        default=None,
        help="Wandb run.",
    )
    main(parser.parse_args())

With TensorBoard, running these:
python automatic.py --use_tb (without fault)
python automatic.py --use_tb --emulate_kill_signal (with fault)
python automatic.py --use_tb --emulate_kill_signal (resume from fault)
Results in the following, where the epoch is properly logged, but not the step:
With wandb, running these:
python automatic.py -e [wandb_entity] -p [wandb_project] -r no_fault (without fault)
python automatic.py -e [wandb_entity] -p [wandb_project] -r fault --emulate_kill_signal (with fault)
python automatic.py -e [wandb_entity] -p [wandb_project] -r fault --emulate_kill_signal (resume from fault)
Results in the following, where the step is properly logged (because I'm only logging once per step, see #13016), but the global_step is reset.
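To see the reset in the logged data itself, here is a sketch that pulls the run history back with the W&B public API (the run path is a placeholder; trainer/global_step is the key the WandbLogger writes alongside each metric):

# Sketch: inspect what the WandbLogger recorded for the faulted/resumed run.
# The run path is a placeholder. In the rows logged after the resume, the
# "trainer/global_step" values restart from 0 instead of continuing from the
# step reached before the fault.
import wandb

api = wandb.Api()
run = api.run("my_entity/my_project/fault")  # placeholder entity/project/run id
history = run.history(keys=["trainer/global_step", "loss"])
print(history)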
Expected behavior
The trainer/global_step in WandbLogger and the step in TensorBoardLogger should properly reflect the global_step state of the trainer/pl_module when resuming from checkpoints (either manually or automatically with fault-tolerant training).
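As a concrete check of this for the TensorBoard case, here is a sketch that reads the recorded steps back from the event files written by the script above (directory names are illustrative):

# Sketch: read back the steps TensorBoard recorded for the "loss" scalar in
# each run directory (names are illustrative). Expected: the steps recorded
# by the resumed run continue from where the faulted run stopped.
# Observed: the steps of the resumed run restart at 0.
import glob
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

for logdir in sorted(glob.glob("lightning_logs/version_*")):
    acc = EventAccumulator(logdir)
    acc.Reload()
    steps = [event.step for event in acc.Scalars("loss")]
    print(logdir, steps)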
Environment
* CUDA:
    - GPU:
    - available: False
    - version: 10.2
* Packages:
    - numpy: 1.22.4
    - pyTorch_debug: False
    - pyTorch_version: 1.11.0+cu102
    - pytorch-lightning: 1.6.3
    - tqdm: 4.64.0
* System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.10.4
    - version: #171-Ubuntu SMP Fri Nov 5 11:55:11 UTC 2021

