From 1cdf21a4ddbfdce90ff1a885b45a092ff53dcd8d Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Wed, 1 May 2024 16:23:16 -0700
Subject: [PATCH 1/2] Fix the incorrect step log for profiler after resuming
 from a checkpoint

Summary:
The profiler keeps its own local iteration counter, which is never
synchronized with the train step restored from a checkpoint, so traces
dumped after resuming are labeled with the wrong step. This PR passes the
checkpointed train step into maybe_enable_profiling() so the counter
starts from the correct value.
---
 torchtitan/profiling.py | 5 ++---
 train.py                | 4 +++-
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/torchtitan/profiling.py b/torchtitan/profiling.py
index 3c093526a8..0940237c9a 100644
--- a/torchtitan/profiling.py
+++ b/torchtitan/profiling.py
@@ -17,7 +17,7 @@
 
 
 @contextlib.contextmanager
-def maybe_enable_profiling(config: JobConfig, *pos_args, **kwargs):
+def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0):
     # get user defined profiler settings
     enable_profiling = config.profiling.enable_profiling
 
@@ -27,8 +27,7 @@ def maybe_enable_profiling(config: JobConfig, *pos_args, **kwargs):
         trace_dir = os.path.join(dump_dir, save_trace_dir)
         profile_freq = config.profiling.profile_freq
 
-        _global_iter_count = 0
-
+        _global_iter_count = global_step
         rank = torch.distributed.get_rank()
 
         def trace_handler(prof):
diff --git a/train.py b/train.py
index ea6cdc3fc0..1db35f1f2d 100644
--- a/train.py
+++ b/train.py
@@ -260,7 +260,9 @@ def loss_fn(pred, labels):
     data_iterator = iter(data_loader)
 
     logger.info(f"Training starts at step {train_state.step + 1}")
-    with maybe_enable_profiling(job_config) as torch_profiler:
+    with maybe_enable_profiling(
+        job_config, global_step=train_state.step
+    ) as torch_profiler:
         checkpoint.reset()
 
         # variables used to keep info for metrics logging

From 34b2773cf5f6816fb3d99b2d64d43ff632872f2a Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Thu, 2 May 2024 15:16:35 -0700
Subject: [PATCH 2/2] Update to use profiler.step_num

Instead of keeping a separate counter in the closure, seed the profiler's
built-in step_num with the checkpointed global step and derive the trace
directory name and the log message from prof.step_num directly.
---
 torchtitan/profiling.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/torchtitan/profiling.py b/torchtitan/profiling.py
index 0940237c9a..b4c2b2e027 100644
--- a/torchtitan/profiling.py
+++ b/torchtitan/profiling.py
@@ -27,18 +27,15 @@ def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0):
         trace_dir = os.path.join(dump_dir, save_trace_dir)
         profile_freq = config.profiling.profile_freq
 
-        _global_iter_count = global_step
         rank = torch.distributed.get_rank()
 
         def trace_handler(prof):
-            nonlocal _global_iter_count
-            _global_iter_count += profile_freq
-            curr_trace_dir_name = "iteration_" + str(_global_iter_count)
+            curr_trace_dir_name = "iteration_" + str(prof.step_num)
             curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name)
             if not os.path.exists(curr_trace_dir):
                 os.makedirs(curr_trace_dir, exist_ok=True)
 
-            logger.info(f"Dumping traces at step {_global_iter_count}")
+            logger.info(f"Dumping traces at step {prof.step_num}")
             begin = time.monotonic()
             prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json")
             logger.info(
@@ -68,6 +65,7 @@ def trace_handler(prof):
             schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
             on_trace_ready=trace_handler,
         ) as torch_profiler:
+            torch_profiler.step_num = global_step
             yield torch_profiler
     else:
         torch_profiler = contextlib.nullcontext()
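
Note (illustration, not part of the patches): below is a minimal,
self-contained sketch of the PATCH 2 technique, seeding
torch.profiler.profile.step_num with the step restored from a checkpoint so
that a resumed run dumps traces under directories named after the true
global step. train_step(), resumed_step, and the "traces/" output path are
hypothetical stand-ins, not torchtitan code:

    import os

    import torch

    def train_step() -> None:
        # Stand-in for one real training iteration.
        torch.randn(128, 128) @ torch.randn(128, 128)

    def trace_handler(prof: torch.profiler.profile) -> None:
        # prof.step_num was seeded with the checkpointed step, so the dump
        # directory reflects the true global step instead of a counter that
        # restarts at zero on every run.
        curr_trace_dir = os.path.join("traces", f"iteration_{prof.step_num}")
        os.makedirs(curr_trace_dir, exist_ok=True)
        prof.export_chrome_trace(os.path.join(curr_trace_dir, "trace.json"))

    resumed_step = 30  # hypothetical step restored from a checkpoint

    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU],
        schedule=torch.profiler.schedule(wait=7, warmup=2, active=1),
        on_trace_ready=trace_handler,
    ) as prof:
        prof.step_num = resumed_step  # the same trick as in PATCH 2
        for _ in range(20):
            train_step()
            prof.step()  # advances prof.step_num by one

Since the schedule is evaluated against step_num on every prof.step() call,
seeding it should also keep the wait/warmup/active cycle aligned with the
true step count, so dumps keep landing at the same step multiples as before
the restart.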