Commit 59b2bfd

cthi authored and facebook-github-bot committed
MXFP8 Grouped GEMM tuning (#4821)

Summary:
Pull Request resolved: #4821
X-link: facebookresearch/FBGEMM#1848

[Re-tuned the MXFP8 grouped GEMM](https://docs.google.com/spreadsheets/d/1xk8h1OZFnvKyH7kpP-FFZmXPu5Tmv1AoJ7--KEX4pXg/edit?gid=0#gid=0) using the tuning tooling to autogenerate a heuristic on B200 at 750W peak power.

- We see a peak of ~2K TFLOPS. Some shapes still cannot reach it, so there is likely room for further improvement.
- Roughly a ~1.5-2x improvement over the BF16 grouped GEMM baseline.
- A ~1.1-1.3x improvement over the old heuristic.
- Note: Blackwell is rather finicky to benchmark, and I noticed a decent amount of variation between runs, but this looks better, so we can ship it for now.

Also removes some unnecessary template code.

Reviewed By: jiawenliu64

Differential Revision: D81683544

fbshipit-source-id: b2218db217d812ed79b4f8f28a8a79a5fe13fc38
1 parent b12d4a0 commit 59b2bfd
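As a rough sanity check on the ~2K peak TFLOPS figure quoted in the summary, grouped-GEMM throughput is conventionally computed as 2*M*N*K FLOPs per group, summed over groups and divided by the measured kernel time. The Python sketch below is not from this commit; the group count, shape, and timing are hypothetical (the N and K values are borrowed from the Llama4 shapes in the old heuristic).

    def grouped_gemm_tflops(group_shapes, seconds):
        # group_shapes: list of per-group (M, N, K); seconds: measured kernel time.
        flops = sum(2 * M * N * K for (M, N, K) in group_shapes)
        return flops / seconds / 1e12

    # Hypothetical example: 8 groups of (M=1024, N=5120, K=1024) finishing in 45 us
    # works out to roughly 1.9e3 TFLOPS, i.e. around the reported peak.
    print(grouped_gemm_tflops([(1024, 5120, 1024)] * 8, 45e-6))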

File tree

3 files changed: +206 -82 lines


fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 44 additions & 0 deletions
@@ -2915,6 +2915,50 @@ def cuda(self) -> bool:
         return True
 
 
+@register_quantize_op
+class BF16GroupedGemm2d3d(QuantizeOpBase):
+    """
+    Torch BF16 grouped GEMM with 2D inputs and 3D weights.
+    """
+
+    def preprocess(self, x, w):
+        assert isinstance(x, list)
+        assert isinstance(w, list)
+        offs = torch.tensor(
+            [i.shape[0] for i in x], dtype=torch.int32, device=x[0].device
+        )
+        offs = torch.cumsum(offs, dim=0).to(torch.int32)
+        x = torch.cat(x, dim=0).contiguous()  # (G * M, K)
+        w = torch.stack(w, dim=0).contiguous()  # (G, N, K)
+        return x, w, offs
+
+    def quantize(self, x, w, offs):
+        return x, w, offs
+
+    def compute(self, x, w, offs):
+        return torch._grouped_mm(
+            x,
+            w.transpose(-2, -1),
+            offs=offs,
+        )
+
+    def quantize_and_compute(self, x, w, offs):
+        x, w, offs = self.quantize(x, w, offs)
+        return self.compute(x, w, offs)
+
+    @property
+    def name(self) -> str:
+        return "bf16_baseline_grouped_2d_3d"
+
+    @property
+    def hip(self) -> bool:
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class MXFP8GroupedGemm2d3d(QuantizeOpBase):
     """

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped.cu

Lines changed: 134 additions & 51 deletions
@@ -19,6 +19,9 @@
 #include <cutlass/epilogue/collective/collective_builder.hpp> // @manual
 // clang-format on
 
+#include "fbgemm_gpu/quantize/tuning_cache.hpp"
+#include "fbgemm_gpu/quantize/utils.h"
+
 #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
 #include "mx8mx8bf16_grouped/mx8mx8bf16_grouped_manifest.cuh"
 #endif
@@ -27,83 +30,160 @@ namespace fbgemm_gpu {
 
 #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
 
-template <typename InputType>
-Kernel_mx8mx8bf16_grouped<InputType>
+Kernel_mx8mx8bf16_grouped get_kernel_via_tuning(
+    int M,
+    int N,
+    int K,
+    int G,
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor output,
+    at::Tensor offsets) {
+  static TuningCache cache("mx8mx8bf16_grouped");
+
+  M = nextPowerOf2(M);
+  N = nextPowerOf2(N);
+  K = nextPowerOf2(K);
+  const std::string shape_key =
+      std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(K);
+
+  const auto& kernels = get_mx8mx8bf16_grouped_kernels();
+  auto kernel = cache.findBestKernelMaybeAutotune(
+      shape_key, kernels, XQ, WQ, x_scale, w_scale, output, G, offsets);
+
+  return kernel;
+}
+
+Kernel_mx8mx8bf16_grouped
 get_kernel_via_heuristics(int M, int N, int K, int G) {
-  // Llama4 shapes
-  if (N == 5120 && K == 1024) {
-    if (G <= 8) {
-      if (M <= 256) {
+  if (M <= 128) {
+    if (N <= 512) {
+      return mx8mx8bf16_grouped_256_64_256_2_1_1;
+    } else if (N <= 1024) {
+      if (K <= 4096) {
         return mx8mx8bf16_grouped_256_64_256_2_1_1;
-      } else if (M <= 512) {
+      } else {
         return mx8mx8bf16_grouped_128_64_256_1_1_1;
-      } else if (M <= 1024) {
-        return mx8mx8bf16_grouped_128_128_256_1_1_1;
       }
-    } else if (G <= 16) {
-      if (M <= 1024) {
-        return mx8mx8bf16_grouped_128_64_256_1_1_1;
-      } else if (M <= 2048) {
+    } else {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
+    }
+  } else if (M <= 512) {
+    if (N <= 512) {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
+    } else if (N <= 4096) {
+      if (K <= 1024) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
        return mx8mx8bf16_grouped_256_128_256_2_1_1;
      }
+    } else if (N <= 8192) {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
    } else {
-      if (M <= 1024) {
-        return mx8mx8bf16_grouped_256_64_256_2_1_1;
-      } else if (M <= 4096) {
-        return mx8mx8bf16_grouped_128_64_256_1_1_1;
-      } else if (M <= 8192) {
-        return mx8mx8bf16_grouped_256_64_256_2_1_1;
+      if (K <= 512) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else if (K <= 4096) {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
+      } else if (K <= 8192) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
      }
    }
-    return mx8mx8bf16_grouped_256_256_256_2_1_1;
-  } else if (N == 2048 && K == 5120) {
-    if (G <= 8) {
-      if (M <= 256) {
-        return mx8mx8bf16_grouped_256_64_256_2_1_1;
-      } else if (M <= 512) {
-        return mx8mx8bf16_grouped_128_64_256_1_1_1;
-      } else if (M <= 1024) {
-        return mx8mx8bf16_grouped_128_128_256_1_1_1;
+  } else if (M <= 1024) {
+    if (N <= 2048) {
+      if (K <= 1024) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
      }
-    } else if (G <= 16) {
-      if (M <= 1024) {
-        return mx8mx8bf16_grouped_256_64_256_2_1_1;
-      } else if (M <= 2048) {
-        return mx8mx8bf16_grouped_128_128_256_1_1_1;
+    } else if (N <= 4096) {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
+    } else if (N <= 8192) {
+      if (K <= 512) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
      }
    } else {
-      if (M <= 1024) {
-        return mx8mx8bf16_grouped_256_64_256_2_1_1;
-      } else if (M <= 16384) {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
+    }
+  } else if (M <= 2048) {
+    if (N <= 1024) {
+      if (K <= 1024) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
+      }
+    } else if (N <= 2048) {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
+    } else {
+      if (K <= 512) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
        return mx8mx8bf16_grouped_256_128_256_2_1_1;
      }
    }
-    return mx8mx8bf16_grouped_256_256_256_2_1_1;
-  }
-
-  // Fallback to legacy heuristic
-  if (M <= 1000) {
-    return mx8mx8bf16_grouped_256_128_256_2_1_1;
+  } else if (M <= 4096) {
+    if (N <= 512) {
+      if (K <= 512) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
+      }
+    } else if (N <= 1024) {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
+    } else {
+      if (K <= 512) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
+      }
+    }
+  } else if (M <= 8192) {
+    if (K <= 512) {
+      return mx8mx8bf16_grouped_256_256_256_2_1_1;
+    } else {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
+    }
  } else {
-    return mx8mx8bf16_grouped_256_256_256_2_1_1;
+    if (N <= 8192) {
+      if (K <= 512) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
+      }
+    } else {
+      if (K <= 512) {
+        return mx8mx8bf16_grouped_128_64_256_1_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
+      }
+    }
  }
 }
 
-template <typename InputType>
 at::Tensor dispatch_mx8_grouped_kernel(
     int M,
     int N,
     int K,
     int G,
-    InputType XQ, // FP8
-    InputType WQ, // FP8
-    InputType x_scale,
-    InputType w_scale,
+    at::Tensor XQ, // FP8
+    at::Tensor WQ, // FP8
+    at::Tensor x_scale,
+    at::Tensor w_scale,
     at::Tensor output,
     at::Tensor offsets) {
   // Select kernel to run via heuristics.
   auto kernel = [&]() {
-    return get_kernel_via_heuristics<InputType>(M, N, K, G);
+    if (std::getenv("FBGEMM_AUTOTUNE_ENABLE")) {
+      return get_kernel_via_tuning(
+          M, N, K, G, XQ, WQ, x_scale, w_scale, output, offsets);
+    } else {
+      return get_kernel_via_heuristics(M, N, K, G);
+    }
  }();
   // Invoke kernel
   return kernel(XQ, WQ, x_scale, w_scale, output, G, offsets);
@@ -149,6 +229,8 @@ at::Tensor mx8mx8bf16_grouped_mm(
         output_actual.size(1) == N,
         "for 2d-3d grouped GEMM, output shape must be (total_M, N).");
 
+    // Normalized jagged dim for heuristics
+    M /= G;
     // 2d-2d case.
   } else if (XQ.dim() == 2 && WQ.dim() == 2) {
     // Alias for clarity that groups are along K dimension for 2d-2d case.
@@ -167,7 +249,8 @@ at::Tensor mx8mx8bf16_grouped_mm(
         output_actual.dim() == 3 && output_actual.size(0) == G &&
         output_actual.size(1) == M && output_actual.size(2) == N,
         "for 2d-2d grouped GEMM, output shape must be (G, M, N).");
-
+    // Normalized jagged dim for heuristics
+    K /= G;
   } else {
     TORCH_CHECK(false, "Invalid input shapes. Must be one of 2D-2D, 2D-3D.");
   }
@@ -178,7 +261,7 @@ at::Tensor mx8mx8bf16_grouped_mm(
   }
 
   // Return continuous view of output.
-  return dispatch_mx8_grouped_kernel<at::Tensor>(
+  return dispatch_mx8_grouped_kernel(
       M, N, K, G, XQ, WQ, x_scale, w_scale, output_actual, offsets);
 }
 
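On the dispatch side, the autotuning path is selected at runtime when the FBGEMM_AUTOTUNE_ENABLE environment variable is set; otherwise the regenerated heuristic above is used. In both cases the jagged dimension is first normalized by the group count (M /= G for 2D-3D, K /= G for 2D-2D), and the tuning cache keys each shape by M, N, and K rounded up to powers of two. The Python sketch below mirrors that bucketing for illustration only; it assumes nextPowerOf2 rounds up to the nearest power of two and leaves exact powers of two unchanged.

    import os

    def next_power_of_2(v: int) -> int:
        # Assumed to match the C++ nextPowerOf2: round up, identity on powers of two.
        return 1 if v <= 1 else 1 << (v - 1).bit_length()

    def tuning_shape_key(total_m: int, n: int, k: int, g: int) -> str:
        m = total_m // g  # normalized jagged dim for the 2D-3D case
        m, n, k = (next_power_of_2(d) for d in (m, n, k))
        return f"{m}_{n}_{k}"

    # Opt in to autotuning (read via std::getenv at dispatch time), then a call with
    # total_M=8192, N=5120, K=1024, G=8 is cached under the key below.
    os.environ["FBGEMM_AUTOTUNE_ENABLE"] = "1"
    print(tuning_shape_key(8192, 5120, 1024, 8))  # prints "1024_8192_1024"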
