Commit df502e0

row of blocks within groups only
1 parent: 11f25ec · commit: df502e0

2 files changed (+39, -44 lines)

test/prototype/moe_training/test_kernels.py

Lines changed: 22 additions & 22 deletions
@@ -272,35 +272,41 @@ def test_mxfp8_per_group_blocked_scales_3d(


 @skip_if_rocm("ROCm enablement in progress")
-@pytest.mark.parametrize("m,total_k,n_groups", [(256, 512, 4)])#, (256, 128, 4), (512, 128, 4), (1024, 128, 4), (1024, 256, 4), (1024, 512, 4), (1024, 1024, 4), (1024, 2048, 4), (1024, 4096, 4), (1024, 8192, 4), (1024, 16384, 4)])
+@pytest.mark.parametrize(
+    "m,total_k,n_groups",
+    [
+        (256, 512, 4),
+        (256, 128, 4),
+        (512, 128, 4),
+        (1024, 128, 4),
+        (1024, 256, 4),
+        (1024, 512, 4),
+        (1024, 1024, 4),
+        (1024, 2048, 4),
+        (1024, 4096, 4),
+        (1024, 8192, 4),
+        (1024, 16384, 4),
+        (5120, 16640, 16),
+    ],
+)
 def test_mxfp8_per_group_blocked_scales_2d2d_lhs(
     m: int,
     total_k: int,
     n_groups: int,
 ):
     device = "cuda"
     block_size = 32
-
-    # Make each group of row blocks have distinct, constinent data for debugging
-    input_data = torch.cat(
-        [
-            torch.ones(m // 2, total_k, device=device),
-            torch.full((m // 2, total_k), 999, device=device),
-        ]
-    )
-    #input_data= torch.randn(m, total_k, device=device)
+    input_data = torch.randn(m, total_k, device=device)

     e8m0_scales, _ = to_mx(
         input_data, elem_dtype=torch.float8_e4m3fn, block_size=block_size
     )

     # Generate group end offsets along total_K, then divide by block_size to get scale group end offsets
-    # input_group_offsets = generate_jagged_offs(
-    #     n_groups, total_k, multiple_of=block_size, device=device
-    # )
-    # input_group_offsets //= block_size
-    input_group_offsets = torch.tensor([3, 8, 12, 16], device=device, dtype=torch.int32)
-    #print(input_group_offsets)
+    input_group_offsets = generate_jagged_offs(
+        n_groups, total_k, multiple_of=block_size, device=device
+    )
+    input_group_offsets //= block_size

     # torch reference
     ref_out_scales, ref_start_cols_after_padding = torch_to_blocked_per_group_2d2d_lhs(
@@ -320,12 +326,6 @@ def test_mxfp8_per_group_blocked_scales_2d2d_lhs(
         input_group_offsets,
         output_group_offsets,
     )
-    print(ref_start_cols_after_padding)
-    with open('tmp-ref.txt', 'w') as f:
-        f.write(str(ref_out_scales.storage()))
-    with open('tmp-triton.txt', 'w') as f:
-        f.write(str(triton_out_scales.storage()))
-    breakpoint()
     assert torch.allclose(ref_out_scales, triton_out_scales, atol=0, rtol=0), (
         "blocked scales not equal"
     )
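
For intuition on the offsets the updated test feeds the kernel: to_mx produces one e8m0 scale per block_size contiguous elements along K, so group end offsets along total_k (which generate_jagged_offs makes multiples of block_size) convert to scale-column end offsets by integer division. A minimal sketch with hand-picked offsets (the numbers below are illustrative, not taken from the test):

import torch

total_k, block_size = 512, 32

# Hand-picked group end offsets along total_k, each a multiple of block_size;
# generate_jagged_offs produces a random jagged split with the same property.
input_group_offsets = torch.tensor([128, 256, 384, 512], dtype=torch.int32)

# One e8m0 scale per block_size contiguous K elements, so integer division
# converts element offsets into scale-column offsets.
scale_group_offsets = input_group_offsets // block_size
assert scale_group_offsets.tolist() == [4, 8, 12, 16]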

torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py

Lines changed: 17 additions & 22 deletions
@@ -479,8 +479,6 @@ def triton_mx_block_rearrange_per_group_2d2d_lhs(
     # Output block stride for the rearranged format
     BLOCK_ROWS, BLOCK_COLS = 128, 4
     output_stride_per_block = BLOCK_ROWS * BLOCK_COLS
-    num_row_blocks = padded_rows // BLOCK_ROWS
-    output_stride_per_col_of_blocks = output_stride_per_block * num_row_blocks

     # We parallelize per group and per row block.
     # Cols per group is variable, so we just loop through col blocks for each group.
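
The launch grid itself is not part of this hunk; the comment above implies one program per (group, row block) pair, with each program looping over its group's col blocks. A hedged sketch of that shape (an assumption based on the comment, not the file's actual launch code):

# Assumed launch shape, with made-up example values:
BLOCK_ROWS = 128
padded_rows = 256  # rows already padded to a multiple of BLOCK_ROWS
num_groups = 4
grid = (num_groups, padded_rows // BLOCK_ROWS)  # one program per (group, row block)
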
@@ -495,17 +493,17 @@
         scales_tensor.stride(1),
         rows,
         cols,
+        padded_rows,
         num_groups,
         # Original offsets (to read from)
         input_group_end_offsets,
         # Output scales tensor and group offsets after padding (to write to)
         output.view(torch.uint8),
         output_group_start_offsets,
         output_stride_per_block,
-        output_stride_per_col_of_blocks,
         BLOCK_ROWS=BLOCK_ROWS,
         BLOCK_COLS=BLOCK_COLS,
-        DEBUG=True,
+        DEBUG=False,
     )
     return output

@@ -517,12 +515,12 @@ def triton_scale_swizzle_per_group_2d2d_lhs(
     scales_stride_dim1,
     scale_rows,
     scale_cols,
+    padded_rows,
     num_groups,
     orig_offsets,  # (num_groups,)
     output_scales_ptr,
     output_scales_group_offsets,  # (num_groups,)
     output_stride_per_block,
-    output_stride_per_col_of_blocks,
     BLOCK_ROWS: tl.constexpr,
     BLOCK_COLS: tl.constexpr,
     DEBUG: tl.constexpr = False,
@@ -557,8 +555,9 @@ def triton_scale_swizzle_per_group_2d2d_lhs(

     # For this group and row block, we iterate through col blocks, reading (BLOCK_ROWS, BLOCK_COLS) from the input scales.
     # We track how many col blocks we have iterated through.
+    out_group_base_offset = output_group_start_col * padded_rows
     curr_input_start_col = input_group_start_col
-    curr_out_start_col_block = output_group_start_col // BLOCK_COLS
+    curr_out_start_col_block = 0
     while curr_input_start_col < input_group_end_col:
         # Read block of input scales
         block_row_offs = block_row_pid * BLOCK_ROWS + row_offs
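
The added out_group_base_offset line anchors the commit's change: each padded output column stores padded_rows scale bytes, so a group that starts at output column output_group_start_col begins at output_group_start_col * padded_rows in the flattened output, and curr_out_start_col_block now counts col blocks from the start of the group rather than globally. With illustrative numbers (not from the source):

padded_rows = 256            # two 128-row blocks after padding
output_group_start_col = 8   # this group's first column in the padded output
out_group_base_offset = output_group_start_col * padded_rows
assert out_group_base_offset == 2048
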
@@ -570,25 +569,21 @@ def triton_scale_swizzle_per_group_2d2d_lhs(
         input_scales = tl.load(scales_ptr + block_offs, mask=mask, other=0.0)
         scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS))

-        # Calculate block offset using provided output block stride
-        tgt_row_off = block_row_pid * output_stride_per_block
-        tgt_col_off = curr_out_start_col_block * output_stride_per_col_of_blocks
-
-        output_block_offsets = tgt_row_off + tgt_col_off
-        if DEBUG:
-            tl.device_print("\nblock_row_pid: ", block_row_pid)
-            tl.device_print("group_pid: ", group_pid)
-            tl.device_print("tgt_row_block", block_row_pid)
-            tl.device_print("output_group_start_col: ", output_group_start_col)
-            tl.device_print("tgt_col_block", curr_out_start_col_block)
-            tl.device_print("tgt_row_off: ", tgt_row_off)
-            tl.device_print("tgt_col_off: ", tgt_col_off)
-            tl.device_print("global_off:", tgt_row_off + tgt_col_off)
-            tl.device_print("writing: ", scales_flat)
+        # Get offset within the group to add to the group's base offset
+        num_cols_in_group = input_group_end_col - input_group_start_col
+        num_col_blocks_in_group = tl.cdiv(num_cols_in_group, BLOCK_COLS)
+        stride_per_row_of_blocks_in_group = (
+            num_col_blocks_in_group * output_stride_per_block
+        )
+        offset_in_group = (
+            block_row_pid * stride_per_row_of_blocks_in_group
+            + curr_out_start_col_block * output_stride_per_block
+        )
+        final_offset = out_group_base_offset + offset_in_group

         # Apply swizzling for write to gmem
         tl.store(
-            output_scales_ptr + output_block_offsets + dest_indices_flat,
+            output_scales_ptr + final_offset + dest_indices_flat,
             scales_flat,
         )
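
Putting the new index math together: a (BLOCK_ROWS, BLOCK_COLS) tile lands at the group's base offset, plus block_row_pid rows of blocks, where a row of blocks now spans only that group's col blocks (hence the commit title), plus the current col block. A host-side trace of the same arithmetic with made-up values (none of these numbers come from the source):

import math

# Made-up example values, used only to trace the kernel's offset math.
BLOCK_ROWS, BLOCK_COLS = 128, 4
output_stride_per_block = BLOCK_ROWS * BLOCK_COLS     # 512 scales per tile
padded_rows = 256                                     # two row blocks
output_group_start_col = 8                            # group base column
input_group_start_col, input_group_end_col = 8, 18    # 10 scale cols in group

out_group_base_offset = output_group_start_col * padded_rows       # 2048
num_col_blocks_in_group = math.ceil(
    (input_group_end_col - input_group_start_col) / BLOCK_COLS     # ceil(10/4) = 3
)
stride_per_row_of_blocks_in_group = (
    num_col_blocks_in_group * output_stride_per_block              # 1536
)

block_row_pid, curr_out_start_col_block = 1, 2  # second row block, third col block
offset_in_group = (
    block_row_pid * stride_per_row_of_blocks_in_group
    + curr_out_start_col_block * output_stride_per_block
)
final_offset = out_group_base_offset + offset_in_group
assert final_offset == 2048 + 3 * 512 + 2 * 512  # 4608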
