|
7 | 7 | import os |
8 | 8 | from dataclasses import dataclass |
9 | 9 | from datetime import timedelta |
| 10 | +import time |
| 11 | +from typing import Optional |
| 12 | + |
10 | 13 |
|
11 | 14 | import torch |
12 | 15 |
|
@@ -79,6 +82,96 @@ class NoColor: |
79 | 82 | white = "" |
80 | 83 | reset = "" |
81 | 84 |
|
class TrackTime:
    """Context manager that measures wall-clock time via time.perf_counter."""

    def __init__(self, use_ms: bool = False, round_to: Optional[int] = 4):
        # use_ms: report the result in milliseconds instead of seconds.
        # round_to: decimal places to round to; None disables rounding.
        self.use_ms = use_ms
        self.round_to = round_to
        self.start_time = 0.0
        self.elapsed_time = 0.0
        self.unit = "milliseconds" if use_ms else "seconds"

    def __enter__(self):
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        duration = time.perf_counter() - self.start_time

        if self.use_ms:
            duration *= 1000  # convert seconds -> milliseconds

        if self.round_to is not None:
            duration = round(duration, self.round_to)

        self.elapsed_time = duration

    def get_time(self) -> float:
        """Return the elapsed time recorded by the most recent with-block."""
        return self.elapsed_time
| 112 | + |
class CUDATrackTime:
    """
    Integrated class for perf timing via cuda events.
    Note - this uses the default stream to synchronize, and defaults to current device.
    The event.record() will create a context on the CUDA device matching the device used at init.
    """

    def __init__(self, device=None, use_ms: bool = False, round_to: Optional[int] = 4):
        """
        device: None (current device), a device string (e.g. "cuda:1"), or an int index.
        use_ms: report the result in milliseconds instead of seconds.
        round_to: decimal places to round to; None disables rounding.
        """
        # Normalize the device argument; str/int forms become torch.device.
        if device is None:
            device = torch.cuda.current_device()
        elif isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(f"cuda:{device}")

        self.device = device
        # Create events on the specified device so records land on the right context.
        with torch.cuda.device(self.device):
            self.start_event = torch.cuda.Event(enable_timing=True)
            self.end_event = torch.cuda.Event(enable_timing=True)

        self.active = False  # True while the timer is running
        self.round_to = round_to
        self.elapsed_time = 0.0
        self.use_ms = use_ms
        self.unit = "seconds" if not use_ms else "milliseconds"

    def start(self):
        """Record the start event. Raises RuntimeError if already running."""
        if self.active:
            raise RuntimeError("Timer is already running. Use .stop() to stop it")
        self.start_event.record()
        self.active = True

    def stop(self):
        """Record the end event. Raises RuntimeError if not running."""
        if not self.active:
            raise RuntimeError("Timer is not running. Use .start() to start it")
        self.end_event.record()
        self.active = False

    def get_time(self):
        """Synchronize the device and return the time between start() and stop().

        Raises RuntimeError if called while the timer is still running.
        """
        if self.active:
            raise RuntimeError("Timer is still running. Use .stop() to stop it")

        torch.cuda.synchronize(self.device)  # Synchronize all streams on the device
        total_time = self.start_event.elapsed_time(self.end_event)  # milliseconds

        if not self.use_ms:
            total_time = total_time / 1000.0  # to seconds

        # Fix: was `if self.round_to:`, which silently skipped rounding when
        # round_to == 0; compare against None to match TrackTime's behavior.
        if self.round_to is not None:
            total_time = round(total_time, self.round_to)

        self.elapsed_time = total_time

        return self.elapsed_time

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()
82 | 175 |
|
83 | 176 | class GPUMemoryMonitor: |
84 | 177 | def __init__(self, device: str): |
|
0 commit comments