|
6 | 6 | # this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py |
7 | 7 |
|
8 | 8 | import itertools |
9 | | -import time |
10 | 9 | from dataclasses import dataclass |
11 | 10 | from typing import List |
12 | 11 |
|
13 | 12 | import torch |
14 | 13 | from tabulate import tabulate |
15 | 14 | from tqdm import tqdm |
| 15 | +from triton.testing import do_bench |
16 | 16 |
|
17 | 17 | from torchao.prototype.moe_training.kernels.jagged_float8_scales import ( |
18 | 18 | triton_fp8_col_major_jagged_colwise_scales, |
@@ -129,18 +129,15 @@ def run_triton( |
129 | 129 |
|
130 | 130 | # bench torch |
131 | 131 | compiled_run_torch = torch.compile(run_torch) |
132 | | - warmup(compiled_run_torch, input_row_major, input_col_major, offs) |
133 | | - start_time_ns = time.perf_counter_ns() |
134 | | - compiled_run_torch(input_row_major, input_col_major, offs) |
135 | | - torch_time_ns = time.perf_counter_ns() - start_time_ns |
136 | | - torch_time_us = torch_time_ns / 1e3 |
| 132 | + torch_time_us = benchmark_cuda_function_in_microseconds( |
| 133 | + compiled_run_torch, input_row_major, input_col_major, offs |
| 134 | + ) |
137 | 135 |
|
138 | 136 | # bench triton |
139 | 137 | warmup(run_triton, input_row_major, input_col_major, offs) |
140 | | - start_time_ns = time.perf_counter_ns() |
141 | | - run_triton(input_row_major, input_col_major, offs) |
142 | | - triton_time_ns = time.perf_counter_ns() - start_time_ns |
143 | | - triton_time_us = triton_time_ns / 1e3 |
| 138 | + triton_time_us = benchmark_cuda_function_in_microseconds( |
| 139 | + run_triton, input_row_major, input_col_major, offs |
| 140 | + ) |
144 | 141 |
|
145 | 142 | return ExperimentResult( |
146 | 143 | torch_time_us=torch_time_us, |
@@ -173,6 +170,10 @@ def print_results(experiments: List[Experiment]): |
173 | 170 | print(tabulate(rows, headers=headers)) |
174 | 171 |
|
175 | 172 |
|
| 173 | +def benchmark_cuda_function_in_microseconds(f, *args): |
| 174 | + return do_bench(lambda: f(*args), return_mode="median") * 1e3 |
| 175 | + |
| 176 | + |
176 | 177 | def main(): |
177 | 178 | torch.random.manual_seed(123) |
178 | 179 | configs = get_configs() |
|
0 commit comments