Skip to content

Commit 695bd01

Browse files
authored
Fix the incorrect step logged by the profiler after resuming from a checkpoint (#293)
Summary: The profiler currently maintains its iteration counter locally, and that counter is not synchronized with the checkpointed train step. This PR fixes the issue by seeding the profiler with the checkpointed global step and using `prof.step_num` when naming trace directories.
1 parent 787a571 commit 695bd01

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

torchtitan/profiling.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818

1919
@contextlib.contextmanager
20-
def maybe_enable_profiling(config: JobConfig, *pos_args, **kwargs):
20+
def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0):
2121
# get user defined profiler settings
2222
enable_profiling = config.profiling.enable_profiling
2323

@@ -27,19 +27,15 @@ def maybe_enable_profiling(config: JobConfig, *pos_args, **kwargs):
2727
trace_dir = os.path.join(dump_dir, save_trace_dir)
2828
profile_freq = config.profiling.profile_freq
2929

30-
_global_iter_count = 0
31-
3230
rank = torch.distributed.get_rank()
3331

3432
def trace_handler(prof):
35-
nonlocal _global_iter_count
36-
_global_iter_count += profile_freq
37-
curr_trace_dir_name = "iteration_" + str(_global_iter_count)
33+
curr_trace_dir_name = "iteration_" + str(prof.step_num)
3834
curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name)
3935
if not os.path.exists(curr_trace_dir):
4036
os.makedirs(curr_trace_dir, exist_ok=True)
4137

42-
logger.info(f"Dumping traces at step {_global_iter_count}")
38+
logger.info(f"Dumping traces at step {prof.step_num}")
4339
begin = time.monotonic()
4440
prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json")
4541
logger.info(
@@ -69,6 +65,7 @@ def trace_handler(prof):
6965
schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
7066
on_trace_ready=trace_handler,
7167
) as torch_profiler:
68+
torch_profiler.step_num = global_step
7269
yield torch_profiler
7370
else:
7471
torch_profiler = contextlib.nullcontext()

train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,9 @@ def loss_fn(pred, labels):
271271
data_iterator = iter(data_loader)
272272

273273
logger.info(f"Training starts at step {train_state.step + 1}")
274-
with maybe_enable_profiling(job_config) as torch_profiler:
274+
with maybe_enable_profiling(
275+
job_config, global_step=train_state.step
276+
) as torch_profiler:
275277
checkpoint.reset()
276278

277279
# variables used to keep info for metrics logging

0 commit comments

Comments (0)