44 changes: 44 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -2915,6 +2915,50 @@ def cuda(self) -> bool:
return True


@register_quantize_op
class BF16GroupedGemm2d3d(QuantizeOpBase):
"""
Torch BF16 grouped GEMM with 2D inputs and 3D weights.
"""

def preprocess(self, x, w):
assert isinstance(x, list)
assert isinstance(w, list)
offs = torch.tensor(
[i.shape[0] for i in x], dtype=torch.int32, device=x[0].device
)
offs = torch.cumsum(offs, dim=0).to(torch.int32)
x = torch.cat(x, dim=0).contiguous() # (G * M, K)
w = torch.stack(w, dim=0).contiguous() # (G, N, K)
return x, w, offs

def quantize(self, x, w, offs):
return x, w, offs

def compute(self, x, w, offs):
return torch._grouped_mm(
x,
w.transpose(-2, -1),
offs=offs,
)

def quantize_and_compute(self, x, w, offs):
        x, w, offs = self.quantize(x, w, offs)
return self.compute(x, w, offs)

@property
def name(self) -> str:
return "bf16_baseline_grouped_2d_3d"

@property
def hip(self) -> bool:
return False

@property
def cuda(self) -> bool:
return True
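
For context, a minimal sketch of what this op exercises end to end, assuming bf16 CUDA tensors and the private torch._grouped_mm API used above (group sizes are illustrative, not from the PR):

import torch

G, K, N = 4, 256, 512
group_ms = [64, 96, 128, 32]  # illustrative per-group row counts
x_list = [torch.randn(m, K, dtype=torch.bfloat16, device="cuda") for m in group_ms]
w_list = [torch.randn(N, K, dtype=torch.bfloat16, device="cuda") for _ in range(G)]

# Same preprocessing as BF16GroupedGemm2d3d.preprocess: concatenate activations
# to 2D, stack weights to 3D, and record cumulative group offsets along M.
offs = torch.cumsum(
    torch.tensor(group_ms, dtype=torch.int32, device="cuda"), dim=0
).to(torch.int32)
x = torch.cat(x_list, dim=0).contiguous()    # (sum(M_g), K)
w = torch.stack(w_list, dim=0).contiguous()  # (G, N, K)

out = torch._grouped_mm(x, w.transpose(-2, -1), offs=offs)  # (sum(M_g), N)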


@register_quantize_op
class MXFP8GroupedGemm2d3d(QuantizeOpBase):
"""
@@ -19,6 +19,9 @@
#include <cutlass/epilogue/collective/collective_builder.hpp> // @manual
// clang-format on

#include "fbgemm_gpu/quantize/tuning_cache.hpp"
#include "fbgemm_gpu/quantize/utils.h"

#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
#include "mx8mx8bf16_grouped/mx8mx8bf16_grouped_manifest.cuh"
#endif
@@ -27,83 +30,160 @@ namespace fbgemm_gpu {

#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)

Kernel_mx8mx8bf16_grouped get_kernel_via_tuning(
int M,
int N,
int K,
int G,
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor output,
at::Tensor offsets) {
static TuningCache cache("mx8mx8bf16_grouped");

M = nextPowerOf2(M);
N = nextPowerOf2(N);
K = nextPowerOf2(K);
const std::string shape_key =
std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(K);

const auto& kernels = get_mx8mx8bf16_grouped_kernels();
auto kernel = cache.findBestKernelMaybeAutotune(
shape_key, kernels, XQ, WQ, x_scale, w_scale, output, G, offsets);

return kernel;
}
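
For intuition, the cache key above buckets each GEMM dimension to the next power of two, so nearby problem sizes reuse one autotuned choice. A rough Python equivalent of the keying step (assuming fbgemm_gpu's nextPowerOf2 rounds up and leaves exact powers of two unchanged):

def next_power_of_2(n: int) -> int:
    # Assumed behavior of nextPowerOf2: round up to the nearest power of two,
    # leaving exact powers of two (and 1) unchanged.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def shape_key(m: int, n: int, k: int) -> str:
    # Mirrors the C++ shape_key: "M_N_K" after power-of-two bucketing.
    return f"{next_power_of_2(m)}_{next_power_of_2(n)}_{next_power_of_2(k)}"

# e.g. shape_key(1000, 5120, 1024) == "1024_8192_1024"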

Kernel_mx8mx8bf16_grouped
get_kernel_via_heuristics(int M, int N, int K, int G) {
// Llama4 shapes
if (N == 5120 && K == 1024) {
if (G <= 8) {
if (M <= 256) {
if (M <= 128) {
if (N <= 512) {
return mx8mx8bf16_grouped_256_64_256_2_1_1;
} else if (N <= 1024) {
if (K <= 4096) {
return mx8mx8bf16_grouped_256_64_256_2_1_1;
} else if (M <= 512) {
} else {
return mx8mx8bf16_grouped_128_64_256_1_1_1;
} else if (M <= 1024) {
return mx8mx8bf16_grouped_128_128_256_1_1_1;
}
} else if (G <= 16) {
if (M <= 1024) {
return mx8mx8bf16_grouped_128_64_256_1_1_1;
} else if (M <= 2048) {
} else {
return mx8mx8bf16_grouped_256_128_256_2_1_1;
}
} else if (M <= 512) {
if (N <= 512) {
return mx8mx8bf16_grouped_256_128_256_2_1_1;
} else if (N <= 4096) {
if (K <= 1024) {
return mx8mx8bf16_grouped_256_256_256_2_1_1;
} else {
return mx8mx8bf16_grouped_256_128_256_2_1_1;
}
} else if (N <= 8192) {
return mx8mx8bf16_grouped_256_128_256_2_1_1;
} else {
if (M <= 1024) {
return mx8mx8bf16_grouped_256_64_256_2_1_1;
} else if (M <= 4096) {
return mx8mx8bf16_grouped_128_64_256_1_1_1;
} else if (M <= 8192) {
return mx8mx8bf16_grouped_256_64_256_2_1_1;
if (K <= 512) {
return mx8mx8bf16_grouped_256_256_256_2_1_1;
} else if (K <= 4096) {
return mx8mx8bf16_grouped_256_128_256_2_1_1;
} else if (K <= 8192) {
return mx8mx8bf16_grouped_256_256_256_2_1_1;
} else {
return mx8mx8bf16_grouped_256_128_256_2_1_1;
}
}
return mx8mx8bf16_grouped_256_256_256_2_1_1;
} else if (N == 2048 && K == 5120) {
if (G <= 8) {
if (M <= 256) {
return mx8mx8bf16_grouped_256_64_256_2_1_1;
} else if (M <= 512) {
return mx8mx8bf16_grouped_128_64_256_1_1_1;
} else if (M <= 1024) {
return mx8mx8bf16_grouped_128_128_256_1_1_1;
} else if (M <= 1024) {
if (N <= 2048) {
if (K <= 1024) {
return mx8mx8bf16_grouped_256_256_256_2_1_1;
} else {
return mx8mx8bf16_grouped_256_128_256_2_1_1;
}
} else if (G <= 16) {
if (M <= 1024) {
return mx8mx8bf16_grouped_256_64_256_2_1_1;
} else if (M <= 2048) {
return mx8mx8bf16_grouped_128_128_256_1_1_1;
} else if (N <= 4096) {
return mx8mx8bf16_grouped_256_128_256_2_1_1;
} else if (N <= 8192) {
if (K <= 512) {
return mx8mx8bf16_grouped_256_256_256_2_1_1;
} else {
return mx8mx8bf16_grouped_256_128_256_2_1_1;
}
} else {
if (M <= 1024) {
return mx8mx8bf16_grouped_256_64_256_2_1_1;
} else if (M <= 16384) {
return mx8mx8bf16_grouped_256_128_256_2_1_1;
}
} else if (M <= 2048) {
if (N <= 1024) {
if (K <= 1024) {
return mx8mx8bf16_grouped_256_256_256_2_1_1;
} else {
return mx8mx8bf16_grouped_256_128_256_2_1_1;
}
} else if (N <= 2048) {
return mx8mx8bf16_grouped_256_128_256_2_1_1;
} else {
if (K <= 512) {
return mx8mx8bf16_grouped_256_256_256_2_1_1;
} else {
return mx8mx8bf16_grouped_256_128_256_2_1_1;
}
}
return mx8mx8bf16_grouped_256_256_256_2_1_1;
}

// Fallback to legacy heuristic
if (M <= 1000) {
return mx8mx8bf16_grouped_256_128_256_2_1_1;
} else if (M <= 4096) {
if (N <= 512) {
if (K <= 512) {
return mx8mx8bf16_grouped_256_256_256_2_1_1;
} else {
return mx8mx8bf16_grouped_256_128_256_2_1_1;
}
} else if (N <= 1024) {
return mx8mx8bf16_grouped_256_128_256_2_1_1;
} else {
if (K <= 512) {
return mx8mx8bf16_grouped_256_256_256_2_1_1;
} else {
return mx8mx8bf16_grouped_256_128_256_2_1_1;
}
}
} else if (M <= 8192) {
if (K <= 512) {
return mx8mx8bf16_grouped_256_256_256_2_1_1;
} else {
return mx8mx8bf16_grouped_256_128_256_2_1_1;
}
} else {
return mx8mx8bf16_grouped_256_256_256_2_1_1;
if (N <= 8192) {
if (K <= 512) {
return mx8mx8bf16_grouped_256_256_256_2_1_1;
} else {
return mx8mx8bf16_grouped_256_128_256_2_1_1;
}
} else {
if (K <= 512) {
return mx8mx8bf16_grouped_128_64_256_1_1_1;
} else {
return mx8mx8bf16_grouped_256_128_256_2_1_1;
}
}
}
}

at::Tensor dispatch_mx8_grouped_kernel(
int M,
int N,
int K,
int G,
    at::Tensor XQ, // FP8
    at::Tensor WQ, // FP8
    at::Tensor x_scale,
    at::Tensor w_scale,
at::Tensor output,
at::Tensor offsets) {
  // Select kernel to run via heuristics (or via autotuning if enabled).
  auto kernel = [&]() {
if (std::getenv("FBGEMM_AUTOTUNE_ENABLE")) {
return get_kernel_via_tuning(
M, N, K, G, XQ, WQ, x_scale, w_scale, output, offsets);
} else {
return get_kernel_via_heuristics(M, N, K, G);
}
}();
// Invoke kernel
return kernel(XQ, WQ, x_scale, w_scale, output, G, offsets);
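
The tuned path is opt-in: the dispatch above only consults the TuningCache when the FBGEMM_AUTOTUNE_ENABLE environment variable is present. One way to turn it on from a Python benchmark process (illustrative; any non-empty value works, since only presence is checked):

import os

os.environ["FBGEMM_AUTOTUNE_ENABLE"] = "1"  # visible to std::getenv in this process
# Subsequent mx8mx8bf16 grouped GEMM calls autotune once per bucketed
# (M, N, K) shape and cache the winning kernel for reuse.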
@@ -149,6 +229,8 @@ at::Tensor mx8mx8bf16_grouped_mm(
output_actual.size(1) == N,
"for 2d-3d grouped GEMM, output shape must be (total_M, N).");

// Normalized jagged dim for heuristics
M /= G;
// 2d-2d case.
} else if (XQ.dim() == 2 && WQ.dim() == 2) {
// Alias for clarity that groups are along K dimension for 2d-2d case.
@@ -167,7 +249,8 @@ at::Tensor mx8mx8bf16_grouped_mm(
output_actual.dim() == 3 && output_actual.size(0) == G &&
output_actual.size(1) == M && output_actual.size(2) == N,
"for 2d-2d grouped GEMM, output shape must be (G, M, N).");

// Normalized jagged dim for heuristics
K /= G;
} else {
TORCH_CHECK(false, "Invalid input shapes. Must be one of 2D-2D, 2D-3D.");
}
@@ -178,7 +261,7 @@
}

  // Return contiguous view of output.
  return dispatch_mx8_grouped_kernel(
M, N, K, G, XQ, WQ, x_scale, w_scale, output_actual, offsets);
}
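
A small worked example of the normalization above, which hands the heuristics a per-group problem size rather than the stacked one (numbers are illustrative):

# 2D-3D case: XQ is (total_M, K) with G groups stacked along M.
G, total_M, N, K = 8, 8192, 5120, 1024
M_for_heuristics = total_M // G   # 1024 rows per group on average

# 2D-2D case: groups are stacked along K instead.
total_K = 8192
K_for_heuristics = total_K // G   # 1024 reduction elements per group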
