diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py index 540dbdc692..cdc8494e45 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py @@ -2916,78 +2916,194 @@ def cuda(self) -> bool: @register_quantize_op -class MXFP8StackedGroupedGemm(QuantizeOpBase): +class MXFP8GroupedGemm2d3d(QuantizeOpBase): """ - MXFP8 grouped matmul with blockwise scaling and stacked inputs. + MXFP8 grouped GEMM with 2D inputs and 3D weights. """ def preprocess(self, x, w): - m_values = [i.shape[0] for i in x] - m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device) + assert isinstance(x, list) + assert isinstance(w, list) + x = torch.cat(x, dim=0).contiguous() # (G * M, K) + w = torch.stack(w, dim=0).contiguous() # (G, N, K) + return x, w + + def quantize(self, x, w): + block_size = 32 + G, N, K = w.shape + total_M = x.shape[0] + group_size = total_M // G + input_group_end_offsets = torch.arange( + group_size, total_M + 1, group_size, dtype=torch.int32, device=x.device + ) + + # For each constituent 2d subtensor in the 3d weights, quantize and convert scale to blocked format separately, + # as each is used for an independent gemm in the grouped gemm. wq_list = [] w_scale_list = [] - for i in range(m_sizes.shape[0]): + for i in range(G): w_scale, wq = to_mxfp8(w[i]) w_scale = _to_blocked(w_scale) wq_list.append(wq) w_scale_list.append(w_scale) wq = torch.stack(wq_list, dim=0).contiguous() w_scale = torch.stack(w_scale_list, dim=0).contiguous() - return x, wq, w_scale, m_sizes - def quantize(self, x, wq, w_scale, m_sizes): - starting_row_after_padding_list = [0] + # For each group along `total_M` in the 2D tensor, quantize and convert scale to blocked format separately, + # as each is used for an independent gemm in the grouped gemm.
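For reference, each iteration of the loop below produces the following shapes (a minimal sketch, assuming to_mxfp8 and _to_blocked behave as the shape comments elsewhere in this diff describe: one block scale per 32 elements along K, and scale matrices padded to 128x4 tiles):

    x_scale, xq = to_mxfp8(x_slice)   # x_slice: (M_g, K) bf16 -> xq: (M_g, K) fp8, x_scale: (M_g, K // 32)
    x_scale = _to_blocked(x_scale)    # padded/swizzled to (round_up(M_g, 128), round_up(K // 32, 4))
    # The per-group blocked scales are concatenated along dim 0; the CUDA args-setup kernel
    # later walks this same layout when computing per-group scale pointers.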
xq_list = [] x_scale_list = [] - for i in range(m_sizes.shape[0]): - scale_slice = x[i] - if m_sizes[i].item() != 0: - x_scale, xq = to_mxfp8(scale_slice) + for i in range(G): + prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1] + curr_group_end = input_group_end_offsets[i] + group_size = curr_group_end - prev_group_end + if group_size > 0: + x_slice = x[prev_group_end:curr_group_end, :] + x_scale, xq = to_mxfp8(x_slice) x_scale = _to_blocked(x_scale) xq_list.append(xq) x_scale_list.append(x_scale) - starting_row_after_padding_list.append( - starting_row_after_padding_list[i] - + x_scale.numel() // (x[0].shape[1] // 32) - ) - else: - starting_row_after_padding_list.append( - starting_row_after_padding_list[i] - ) xq = torch.cat(xq_list, dim=0).contiguous() x_scale = torch.cat(x_scale_list, dim=0).contiguous() - x_scale = x_scale.reshape(-1, x[0].shape[-1] // 32) + x_scale = x_scale.reshape(-1, K // block_size) xq = xq.view(-1, xq.shape[-1]) - return ( + return xq, wq, x_scale, w_scale, input_group_end_offsets + + def compute(self, xq, wq, x_scale, w_scale, input_group_end_offsets): + return torch.ops.fbgemm.mx8mx8bf16_grouped_mm( xq, - wq, + wq.transpose(-2, -1), x_scale, w_scale, - m_sizes, - torch.tensor(starting_row_after_padding_list, device=xq.device), + input_group_end_offsets, ) - def compute(self, xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding): - return torch.ops.fbgemm.mx8mx8bf16_grouped_stacked( + def quantize_and_compute(self, x, w): + xq, wq, x_scale, w_scale, input_group_end_offsets = self.quantize(x, w) + return self.compute( xq, wq, x_scale, w_scale, - m_sizes, - starting_row_after_padding=starting_row_after_padding, + input_group_end_offsets, ) - def quantize_and_compute(self, x, w): - xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding = self.quantize( - x, w + @property + def name(self) -> str: + return "cutlass_mx8mx8bf16_grouped_mm_2d_3d" + + @property + def hip(self) -> bool: + return False + + @property + def cuda(self) -> bool: + return True + + +@register_quantize_op +class MXFP8GroupedGemm2d2d(QuantizeOpBase): + """ + MXFP8 grouped GEMM with 2D inputs and 2D weights. + """ + + def preprocess(self, x, w): + assert isinstance(x, list) + assert isinstance(w, list) + G = len(x) + x = torch.cat(x, dim=1).contiguous() # (M, total_K) + w = torch.cat(w, dim=1).contiguous() # (N, total_K) + return x, w, G + + def quantize(self, x, w, G): + # Simulate 2d-2d grouped gemm in backward pass `grad_weight = grad_output_t @ input`, + # where we use "K" as the contracting dim which has "G" groups. + M, total_K = x.shape + N, _ = w.shape + group_size = total_K // G + input_group_end_offsets = torch.arange( + group_size, total_K + 1, group_size, dtype=torch.int32, device=x.device + ) + + # Convert scales to blocked format.
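To make the intent of this op concrete, the 2d-2d grouped GEMM benchmarked here is numerically equivalent to the following unfused loop (a sketch; input_group_end_offsets holds each group's end index along the contracting dimension, as constructed above):

    out = torch.empty(G, M, N, dtype=x.dtype, device=x.device)
    start = 0
    for g in range(G):
        end = int(input_group_end_offsets[g])
        out[g] = x[:, start:end] @ w[:, start:end].t()   # (M, K_g) @ (K_g, N) -> (M, N)
        start = end

This mirrors the `grad_weight = grad_output_t @ input` pattern referenced in the comment above.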
+ x_list = [] + w_list = [] + x_blocked_scale_list = [] + w_blocked_scale_list = [] + + def round_up(x: int, y: int) -> int: + return ((x + y - 1) // y) * y + + for group_idx in range(G): + # to_mxfp8 per group + prev_group_end_offset = ( + 0 if group_idx == 0 else input_group_end_offsets[group_idx - 1] + ) + curr_group_end_offset = input_group_end_offsets[group_idx] + group_size = curr_group_end_offset - prev_group_end_offset + if group_size > 0: + x_slice = x[ + :, prev_group_end_offset:curr_group_end_offset + ].contiguous() # (M, K_group) + w_slice = w[ + :, prev_group_end_offset:curr_group_end_offset + ].contiguous() # (N, K_group) + x_scale_slice, xq_slice = to_mxfp8( + x_slice + ) # scale shape -> (M, K_group // 32) + w_scale_slice, wq_slice = to_mxfp8( + w_slice + ) # scale shape -> (N, K_group // 32) + x_list.append(xq_slice) + w_list.append(wq_slice) + + # Convert scales to blocked format. + x_scale_slice_blocked = _to_blocked( + x_scale_slice + ) # (round_up(M, 128), round_up(K_group//32, 4)) + w_scale_slice_blocked = _to_blocked( + w_scale_slice + ) # (round_up(N, 128), round_up(K_group//32, 4)) + x_blocked_scale_list.append(x_scale_slice_blocked) + w_blocked_scale_list.append(w_scale_slice_blocked) + + # Assemble the full XQ and WQ + xq = torch.cat(x_list, dim=1).contiguous() + wq = torch.cat(w_list, dim=1).contiguous() + + # Combine all XQ groups blocked scales into one tensor. + x_blocked_scales = torch.cat(x_blocked_scale_list, dim=0) + M_rounded = round_up(M, 128) + x_blocked_scales = x_blocked_scales.reshape(M_rounded, -1) + + # Combine all WQ groups blocked scales into one tensor. + w_blocked_scales = torch.cat(w_blocked_scale_list, dim=0) + N_rounded = round_up(N, 128) + w_blocked_scales = w_blocked_scales.reshape(N_rounded, -1) + return xq, wq, x_blocked_scales, w_blocked_scales, input_group_end_offsets + + def compute(self, xq, wq, x_scale, w_scale, input_group_end_offsets): + return torch.ops.fbgemm.mx8mx8bf16_grouped_mm( + xq, + wq.transpose(-2, -1), + x_scale, + w_scale, + input_group_end_offsets, ) + + def quantize_and_compute(self, x, w): + xq, wq, x_scale, w_scale, input_group_end_offsets = self.quantize(x, w) return self.compute( - xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding + xq, + wq, + x_scale, + w_scale, + input_group_end_offsets, ) @property def name(self) -> str: - return "cutlass_mx8mx8bf16_grouped_stacked" + return "cutlass_mx8mx8bf16_grouped_mm_2d_2d" @property def hip(self) -> bool: diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped.cu index 693c05739a..cae9addb80 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped.cu @@ -29,52 +29,52 @@ namespace fbgemm_gpu { template Kernel_mx8mx8bf16_grouped -get_kernel_via_heuristics(int total_M, int N, int K, int G) { +get_kernel_via_heuristics(int M, int N, int K, int G) { // Llama4 shapes if (N == 5120 && K == 1024) { if (G <= 8) { - if (total_M <= 256) { + if (M <= 256) { return mx8mx8bf16_grouped_256_64_256_2_1_1; - } else if (total_M <= 512) { + } else if (M <= 512) { return mx8mx8bf16_grouped_128_64_256_1_1_1; - } else if (total_M <= 1024) { + } else if (M <= 1024) { return mx8mx8bf16_grouped_128_128_256_1_1_1; } } else if (G <= 16) { - if (total_M <= 1024) { + if (M <= 1024) { return mx8mx8bf16_grouped_128_64_256_1_1_1; - } else 
if (total_M <= 2048) { + } else if (M <= 2048) { return mx8mx8bf16_grouped_256_128_256_2_1_1; } } else { - if (total_M <= 1024) { + if (M <= 1024) { return mx8mx8bf16_grouped_256_64_256_2_1_1; - } else if (total_M <= 4096) { + } else if (M <= 4096) { return mx8mx8bf16_grouped_128_64_256_1_1_1; - } else if (total_M <= 8192) { + } else if (M <= 8192) { return mx8mx8bf16_grouped_256_64_256_2_1_1; } } return mx8mx8bf16_grouped_256_256_256_2_1_1; } else if (N == 2048 && K == 5120) { if (G <= 8) { - if (total_M <= 256) { + if (M <= 256) { return mx8mx8bf16_grouped_256_64_256_2_1_1; - } else if (total_M <= 512) { + } else if (M <= 512) { return mx8mx8bf16_grouped_128_64_256_1_1_1; - } else if (total_M <= 1024) { + } else if (M <= 1024) { return mx8mx8bf16_grouped_128_128_256_1_1_1; } } else if (G <= 16) { - if (total_M <= 1024) { + if (M <= 1024) { return mx8mx8bf16_grouped_256_64_256_2_1_1; - } else if (total_M <= 2048) { + } else if (M <= 2048) { return mx8mx8bf16_grouped_128_128_256_1_1_1; } } else { - if (total_M <= 1024) { + if (M <= 1024) { return mx8mx8bf16_grouped_256_64_256_2_1_1; - } else if (total_M <= 16384) { + } else if (M <= 16384) { return mx8mx8bf16_grouped_256_128_256_2_1_1; } } @@ -82,7 +82,7 @@ get_kernel_via_heuristics(int total_M, int N, int K, int G) { } // Fallback to legacy heuristic - if (total_M <= 1000) { + if (M <= 1000) { return mx8mx8bf16_grouped_256_128_256_2_1_1; } else { return mx8mx8bf16_grouped_256_256_256_2_1_1; @@ -91,7 +91,7 @@ get_kernel_via_heuristics(int total_M, int N, int K, int G) { template at::Tensor dispatch_mx8_grouped_kernel( - int total_M, + int M, int N, int K, int G, @@ -100,86 +100,101 @@ at::Tensor dispatch_mx8_grouped_kernel( InputType x_scale, InputType w_scale, at::Tensor output, - std::optional zero_start_index_M = std::nullopt, - std::optional M_sizes = std::nullopt, - std::optional starting_row_after_padding = std::nullopt) { - TORCH_CHECK( - zero_start_index_M.has_value() != M_sizes.has_value(), - "One of zero_start_index_M or M_sizes must be provided."); - TORCH_CHECK(M_sizes.has_value(), "M_sizes is assumed to be provided."); - TORCH_CHECK( - starting_row_after_padding.has_value(), - "starting_row_after_padding is assumed to be provided."); - at::Tensor starting_row_after_padding_actual = - starting_row_after_padding.value_or(at::zeros({0})); - TORCH_CHECK(starting_row_after_padding_actual.size(0) % (G + 1) == 0); - + at::Tensor offsets) { // Select kernel to run via heuristics. 
auto kernel = [&]() { - return get_kernel_via_heuristics(total_M, N, K, G); + return get_kernel_via_heuristics(M, N, K, G); }(); // Invoke kernel - return kernel( - XQ, - WQ, - x_scale, - w_scale, - output, - G, - zero_start_index_M, - M_sizes, - starting_row_after_padding); + return kernel(XQ, WQ, x_scale, w_scale, output, G, offsets); } -at::Tensor mx8mx8bf16_grouped_stacked( - at::Tensor XQ, // FP8 - at::Tensor WQ, // FP8 +at::Tensor mx8mx8bf16_grouped_mm( + at::Tensor XQ, + at::Tensor WQ, at::Tensor x_scale, at::Tensor w_scale, - at::Tensor M_sizes, - std::optional starting_row_after_padding = std::nullopt) { - int64_t total_M = XQ.size(0); - int64_t N = WQ.size(1); - int64_t K = WQ.size(2); - int64_t G = M_sizes.size(0); - TORCH_CHECK( - M_sizes.device() == XQ.device(), - "M_sizes must be on same device as inputs."); - TORCH_CHECK( - WQ.dim() == 3 && WQ.size(0) == G, "Weights should be shape [G, N, K].") - at::Tensor Y = at::empty({total_M, N}, XQ.options().dtype(at::kBFloat16)); + at::Tensor offsets, + std::optional output) { + TORCH_CHECK(offsets.dtype() == at::kInt, "offsets must be int32."); + TORCH_CHECK(offsets.dim() == 1, "offsets must be 1D tensor."); + TORCH_CHECK(XQ.is_contiguous(), "XQ must be row major."); + TORCH_CHECK(WQ.transpose(-2, -1).is_contiguous(), "WQ must be column major."); + TORCH_CHECK(x_scale.is_contiguous(), "x_scale must be contiguous."); + TORCH_CHECK(w_scale.is_contiguous(), "w_scale must be contiguous."); + + int64_t G = offsets.size(0); + int64_t M = XQ.size(0); + int64_t N = WQ.size(-1); + int64_t K = WQ.size(-2); + + at::Tensor output_actual; + + // 2d-3d case. + if (XQ.dim() == 2 && WQ.dim() == 3) { + // Alias for clarity that groups are along M dimension for 2d-3d case. + int64_t total_M = M; + + // Allocate output tensor if necessary. + output_actual = output.has_value() + ? output.value() + : at::empty({total_M, N}, XQ.options().dtype(at::kBFloat16)); + + TORCH_CHECK( + XQ.size(-1) == K && WQ.size(0) == G, + "for 2d-3d grouped GEMM, XQ shape must be (total_M, K) and WQ shape must be (G, K, N)."); + + TORCH_CHECK( + output_actual.dim() == 2 && output_actual.size(0) == total_M && + output_actual.size(1) == N, + "for 2d-3d grouped GEMM, output shape must be (total_M, N)."); + + // 2d-2d case. + } else if (XQ.dim() == 2 && WQ.dim() == 2) { + // Alias for clarity that groups are along K dimension for 2d-2d case. + int64_t total_K = K; + + // Allocate output tensor if necessary. + output_actual = output.has_value() + ? output.value() + : at::empty({G, M, N}, XQ.options().dtype(at::kBFloat16)); + + TORCH_CHECK( + XQ.dim() == 2 && WQ.dim() == 2 && WQ.size(-2) == total_K, + "for 2d-2d grouped GEMM, XQ shape must be (M, total_K) and WQ shape must be (total_K, N)."); + + TORCH_CHECK( + output_actual.dim() == 3 && output_actual.size(0) == G && + output_actual.size(1) == M && output_actual.size(2) == N, + "for 2d-2d grouped GEMM, output shape must be (G, M, N)."); + + } else { + TORCH_CHECK(false, "Invalid input shapes. Must be one of 2D-2D, 2D-3D."); + } + // Early exit for empty inputs. - if (total_M == 0) { - return Y; + if (M == 0) { + return output_actual; } + // Return continuous view of output. 
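Taken together, these checks admit two call patterns. A usage sketch from the Python side (offsets is a 1D int32 tensor of per-group end indices with length G, the weights are passed as a column-major transposed view as validated above, and an optional preallocated output tensor may be passed as a sixth argument):

    # 2d-3d (MoE style, groups along M): xq (total_M, K), wq (G, N, K) -> y (total_M, N)
    y = torch.ops.fbgemm.mx8mx8bf16_grouped_mm(
        xq, wq.transpose(-2, -1), x_scale, w_scale, offsets
    )

    # 2d-2d (groups along the contracting dim K): xq (M, total_K), wq (N, total_K) -> y (G, M, N)
    y = torch.ops.fbgemm.mx8mx8bf16_grouped_mm(
        xq, wq.transpose(-2, -1), x_scale, w_scale, offsets
    )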
return dispatch_mx8_grouped_kernel( - total_M, - N, - K, - G, - XQ, - WQ, - x_scale, - w_scale, - Y, - std::nullopt, - M_sizes, - starting_row_after_padding); + M, N, K, G, XQ, WQ, x_scale, w_scale, output_actual, offsets); } #else -at::Tensor mx8mx8bf16_grouped_stacked( - at::Tensor XQ, // FP8 - at::Tensor WQ, // FP8 +at::Tensor mx8mx8bf16_grouped_mm( + at::Tensor XQ, + at::Tensor WQ, at::Tensor x_scale, at::Tensor w_scale, - at::Tensor M_sizes, - std::optional starting_row_after_padding = std::nullopt) { + at::Tensor offsets, + std::optional output) { throw std::runtime_error( "CUDA version is older than 12.8"); // requires CUDA>=12.8 } + #endif } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_128_128_256_1_1_1.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_128_128_256_1_1_1.cu index 778414b91a..2c0682bf7b 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_128_128_256_1_1_1.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_128_128_256_1_1_1.cu @@ -19,19 +19,9 @@ at::Tensor mx8mx8bf16_grouped_128_128_256_1_1_1( at::Tensor w_scale, at::Tensor output, int64_t G, - std::optional zero_start_index_M, - std::optional M_sizes, - std::optional starting_row_after_padding) { + at::Tensor offsets) { return mx8mx8bf16_grouped_impl( - XQ, - WQ, - x_scale, - w_scale, - output, - G, - zero_start_index_M, - M_sizes, - starting_row_after_padding); + XQ, WQ, x_scale, w_scale, output, G, offsets); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_128_64_256_1_1_1.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_128_64_256_1_1_1.cu index ae8b8a5e08..88003a63ef 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_128_64_256_1_1_1.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_128_64_256_1_1_1.cu @@ -19,19 +19,9 @@ at::Tensor mx8mx8bf16_grouped_128_64_256_1_1_1( at::Tensor w_scale, at::Tensor output, int64_t G, - std::optional zero_start_index_M, - std::optional M_sizes, - std::optional starting_row_after_padding) { + at::Tensor offsets) { return mx8mx8bf16_grouped_impl( - XQ, - WQ, - x_scale, - w_scale, - output, - G, - zero_start_index_M, - M_sizes, - starting_row_after_padding); + XQ, WQ, x_scale, w_scale, output, G, offsets); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_256_128_256_2_1_1.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_256_128_256_2_1_1.cu index 7142ef01c5..35d5d25d8e 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_256_128_256_2_1_1.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_256_128_256_2_1_1.cu @@ -19,19 +19,9 @@ at::Tensor mx8mx8bf16_grouped_256_128_256_2_1_1( at::Tensor w_scale, at::Tensor output, int64_t G, - std::optional zero_start_index_M, - std::optional M_sizes, - std::optional starting_row_after_padding) { + at::Tensor offsets) { return mx8mx8bf16_grouped_impl( - XQ, - WQ, - x_scale, - w_scale, - 
output, - G, - zero_start_index_M, - M_sizes, - starting_row_after_padding); + XQ, WQ, x_scale, w_scale, output, G, offsets); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_256_256_256_2_1_1.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_256_256_256_2_1_1.cu index f9e4444603..c6c58b510a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_256_256_256_2_1_1.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_256_256_256_2_1_1.cu @@ -19,19 +19,9 @@ at::Tensor mx8mx8bf16_grouped_256_256_256_2_1_1( at::Tensor w_scale, at::Tensor output, int64_t G, - std::optional zero_start_index_M, - std::optional M_sizes, - std::optional starting_row_after_padding) { + at::Tensor offsets) { return mx8mx8bf16_grouped_impl( - XQ, - WQ, - x_scale, - w_scale, - output, - G, - zero_start_index_M, - M_sizes, - starting_row_after_padding); + XQ, WQ, x_scale, w_scale, output, G, offsets); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_256_64_256_2_1_1.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_256_64_256_2_1_1.cu index 601e0904ff..7c18a46aaf 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_256_64_256_2_1_1.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_256_64_256_2_1_1.cu @@ -19,19 +19,9 @@ at::Tensor mx8mx8bf16_grouped_256_64_256_2_1_1( at::Tensor w_scale, at::Tensor output, int64_t G, - std::optional zero_start_index_M, - std::optional M_sizes, - std::optional starting_row_after_padding) { + at::Tensor offsets) { return mx8mx8bf16_grouped_impl( - XQ, - WQ, - x_scale, - w_scale, - output, - G, - zero_start_index_M, - M_sizes, - starting_row_after_padding); + XQ, WQ, x_scale, w_scale, output, G, offsets); } #endif diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_common.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_common.cuh index 8e2cc7f008..154098ee34 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_common.cuh +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_common.cuh @@ -21,6 +21,13 @@ #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) +enum GroupedGemmInputType { + // K dynamic + _2D2D, + // M dynamic (MoE style) + _2D3D +}; + inline int64_t _byte_align(int64_t offset) { int64_t remainder = offset % 16; if (remainder != 0) { @@ -34,156 +41,252 @@ template < typename ElementA, typename ElementB, typename ElementC, - typename ElementComputeEpilogue, + typename ScaleDtype, typename StrideA, typename StrideB, typename StrideC, typename LayoutSFA, typename LayoutSFB, - typename ElementGlobalScale, - typename Sm1xxBlkScaledConfig> -__global__ void set_kernel_args_kernel( - int i, // Group index - int64_t G, // Total groups. 
- int64_t M, - int64_t N, - int64_t K, - ProblemShape* problem_shape_ptr, - ElementA* xq, - const ElementA** xq_ptr, - ElementB* wq, - const ElementB** wq_ptr, - ElementComputeEpilogue* x_scale, - const ElementComputeEpilogue** x_scale_ptr, - ElementComputeEpilogue* w_scale, - const ElementComputeEpilogue** w_scale_ptr, - ElementC* output, - ElementC** output_ptr, - StrideA* stride_a_ptr, - StrideB* stride_b_ptr, - StrideC* stride_c_ptr, - LayoutSFA* layout_SFA, - LayoutSFB* layout_SFB) { - uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; - // Each kernel annoyingly can only set the kernel args for one group. - // This could only be avoided with complicated memory management. - if (idx == 0) { - problem_shape_ptr[i] = ProblemShape(N, M, K); - xq_ptr[i] = xq; - wq_ptr[i] = wq; - x_scale_ptr[i] = x_scale; - w_scale_ptr[i] = w_scale; - output_ptr[i] = output; - stride_a_ptr[i] = cutlass::make_cute_packed_stride( - StrideA{}, cute::make_shape(int(M), int(K), 1)); - stride_b_ptr[i] = cutlass::make_cute_packed_stride( - StrideB{}, cute::make_shape(int(N), int(K), 1)); - stride_c_ptr[i] = cutlass::make_cute_packed_stride( - StrideC{}, cute::make_shape(int(N), int(M), 1)); - layout_SFA[i] = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA( - cute::make_shape(int(M), int(N), int(K), 1)); - layout_SFB[i] = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB( - cute::make_shape(int(M), int(N), int(K), 1)); - } -} - -template < - typename ProblemShape, - typename ElementA, - typename ElementB, - typename ElementC, - typename ElementComputeEpilogue, - typename StrideA, - typename StrideB, - typename StrideC, - typename LayoutSFA, - typename LayoutSFB, - typename ElementGlobalScale, typename Sm1xxBlkScaledConfig> __global__ void set_stacked_kernel_args_kernel( int64_t G, + int64_t M, int64_t N, int64_t K, - int64_t num_x_scale_per_group, - int64_t num_w_scale_per_group, ProblemShape* problem_shape_ptr, ElementA* xq, const ElementA** xq_ptr, ElementB* wq, const ElementB** wq_ptr, - ElementComputeEpilogue* x_scale, - const ElementComputeEpilogue** x_scale_ptr, - ElementComputeEpilogue* w_scale, - const ElementComputeEpilogue** w_scale_ptr, + ScaleDtype* x_scale, + const ScaleDtype** x_scale_ptr, + ScaleDtype* w_scale, + const ScaleDtype** w_scale_ptr, ElementC* output, ElementC** output_ptr, StrideA* stride_a_ptr, StrideB* stride_b_ptr, StrideC* stride_c_ptr, - int64_t* M_sizes, + int32_t* offsets, // Group end offsets LayoutSFA* layout_SFA, LayoutSFB* layout_SFB, - int64_t* starting_row_after_padding) { + GroupedGemmInputType gemm_type) { uint32_t group_index = blockIdx.x * blockDim.x + threadIdx.x; // If this thread corresponds to a valid group, write kernel args to device // memory. if (group_index < G) { - // Its possible that we're only writing a subset of the groups to - // kernel args. To do this, we need to set all groups initially to empty. - // and keep a problem counter for the number of non-empty groups. - __shared__ int non_zero_counter; - // Initialize counter in first group. - if (group_index == 0) { - non_zero_counter = 0; - } // Set problem shapes to empty by default. problem_shape_ptr[group_index] = ProblemShape(0, 0, 0); - // Sync threads to get consistent state in the block. - __syncthreads(); - - // Compute shape for this group. - // M for this group is pulled directly from M_sizes. - int M = M_sizes[group_index]; - // Only proceed to writing kernel args if this group is non-empty. - if (M > 0) { - // Get the index for this group atomically. 
- int non_zero_idx = atomicAdd(&non_zero_counter, 1); - // We compute the offset by getting the cumulative sum over - // prior groups. - int64_t offset_M = 0; - int64_t accumulated_x_scale = 0; - int64_t accumulated_w_scale = 0; - for (int i = 0; i < group_index; i++) { - offset_M += M_sizes[i]; - /* It's calculated this way since the scales are at least padded to - multiples of (128, 4), and there is a group of 32 elements per scale. - */ - accumulated_w_scale += - (((N + 128 - 1) / 128) * 128 * ((K + 4 - 1) / 4) * 4 / 32); + + // Offsets for this group. + int64_t xq_offset = 0; + int64_t wq_offset = 0; + int64_t output_offset = 0; + int64_t x_scale_offset = 0; + int64_t w_scale_offset = 0; + + auto round_up = [](int64_t x, int64_t y) { return ((x + y - 1) / y) * y; }; + + // Pre-compute common rounded values to minimize round_up calls + const int64_t N_rounded = round_up(N, 128); + const int64_t M_rounded = round_up(M, 128); + + // Handle offsets API (torch compliant API for 2D-2D and 2D-3D inputs from + // mx8mx8bf16_grouped) + CUDA_KERNEL_ASSERT( + offsets != nullptr && + "offsets must be set for 2d-2d and 2d-3d grouped GEMMs"); + switch (gemm_type) { + // In the 2d-2d case, contraction dim (total_K) has variable group + // sizes. XQ = (M, total_K) WQ = (N, total_K) Main loop defined with WQ + // @ XQ^T = (N, M) for each group. out = (G, N, M) + case GroupedGemmInputType::_2D2D: { + // `offsets` contains end index of each group. + const int32_t prev_group_end_offset = + (group_index == 0) ? 0 : offsets[group_index - 1]; + const int32_t curr_group_end_offset = offsets[group_index]; + const int32_t K_group_size = + curr_group_end_offset - prev_group_end_offset; + + // Validate group offsets. + int align = 128 / cutlass::sizeof_bits::value; + CUDA_KERNEL_ASSERT( + K_group_size % align == 0 && + "for 2d-2d grouped gemm, group sizes along K dim must be non-negative multiple of 16\n"); + CUDA_KERNEL_ASSERT( + curr_group_end_offset <= K && + "for 2d-2d grouped gemm, group end offsets must be non-negative and must be <= K\n"); + + // Set starting input offsets for this group. + // XQ is shape (M,K) with strides (K, 1) and group offsets are along + // the K dim, so: xq_offset -> prev_group_end_offset * 1 + xq_offset = prev_group_end_offset; + + // WQ is shape (N,K) with strides (K, 1) and group offsets are along + // the K dim, so: wq_offset -> prev_group_end_offset * 1 + wq_offset = prev_group_end_offset; + + // Output for 2d-2d grouped GEMM is shape (G, M, N) + // output_offset -> group_index rows with stride of M * N + output_offset = group_index * M * N; + + // Group sizes are variable and converted to blocked/padded format, so + // to calculate the starting offset of this group's scales, we do the + // following: For each previous group + // - Calculate the expected size of its blocked formatted scales + // - Increment the scale offsets by that size + // x_scale shape (M_rounded, total_K_padded_per_group). + // w_scale has shape (N_rounded, total_K_padded_per_group). + for (int i = 0; i < group_index; i++) { + int group_i_size = i == 0 ? offsets[i] : offsets[i] - offsets[i - 1]; + int scale_cols_for_group_i_padded = round_up(group_i_size / 32, 4); + x_scale_offset += M_rounded * scale_cols_for_group_i_padded; + w_scale_offset += N_rounded * scale_cols_for_group_i_padded; + } + + // Only write kernel args if this group is non-empty + if (K_group_size > 0) { + // Get index automatically for this group + int total_K = K; // Name alias for clarity/readability. + + // Set problem shape. 
+ // Main loop passes inputs in B,A order, so we have: (N, K_group) @ + // (M, K_group)^T = (N, M) for each group. + problem_shape_ptr[group_index] = ProblemShape(N, M, K_group_size); + + // Set pointers for this group. + xq_ptr[group_index] = xq + xq_offset; + wq_ptr[group_index] = wq + wq_offset; + x_scale_ptr[group_index] = x_scale + x_scale_offset; + w_scale_ptr[group_index] = w_scale + w_scale_offset; + output_ptr[group_index] = output + output_offset; + + // Set strides. + // TODO: make strides configurable to handle all NT/TN/NN/NT layouts + // that Blackwell supports. For XQ, the group processes a slice (M, + // K_group_size) but it's part of a larger tensor (M, total_K). The + // stride needs to reflect that rows are separated by total_K + // elements in the original tensor. + stride_a_ptr[group_index] = cutlass::make_cute_packed_stride( + StrideA{}, cute::make_shape(int(M), int(total_K), 1)); + + // For WQ, the group processes a slice (N, K_group_size) but it's + // part of a larger tensor (N, total_K). The stride needs to reflect + // that rows are separated by total_K elements in the original + // tensor. + stride_b_ptr[group_index] = cutlass::make_cute_packed_stride( + StrideB{}, cute::make_shape(int(N), int(total_K), 1)); + + // For output of this group, (M, K_group_size) @ (N, K_group_size)^T + // = (M, N) + stride_c_ptr[group_index] = cutlass::make_cute_packed_stride( + StrideC{}, cute::make_shape(int(N), int(M), 1)); + + // Set layouts for scale factors. + // Groups of variable size are along the K dim, so we need to + // calculate the size of the blocked group scale factor here. + layout_SFA[group_index] = + Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA( + cute::make_shape(int(M), int(N), int(K_group_size), 1)); + layout_SFB[group_index] = + Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB( + cute::make_shape(int(M), int(N), int(K_group_size), 1)); + } + break; + } + case GroupedGemmInputType::_2D3D: { + // `offsets` contains end index of each group. + const int32_t prev_group_end_offset = + (group_index == 0) ? 0 : offsets[group_index - 1]; + const int32_t curr_group_end_offset = offsets[group_index]; + const int32_t M_group_size = + curr_group_end_offset - prev_group_end_offset; + + if (M_group_size > 0) { + // Validate group offsets. + CUDA_KERNEL_ASSERT( + curr_group_end_offset <= M && + "for 2d-3d grouped gemm, group end offsets must be non-negative and must be <= M\n"); + + // Compute starting offset for this group when M_group size > 0 + int64_t group_offset_M = + group_index == 0 ? 0 : offsets[group_index - 1]; + int64_t scale_group_offset_M = 0; + for (int i = 0; i < group_index; i++) { + // Group offset on XQ along total_M dim is the sum of all previous + // group sizes. + int group_i_size = + i == 0 ? offsets[i] : offsets[i] - offsets[i - 1]; + + // Scale group offset on x_scale is sum of all previous scale + // group sizes. 
+ int scale_group_rows_padded = round_up(group_i_size, 128); + scale_group_offset_M += scale_group_rows_padded; + } + + // wq_offset -> group_offset_M rows with stride of K + xq_offset = group_offset_M * K; + + // wq_offset -> group_index rows with stride of N * K (3d tensor) + wq_offset = group_index * N * K; + + // output_offset -> group_offset_M rows with stride of N + output_offset = group_offset_M * N; + + // x_scale offset -> sum of all padded group sizes (rows) * rounded + // scale group cols + const int64_t K_div_32_rounded = round_up(K / 32, 4); + x_scale_offset = scale_group_offset_M * K_div_32_rounded; + + // w_scale_offset -> group_index rows with stride of (N rounded to + // nearest multiple of 128 * K rounded to nearest multiple of 4) + w_scale_offset = group_index * N_rounded * K_div_32_rounded; + + // Set problem shape + problem_shape_ptr[group_index] = ProblemShape(N, M_group_size, K); + + // Set pointers + xq_ptr[group_index] = xq + xq_offset; + wq_ptr[group_index] = wq + wq_offset; + x_scale_ptr[group_index] = x_scale + x_scale_offset; + w_scale_ptr[group_index] = w_scale + w_scale_offset; + output_ptr[group_index] = output + output_offset; + + // Set strides + stride_a_ptr[group_index] = cutlass::make_cute_packed_stride( + StrideA{}, cute::make_shape(int(M_group_size), int(K), 1)); + stride_b_ptr[group_index] = cutlass::make_cute_packed_stride( + StrideB{}, cute::make_shape(int(N), int(K), 1)); + stride_c_ptr[group_index] = cutlass::make_cute_packed_stride( + StrideC{}, cute::make_shape(int(N), int(M_group_size), 1)); + + // Set layouts for scale factors + layout_SFA[group_index] = + Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA( + cute::make_shape(int(M_group_size), int(N), int(K), 1)); + layout_SFB[group_index] = + Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB( + cute::make_shape(int(M_group_size), int(N), int(K), 1)); + } + break; } - accumulated_x_scale = starting_row_after_padding[group_index] * K / 32; - // Set the problem shape for this group. - problem_shape_ptr[non_zero_idx] = ProblemShape(N, M, K); - // Set input pointers. - xq_ptr[non_zero_idx] = xq + (offset_M * K); - wq_ptr[non_zero_idx] = wq + (group_index * N * K); - x_scale_ptr[non_zero_idx] = x_scale + accumulated_x_scale; - w_scale_ptr[non_zero_idx] = w_scale + accumulated_w_scale; - output_ptr[non_zero_idx] = output + (offset_M * N); - stride_a_ptr[non_zero_idx] = cutlass::make_cute_packed_stride( - StrideA{}, cute::make_shape(int(M), int(K), 1)); - stride_b_ptr[non_zero_idx] = cutlass::make_cute_packed_stride( - StrideB{}, cute::make_shape(int(N), int(K), 1)); - stride_c_ptr[non_zero_idx] = cutlass::make_cute_packed_stride( - StrideC{}, cute::make_shape(int(N), int(M), 1)); - layout_SFA[non_zero_idx] = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA( - cute::make_shape(int(M), int(N), int(K), 1)); - layout_SFB[non_zero_idx] = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB( - cute::make_shape(int(M), int(N), int(K), 1)); } } } +/* + MXFP8 grouped GEMM that performs, which handles 2 cases: + + 1. XQ 2d, WQ 3d: + XQ shape = (total_M, K) where groups are along the M dimension + WQ shape = (N, K) + out shape = (total_M, N) + + 2. 
XQ 2d, WQ 2d: + XQ shape = (M, total_K) where groups are along the K dimension + WQ shape = (N, total_K) where groups are along the K dimension + out shape = (num_groups, M, N) +*/ template < typename InputType, int TB_M, @@ -199,9 +302,7 @@ at::Tensor mx8mx8bf16_grouped_impl( InputType w_scale, at::Tensor output, int64_t G, - std::optional zero_start_index_M, - std::optional M_sizes, - std::optional starting_row_after_padding) { + at::Tensor offsets) { // The number of groups the kernel uses may vary. int kernel_groups = G; @@ -213,6 +314,11 @@ at::Tensor mx8mx8bf16_grouped_impl( return output; } + // WQ is shape (K,N) or (E,K,N) in column major layout, to align with + // torch._scaled_grouped_mm. We transpose here to match cutlass kernel + // requirements. + InputType WQ_contig = WQ.transpose(-2, -1); + // Define gemm configuration. using ProblemShape = cutlass::gemm::GroupProblemShape>; @@ -226,9 +332,8 @@ at::Tensor mx8mx8bf16_grouped_impl( typename cutlass::layout::LayoutTranspose::type; using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose::type; - constexpr int AlignmentA = 32; - constexpr int AlignmentB = 32; - using ElementGlobalScale = float; + constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; using ElementAccumulator = float; using ArchTag = cutlass::arch::Sm100; using StageCountType = cutlass::gemm::collective::StageCountAuto; @@ -289,7 +394,7 @@ at::Tensor mx8mx8bf16_grouped_impl( using StrideB = typename Gemm::GemmKernel::InternalStrideB; using StrideC = typename Gemm::GemmKernel::InternalStrideD; - using ElementComputeEpilogue = typename ElementA::ScaleFactorType; + using ScaleDtype = typename ElementA::ScaleFactorType; using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop:: InternalLayoutSFA; // Scale Factor tensors have an interleaved layout. @@ -315,11 +420,11 @@ at::Tensor mx8mx8bf16_grouped_impl( // X block scales. const int64_t x_scale_offset = wq_offset + wq_size_buffer; - int64_t x_scale_buffer = _byte_align(G * sizeof(ElementComputeEpilogue**)); + int64_t x_scale_buffer = _byte_align(G * sizeof(ScaleDtype**)); // W block scales. const int64_t w_scale_offset = x_scale_offset + x_scale_buffer; - int64_t w_scale_buffer = _byte_align(G * sizeof(ElementComputeEpilogue**)); + int64_t w_scale_buffer = _byte_align(G * sizeof(ScaleDtype**)); // Outputs. 
const int64_t output_offset = w_scale_offset + w_scale_buffer; @@ -363,12 +468,10 @@ at::Tensor mx8mx8bf16_grouped_impl( reinterpret_cast(kernel_args_ptr + xq_offset); const ElementB** wq_ptr = reinterpret_cast(kernel_args_ptr + wq_offset); - const ElementComputeEpilogue** x_scale_ptr = - reinterpret_cast( - kernel_args_ptr + x_scale_offset); - const ElementComputeEpilogue** w_scale_ptr = - reinterpret_cast( - kernel_args_ptr + w_scale_offset); + const ScaleDtype** x_scale_ptr = + reinterpret_cast(kernel_args_ptr + x_scale_offset); + const ScaleDtype** w_scale_ptr = + reinterpret_cast(kernel_args_ptr + w_scale_offset); ElementC** output_ptr = reinterpret_cast(kernel_args_ptr + output_offset); StrideA* stride_a_ptr = @@ -388,82 +491,63 @@ at::Tensor mx8mx8bf16_grouped_impl( using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; - TORCH_CHECK( - !zero_start_index_M.has_value() || - zero_start_index_M->dtype() == at::kLong, - "zero_start_index_M must be int64."); - - TORCH_CHECK( - !M_sizes.has_value() || M_sizes->dtype() == at::kLong, - "M_sizes must be int64."); - // When m_offsets is used, XQ is shape [total_M, K]. When zero_start_index_M - // is used, shape is [G, M, K]. - int64_t M = XQ.size(XQ.dim() - 2); - int64_t N = WQ.size(1); - int64_t K = WQ.size(2); - - // Calculate the number of scale elements per group - int64_t num_x_scale_per_group; - int64_t num_w_scale_per_group; - TORCH_CHECK( - x_scale.dim() == 2 || x_scale.dim() == 3, - "x_scale must be either 2D or 3D tensor") - if (x_scale.dim() == 3) { - num_x_scale_per_group = x_scale.size(1) * x_scale.size(2); - } else { - num_x_scale_per_group = x_scale.size(1); - } + TORCH_CHECK(x_scale.dim() == 2, "x_scale must be a 2D tensor"); TORCH_CHECK( w_scale.dim() == 2 || w_scale.dim() == 3, - "w_scale must be either 2D or 3D tensor") - if (w_scale.dim() == 3) { - num_w_scale_per_group = w_scale.size(1) * w_scale.size(2); + "w_scale must be either 2D or 3D tensor"); + + int64_t M = XQ.size(0); + int64_t N = WQ_contig.size(-2); + int64_t K = WQ_contig.size(-1); + int32_t* offsets_ptr = reinterpret_cast(offsets.data_ptr()); + + // Determine gemm type. + GroupedGemmInputType gemm_type; + if (XQ.dim() == 2 && WQ_contig.dim() == 2) { + gemm_type = GroupedGemmInputType::_2D2D; + } else if (XQ.dim() == 2 && WQ_contig.dim() == 3) { + gemm_type = GroupedGemmInputType::_2D3D; } else { - num_w_scale_per_group = w_scale.size(1); + TORCH_CHECK( + false, + "Invalid input dimensions. MXFP8 grouped GEMM currently only supports 2D-2D and 2D-3D inputs."); } - int64_t* M_sizes_ptr = reinterpret_cast(M_sizes.value().data_ptr()); - int64_t* starting_row_after_padding_ptr = - reinterpret_cast(starting_row_after_padding.value().data_ptr()); + // Execute kernel to dynamically set kernel arguments for each group. 
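The launch below runs set_stacked_kernel_args_kernel with a single block of G threads, one thread per group; each thread derives its group's pointers from the blocked-scale layout produced on the Python side. For the 2d-3d case, the element offsets it computes amount to the following (a sketch; group_sizes[i] stands for offsets[i] - offsets[i - 1], and round_up rounds to the given multiple):

    xq_offset      = sum(group_sizes[:g]) * K
    output_offset  = sum(group_sizes[:g]) * N
    x_scale_offset = sum(round_up(m, 128) for m in group_sizes[:g]) * round_up(K // 32, 4)
    w_scale_offset = g * round_up(N, 128) * round_up(K // 32, 4)   # one padded scale block per weight group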
set_stacked_kernel_args_kernel< ProblemShape::UnderlyingProblemShape, ElementA, ElementB, ElementC, - ElementComputeEpilogue, + ScaleDtype, StrideA, StrideB, StrideC, LayoutSFA, LayoutSFB, - ElementGlobalScale, Sm1xxBlkScaledConfig><<<1, G, 0, stream>>>( G, + M, N, K, - num_x_scale_per_group, - num_w_scale_per_group, problem_shape_ptr, reinterpret_cast(XQ.data_ptr()), xq_ptr, - reinterpret_cast(WQ.data_ptr()), + reinterpret_cast(WQ_contig.data_ptr()), wq_ptr, - reinterpret_cast(x_scale.data_ptr()), + reinterpret_cast(x_scale.data_ptr()), x_scale_ptr, - reinterpret_cast(w_scale.data_ptr()), + reinterpret_cast(w_scale.data_ptr()), w_scale_ptr, reinterpret_cast(output.data_ptr()), output_ptr, stride_a_ptr, stride_b_ptr, stride_c_ptr, - M_sizes_ptr, + offsets_ptr, layout_SFA, layout_SFB, - starting_row_after_padding_ptr); - // Set the number of groups to the kernel to be at most the number of - // non-zero rows. - kernel_groups = int(std::min(M, G)); + gemm_type); cutlass::KernelHardwareInfo hw_info; // Change device_id to another value if you are running on a machine with @@ -480,14 +564,16 @@ at::Tensor mx8mx8bf16_grouped_impl( typename Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGrouped, {kernel_groups, problem_shape_ptr, nullptr}, - {reinterpret_cast(wq_ptr), - stride_b_ptr, - reinterpret_cast(xq_ptr), - stride_a_ptr, - reinterpret_cast(w_scale_ptr), - layout_SFB, - reinterpret_cast(x_scale_ptr), - layout_SFA}, + { + reinterpret_cast(wq_ptr), + stride_b_ptr, + reinterpret_cast(xq_ptr), + stride_a_ptr, + reinterpret_cast(w_scale_ptr), + layout_SFB, + reinterpret_cast(x_scale_ptr), + layout_SFA, + }, {{}, nullptr, stride_c_ptr, output_ptr, stride_c_ptr}, hw_info}; diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_manifest.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_manifest.cuh index e2b76186a2..86c6946895 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_manifest.cuh +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mx8mx8bf16_grouped/mx8mx8bf16_grouped_manifest.cuh @@ -19,9 +19,7 @@ at::Tensor mx8mx8bf16_grouped_128_64_256_1_1_1( at::Tensor w_scale, at::Tensor output, int64_t G, - std::optional zero_start_index_M, - std::optional M_sizes, - std::optional starting_row_after_padding); + at::Tensor offsets); at::Tensor mx8mx8bf16_grouped_128_128_256_1_1_1( at::Tensor XQ, // FP8 @@ -30,9 +28,7 @@ at::Tensor mx8mx8bf16_grouped_128_128_256_1_1_1( at::Tensor w_scale, at::Tensor output, int64_t G, - std::optional zero_start_index_M, - std::optional M_sizes, - std::optional starting_row_after_padding); + at::Tensor offsets); at::Tensor mx8mx8bf16_grouped_256_64_256_2_1_1( at::Tensor XQ, // FP8 @@ -41,9 +37,7 @@ at::Tensor mx8mx8bf16_grouped_256_64_256_2_1_1( at::Tensor w_scale, at::Tensor output, int64_t G, - std::optional zero_start_index_M, - std::optional M_sizes, - std::optional starting_row_after_padding); + at::Tensor offsets); at::Tensor mx8mx8bf16_grouped_256_128_256_2_1_1( at::Tensor XQ, // FP8 @@ -52,9 +46,7 @@ at::Tensor mx8mx8bf16_grouped_256_128_256_2_1_1( at::Tensor w_scale, at::Tensor output, int64_t G, - std::optional zero_start_index_M, - std::optional M_sizes, - std::optional starting_row_after_padding); + at::Tensor offsets); at::Tensor mx8mx8bf16_grouped_256_256_256_2_1_1( at::Tensor XQ, // FP8 @@ -63,21 +55,17 @@ at::Tensor 
mx8mx8bf16_grouped_256_256_256_2_1_1( at::Tensor w_scale, at::Tensor output, int64_t G, - std::optional zero_start_index_M, - std::optional M_sizes, - std::optional starting_row_after_padding); + at::Tensor offsets); template using Kernel_mx8mx8bf16_grouped = at::Tensor (*)( - InputType, - InputType, - InputType, - InputType, - at::Tensor, - int64_t, - std::optional, - std::optional, - std::optional); + InputType, // XQ + InputType, // WQ + InputType, // x_scale + InputType, // w_scale + at::Tensor, // output + int64_t, // G + at::Tensor); // offsets template const std::unordered_map>& diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp index e80d0dc1fa..316bfd8a1b 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp @@ -64,13 +64,13 @@ at::Tensor f4f4bf16_grouped_stacked( std::optional global_scale = std::nullopt, std::optional starting_row_after_padding = std::nullopt, bool use_mx = true); -at::Tensor mx8mx8bf16_grouped_stacked( +at::Tensor mx8mx8bf16_grouped_mm( at::Tensor XQ, at::Tensor WQ, at::Tensor x_scale, at::Tensor w_scale, - at::Tensor M_sizes, - std::optional starting_row_after_padding = std::nullopt); + at::Tensor offsets, + std::optional output = std::nullopt); at::Tensor f8f8bf16( at::Tensor XQ, at::Tensor WQ, @@ -320,7 +320,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { m.impl("f4f4bf16", f4f4bf16); m.impl("f4f4bf16_grouped", f4f4bf16_grouped); m.impl("f4f4bf16_grouped_stacked", f4f4bf16_grouped_stacked); - m.impl("mx8mx8bf16_grouped_stacked", mx8mx8bf16_grouped_stacked); + m.impl("mx8mx8bf16_grouped_mm", mx8mx8bf16_grouped_mm); m.impl("f8f8bf16", f8f8bf16); m.impl("f8f8bf16_cublas", f8f8bf16_cublas); m.impl("bf16_fast_gemv", bf16_fast_gemv); @@ -376,7 +376,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { m.impl("f4f4bf16", f4f4bf16); m.impl("f4f4bf16_grouped", f4f4bf16_grouped); m.impl("f4f4bf16_grouped_stacked", f4f4bf16_grouped_stacked); - m.impl("mx8mx8bf16_grouped_stacked", mx8mx8bf16_grouped_stacked); + m.impl("mx8mx8bf16_grouped_mm", mx8mx8bf16_grouped_mm); m.impl("f8f8bf16", f8f8bf16); m.impl("f8f8bf16_cublas", f8f8bf16_cublas); m.impl("bf16_fast_gemv", bf16_fast_gemv); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp index dd6f949338..9f31dcd507 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp @@ -25,7 +25,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "f4f4bf16_grouped_stacked(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor M_sizes, Tensor? global_scale=None, Tensor? starting_row_after_padding=None, bool use_mx=True) -> Tensor"); m.def( - "mx8mx8bf16_grouped_stacked(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor M_sizes, Tensor? starting_row_after_padding=None) -> Tensor"); + "mx8mx8bf16_grouped_mm(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor offsets, Tensor(a!)? 
output=None) -> Tensor"); m.def( "f8f8bf16(Tensor XQ, Tensor WQ, Tensor scale, bool use_fast_accum=True) -> Tensor"); m.def( diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py index 5638f40144..12bfe92e60 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py @@ -160,6 +160,44 @@ def sample_scales() -> st.SearchStrategy[Optional[torch.Tensor]]: ) +# Source: https://github.com/pytorch/ao/blob/568c1932a16ae9f30d48da214a88dc0013e98ed8/torchao/prototype/moe_training/utils.py#L310 +def generate_jagged_offs(E, M, multiple_of=16, dtype=torch.int32, device="cuda"): + """ + Utility function for tests and benchmarks. + + Generates a tensor of length E, containing random values divisible by `multiple_of`, + from 0 to M, in sorted order, and where the final value in the tensor is always M. + Args: + E (int): The length of the tensor. + M (int): The maximum value in the tensor. + Returns: + torch.Tensor: A tensor of length E with the specified properties. + """ + import random + + # Ensure M is divisible by 16 + if M % multiple_of != 0: + raise ValueError(f"M must be divisible by {multiple_of}") + + # Generate a list of possible values + possible_values = list(range(multiple_of, M + 1, multiple_of)) + + # If E is larger than the number of possible values, raise an error + if E > len(possible_values): + raise ValueError("E cannot be larger than the number of possible values") + + # Randomly select E - 1 values from the possible values (excluding M) + selected_values = torch.tensor(random.sample(possible_values[:-1], E - 1)) + + # Append M to the selected values + selected_values = torch.cat((selected_values, torch.tensor([M]))) + + # Sort the selected values + selected_values, _ = torch.sort(selected_values) + + return selected_values.to(dtype).to(device) + + @unittest.skipIf( not torch.cuda.is_available(), "Skip when no GPU is available. This test is only for GPU.", @@ -1224,6 +1262,110 @@ def test_grouped_gemm_2d_3d( # BF16 loopover gemm reference self.bf16_loopover_validate(x_group, W, y_fp8_group, y_bf16_group) + @unittest.skipIf(not SUPPORTS_MXFP8, "MXFP8 not supported on this platform") + @settings(deadline=None) + @given( + G=st.sampled_from([1, 4, 16]), + K=st.sampled_from([2048, 3584]), + N=st.sampled_from([256, 1024, 6144]), + M=st.sampled_from([256, 512, 3584]), + ) + def test_mx_grouped_gemm_2d_2d( + self, + G: int, + M: int, + N: int, + K: int, + ) -> None: + # Simulate 2d-2d grouped gemm in backward pass `grad_weight = grad_output_t @ input`, + # where we use "K" as the contracting dim which has "G" groups. + from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import to_mxfp8 + + total_K = K # Alias for clarity, communicating this consists of several groups along this dim + input_group_end_offsets = generate_jagged_offs( + G, total_K, multiple_of=32, device=self.device + ) + X = torch.randn((M, total_K), dtype=torch.bfloat16, device=self.device) * 0.1 + W = torch.randn((N, total_K), dtype=torch.bfloat16, device=self.device) * 0.01 + + # Convert scales to blocked format. 
+ x_list = [] + w_list = [] + x_blocked_scale_list = [] + w_blocked_scale_list = [] + + def round_up(x: int, y: int) -> int: + return ((x + y - 1) // y) * y + + for group_idx in range(G): + # to_mxfp8 per group + prev_group_end_offset = ( + 0 if group_idx == 0 else input_group_end_offsets[group_idx - 1] + ) + curr_group_end_offset = input_group_end_offsets[group_idx] + group_size = curr_group_end_offset - prev_group_end_offset + if group_size > 0: + x_slice = X[ + :, prev_group_end_offset:curr_group_end_offset + ].contiguous() # (M, K_group) + w_slice = W[ + :, prev_group_end_offset:curr_group_end_offset + ].contiguous() # (N, K_group) + x_scale_slice, xq_slice = to_mxfp8( + x_slice + ) # scale shape -> (M, K_group // 32) + w_scale_slice, wq_slice = to_mxfp8( + w_slice + ) # scale shape -> (N, K_group // 32) + x_list.append(xq_slice) + w_list.append(wq_slice) + + # Convert scales to blocked format. + x_scale_slice_blocked = _to_blocked( + x_scale_slice + ) # (round_up(M, 128), round_up(K_group//32, 4)) + w_scale_slice_blocked = _to_blocked( + w_scale_slice + ) # (round_up(N, 128), round_up(K_group//32, 4)) + x_blocked_scale_list.append(x_scale_slice_blocked) + w_blocked_scale_list.append(w_scale_slice_blocked) + + # Assemble the full XQ and WQ + xq = torch.cat(x_list, dim=1).contiguous() + wq = torch.cat(w_list, dim=1).contiguous() + + # Combine all XQ groups blocked scales into one tensor. + x_blocked_scales = torch.cat(x_blocked_scale_list, dim=0) + M_rounded = round_up(M, 128) + x_blocked_scales = x_blocked_scales.reshape(M_rounded, -1) + + # Combine all WQ groups blocked scales into one tensor. + w_blocked_scales = torch.cat(w_blocked_scale_list, dim=0) + N_rounded = round_up(N, 128) + w_blocked_scales = w_blocked_scales.reshape(N_rounded, -1) + + # Compute mxfp8 grouped mm output + out = torch.empty((G, M, N), dtype=torch.bfloat16, device=self.device) + y_mxfp8 = torch.ops.fbgemm.mx8mx8bf16_grouped_mm( + xq, # (M, total_K) + wq.transpose(-2, -1), # (total_K, N) + x_blocked_scales, # to_blocked_per_group(M, total_K//32) + w_blocked_scales, # to_blocked_per_group(N, total_K//32) + input_group_end_offsets, # (G,) + out, # (G, M, N) + ) + + # bf16 reference output + y_bf16 = torch._grouped_mm( + X, W.t(), offs=input_group_end_offsets, out_dtype=torch.bfloat16 + ) + + # Assert no NaNs + assert not y_mxfp8.isnan().any(), "mxfp8 output contains NaN" + + # Assert outputs are close + torch.testing.assert_close(y_mxfp8, y_bf16, atol=8.0e-2, rtol=8.0e-2) + @unittest.skipIf(not SUPPORTS_MXFP8, "MXFP8 not supported on this platform") @settings(deadline=None) @given( @@ -1232,7 +1374,7 @@ def test_grouped_gemm_2d_3d( N=st.sampled_from([256, 1024, 6144]), K=st.sampled_from([256, 512, 3584]), ) - def test_mx_grouped_gemm( + def test_mx_grouped_gemm_2d_3d( self, G: int, M: int, @@ -1241,15 +1383,21 @@ def test_mx_grouped_gemm( ) -> None: from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import to_mxfp8 - X = torch.randn((G, M, K), dtype=torch.bfloat16, device=self.device) * 0.1 + # Simulate 2d-3d grouped gemm `out = input @ weight.t()` + # 2D inputs with groups along M, 3D weights. + block_size = 32 + total_M = M # Alias for clarity that M dim contains groups. 
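For reference, the 2d-3d grouped GEMM exercised by this test is equivalent to the following unfused bf16 loop (a sketch; group boundaries along total_M come from input_group_end_offsets, built just below):

    y = torch.empty(total_M, N, dtype=torch.bfloat16, device=X.device)
    start = 0
    for g in range(G):
        end = int(input_group_end_offsets[g])
        y[start:end] = X[start:end] @ W[g].t()   # (M_g, K) @ (K, N) -> (M_g, N)
        start = end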
+ X = torch.randn((total_M, K), dtype=torch.bfloat16, device=self.device) * 0.1 W = torch.randn((G, N, K), dtype=torch.bfloat16, device=self.device) * 0.01 + input_group_end_offsets = generate_jagged_offs( + G, total_M, multiple_of=32, device=self.device + ) - m_values = [i.shape[0] for i in X] - m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=X[0].device) - + # For each constituent 2d subtensor in the 3d weights, quantize and convert scale to blocked format separately, + # as they each used for independent gemm in the grouped gemm. wq_list = [] w_scale_list = [] - for i in range(m_sizes.shape[0]): + for i in range(G): w_scale, wq = to_mxfp8(W[i]) w_scale = _to_blocked(w_scale) wq_list.append(wq) @@ -1257,47 +1405,45 @@ def test_mx_grouped_gemm( wq = torch.stack(wq_list, dim=0).contiguous() w_scale = torch.stack(w_scale_list, dim=0).contiguous() - starting_row_after_padding_list = [0] + # For each group along `total_M` in the 2D tensor, quantize and convert scale to blocked format separately, + # as they each used for independent gemm in the grouped gemm. xq_list = [] x_scale_list = [] - for i in range(m_sizes.shape[0]): - scale_slice = X[i] - if m_sizes[i].item() != 0: - x_scale, xq = to_mxfp8(scale_slice) + for i in range(G): + prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1] + curr_group_end = input_group_end_offsets[i] + group_size = curr_group_end - prev_group_end + if group_size > 0: + x_slice = X[prev_group_end:curr_group_end, :] + x_scale, xq = to_mxfp8(x_slice) x_scale = _to_blocked(x_scale) xq_list.append(xq) x_scale_list.append(x_scale) - starting_row_after_padding_list.append( - starting_row_after_padding_list[i] - + x_scale.numel() // (X[0].shape[1] // 32) - ) - else: - starting_row_after_padding_list.append( - starting_row_after_padding_list[i] - ) - starting_row_after_padding = torch.tensor( - starting_row_after_padding_list, device=xq.device - ) - xq = torch.cat(xq_list, dim=0).contiguous() x_scale = torch.cat(x_scale_list, dim=0).contiguous() - x_scale = x_scale.reshape(-1, X[0].shape[-1] // 32) + x_scale = x_scale.reshape(-1, K // block_size) xq = xq.view(-1, xq.shape[-1]) - y_mxfp8 = torch.ops.fbgemm.mx8mx8bf16_grouped_stacked( + # Compute mxfp8 grouped gemm. + out = torch.empty((total_M, N), dtype=torch.bfloat16, device=self.device) + y_mxfp8 = torch.ops.fbgemm.mx8mx8bf16_grouped_mm( xq, - wq, + wq.transpose(-2, -1), x_scale, w_scale, - m_sizes, - starting_row_after_padding=starting_row_after_padding, + input_group_end_offsets, + out, ) - y_bf16_group = [] - for i in range(G): - y_bf16_group.append(torch.matmul(X[i], W[i].t())) - y_bf16 = torch.cat(y_bf16_group, dim=0) + # Compute reference bf16 grouped gemm. + y_bf16 = torch._grouped_mm( + X, + W.transpose(-2, -1), + offs=input_group_end_offsets, + out_dtype=torch.bfloat16, + ) + # Assert outputs are close. torch.testing.assert_close(y_mxfp8, y_bf16, atol=8.0e-2, rtol=8.0e-2) @unittest.skipIf(