diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
index cdc8494e45..f57db80c94 100644
--- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
+++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -2915,6 +2915,50 @@ def cuda(self) -> bool:
         return True


+@register_quantize_op
+class BF16GroupedGemm2d3d(QuantizeOpBase):
+    """
+    Torch BF16 grouped GEMM with 2D inputs and 3D weights.
+    """
+
+    def preprocess(self, x, w):
+        assert isinstance(x, list)
+        assert isinstance(w, list)
+        offs = torch.tensor(
+            [i.shape[0] for i in x], dtype=torch.int32, device=x[0].device
+        )
+        offs = torch.cumsum(offs, dim=0).to(torch.int32)
+        x = torch.cat(x, dim=0).contiguous()  # (G * M, K)
+        w = torch.stack(w, dim=0).contiguous()  # (G, N, K)
+        return x, w, offs
+
+    def quantize(self, x, w, offs):
+        return x, w, offs
+
+    def compute(self, x, w, offs):
+        return torch._grouped_mm(
+            x,
+            w.transpose(-2, -1),
+            offs=offs,
+        )
+
+    def quantize_and_compute(self, x, w, offs):
+        x, w, offs = self.quantize(x, w, offs)
+        return self.compute(x, w, offs)
+
+    @property
+    def name(self) -> str:
+        return "bf16_baseline_grouped_2d_3d"
+
+    @property
+    def hip(self) -> bool:
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class MXFP8GroupedGemm2d3d(QuantizeOpBase):
     """
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped.cu
index cae9addb80..533378a3b1 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped.cu
@@ -19,6 +19,9 @@
 #include // @manual
 // clang-format on

+#include "fbgemm_gpu/quantize/tuning_cache.hpp"
+#include "fbgemm_gpu/quantize/utils.h"
+
 #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
 #include "mx8mx8bf16_grouped/mx8mx8bf16_grouped_manifest.cuh"
 #endif
@@ -27,83 +30,160 @@ namespace fbgemm_gpu {

 #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)

-template <typename InputType>
-Kernel_mx8mx8bf16_grouped<InputType>
+Kernel_mx8mx8bf16_grouped get_kernel_via_tuning(
+    int M,
+    int N,
+    int K,
+    int G,
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor output,
+    at::Tensor offsets) {
+  static TuningCache cache("mx8mx8bf16_grouped");
+
+  M = nextPowerOf2(M);
+  N = nextPowerOf2(N);
+  K = nextPowerOf2(K);
+  const std::string shape_key =
+      std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(K);
+
+  const auto& kernels = get_mx8mx8bf16_grouped_kernels();
+  auto kernel = cache.findBestKernelMaybeAutotune(
+      shape_key, kernels, XQ, WQ, x_scale, w_scale, output, G, offsets);
+
+  return kernel;
+}
+
+Kernel_mx8mx8bf16_grouped
 get_kernel_via_heuristics(int M, int N, int K, int G) {
-  // Llama4 shapes
-  if (N == 5120 && K == 1024) {
-    if (G <= 8) {
-      if (M <= 256) {
+  if (M <= 128) {
+    if (N <= 512) {
+      return mx8mx8bf16_grouped_256_64_256_2_1_1;
+    } else if (N <= 1024) {
+      if (K <= 4096) {
         return mx8mx8bf16_grouped_256_64_256_2_1_1;
-      } else if (M <= 512) {
+      } else {
         return mx8mx8bf16_grouped_128_64_256_1_1_1;
-      } else if (M <= 1024) {
-        return mx8mx8bf16_grouped_128_128_256_1_1_1;
       }
-    } else if (G <= 16) {
-      if (M <= 1024) {
-        return mx8mx8bf16_grouped_128_64_256_1_1_1;
-      } else if (M <= 2048) {
+    } else {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
+    }
+  } else if (M <= 512) {
+    if (N <= 512) {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
+    } else if (N <= 4096) {
+      if (K <= 1024) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
         return mx8mx8bf16_grouped_256_128_256_2_1_1;
       }
+    } else if (N <= 8192) {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
     } else {
-      if (M <= 1024) {
-        return mx8mx8bf16_grouped_256_64_256_2_1_1;
-      } else if (M <= 4096) {
-        return mx8mx8bf16_grouped_128_64_256_1_1_1;
-      } else if (M <= 8192) {
-        return mx8mx8bf16_grouped_256_64_256_2_1_1;
+      if (K <= 512) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else if (K <= 4096) {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
+      } else if (K <= 8192) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
       }
     }
-    return mx8mx8bf16_grouped_256_256_256_2_1_1;
-  } else if (N == 2048 && K == 5120) {
-    if (G <= 8) {
-      if (M <= 256) {
-        return mx8mx8bf16_grouped_256_64_256_2_1_1;
-      } else if (M <= 512) {
-        return mx8mx8bf16_grouped_128_64_256_1_1_1;
-      } else if (M <= 1024) {
-        return mx8mx8bf16_grouped_128_128_256_1_1_1;
+  } else if (M <= 1024) {
+    if (N <= 2048) {
+      if (K <= 1024) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
       }
-    } else if (G <= 16) {
-      if (M <= 1024) {
-        return mx8mx8bf16_grouped_256_64_256_2_1_1;
-      } else if (M <= 2048) {
-        return mx8mx8bf16_grouped_128_128_256_1_1_1;
+    } else if (N <= 4096) {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
+    } else if (N <= 8192) {
+      if (K <= 512) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
       }
     } else {
-      if (M <= 1024) {
-        return mx8mx8bf16_grouped_256_64_256_2_1_1;
-      } else if (M <= 16384) {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
+    }
+  } else if (M <= 2048) {
+    if (N <= 1024) {
+      if (K <= 1024) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
+      }
+    } else if (N <= 2048) {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
+    } else {
+      if (K <= 512) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
         return mx8mx8bf16_grouped_256_128_256_2_1_1;
       }
     }
-    return mx8mx8bf16_grouped_256_256_256_2_1_1;
-  }
-
-  // Fallback to legacy heuristic
-  if (M <= 1000) {
-    return mx8mx8bf16_grouped_256_128_256_2_1_1;
+  } else if (M <= 4096) {
+    if (N <= 512) {
+      if (K <= 512) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
+      }
+    } else if (N <= 1024) {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
+    } else {
+      if (K <= 512) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
+      }
+    }
+  } else if (M <= 8192) {
+    if (K <= 512) {
+      return mx8mx8bf16_grouped_256_256_256_2_1_1;
+    } else {
+      return mx8mx8bf16_grouped_256_128_256_2_1_1;
+    }
   } else {
-    return mx8mx8bf16_grouped_256_256_256_2_1_1;
+    if (N <= 8192) {
+      if (K <= 512) {
+        return mx8mx8bf16_grouped_256_256_256_2_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
+      }
+    } else {
+      if (K <= 512) {
+        return mx8mx8bf16_grouped_128_64_256_1_1_1;
+      } else {
+        return mx8mx8bf16_grouped_256_128_256_2_1_1;
+      }
+    }
   }
 }

-template <typename InputType>
 at::Tensor dispatch_mx8_grouped_kernel(
     int M,
     int N,
     int K,
     int G,
-    InputType XQ, // FP8
-    InputType WQ, // FP8
-    InputType x_scale,
-    InputType w_scale,
+    at::Tensor XQ, // FP8
+    at::Tensor WQ, // FP8
+    at::Tensor x_scale,
+    at::Tensor w_scale,
     at::Tensor output,
     at::Tensor offsets) {
   // Select kernel to run via heuristics.
   auto kernel = [&]() {
-    return get_kernel_via_heuristics<InputType>(M, N, K, G);
+    if (std::getenv("FBGEMM_AUTOTUNE_ENABLE")) {
+      return get_kernel_via_tuning(
+          M, N, K, G, XQ, WQ, x_scale, w_scale, output, offsets);
+    } else {
+      return get_kernel_via_heuristics(M, N, K, G);
+    }
   }();
   // Invoke kernel
   return kernel(XQ, WQ, x_scale, w_scale, output, G, offsets);
@@ -149,6 +229,8 @@ at::Tensor mx8mx8bf16_grouped_mm(
         output_actual.size(1) == N,
         "for 2d-3d grouped GEMM, output shape must be (total_M, N).");

+    // Normalized jagged dim for heuristics
+    M /= G;
     // 2d-2d case.
   } else if (XQ.dim() == 2 && WQ.dim() == 2) {
     // Alias for clarity that groups are along K dimension for 2d-2d case.
@@ -167,7 +249,8 @@ at::Tensor mx8mx8bf16_grouped_mm(
         output_actual.dim() == 3 && output_actual.size(0) == G &&
             output_actual.size(1) == M && output_actual.size(2) == N,
         "for 2d-2d grouped GEMM, output shape must be (G, M, N).");
-
+    // Normalized jagged dim for heuristics
+    K /= G;
   } else {
     TORCH_CHECK(false, "Invalid input shapes. Must be one of 2D-2D, 2D-3D.");
   }
@@ -178,7 +261,7 @@ at::Tensor mx8mx8bf16_grouped_mm(
   }

   // Return continuous view of output.
-  return dispatch_mx8_grouped_kernel<at::Tensor>(
+  return dispatch_mx8_grouped_kernel(
       M, N, K, G, XQ, WQ, x_scale, w_scale, output_actual, offsets);
 }
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_manifest.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_manifest.cuh
index 86c6946895..d35d70c157 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_manifest.cuh
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_manifest.cuh
@@ -13,8 +13,8 @@ namespace fbgemm_gpu {
 #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)

 at::Tensor mx8mx8bf16_grouped_128_64_256_1_1_1(
-    at::Tensor XQ, // FP8
-    at::Tensor WQ, // FP8
+    at::Tensor XQ,
+    at::Tensor WQ,
     at::Tensor x_scale,
     at::Tensor w_scale,
     at::Tensor output,
@@ -22,8 +22,8 @@ at::Tensor mx8mx8bf16_grouped_128_64_256_1_1_1(
     at::Tensor offsets);

 at::Tensor mx8mx8bf16_grouped_128_128_256_1_1_1(
-    at::Tensor XQ, // FP8
-    at::Tensor WQ, // FP8
+    at::Tensor XQ,
+    at::Tensor WQ,
     at::Tensor x_scale,
     at::Tensor w_scale,
     at::Tensor output,
@@ -31,8 +31,8 @@ at::Tensor mx8mx8bf16_grouped_128_128_256_1_1_1(
     at::Tensor offsets);

 at::Tensor mx8mx8bf16_grouped_256_64_256_2_1_1(
-    at::Tensor XQ, // FP8
-    at::Tensor WQ, // FP8
+    at::Tensor XQ,
+    at::Tensor WQ,
     at::Tensor x_scale,
     at::Tensor w_scale,
     at::Tensor output,
@@ -40,8 +40,8 @@ at::Tensor mx8mx8bf16_grouped_256_64_256_2_1_1(
     at::Tensor offsets);

 at::Tensor mx8mx8bf16_grouped_256_128_256_2_1_1(
-    at::Tensor XQ, // FP8
-    at::Tensor WQ, // FP8
+    at::Tensor XQ,
+    at::Tensor WQ,
     at::Tensor x_scale,
     at::Tensor w_scale,
     at::Tensor output,
@@ -49,41 +49,38 @@ at::Tensor mx8mx8bf16_grouped_256_128_256_2_1_1(
     at::Tensor offsets);

 at::Tensor mx8mx8bf16_grouped_256_256_256_2_1_1(
-    at::Tensor XQ, // FP8
-    at::Tensor WQ, // FP8
+    at::Tensor XQ,
+    at::Tensor WQ,
     at::Tensor x_scale,
     at::Tensor w_scale,
     at::Tensor output,
     int64_t G,
     at::Tensor offsets);

-template <typename InputType>
 using Kernel_mx8mx8bf16_grouped = at::Tensor (*)(
-    InputType, // XQ
-    InputType, // WQ
-    InputType, // x_scale
-    InputType, // w_scale
+    at::Tensor, // XQ
+    at::Tensor, // WQ
+    at::Tensor, // x_scale
+    at::Tensor, // w_scale
     at::Tensor, // output
     int64_t, // G
     at::Tensor); // offsets

-template <typename InputType>
-const std::unordered_map<std::string, Kernel_mx8mx8bf16_grouped<InputType>>&
+const std::unordered_map<std::string, Kernel_mx8mx8bf16_grouped>&
 get_mx8mx8bf16_grouped_kernels() {
-  static const std::
-      unordered_map<std::string, Kernel_mx8mx8bf16_grouped<InputType>>
-          kernels = {
-              {"mx8mx8bf16_grouped_128_64_256_1_1_1",
-               mx8mx8bf16_grouped_128_64_256_1_1_1},
-              {"mx8mx8bf16_grouped_128_128_256_1_1_1",
-               mx8mx8bf16_grouped_128_128_256_1_1_1},
-              {"mx8mx8bf16_grouped_256_64_256_2_1_1",
-               mx8mx8bf16_grouped_256_64_256_2_1_1},
-              {"mx8mx8bf16_grouped_256_128_256_2_1_1",
-               mx8mx8bf16_grouped_256_128_256_2_1_1},
-              {"mx8mx8bf16_grouped_256_256_256_2_1_1",
-               mx8mx8bf16_grouped_256_256_256_2_1_1},
-          };
+  static const std::unordered_map<std::string, Kernel_mx8mx8bf16_grouped>
+      kernels = {
+          {"mx8mx8bf16_grouped_128_64_256_1_1_1",
+           mx8mx8bf16_grouped_128_64_256_1_1_1},
+          {"mx8mx8bf16_grouped_128_128_256_1_1_1",
+           mx8mx8bf16_grouped_128_128_256_1_1_1},
+          {"mx8mx8bf16_grouped_256_64_256_2_1_1",
+           mx8mx8bf16_grouped_256_64_256_2_1_1},
+          {"mx8mx8bf16_grouped_256_128_256_2_1_1",
+           mx8mx8bf16_grouped_256_128_256_2_1_1},
+          {"mx8mx8bf16_grouped_256_256_256_2_1_1",
+           mx8mx8bf16_grouped_256_256_256_2_1_1},
+      };
   return kernels;
 }
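
Note: the new bf16_baseline_grouped_2d_3d op above is a thin wrapper around torch._grouped_mm. The standalone sketch below mirrors its preprocess/compute flow outside the benchmark harness; the group sizes and dtypes are illustrative only, and it assumes the private torch._grouped_mm API is available and supports bf16 grouped GEMM on the local GPU (recent PyTorch, Hopper-class hardware).

import torch

# Hypothetical per-group shapes, for illustration only.
G, K, N = 4, 512, 1024
group_sizes = [128, 256, 64, 192]  # per-group M_i

# One 2D activation per group and one (N, K) weight per group.
xs = [torch.randn(m, K, device="cuda", dtype=torch.bfloat16) for m in group_sizes]
ws = [torch.randn(N, K, device="cuda", dtype=torch.bfloat16) for _ in range(G)]

# preprocess(): cumulative row offsets mark each group's end along dim 0.
offs = torch.cumsum(
    torch.tensor(group_sizes, dtype=torch.int32, device="cuda"), dim=0
).to(torch.int32)
x = torch.cat(xs, dim=0).contiguous()    # (total_M, K)
w = torch.stack(ws, dim=0).contiguous()  # (G, N, K)

# compute(): grouped GEMM against the 3D weights, transposed to (G, K, N).
out = torch._grouped_mm(x, w.transpose(-2, -1), offs=offs)
assert out.shape == (sum(group_sizes), N)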
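Note: the TuningCache path added in mx8mx8bf16_grouped.cu keys its lookups on power-of-two-rounded (M, N, K), with M first divided by G in the 2D-3D case (and K by G in the 2D-2D case). The following Python sketch of that bucketing assumes nextPowerOf2 from fbgemm_gpu/quantize/utils.h rounds up to the nearest power of two.

def next_power_of_2(x: int) -> int:
    # Assumed to match nextPowerOf2() in fbgemm_gpu/quantize/utils.h (round up).
    return 1 if x <= 1 else 1 << (x - 1).bit_length()

def shape_key_2d_3d(total_m: int, n: int, k: int, g: int) -> str:
    m = total_m // g  # "Normalized jagged dim for heuristics" (M /= G)
    return f"{next_power_of_2(m)}_{next_power_of_2(n)}_{next_power_of_2(k)}"

# Example: 8192 total rows over 16 groups with N=5120, K=1024 share one bucket.
print(shape_key_2d_3d(8192, 5120, 1024, 16))  # -> "512_8192_1024"

Every problem shape that rounds to the same key reuses the kernel chosen by findBestKernelMaybeAutotune, so the autotuning cost is paid once per bucket.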
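Note: autotuning stays opt-in. dispatch_mx8_grouped_kernel only consults the TuningCache when std::getenv("FBGEMM_AUTOTUNE_ENABLE") returns non-null, so the variable merely has to be set (to any value) in the process before the grouped GEMM is reached; otherwise the shape heuristics above are used. A minimal sketch from Python:

import os

# Any value works; the C++ side only checks that the variable is set.
os.environ["FBGEMM_AUTOTUNE_ENABLE"] = "1"
# Subsequent mx8mx8bf16 grouped GEMM calls in this process now pick their
# kernel via the TuningCache (per rounded M_N_K bucket) instead of the
# heuristics.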