Skip to content

Commit 441adf5

Browse files
committed
reduce autotuning range and add second bmm to benchmark
1 parent b26753d commit 441adf5

File tree

2 files changed

+33
-13
lines changed

2 files changed

+33
-13
lines changed

methods/pca_topk/kernel/benchmark.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,13 @@
44
from pca_topk import gather_outer_bmv_optimized, gather_inner_matrix_only_bmv_optimized
55
from sparq import gather_outer_bmv, gather_inner_matrix_only_bmv
66

7-
8-
B = 1
7+
B = 4
98
NH = 32
10-
S = 500
9+
S = 800
1110
D = 128
1211
dtype = torch.float16
1312

14-
print("===== BENCHMARKING s.v with various sparsities =======")
15-
print("Batch Size : ", B)
16-
print("Number of Heads : ", NH)
17-
print("Number of Key Tokens (or sequence length) : ", S)
18-
print("Hidden dimension per head : ", D)
13+
1914

2015
configs = [
2116
triton.testing.Benchmark(
@@ -25,7 +20,7 @@
2520
# Possible values for `line_arg`
2621
# Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment.
2722
line_vals=["torch", "triton-optimized"], # Label name for the lines
28-
line_names=["torch (full keys)", "Triton (Optimized)"], # Line styles
23+
line_names=["torch (full keys and values)", "Triton (Optimized)"], # Line styles
2924
styles=[("black", "-"), ("blue", "-")],
3025
ylabel="TFLOPS", # Label name for the y-axis
3126
plot_name="matmul-performance-" + ("fp16 (time in ms)" ), # Name for the plot, used also as a file name for saving the plot.
@@ -72,7 +67,22 @@ def benchmark_bmm2(sparsity, B, NH, S, D, provider):
7267

7368
return ms, max_ms, min_ms
7469

70+
71+
72+
print("===== BENCHMARKING q@k.t() with various sparsities =======")
73+
print("Batch Size : ", B)
74+
print("Number of Heads : ", NH)
75+
print("Number of Key Tokens (or sequence length) : ", S)
76+
print("Hidden dimension per head : ", D)
77+
result = benchmark_bmm1.run(print_data=True)
78+
79+
80+
81+
print("===== BENCHMARKING s@v with various sparsities =======")
82+
print("Batch Size : ", B)
83+
print("Number of Heads : ", NH)
84+
print("Number of Key Tokens (or sequence length) : ", S)
85+
print("Hidden dimension per head : ", D)
7586
result = benchmark_bmm2.run(print_data=True)
7687

77-
print(result)
7888

methods/pca_topk/kernel/pca_topk.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import triton.language as tl
77
from torch import Tensor
88

9-
def get_autotune_config():
9+
def get_autotune_config_outer():
1010
return [
1111
triton.Config({"n_chunk": 4}),
1212
triton.Config({"n_chunk": 8}),
@@ -20,7 +20,7 @@ def get_autotune_config():
2020
]
2121

2222
@triton.autotune(
23-
configs=get_autotune_config(),
23+
configs=get_autotune_config_outer(),
2424
key=['b', 'n', 'k'],
2525
)
2626
@triton.jit
@@ -110,8 +110,18 @@ def gather_outer_bmv_optimized(A: Tensor, B: Tensor, I: Tensor) -> Tensor:
110110
return Y
111111

112112

113+
def get_autotune_config_inner():
114+
return [
115+
triton.Config({"n_chunk": 4}),
116+
triton.Config({"n_chunk": 8}),
117+
triton.Config({"n_chunk": 16}),
118+
triton.Config({"n_chunk": 32}),
119+
triton.Config({"n_chunk": 64}),
120+
triton.Config({"n_chunk": 128}),
121+
]
122+
113123
@triton.autotune(
114-
configs=get_autotune_config(),
124+
configs=get_autotune_config_inner(),
115125
key=['b', 'n', 'k'],
116126
)
117127
@triton.jit

0 commit comments

Comments
 (0)