Commit 069b945

Add 2d-2d support to MXFP8 Grouped GEMM (#4816)
Summary:

## MXFP8 grouped GEMM updates to (1) handle the 2d-2d case and (2) provide a PyTorch-compliant API

- Add support for 2d-2d inputs with dynamic groups along the K dimension (see the offsets sketch below)
- Add tests to ensure correct numerics for both 2d-2d and 2d-3d cases, with random group sizes
- Add benchmarks for both 2d-3d and 2d-2d cases

Reviewed By: ngimel, cthi

Differential Revision: D81362680
1 parent c43677d commit 069b945

12 files changed: +717, -417 lines
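The "dynamic groups along the K dimension" item in the summary above refers to the convention used by the new 2d-2d path: per-group K slices are concatenated into a single tensor, and group boundaries are passed as one int32 tensor of cumulative end offsets along K. A minimal sketch of that offsets convention, with illustrative group sizes that are not taken from this commit:

import torch

# Hypothetical per-group K sizes for G = 3 groups (sizes chosen for illustration only).
group_k_sizes = [96, 32, 128]

# Cumulative end offset of each group along the contracting K dim, matching the
# int32 `input_group_end_offsets` convention used in the diff below.
input_group_end_offsets = torch.cumsum(
    torch.tensor(group_k_sizes, dtype=torch.int32), dim=0
).to(torch.int32)
print(input_group_end_offsets)  # tensor([ 96, 128, 256], dtype=torch.int32)

The benchmark classes in the diff below build these offsets with torch.arange (equal group sizes), but the kernel only sees the end offsets, so unequal groups can be described the same way.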

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 150 additions & 34 deletions
@@ -2873,78 +2873,194 @@ def cuda(self) -> bool:
 
 
 @register_quantize_op
-class MXFP8StackedGroupedGemm(QuantizeOpBase):
+class MXFP8GroupedGemm2d3d(QuantizeOpBase):
     """
-    MXFP8 grouped matmul with blockwise scaling and stacked inputs.
+    MXFP8 grouped GEMM with 2D inputs and 3D weights.
     """
 
     def preprocess(self, x, w):
-        m_values = [i.shape[0] for i in x]
-        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
+        assert isinstance(x, list)
+        assert isinstance(w, list)
+        x = torch.cat(x, dim=0).contiguous()  # (G * M, K)
+        w = torch.stack(w, dim=0).contiguous()  # (G, N, K)
+        return x, w
+
+    def quantize(self, x, w):
+        block_size = 32
+        G, N, K = w.shape
+        total_M = x.shape[0]
+        group_size = total_M // G
+        input_group_end_offsets = torch.arange(
+            group_size, total_M + 1, group_size, dtype=torch.int32, device=x.device
+        )
+
+        # For each constituent 2d subtensor in the 3d weights, quantize and convert its scale to
+        # blocked format separately, as each is used for an independent gemm in the grouped gemm.
         wq_list = []
         w_scale_list = []
-        for i in range(m_sizes.shape[0]):
+        for i in range(G):
             w_scale, wq = to_mxfp8(w[i])
             w_scale = _to_blocked(w_scale)
             wq_list.append(wq)
             w_scale_list.append(w_scale)
         wq = torch.stack(wq_list, dim=0).contiguous()
         w_scale = torch.stack(w_scale_list, dim=0).contiguous()
-        return x, wq, w_scale, m_sizes
 
-    def quantize(self, x, wq, w_scale, m_sizes):
-        starting_row_after_padding_list = [0]
+        # For each group along `total_M` in the 2D tensor, quantize and convert its scale to
+        # blocked format separately, as each is used for an independent gemm in the grouped gemm.
         xq_list = []
         x_scale_list = []
-        for i in range(m_sizes.shape[0]):
-            scale_slice = x[i]
-            if m_sizes[i].item() != 0:
-                x_scale, xq = to_mxfp8(scale_slice)
+        for i in range(G):
+            prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
+            curr_group_end = input_group_end_offsets[i]
+            group_size = curr_group_end - prev_group_end
+            if group_size > 0:
+                x_slice = x[prev_group_end:curr_group_end, :]
+                x_scale, xq = to_mxfp8(x_slice)
                 x_scale = _to_blocked(x_scale)
                 xq_list.append(xq)
                 x_scale_list.append(x_scale)
-                starting_row_after_padding_list.append(
-                    starting_row_after_padding_list[i]
-                    + x_scale.numel() // (x[0].shape[1] // 32)
-                )
-            else:
-                starting_row_after_padding_list.append(
-                    starting_row_after_padding_list[i]
-                )
         xq = torch.cat(xq_list, dim=0).contiguous()
         x_scale = torch.cat(x_scale_list, dim=0).contiguous()
-        x_scale = x_scale.reshape(-1, x[0].shape[-1] // 32)
+        x_scale = x_scale.reshape(-1, K // block_size)
         xq = xq.view(-1, xq.shape[-1])
-        return (
+        return xq, wq, x_scale, w_scale, input_group_end_offsets
+
+    def compute(self, xq, wq, x_scale, w_scale, input_group_end_offsets):
+        return torch.ops.fbgemm.mx8mx8bf16_grouped_mm(
             xq,
-            wq,
+            wq.transpose(-2, -1),
             x_scale,
             w_scale,
-            m_sizes,
-            torch.tensor(starting_row_after_padding_list, device=xq.device),
+            input_group_end_offsets,
         )
 
-    def compute(self, xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding):
-        return torch.ops.fbgemm.mx8mx8bf16_grouped_stacked(
+    def quantize_and_compute(self, x, w):
+        xq, wq, x_scale, w_scale, input_group_end_offsets = self.quantize(x, w)
+        return self.compute(
             xq,
             wq,
             x_scale,
             w_scale,
-            m_sizes,
-            starting_row_after_padding=starting_row_after_padding,
+            input_group_end_offsets,
         )
 
-    def quantize_and_compute(self, x, w):
-        xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding = self.quantize(
-            x, w
+    @property
+    def name(self) -> str:
+        return "cutlass_mx8mx8bf16_grouped_mm_2d_3d"
+
+    @property
+    def hip(self) -> bool:
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
+@register_quantize_op
+class MXFP8GroupedGemm2d2d(QuantizeOpBase):
+    """
+    MXFP8 grouped GEMM with 2D inputs and 2D weights.
+    """
+
+    def preprocess(self, x, w):
+        assert isinstance(x, list)
+        assert isinstance(w, list)
+        G = len(x)
+        x = torch.cat(x, dim=1).contiguous()  # (M, total_K)
+        w = torch.cat(w, dim=1).contiguous()  # (N, total_K)
+        return x, w, G
+
+    def quantize(self, x, w, G):
+        # Simulate a 2d-2d grouped gemm in the backward pass `grad_weight = grad_output_t @ input`,
+        # where we use "K" as the contracting dim, which has "G" groups.
+        M, total_K = x.shape
+        N, _ = w.shape
+        group_size = total_K // G
+        input_group_end_offsets = torch.arange(
+            group_size, total_K + 1, group_size, dtype=torch.int32, device=x.device
+        )
+
+        # Quantize each group and convert its scales to blocked format.
+        x_list = []
+        w_list = []
+        x_blocked_scale_list = []
+        w_blocked_scale_list = []
+
+        def round_up(x: int, y: int) -> int:
+            return ((x + y - 1) // y) * y
+
+        for group_idx in range(G):
+            # to_mxfp8 per group
+            prev_group_end_offset = (
+                0 if group_idx == 0 else input_group_end_offsets[group_idx - 1]
+            )
+            curr_group_end_offset = input_group_end_offsets[group_idx]
+            group_size = curr_group_end_offset - prev_group_end_offset
+            if group_size > 0:
+                x_slice = x[
+                    :, prev_group_end_offset:curr_group_end_offset
+                ].contiguous()  # (M, K_group)
+                w_slice = w[
+                    :, prev_group_end_offset:curr_group_end_offset
+                ].contiguous()  # (N, K_group)
+                x_scale_slice, xq_slice = to_mxfp8(
+                    x_slice
+                )  # scale shape -> (M, K_group // 32)
+                w_scale_slice, wq_slice = to_mxfp8(
+                    w_slice
+                )  # scale shape -> (N, K_group // 32)
+                x_list.append(xq_slice)
+                w_list.append(wq_slice)
+
+                # Convert scales to blocked format.
+                x_scale_slice_blocked = _to_blocked(
+                    x_scale_slice
+                )  # (round_up(M, 128), round_up(K_group // 32, 4))
+                w_scale_slice_blocked = _to_blocked(
+                    w_scale_slice
+                )  # (round_up(N, 128), round_up(K_group // 32, 4))
+                x_blocked_scale_list.append(x_scale_slice_blocked)
+                w_blocked_scale_list.append(w_scale_slice_blocked)
+
+        # Assemble the full XQ and WQ.
+        xq = torch.cat(x_list, dim=1).contiguous()
+        wq = torch.cat(w_list, dim=1).contiguous()
+
+        # Combine all XQ groups' blocked scales into one tensor.
+        x_blocked_scales = torch.cat(x_blocked_scale_list, dim=0)
+        M_rounded = round_up(M, 128)
+        x_blocked_scales = x_blocked_scales.reshape(M_rounded, -1)
+
+        # Combine all WQ groups' blocked scales into one tensor.
+        w_blocked_scales = torch.cat(w_blocked_scale_list, dim=0)
+        N_rounded = round_up(N, 128)
+        w_blocked_scales = w_blocked_scales.reshape(N_rounded, -1)
+        return xq, wq, x_blocked_scales, w_blocked_scales, input_group_end_offsets
+
+    def compute(self, xq, wq, x_scale, w_scale, input_group_end_offsets):
+        return torch.ops.fbgemm.mx8mx8bf16_grouped_mm(
+            xq,
+            wq.transpose(-2, -1),
+            x_scale,
+            w_scale,
+            input_group_end_offsets,
         )
+
+    def quantize_and_compute(self, x, w):
+        xq, wq, x_scale, w_scale, input_group_end_offsets = self.quantize(x, w)
         return self.compute(
-            xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding
+            xq,
+            wq,
+            x_scale,
+            w_scale,
+            input_group_end_offsets,
         )
 
     @property
     def name(self) -> str:
-        return "cutlass_mx8mx8bf16_grouped_stacked"
+        return "cutlass_mx8mx8bf16_grouped_mm_2d_2d"
 
     @property
     def hip(self) -> bool:
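For context, here is a minimal sketch of driving the new 2d-3d benchmark op end to end, based only on the methods shown in the diff above; the shapes, dtypes, and import path are illustrative assumptions rather than part of the commit:

import torch

# Import path assumed from the file location above; adjust to however the bench
# module is exposed in your environment.
from fbgemm_gpu.experimental.gen_ai.bench.quantize_ops import MXFP8GroupedGemm2d3d

# Illustrative shapes: G groups, each with an (M, K) activation and an (N, K) weight.
G, M, N, K = 4, 256, 512, 1024
x = [torch.randn(M, K, device="cuda", dtype=torch.bfloat16) for _ in range(G)]
w = [torch.randn(N, K, device="cuda", dtype=torch.bfloat16) for _ in range(G)]

op = MXFP8GroupedGemm2d3d()
x_cat, w_stacked = op.preprocess(x, w)  # (G * M, K) and (G, N, K)
# quantize() returns xq, wq, x_scale, w_scale, input_group_end_offsets.
args = op.quantize(x_cat, w_stacked)
out = op.compute(*args)  # bf16 output from torch.ops.fbgemm.mx8mx8bf16_grouped_mm

The quantize_and_compute method shown in the diff simply combines the last two calls.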
