 #include <cutlass/epilogue/collective/collective_builder.hpp> // @manual
 // clang-format on

+#include "fbgemm_gpu/quantize/tuning_cache.hpp"
+#include "fbgemm_gpu/quantize/utils.h"
+
 #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
 #include "mx8mx8bf16_grouped/mx8mx8bf16_grouped_manifest.cuh"
 #endif
@@ -27,83 +30,160 @@ namespace fbgemm_gpu {

 #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)

-template <typename InputType>
-Kernel_mx8mx8bf16_grouped<InputType>
+Kernel_mx8mx8bf16_grouped get_kernel_via_tuning(
+    int M,
+    int N,
+    int K,
+    int G,
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor output,
+    at::Tensor offsets) {
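+  // Static so that tuning results persist across calls for the lifetime of
+  // the process.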
+  static TuningCache cache("mx8mx8bf16_grouped");
+
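+  // Round each dimension up to the next power of 2 so that nearby shapes
+  // share a cache entry and the number of tuned configurations stays small.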
+  M = nextPowerOf2(M);
+  N = nextPowerOf2(N);
+  K = nextPowerOf2(K);
+  const std::string shape_key =
+      std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(K);
+
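+  // Benchmark the candidate kernels the first time a shape key is seen,
+  // then reuse the cached winner on subsequent calls.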
+  const auto& kernels = get_mx8mx8bf16_grouped_kernels();
+  auto kernel = cache.findBestKernelMaybeAutotune(
+      shape_key, kernels, XQ, WQ, x_scale, w_scale, output, G, offsets);
+
+  return kernel;
+}
+
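+// Heuristic fallback: pick a tile configuration from a static decision tree
+// over M, N, and K buckets.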
+Kernel_mx8mx8bf16_grouped
 get_kernel_via_heuristics(int M, int N, int K, int G) {
-  // Llama4 shapes
-  if (N == 5120 && K == 1024) {
-    if (G <= 8) {
-      if (M <= 256) {
+  if (M <= 128) {
+    if (N <= 512) {
+      return mx8mx8bf16_grouped_256_64_256_2_1_1;
+    } else if (N <= 1024) {
+      if (K <= 4096) {
         return mx8mx8bf16_grouped_256_64_256_2_1_1;
-      } else if (M <= 512) {
+      } else {
         return mx8mx8bf16_grouped_128_64_256_1_1_1;
-      } else if (M <= 1024) {
-        return mx8mx8bf16_grouped_128_128_256_1_1_1;
       }
-    } else if (G <= 16) {
-      if (M <= 1024) {
-        return mx8mx8bf16_grouped_128_64_256_1_1_1;
-      } else if (M <= 2048) {
+    } else {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
+    }
+  } else if (M <= 512) {
+    if (N <= 512) {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
+    } else if (N <= 4096) {
+      if (K <= 1024) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
         return mx8mx8bf16_grouped_256_128_256_2_1_1;
       }
+    } else if (N <= 8192) {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
     } else {
-      if (M <= 1024) {
-        return mx8mx8bf16_grouped_256_64_256_2_1_1;
-      } else if (M <= 4096) {
-        return mx8mx8bf16_grouped_128_64_256_1_1_1;
-      } else if (M <= 8192) {
-        return mx8mx8bf16_grouped_256_64_256_2_1_1;
+      if (K <= 512) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else if (K <= 4096) {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
+      } else if (K <= 8192) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
       }
     }
-    return mx8mx8bf16_grouped_256_256_256_2_1_1;
-  } else if (N == 2048 && K == 5120) {
-    if (G <= 8) {
-      if (M <= 256) {
-        return mx8mx8bf16_grouped_256_64_256_2_1_1;
-      } else if (M <= 512) {
-        return mx8mx8bf16_grouped_128_64_256_1_1_1;
-      } else if (M <= 1024) {
-        return mx8mx8bf16_grouped_128_128_256_1_1_1;
+  } else if (M <= 1024) {
+    if (N <= 2048) {
+      if (K <= 1024) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
       }
-    } else if (G <= 16) {
-      if (M <= 1024) {
-        return mx8mx8bf16_grouped_256_64_256_2_1_1;
-      } else if (M <= 2048) {
-        return mx8mx8bf16_grouped_128_128_256_1_1_1;
+    } else if (N <= 4096) {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
+    } else if (N <= 8192) {
+      if (K <= 512) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
       }
     } else {
-      if (M <= 1024) {
-        return mx8mx8bf16_grouped_256_64_256_2_1_1;
-      } else if (M <= 16384) {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
+    }
+  } else if (M <= 2048) {
+    if (N <= 1024) {
+      if (K <= 1024) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
+      }
+    } else if (N <= 2048) {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
+    } else {
+      if (K <= 512) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
         return mx8mx8bf16_grouped_256_128_256_2_1_1;
       }
     }
-    return mx8mx8bf16_grouped_256_256_256_2_1_1;
-  }
-
-  // Fallback to legacy heuristic
-  if (M <= 1000) {
-    return mx8mx8bf16_grouped_256_128_256_2_1_1;
+  } else if (M <= 4096) {
+    if (N <= 512) {
+      if (K <= 512) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
+      }
+    } else if (N <= 1024) {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
+    } else {
+      if (K <= 512) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
+      }
+    }
+  } else if (M <= 8192) {
+    if (K <= 512) {
+      return mx8mx8bf16_grouped_256_256_256_2_1_1;
+    } else {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
+    }
   } else {
-    return mx8mx8bf16_grouped_256_256_256_2_1_1;
+    if (N <= 8192) {
+      if (K <= 512) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
+      }
+    } else {
+      if (K <= 512) {
+        return mx8mx8bf16_grouped_128_64_256_1_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
+      }
+    }
   }
 }

-template <typename InputType>
 at::Tensor dispatch_mx8_grouped_kernel(
     int M,
     int N,
     int K,
     int G,
-    InputType XQ, // FP8
-    InputType WQ, // FP8
-    InputType x_scale,
-    InputType w_scale,
+    at::Tensor XQ, // FP8
+    at::Tensor WQ, // FP8
+    at::Tensor x_scale,
+    at::Tensor w_scale,
     at::Tensor output,
     at::Tensor offsets) {
   // Select kernel to run via tuning or heuristics.
   auto kernel = [&]() {
-    return get_kernel_via_heuristics<InputType>(M, N, K, G);
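+    // Runtime autotuning is opt-in via the FBGEMM_AUTOTUNE_ENABLE
+    // environment variable; the default path uses static heuristics.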
+    if (std::getenv("FBGEMM_AUTOTUNE_ENABLE")) {
+      return get_kernel_via_tuning(
+          M, N, K, G, XQ, WQ, x_scale, w_scale, output, offsets);
+    } else {
+      return get_kernel_via_heuristics(M, N, K, G);
+    }
   }();
   // Invoke kernel
   return kernel(XQ, WQ, x_scale, w_scale, output, G, offsets);
@@ -149,6 +229,8 @@ at::Tensor mx8mx8bf16_grouped_mm(
       output_actual.size(1) == N,
       "for 2d-3d grouped GEMM, output shape must be (total_M, N).");

+    // M counts rows across all groups here; normalize to the average
+    // per-group M for the kernel-selection heuristics.
+    M /= G;
     // 2d-2d case.
   } else if (XQ.dim() == 2 && WQ.dim() == 2) {
     // Alias for clarity that groups are along K dimension for 2d-2d case.
@@ -167,7 +249,8 @@ at::Tensor mx8mx8bf16_grouped_mm(
       output_actual.dim() == 3 && output_actual.size(0) == G &&
       output_actual.size(1) == M && output_actual.size(2) == N,
       "for 2d-2d grouped GEMM, output shape must be (G, M, N).");
-
+    // K counts the contraction dim across all groups here; normalize to the
+    // average per-group K for the kernel-selection heuristics.
+    K /= G;
   } else {
     TORCH_CHECK(false, "Invalid input shapes. Must be one of 2D-2D, 2D-3D.");
   }
@@ -178,7 +261,7 @@ at::Tensor mx8mx8bf16_grouped_mm(
   }

   // Return contiguous view of output.
-  return dispatch_mx8_grouped_kernel<at::Tensor>(
+  return dispatch_mx8_grouped_kernel(
       M, N, K, G, XQ, WQ, x_scale, w_scale, output_actual, offsets);
 }
