diff --git a/torchtitan/profiling.py b/torchtitan/profiling.py
index 3c093526a8..b4c2b2e027 100644
--- a/torchtitan/profiling.py
+++ b/torchtitan/profiling.py
@@ -17,7 +17,7 @@
 
 
 @contextlib.contextmanager
-def maybe_enable_profiling(config: JobConfig, *pos_args, **kwargs):
+def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0):
     # get user defined profiler settings
     enable_profiling = config.profiling.enable_profiling
 
@@ -27,19 +27,15 @@ def maybe_enable_profiling(config: JobConfig, *pos_args, **kwargs):
         trace_dir = os.path.join(dump_dir, save_trace_dir)
         profile_freq = config.profiling.profile_freq
 
-        _global_iter_count = 0
-
         rank = torch.distributed.get_rank()
 
         def trace_handler(prof):
-            nonlocal _global_iter_count
-            _global_iter_count += profile_freq
-            curr_trace_dir_name = "iteration_" + str(_global_iter_count)
+            curr_trace_dir_name = "iteration_" + str(prof.step_num)
             curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name)
             if not os.path.exists(curr_trace_dir):
                 os.makedirs(curr_trace_dir, exist_ok=True)
 
-            logger.info(f"Dumping traces at step {_global_iter_count}")
+            logger.info(f"Dumping traces at step {prof.step_num}")
             begin = time.monotonic()
             prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json")
             logger.info(
@@ -69,6 +65,7 @@ def trace_handler(prof):
             schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
             on_trace_ready=trace_handler,
         ) as torch_profiler:
+            torch_profiler.step_num = global_step
             yield torch_profiler
     else:
         torch_profiler = contextlib.nullcontext()
diff --git a/train.py b/train.py
index ea6cdc3fc0..1db35f1f2d 100644
--- a/train.py
+++ b/train.py
@@ -260,7 +260,9 @@ def loss_fn(pred, labels):
     data_iterator = iter(data_loader)
 
     logger.info(f"Training starts at step {train_state.step + 1}")
-    with maybe_enable_profiling(job_config) as torch_profiler:
+    with maybe_enable_profiling(
+        job_config, global_step=train_state.step
+    ) as torch_profiler:
         checkpoint.reset()
 
         # variables used to keep info for metrics logging
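The mechanism the patch relies on: `torch.profiler.profile` keeps an internal `step_num` counter that `step()` increments and feeds to the `schedule`, so seeding it with the checkpointed step makes `trace_handler` label dumps with the true global training iteration after a resume, instead of restarting the count at 0. A minimal sketch of that behavior in isolation, not the torchtitan code itself: the 6/3/1 schedule, the step-90 resume point, and the `print` standing in for `export_chrome_trace` are illustrative choices.

```python
import torch


def trace_handler(prof):
    # With step_num seeded, this reports the global training step,
    # not a counter that restarts at 0 on every resume.
    print(f"trace would go to iteration_{prof.step_num}")


# Illustrative 10-step cycle: 6 wait + 3 warmup + 1 active.
with torch.profiler.profile(
    schedule=torch.profiler.schedule(wait=6, warmup=3, active=1),
    on_trace_ready=trace_handler,
) as prof:
    prof.step_num = 90  # pretend we resumed from a checkpoint taken at step 90
    for _ in range(20):
        # ... one training step would run here ...
        prof.step()
# Expected: trace_handler fires once per 10-step cycle, reporting
# iteration_100 and then iteration_110 rather than iteration_10 and iteration_20.
```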