diff --git a/torchtitan/profiling.py b/torchtitan/profiling.py index c993a74f53..662b64f8c7 100644 --- a/torchtitan/profiling.py +++ b/torchtitan/profiling.py @@ -68,6 +68,7 @@ def trace_handler(prof): ], schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active), on_trace_ready=trace_handler, + record_shapes=True, ) as torch_profiler: torch_profiler.step_num = global_step yield torch_profiler