torchtitan/profiling.py (4 additions, 7 deletions)
@@ -17,7 +17,7 @@
 
 
 @contextlib.contextmanager
-def maybe_enable_profiling(config: JobConfig, *pos_args, **kwargs):
+def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0):
     # get user defined profiler settings
     enable_profiling = config.profiling.enable_profiling
 
@@ -27,19 +27,15 @@ def maybe_enable_profiling(config: JobConfig, *pos_args, **kwargs):
         trace_dir = os.path.join(dump_dir, save_trace_dir)
         profile_freq = config.profiling.profile_freq
 
-        _global_iter_count = 0
-
         rank = torch.distributed.get_rank()
 
         def trace_handler(prof):
-            nonlocal _global_iter_count
-            _global_iter_count += profile_freq
-            curr_trace_dir_name = "iteration_" + str(_global_iter_count)
+            curr_trace_dir_name = "iteration_" + str(prof.step_num)
             curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name)
             if not os.path.exists(curr_trace_dir):
                 os.makedirs(curr_trace_dir, exist_ok=True)
 
-            logger.info(f"Dumping traces at step {_global_iter_count}")
+            logger.info(f"Dumping traces at step {prof.step_num}")
             begin = time.monotonic()
             prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json")
             logger.info(
@@ -69,6 +65,7 @@ def trace_handler(prof):
             schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
             on_trace_ready=trace_handler,
         ) as torch_profiler:
+            torch_profiler.step_num = global_step
             yield torch_profiler
     else:
         torch_profiler = contextlib.nullcontext()
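For readers unfamiliar with the profiler API: torch.profiler.profile keeps an internal step_num counter that advances on each profiler.step() call and is visible to the on_trace_ready callback, which is what lets the patch above drop the hand-rolled _global_iter_count and name trace directories by training step once that counter is seeded. The following is a minimal sketch of that mechanism, not the repository's code; the ./traces path, rank=0, the schedule numbers, and the step counts are assumed values for illustration.

import os

import torch


def make_trace_handler(trace_dir: str, rank: int):
    def trace_handler(prof):
        # prof.step_num is the profiler's own counter; after being seeded below,
        # it tracks the real training step instead of restarting from zero.
        curr_trace_dir = os.path.join(trace_dir, f"iteration_{prof.step_num}")
        os.makedirs(curr_trace_dir, exist_ok=True)
        prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json")

    return trace_handler


global_step = 1000  # assumed value, e.g. restored from a checkpoint
with torch.profiler.profile(
    schedule=torch.profiler.schedule(wait=8, warmup=1, active=1),
    on_trace_ready=make_trace_handler("./traces", rank=0),
) as prof:
    prof.step_num = global_step  # seed the counter, as the patch does
    for _ in range(20):
        # ... one training step would run here ...
        prof.step()

With these assumed values, traces land in ./traces/iteration_1010 and ./traces/iteration_1020 rather than iteration_10 and iteration_20.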
train.py (3 additions, 1 deletion)
@@ -260,7 +260,9 @@ def loss_fn(pred, labels):
     data_iterator = iter(data_loader)
 
     logger.info(f"Training starts at step {train_state.step + 1}")
-    with maybe_enable_profiling(job_config) as torch_profiler:
+    with maybe_enable_profiling(
+        job_config, global_step=train_state.step
+    ) as torch_profiler:
         checkpoint.reset()
 
         # variables used to keep info for metrics logging
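At the call site, train_state.step is the step restored from the checkpoint (as the preceding log line suggests), so a resumed run seeds the profiler with the point it left off; trace directories then continue from the true global step instead of restarting near zero and potentially colliding with directories written before the restart.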