Commit 14ca521

[mxfp8 moe training] per group scale conversion to blocked format with groups along K dim (for 2d2d grouped gemm) (#2956)
1 parent: cc35151
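
Context: mxfp8 grouped GEMMs consume per-block e8m0 scales in a padded, swizzled "blocked" layout, and each group's scales must be padded and swizzled independently. The existing helpers handled groups laid out along the M dim (the 2d-3d case); this commit renames those for clarity and adds variants for groups along the K (contraction) dim, which the 2d-2d grouped gemm needs. As a rough mental model only, the K-groups conversion behaves like the hedged sketch below, which leans on the single-matrix to_blocked helper (assumed to live in torchao.prototype.mx_formats.utils); the per-group loop and the padding arithmetic are illustrative, not the kernel added in this commit:

    # Illustrative sketch only, not this commit's implementation: convert
    # per-group e8m0 scales to blocked format when groups are jagged slices
    # along the K (contraction) dim.
    import torch
    from torchao.prototype.mx_formats.utils import to_blocked  # assumed helper

    def blocked_scales_for_K_groups(scales: torch.Tensor, group_end_offsets: torch.Tensor):
        """scales: (M, total_K // block_size) e8m0 scales.
        group_end_offsets: ascending end column of each group's scales."""
        outputs, start_cols, group_start = [], [0], 0
        for group_end in group_end_offsets.tolist():
            group = scales[:, group_start:group_end]     # jagged K-slice
            outputs.append(to_blocked(group).flatten())  # pad + swizzle this group
            # assumed: each group's scale columns pad up to the layout's 4-col tile
            padded_cols = -(-(group_end - group_start) // 4) * 4
            start_cols.append(start_cols[-1] + padded_cols)
            group_start = group_end
        return torch.cat(outputs), torch.tensor(start_cols, device=scales.device)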

File tree

5 files changed (+359, -57 lines)

benchmarks/prototype/moe_training/benchmark_2d_3d_grouped_gemms.py

Lines changed: 3 additions & 3 deletions

@@ -18,7 +18,7 @@
 from torchao.float8.config import ScalingGranularity
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
 from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
-    torch_to_blocked_per_group_2d,
+    torch_to_blocked_2d_M_groups,
     torch_to_blocked_per_group_3d,
 )
 from torchao.prototype.moe_training.utils import generate_jagged_offs
@@ -230,8 +230,8 @@ def bench_mxfp8_grouped_mm(A, B_t, offs, block_size=32) -> float:

     # Convert scales for each group to blocked format.
     Mg, K = A_fp8.shape
-    A_scales_blocked, starting_row_after_padding = torch_to_blocked_per_group_2d(
-        A_scales, offs, Mg, K
+    A_scales_blocked, starting_row_after_padding = torch_to_blocked_2d_M_groups(
+        A_scales, offs, K
     )
     B_scales_blocked = torch_to_blocked_per_group_3d(B_scales)

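The call-site change above reflects the new signature: torch_to_blocked_2d_M_groups no longer takes Mg, since the total row count can be read from the scales tensor itself. A hedged usage sketch with illustrative shapes (the to_mx import path is an assumption):

    import torch
    from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
        torch_to_blocked_2d_M_groups,
    )
    from torchao.prototype.moe_training.utils import generate_jagged_offs
    from torchao.prototype.mx_formats.mx_tensor import to_mx  # assumed import path

    Mg, K, block_size, n_groups = 1024, 2048, 32, 4  # illustrative shapes
    A = torch.randn(Mg, K, device="cuda")
    A_scales, _ = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
    offs = generate_jagged_offs(n_groups, Mg, multiple_of=block_size)  # group end rows
    # Mg is no longer passed; it is inferred from A_scales itself.
    A_scales_blocked, starting_row_after_padding = torch_to_blocked_2d_M_groups(
        A_scales, offs, K
    )
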
benchmarks/prototype/moe_training/benchmark_2d_blocked_swizzle_scale_kernels.py

Lines changed: 8 additions & 8 deletions

@@ -15,9 +15,9 @@

 from benchmarks.utils import benchmark_cuda_function_in_microseconds
 from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
-    compute_per_group_blocked_scale_offsets,
-    torch_to_blocked_per_group_2d,
-    triton_mx_block_rearrange_per_group_2d,
+    compute_blocked_scale_offsets_for_M_groups,
+    torch_to_blocked_2d_M_groups,
+    triton_mx_block_rearrange_2d_M_groups,
 )
 from torchao.prototype.moe_training.utils import generate_jagged_offs

@@ -82,9 +82,9 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
     input_group_offsets = generate_jagged_offs(num_groups, Mg, multiple_of=32)

     # bench torch
-    compiled_run_torch = torch.compile(torch_to_blocked_per_group_2d)
+    compiled_run_torch = torch.compile(torch_to_blocked_2d_M_groups)
     torch_out_scales, torch_group_offs = compiled_run_torch(
-        input_tensor, input_group_offsets, Mg, K
+        input_tensor, input_group_offsets, K
     )
     torch_time_us = benchmark_cuda_function_in_microseconds(
         compiled_run_torch,
@@ -95,16 +95,16 @@
     )

     # bench triton
-    _, output_group_offsets = compute_per_group_blocked_scale_offsets(
+    _, output_group_offsets = compute_blocked_scale_offsets_for_M_groups(
         input_group_offsets
     )
-    triton_out_scales = triton_mx_block_rearrange_per_group_2d(
+    triton_out_scales = triton_mx_block_rearrange_2d_M_groups(
         input_tensor,
         input_group_offsets,
         output_group_offsets,
     )
     triton_time_us = benchmark_cuda_function_in_microseconds(
-        triton_mx_block_rearrange_per_group_2d,
+        triton_mx_block_rearrange_2d_M_groups,
         input_tensor,
         input_group_offsets,
         output_group_offsets,
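
As benchmarked above, the triton path is a two-step flow: compute_blocked_scale_offsets_for_M_groups precomputes where each group starts in the padded blocked output, and the rearrange kernel then swizzles each group into place. A self-contained hedged sketch (shapes illustrative, to_mx import path assumed):

    import torch
    from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
        compute_blocked_scale_offsets_for_M_groups,
        triton_mx_block_rearrange_2d_M_groups,
    )
    from torchao.prototype.moe_training.utils import generate_jagged_offs
    from torchao.prototype.mx_formats.mx_tensor import to_mx  # assumed import path

    Mg, K, block_size, num_groups = 1024, 2048, 32, 4  # illustrative shapes
    x = torch.randn(Mg, K, device="cuda")
    input_tensor, _ = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
    input_group_offsets = generate_jagged_offs(num_groups, Mg, multiple_of=32)

    # Step 1: per-group start rows in the padded, blocked output.
    _, output_group_offsets = compute_blocked_scale_offsets_for_M_groups(
        input_group_offsets
    )
    # Step 2: swizzle each group's scales into the blocked layout.
    triton_out_scales = triton_mx_block_rearrange_2d_M_groups(
        input_tensor,
        input_group_offsets,
        output_group_offsets,
    )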

test/prototype/moe_training/test_kernels.py

Lines changed: 54 additions & 7 deletions

@@ -22,10 +22,13 @@
     triton_fp8_per_group_rowwise_scales,
 )
 from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
-    compute_per_group_blocked_scale_offsets,
-    torch_to_blocked_per_group_2d,
+    compute_blocked_scale_offsets_for_K_groups,
+    compute_blocked_scale_offsets_for_M_groups,
+    torch_to_blocked_2d_K_groups,
+    torch_to_blocked_2d_M_groups,
     torch_to_blocked_per_group_3d,
-    triton_mx_block_rearrange_per_group_2d,
+    triton_mx_block_rearrange_2d_K_groups,
+    triton_mx_block_rearrange_2d_M_groups,
     triton_mx_block_rearrange_per_group_3d,
 )
 from torchao.prototype.moe_training.utils import (
@@ -226,15 +229,15 @@ def test_mxfp8_per_group_blocked_scales_2d(
     )

     # torch reference
-    ref_out_scales, _ = torch_to_blocked_per_group_2d(
-        e8m0_scales, input_group_offsets, m, k, block_size=block_size
+    ref_out_scales, _ = torch_to_blocked_2d_M_groups(
+        e8m0_scales, input_group_offsets, k, block_size=block_size
     )

     # triton kernel
-    _, output_group_offsets = compute_per_group_blocked_scale_offsets(
+    _, output_group_offsets = compute_blocked_scale_offsets_for_M_groups(
         input_group_offsets
     )
-    triton_out_scales = triton_mx_block_rearrange_per_group_2d(
+    triton_out_scales = triton_mx_block_rearrange_2d_M_groups(
         e8m0_scales,
         input_group_offsets,
         output_group_offsets,
@@ -266,3 +269,47 @@ def test_mxfp8_per_group_blocked_scales_3d(
     assert torch.allclose(ref_out_scales, triton_out_scales, atol=0, rtol=0), (
         "blocked scales not equal"
     )
+
+
+@skip_if_rocm("ROCm enablement in progress")
+@pytest.mark.parametrize("m", [256, 512, 1024, 5120])
+@pytest.mark.parametrize("total_k", [512, 1024, 2048, 4096, 8192, 16384])
+@pytest.mark.parametrize("n_groups", [1, 4, 8, 16])
+def test_mxfp8_per_group_blocked_scales_2d2d(
+    m: int,
+    total_k: int,
+    n_groups: int,
+):
+    device = "cuda"
+    block_size = 32
+    input_data = torch.randn(m, total_k, device=device)
+
+    e8m0_scales, _ = to_mx(
+        input_data, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+    )
+
+    # Generate group end offsets along total_K, then divide by block_size
+    # to get scale group end offsets.
+    input_group_offsets = generate_jagged_offs(
+        n_groups, total_k, multiple_of=block_size, device=device
+    )
+    input_group_offsets //= block_size
+
+    # torch reference
+    ref_out_scales, ref_start_cols_after_padding = torch_to_blocked_2d_K_groups(
+        e8m0_scales,
+        input_group_offsets,
+    )
+
+    # triton kernel
+    _, output_group_offsets = compute_blocked_scale_offsets_for_K_groups(
+        input_group_offsets
+    )
+    assert torch.equal(output_group_offsets, ref_start_cols_after_padding), (
+        "output scale group start offsets not equal"
+    )
+    triton_out_scales = triton_mx_block_rearrange_2d_K_groups(
+        e8m0_scales,
+        input_group_offsets,
+        output_group_offsets,
+    )
+    assert torch.equal(ref_out_scales, triton_out_scales), "blocked scales not equal"
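
The new test doubles as a usage recipe for the K-groups path and pins down the contract: the offsets from compute_blocked_scale_offsets_for_K_groups must match the torch reference's start-columns-after-padding, and the triton output must match the reference bit-for-bit. A condensed hedged sketch (shapes illustrative, to_mx import path assumed):

    import torch
    from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
        compute_blocked_scale_offsets_for_K_groups,
        torch_to_blocked_2d_K_groups,
        triton_mx_block_rearrange_2d_K_groups,
    )
    from torchao.prototype.moe_training.utils import generate_jagged_offs
    from torchao.prototype.mx_formats.mx_tensor import to_mx  # assumed import path

    m, total_k, n_groups, block_size = 256, 2048, 4, 32  # illustrative shapes
    x = torch.randn(m, total_k, device="cuda")
    e8m0_scales, _ = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)

    # Group end offsets along total_k, converted to scale-column offsets.
    scale_group_offsets = generate_jagged_offs(
        n_groups, total_k, multiple_of=block_size, device="cuda"
    ) // block_size

    ref_scales, ref_start_cols = torch_to_blocked_2d_K_groups(
        e8m0_scales, scale_group_offsets
    )
    _, out_group_offsets = compute_blocked_scale_offsets_for_K_groups(scale_group_offsets)
    triton_scales = triton_mx_block_rearrange_2d_K_groups(
        e8m0_scales, scale_group_offsets, out_group_offsets
    )
    assert torch.equal(ref_scales, triton_scales)  # bitwise equality expected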
