
Commit f809b74

Add 2d-2d support to MXFP8 Grouped GEMM (#4816)
Summary:

X-link: facebookresearch/FBGEMM#1846

## MXFP8 grouped GEMM updates to (1) handle the 2d-2d case and (2) provide a PyTorch-compliant API

- Add support for 2d-2d inputs with dynamic groups along the K dimension
- Add tests to verify correct numerics for both 2d-2d and 2d-3d cases, with random group sizes
- Add benchmarks for both 2d-3d and 2d-2d cases

Reviewed By: ngimel, cthi

Differential Revision: D81362680
1 parent 10367cc commit f809b74

12 files changed: +718 -417 lines changed
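For orientation, here is a minimal sketch (not part of the commit) of how the group boundaries differ between the two cases: in 2d-3d the stacked activations are partitioned along `total_M`, while in 2d-2d the contracting dimension K is partitioned. Sizes are hypothetical; the `torch.arange` construction mirrors the one used in the diff below.

```python
import torch

G = 4             # number of groups (hypothetical)
M, K = 256, 1024  # per-group rows / contracting dim (hypothetical)

# 2d-3d: activations stacked along M; groups end at multiples of M.
total_M = G * M
offsets_2d3d = torch.arange(M, total_M + 1, M, dtype=torch.int32)
# tensor([ 256,  512,  768, 1024], dtype=torch.int32)

# 2d-2d: the contracting dim K is split into G groups instead.
total_K = G * K
offsets_2d2d = torch.arange(K, total_K + 1, K, dtype=torch.int32)
# tensor([1024, 2048, 3072, 4096], dtype=torch.int32)
```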

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 150 additions & 34 deletions
```diff
@@ -2916,78 +2916,194 @@ def cuda(self) -> bool:
 
 
 @register_quantize_op
-class MXFP8StackedGroupedGemm(QuantizeOpBase):
+class MXFP8GroupedGemm2d3d(QuantizeOpBase):
     """
-    MXFP8 grouped matmul with blockwise scaling and stacked inputs.
+    MXFP8 grouped GEMM with 2D inputs and 3D weights.
     """
 
     def preprocess(self, x, w):
-        m_values = [i.shape[0] for i in x]
-        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
+        assert isinstance(x, list)
+        assert isinstance(w, list)
+        x = torch.cat(x, dim=0).contiguous()  # (G * M, K)
+        w = torch.stack(w, dim=0).contiguous()  # (G, N, K)
+        return x, w
+
+    def quantize(self, x, w):
+        block_size = 32
+        G, N, K = w.shape
+        total_M = x.shape[0]
+        group_size = total_M // G
+        input_group_end_offsets = torch.arange(
+            group_size, total_M + 1, group_size, dtype=torch.int32, device=x.device
+        )
+
+        # For each constituent 2d subtensor in the 3d weights, quantize and convert
+        # the scale to blocked format separately, as each is used for an independent
+        # gemm in the grouped gemm.
         wq_list = []
         w_scale_list = []
-        for i in range(m_sizes.shape[0]):
+        for i in range(G):
             w_scale, wq = to_mxfp8(w[i])
             w_scale = _to_blocked(w_scale)
             wq_list.append(wq)
             w_scale_list.append(w_scale)
         wq = torch.stack(wq_list, dim=0).contiguous()
         w_scale = torch.stack(w_scale_list, dim=0).contiguous()
-        return x, wq, w_scale, m_sizes
 
-    def quantize(self, x, wq, w_scale, m_sizes):
-        starting_row_after_padding_list = [0]
+        # For each group along `total_M` in the 2D tensor, quantize and convert
+        # the scale to blocked format separately, as each is used for an independent
+        # gemm in the grouped gemm.
         xq_list = []
         x_scale_list = []
-        for i in range(m_sizes.shape[0]):
-            scale_slice = x[i]
-            if m_sizes[i].item() != 0:
-                x_scale, xq = to_mxfp8(scale_slice)
+        for i in range(G):
+            prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
+            curr_group_end = input_group_end_offsets[i]
+            group_size = curr_group_end - prev_group_end
+            if group_size > 0:
+                x_slice = x[prev_group_end:curr_group_end, :]
+                x_scale, xq = to_mxfp8(x_slice)
                 x_scale = _to_blocked(x_scale)
                 xq_list.append(xq)
                 x_scale_list.append(x_scale)
-                starting_row_after_padding_list.append(
-                    starting_row_after_padding_list[i]
-                    + x_scale.numel() // (x[0].shape[1] // 32)
-                )
-            else:
-                starting_row_after_padding_list.append(
-                    starting_row_after_padding_list[i]
-                )
         xq = torch.cat(xq_list, dim=0).contiguous()
         x_scale = torch.cat(x_scale_list, dim=0).contiguous()
-        x_scale = x_scale.reshape(-1, x[0].shape[-1] // 32)
+        x_scale = x_scale.reshape(-1, K // block_size)
         xq = xq.view(-1, xq.shape[-1])
-        return (
+        return xq, wq, x_scale, w_scale, input_group_end_offsets
+
```
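As a toy check of the slicing logic in `quantize` above (illustrative sizes only, not from the commit): each slice bounded by consecutive end offsets is quantized independently.

```python
import torch

# total_M = 12 rows split into G = 3 equal groups -> end offsets [4, 8, 12].
x = torch.randn(12, 64)
offsets = torch.arange(4, 13, 4, dtype=torch.int32)

prev = 0
for end in offsets.tolist():
    group = x[prev:end, :]  # this slice is what to_mxfp8 quantizes independently
    print(group.shape)      # torch.Size([4, 64]) for each of the 3 groups
    prev = end
```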
```diff
+    def compute(self, xq, wq, x_scale, w_scale, input_group_end_offsets):
+        return torch.ops.fbgemm.mx8mx8bf16_grouped_mm(
             xq,
-            wq,
+            wq.transpose(-2, -1),
             x_scale,
             w_scale,
-            m_sizes,
-            torch.tensor(starting_row_after_padding_list, device=xq.device),
+            input_group_end_offsets,
         )
 
-    def compute(self, xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding):
-        return torch.ops.fbgemm.mx8mx8bf16_grouped_stacked(
+    def quantize_and_compute(self, x, w):
+        xq, wq, x_scale, w_scale, input_group_end_offsets = self.quantize(x, w)
+        return self.compute(
             xq,
             wq,
             x_scale,
             w_scale,
-            m_sizes,
-            starting_row_after_padding=starting_row_after_padding,
+            input_group_end_offsets,
         )
 
-    def quantize_and_compute(self, x, w):
-        xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding = self.quantize(
-            x, w
+    @property
+    def name(self) -> str:
+        return "cutlass_mx8mx8bf16_grouped_mm_2d_3d"
+
+    @property
+    def hip(self) -> bool:
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
```
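Note that `compute` passes `wq.transpose(-2, -1)`, so the kernel sees each group's weights as a (K, N) view. A shape-only sketch of that convention (the sizes and fp8 dtype here are assumptions, not from the commit):

```python
import torch

G, M, N, K = 2, 8, 16, 64
xq = torch.zeros(G * M, K, dtype=torch.float8_e4m3fn)  # stacked fp8 activations
wq = torch.zeros(G, N, K, dtype=torch.float8_e4m3fn)   # per-group fp8 weights

# The grouped GEMM receives weights as (G, K, N) views, matching a
# (G * M, K) x (G, K, N) -> (G * M, N) contraction.
assert wq.transpose(-2, -1).shape == (G, K, N)
```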
```diff
+@register_quantize_op
+class MXFP8GroupedGemm2d2d(QuantizeOpBase):
+    """
+    MXFP8 grouped GEMM with 2D inputs and 2D weights.
+    """
+
+    def preprocess(self, x, w):
+        assert isinstance(x, list)
+        assert isinstance(w, list)
+        G = len(x)
+        x = torch.cat(x, dim=1).contiguous()  # (M, total_K)
+        w = torch.cat(w, dim=1).contiguous()  # (N, total_K)
+        return x, w, G
+
+    def quantize(self, x, w, G):
+        # Simulate the 2d-2d grouped gemm in the backward pass
+        # `grad_weight = grad_output_t @ input`, where "K" is the contracting dim
+        # and has "G" groups.
+        M, total_K = x.shape
+        N, _ = w.shape
+        group_size = total_K // G
+        input_group_end_offsets = torch.arange(
+            group_size, total_K + 1, group_size, dtype=torch.int32, device=x.device
+        )
+
+        # Quantize each group and convert its scales to blocked format.
+        x_list = []
+        w_list = []
+        x_blocked_scale_list = []
+        w_blocked_scale_list = []
+
+        def round_up(x: int, y: int) -> int:
+            return ((x + y - 1) // y) * y
+
+        for group_idx in range(G):
+            # Run to_mxfp8 on each group.
+            prev_group_end_offset = (
+                0 if group_idx == 0 else input_group_end_offsets[group_idx - 1]
+            )
+            curr_group_end_offset = input_group_end_offsets[group_idx]
+            group_size = curr_group_end_offset - prev_group_end_offset
+            if group_size > 0:
+                x_slice = x[
+                    :, prev_group_end_offset:curr_group_end_offset
+                ].contiguous()  # (M, K_group)
+                w_slice = w[
+                    :, prev_group_end_offset:curr_group_end_offset
+                ].contiguous()  # (N, K_group)
+                x_scale_slice, xq_slice = to_mxfp8(
+                    x_slice
+                )  # scale shape -> (M, K_group // 32)
+                w_scale_slice, wq_slice = to_mxfp8(
+                    w_slice
+                )  # scale shape -> (N, K_group // 32)
+                x_list.append(xq_slice)
+                w_list.append(wq_slice)
+
+                # Convert scales to blocked format.
+                x_scale_slice_blocked = _to_blocked(
+                    x_scale_slice
+                )  # (round_up(M, 128), round_up(K_group // 32, 4))
+                w_scale_slice_blocked = _to_blocked(
+                    w_scale_slice
+                )  # (round_up(N, 128), round_up(K_group // 32, 4))
+                x_blocked_scale_list.append(x_scale_slice_blocked)
+                w_blocked_scale_list.append(w_scale_slice_blocked)
+
+        # Assemble the full XQ and WQ.
+        xq = torch.cat(x_list, dim=1).contiguous()
+        wq = torch.cat(w_list, dim=1).contiguous()
+
+        # Combine all XQ groups' blocked scales into one tensor.
+        x_blocked_scales = torch.cat(x_blocked_scale_list, dim=0)
+        M_rounded = round_up(M, 128)
+        x_blocked_scales = x_blocked_scales.reshape(M_rounded, -1)
+
+        # Combine all WQ groups' blocked scales into one tensor.
+        w_blocked_scales = torch.cat(w_blocked_scale_list, dim=0)
+        N_rounded = round_up(N, 128)
+        w_blocked_scales = w_blocked_scales.reshape(N_rounded, -1)
+        return xq, wq, x_blocked_scales, w_blocked_scales, input_group_end_offsets
```
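The shape comments above rely on `_to_blocked` padding each scale tile to 128-row by 4-column multiples; a worked instance of that arithmetic with made-up sizes:

```python
def round_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y

M, K_group = 100, 64
raw_scale_shape = (M, K_group // 32)                            # (100, 2)
blocked_shape = (round_up(M, 128), round_up(K_group // 32, 4))  # padded tile
print(raw_scale_shape, blocked_shape)                           # (100, 2) (128, 4)
```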
```diff
+
+    def compute(self, xq, wq, x_scale, w_scale, input_group_end_offsets):
+        return torch.ops.fbgemm.mx8mx8bf16_grouped_mm(
+            xq,
+            wq.transpose(-2, -1),
+            x_scale,
+            w_scale,
+            input_group_end_offsets,
         )
+
+    def quantize_and_compute(self, x, w, G):
+        xq, wq, x_scale, w_scale, input_group_end_offsets = self.quantize(x, w, G)
         return self.compute(
-            xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding
+            xq,
+            wq,
+            x_scale,
+            w_scale,
+            input_group_end_offsets,
         )
 
     @property
     def name(self) -> str:
-        return "cutlass_mx8mx8bf16_grouped_stacked"
+        return "cutlass_mx8mx8bf16_grouped_mm_2d_2d"
 
     @property
     def hip(self) -> bool:
```
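A hedged end-to-end sketch of driving the new 2d-3d path through the benchmark class above; it assumes a CUDA build of fbgemm_gpu that registers `mx8mx8bf16_grouped_mm`, and the bf16 input dtype is a guess rather than something the commit states:

```python
import torch

# Hypothetical inputs: G groups of (M, K) activations and (N, K) weights.
op = MXFP8GroupedGemm2d3d()
G, M, N, K = 4, 128, 256, 512
x = [torch.randn(M, K, device="cuda", dtype=torch.bfloat16) for _ in range(G)]
w = [torch.randn(N, K, device="cuda", dtype=torch.bfloat16) for _ in range(G)]

x_cat, w_stacked = op.preprocess(x, w)  # (G * M, K), (G, N, K)
args = op.quantize(x_cat, w_stacked)    # xq, wq, scales, group end offsets
out = op.compute(*args)                 # expected: (G * M, N) in bf16
```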
