184 changes: 150 additions & 34 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -2916,78 +2916,194 @@ def cuda(self) -> bool:


@register_quantize_op
class MXFP8GroupedGemm2d3d(QuantizeOpBase):
    """
    MXFP8 grouped GEMM with 2D inputs and 3D weights.
    """

    def preprocess(self, x, w):
        assert isinstance(x, list)
        assert isinstance(w, list)
        x = torch.cat(x, dim=0).contiguous()  # (G * M, K)
        w = torch.stack(w, dim=0).contiguous()  # (G, N, K)
        return x, w

    def quantize(self, x, w):
        block_size = 32
        G, N, K = w.shape
        total_M = x.shape[0]
        group_size = total_M // G
        input_group_end_offsets = torch.arange(
            group_size, total_M + 1, group_size, dtype=torch.int32, device=x.device
        )

        # For each constituent 2D subtensor in the 3D weights, quantize it and convert its scale
        # to blocked format separately, as each one is used for an independent GEMM in the grouped GEMM.
        wq_list = []
        w_scale_list = []
        for i in range(G):
            w_scale, wq = to_mxfp8(w[i])
            w_scale = _to_blocked(w_scale)
            wq_list.append(wq)
            w_scale_list.append(w_scale)
        wq = torch.stack(wq_list, dim=0).contiguous()
        w_scale = torch.stack(w_scale_list, dim=0).contiguous()

        # For each group along `total_M` in the 2D input tensor, quantize it and convert its scale
        # to blocked format separately, as each one is used for an independent GEMM in the grouped GEMM.
        xq_list = []
        x_scale_list = []
        for i in range(G):
            prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
            curr_group_end = input_group_end_offsets[i]
            group_size = curr_group_end - prev_group_end
            if group_size > 0:
                x_slice = x[prev_group_end:curr_group_end, :]
                x_scale, xq = to_mxfp8(x_slice)
                x_scale = _to_blocked(x_scale)
                xq_list.append(xq)
                x_scale_list.append(x_scale)
        xq = torch.cat(xq_list, dim=0).contiguous()
        x_scale = torch.cat(x_scale_list, dim=0).contiguous()
        x_scale = x_scale.reshape(-1, K // block_size)
        xq = xq.view(-1, xq.shape[-1])
        return xq, wq, x_scale, w_scale, input_group_end_offsets

    def compute(self, xq, wq, x_scale, w_scale, input_group_end_offsets):
        return torch.ops.fbgemm.mx8mx8bf16_grouped_mm(
            xq,
            wq.transpose(-2, -1),
            x_scale,
            w_scale,
            input_group_end_offsets,
        )

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale, input_group_end_offsets = self.quantize(x, w)
        return self.compute(
            xq,
            wq,
            x_scale,
            w_scale,
            input_group_end_offsets,
        )

    @property
    def name(self) -> str:
        return "cutlass_mx8mx8bf16_grouped_mm_2d_3d"

    @property
    def hip(self) -> bool:
        return False

    @property
    def cuda(self) -> bool:
        return True
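
For reference, a minimal sketch of driving the 2d-3d op above end to end. It is illustrative only, not part of this diff: it assumes a CUDA build of fbgemm_gpu that registers torch.ops.fbgemm.mx8mx8bf16_grouped_mm, that to_mxfp8 and _to_blocked are in scope as in this module, and that the op can be constructed directly; the shapes are arbitrary choices keeping K a multiple of the 32-element MXFP8 block.

# Illustrative usage sketch (not part of this diff).
import torch

G, M, N, K = 4, 128, 256, 512  # K is a multiple of the 32-element MXFP8 block size
x = [torch.randn(M, K, device="cuda", dtype=torch.bfloat16) for _ in range(G)]
w = [torch.randn(N, K, device="cuda", dtype=torch.bfloat16) for _ in range(G)]

op = MXFP8GroupedGemm2d3d()
x_cat, w_stacked = op.preprocess(x, w)  # (G * M, K) activations, (G, N, K) weights
xq, wq, x_scale, w_scale, group_ends = op.quantize(x_cat, w_stacked)
out = op.compute(xq, wq, x_scale, w_scale, group_ends)  # bf16 grouped GEMM output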


@register_quantize_op
class MXFP8GroupedGemm2d2d(QuantizeOpBase):
    """
    MXFP8 grouped GEMM with 2D inputs and 2D weights.
    """

    def preprocess(self, x, w):
        assert isinstance(x, list)
        assert isinstance(w, list)
        G = len(x)
        x = torch.cat(x, dim=1).contiguous()  # (M, total_K)
        w = torch.cat(w, dim=1).contiguous()  # (N, total_K)
        return x, w, G

    def quantize(self, x, w, G):
        # Simulate a 2d-2d grouped GEMM as used in the backward pass `grad_weight = grad_output_t @ input`,
        # where "K" is the contracting dim and is split into "G" groups.
        M, total_K = x.shape
        N, _ = w.shape
        group_size = total_K // G
        input_group_end_offsets = torch.arange(
            group_size, total_K + 1, group_size, dtype=torch.int32, device=x.device
        )

        # Convert scales to blocked format.
        x_list = []
        w_list = []
        x_blocked_scale_list = []
        w_blocked_scale_list = []

        def round_up(x: int, y: int) -> int:
            return ((x + y - 1) // y) * y

        for group_idx in range(G):
            # to_mxfp8 per group
            prev_group_end_offset = (
                0 if group_idx == 0 else input_group_end_offsets[group_idx - 1]
            )
            curr_group_end_offset = input_group_end_offsets[group_idx]
            group_size = curr_group_end_offset - prev_group_end_offset
            if group_size > 0:
                x_slice = x[
                    :, prev_group_end_offset:curr_group_end_offset
                ].contiguous()  # (M, K_group)
                w_slice = w[
                    :, prev_group_end_offset:curr_group_end_offset
                ].contiguous()  # (N, K_group)
                x_scale_slice, xq_slice = to_mxfp8(
                    x_slice
                )  # scale shape -> (M, K_group // 32)
                w_scale_slice, wq_slice = to_mxfp8(
                    w_slice
                )  # scale shape -> (N, K_group // 32)
                x_list.append(xq_slice)
                w_list.append(wq_slice)

                # Convert scales to blocked format.
                x_scale_slice_blocked = _to_blocked(
                    x_scale_slice
                )  # (round_up(M, 128), round_up(K_group // 32, 4))
                w_scale_slice_blocked = _to_blocked(
                    w_scale_slice
                )  # (round_up(N, 128), round_up(K_group // 32, 4))
                x_blocked_scale_list.append(x_scale_slice_blocked)
                w_blocked_scale_list.append(w_scale_slice_blocked)

        # Assemble the full XQ and WQ.
        xq = torch.cat(x_list, dim=1).contiguous()
        wq = torch.cat(w_list, dim=1).contiguous()

        # Combine all XQ groups' blocked scales into one tensor.
        x_blocked_scales = torch.cat(x_blocked_scale_list, dim=0)
        M_rounded = round_up(M, 128)
        x_blocked_scales = x_blocked_scales.reshape(M_rounded, -1)

        # Combine all WQ groups' blocked scales into one tensor.
        w_blocked_scales = torch.cat(w_blocked_scale_list, dim=0)
        N_rounded = round_up(N, 128)
        w_blocked_scales = w_blocked_scales.reshape(N_rounded, -1)
        return xq, wq, x_blocked_scales, w_blocked_scales, input_group_end_offsets

    def compute(self, xq, wq, x_scale, w_scale, input_group_end_offsets):
        return torch.ops.fbgemm.mx8mx8bf16_grouped_mm(
            xq,
            wq.transpose(-2, -1),
            x_scale,
            w_scale,
            input_group_end_offsets,
        )

    def quantize_and_compute(self, x, w, G):
        xq, wq, x_scale, w_scale, input_group_end_offsets = self.quantize(x, w, G)
        return self.compute(
            xq,
            wq,
            x_scale,
            w_scale,
            input_group_end_offsets,
        )

    @property
    def name(self) -> str:
        return "cutlass_mx8mx8bf16_grouped_mm_2d_2d"

    @property
    def hip(self) -> bool:
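
As above, a minimal sketch of the 2d-2d case, which mimics the backward-pass computation `grad_weight = grad_output_t @ input` with the contracting dim K split into G groups. Same assumptions as the earlier sketch; the shapes and the direct construction of the op are illustrative choices, not part of this diff.

# Illustrative usage sketch (not part of this diff).
import torch

G, M, N, total_K = 4, 128, 256, 2048  # total_K splits evenly into G groups along the contracting dim
x = [torch.randn(M, total_K // G, device="cuda", dtype=torch.bfloat16) for _ in range(G)]
w = [torch.randn(N, total_K // G, device="cuda", dtype=torch.bfloat16) for _ in range(G)]

op = MXFP8GroupedGemm2d2d()
x_cat, w_cat, num_groups = op.preprocess(x, w)  # (M, total_K), (N, total_K), G
xq, wq, x_scale, w_scale, group_ends = op.quantize(x_cat, w_cat, num_groups)
out = op.compute(xq, wq, x_scale, w_scale, group_ends)  # bf16 grouped GEMM output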