|
4 | 4 | from pca_topk import gather_outer_bmv_optimized, gather_inner_matrix_only_bmv_optimized
|
5 | 5 | from sparq import gather_outer_bmv, gather_inner_matrix_only_bmv
|
6 | 6 |
|
7 |
| - |
8 |
| -B = 1 |
| 7 | +B = 4 |
9 | 8 | NH = 32
|
10 |
| -S = 500 |
| 9 | +S = 800 |
11 | 10 | D = 128
|
12 | 11 | dtype = torch.float16
|
13 | 12 |
|
14 |
| -print("===== BENCHMARKING s.v with various sparsities =======") |
15 |
| -print("Batch Size : ", B) |
16 |
| -print("Number of Heads : ", NH) |
17 |
| -print("Number of Key Tokens (or sequence length) : ", S) |
18 |
| -print("Hidden dimension per head : ", D) |
| 13 | + |
19 | 14 |
|
20 | 15 | configs = [
|
21 | 16 | triton.testing.Benchmark(
|
|
25 | 20 | # Possible values for `line_arg`
|
26 | 21 | # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment.
|
27 | 22 | line_vals=["torch", "triton-optimized"], # Label name for the lines
|
28 |
| - line_names=["torch (full keys)", "Triton (Optimized)"], # Line styles |
| 23 | + line_names=["torch (full keys and values)", "Triton (Optimized)"], # Line styles |
29 | 24 | styles=[("black", "-"), ("blue", "-")],
|
30 | 25 | ylabel="TFLOPS", # Label name for the y-axis
|
31 | 26 | plot_name="matmul-performance-" + ("fp16 (time in ms)" ), # Name for the plot, used also as a file name for saving the plot.
|
@@ -72,7 +67,22 @@ def benchmark_bmm2(sparsity, B, NH, S, D, provider):
|
72 | 67 |
|
73 | 68 | return ms, max_ms, min_ms
|
74 | 69 |
|
| 70 | + |
| 71 | + |
| 72 | +print( "===== BENCHMARKING [email protected]() with various sparsities =======") |
| 73 | +print("Batch Size : ", B) |
| 74 | +print("Number of Heads : ", NH) |
| 75 | +print("Number of Key Tokens (or sequence length) : ", S) |
| 76 | +print("Hidden dimension per head : ", D) |
| 77 | +result = benchmark_bmm1.run(print_data=True) |
| 78 | + |
| 79 | + |
| 80 | + |
| 81 | +print("===== BENCHMARKING s@v with various sparsities =======") |
| 82 | +print("Batch Size : ", B) |
| 83 | +print("Number of Heads : ", NH) |
| 84 | +print("Number of Key Tokens (or sequence length) : ", S) |
| 85 | +print("Hidden dimension per head : ", D) |
75 | 86 | result = benchmark_bmm2.run(print_data=True)
|
76 | 87 |
|
77 |
| -print(result) |
78 | 88 |
|
0 commit comments