This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 4e7332f

[distributed] add TrackTime, CUDATrackTime to monitor perf for weight loading per stage and future perf measurements (#1121)
* add TrackTime, monitor perf for weight loading per stage
* add CUDATrackTime
* ruff formatting
* add device for CUDATrackTime per PR feedback
* add comment re: cuda context, ruff format
1 parent e2049f4 commit 4e7332f

File tree

2 files changed: +103 −2 lines changed

dist_run.py

Lines changed: 10 additions & 2 deletions
@@ -23,7 +23,9 @@
     get_hf_weight_map_and_path,
     load_safetensor_weights,
 )
-from distributed.utils import Color as color, GPUMemoryMonitor
+
+from distributed.utils import Color as color, TrackTime, CUDATrackTime, GPUMemoryMonitor
+
 from distributed.verification_utils import find_cpu_tensors
 from torchchat.cli.builder import TokenizerArgs, _initialize_tokenizer
 from torchchat.model import ModelArgs, Transformer
@@ -188,8 +190,14 @@ def main():
 
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
-    _load_model_weights(model, hf_model_name, device=device, model_config=config)
+    with TrackTime("cuda") as timer:
+        _load_model_weights(model, hf_model_name, device=device, model_config=config)
 
+    logger.info(
+        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
+    )
+
+
     # Setup input position
     # input_pos for prefill: a list of increasing integers from 0 to seqlen
     input_pos = torch.arange(seqlen, device=device)
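The change above wraps per-stage weight loading in the new TrackTime context manager and logs the elapsed time for each pipeline stage. A minimal standalone sketch of the same pattern; note that load_weights below is a hypothetical stand-in for _load_model_weights and is not part of this commit:

from distributed.utils import TrackTime

def load_weights():
    ...  # placeholder for the real weight-loading work

with TrackTime() as timer:  # defaults to seconds; pass use_ms=True for milliseconds
    load_weights()

print(f"Total weight loading time: {timer.get_time()} {timer.unit}")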

distributed/utils.py

Lines changed: 93 additions & 0 deletions
@@ -7,6 +7,9 @@
 import os
 from dataclasses import dataclass
 from datetime import timedelta
+import time
+from typing import Optional
+
 
 import torch
 
@@ -79,6 +82,96 @@ class NoColor:
     white = ""
     reset = ""
 
+class TrackTime:
+    """integrated class for perf timing via perf_counter"""
+
+    def __init__(self, use_ms: bool = False, round_to: Optional[int] = 4):
+        self.use_ms = use_ms
+        self.round_to = round_to
+        self.start_time = 0.0
+        self.elapsed_time = 0.0
+        self.unit = "seconds" if not use_ms else "milliseconds"
+
+    def __enter__(self):
+        self.start_time = time.perf_counter()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        end_time = time.perf_counter()
+        self.elapsed_time = end_time - self.start_time
+
+        if self.use_ms:
+            self.elapsed_time *= 1000  # Convert to milliseconds
+
+        if self.round_to is not None:
+            self.elapsed_time = round(self.elapsed_time, self.round_to)
+
+    def get_time(self) -> float:
+        return self.elapsed_time
+
+
+class CUDATrackTime:
+    """
+    Integrated class for perf timing via cuda events.
+    Note - this uses the default stream to synchronize, and defaults to current device.
+    The event.record() will create a context on the CUDA device matching the device used at init.
+    """
+
+    def __init__(self, device=None, use_ms: bool = False, round_to: Optional[int] = 4):
+        if device is None:
+            device = torch.cuda.current_device()
+        elif isinstance(device, str):
+            device = torch.device(device)
+        elif isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+
+        self.device = device
+        # Create events on the specified device
+        with torch.cuda.device(self.device):
+            self.start_event = torch.cuda.Event(enable_timing=True)
+            self.end_event = torch.cuda.Event(enable_timing=True)
+
+        self.active = False
+        self.round_to = round_to
+        self.elapsed_time = 0.0
+        self.use_ms = use_ms
+        self.unit = "seconds" if not use_ms else "milliseconds"
+
+    def start(self):
+        if self.active:
+            raise RuntimeError("Timer is already running. Use .stop() to stop it")
+        self.start_event.record()
+        self.active = True
+
+    def stop(self):
+        if not self.active:
+            raise RuntimeError("Timer is not running. Use .start() to start it")
+        self.end_event.record()
+        self.active = False
+
+    def get_time(self):
+        if self.active:
+            raise RuntimeError("Timer is still running. Use .stop() to stop it")
+
+        torch.cuda.synchronize(self.device)  # Synchronize all streams on the device
+        total_time = self.start_event.elapsed_time(self.end_event)
+
+        if not self.use_ms:
+            total_time = total_time / 1000.0  # to seconds
+
+        if self.round_to:
+            total_time = round(total_time, self.round_to)
+
+        self.elapsed_time = total_time
+
+        return self.elapsed_time
+
+    def __enter__(self):
+        self.start()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.stop()
 
 class GPUMemoryMonitor:
     def __init__(self, device: str):
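For reference, a minimal usage sketch of the CUDATrackTime class added above, assuming a CUDA-capable device is available; the matmul workload and the "cuda" device string are illustrative only and not taken from this commit:

import torch
from distributed.utils import CUDATrackTime

x = torch.randn(4096, 4096, device="cuda")

# Explicit start/stop around a GPU workload (illustrative matmul)
timer = CUDATrackTime(device="cuda")
timer.start()
y = x @ x
timer.stop()
print(f"matmul took {timer.get_time()} {timer.unit}")

# Or as a context manager, mirroring TrackTime
with CUDATrackTime(device="cuda") as timer:
    y = x @ x
print(f"matmul took {timer.get_time()} {timer.unit}")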
