|
22 | 22 | triton_fp8_per_group_rowwise_scales, |
23 | 23 | ) |
24 | 24 | from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import ( |
25 | | - compute_per_group_blocked_scale_offsets, |
26 | | - torch_to_blocked_per_group_2d, |
| 25 | + compute_blocked_scale_offsets_for_K_groups, |
| 26 | + compute_blocked_scale_offsets_for_M_groups, |
| 27 | + torch_to_blocked_2d_K_groups, |
| 28 | + torch_to_blocked_2d_M_groups, |
27 | 29 | torch_to_blocked_per_group_3d, |
28 | | - triton_mx_block_rearrange_per_group_2d, |
| 30 | + triton_mx_block_rearrange_2d_K_groups, |
| 31 | + triton_mx_block_rearrange_2d_M_groups, |
29 | 32 | triton_mx_block_rearrange_per_group_3d, |
30 | 33 | ) |
31 | 34 | from torchao.prototype.moe_training.utils import ( |
@@ -226,15 +229,15 @@ def test_mxfp8_per_group_blocked_scales_2d( |
226 | 229 | ) |
227 | 230 |
|
228 | 231 | # torch reference |
229 | | - ref_out_scales, _ = torch_to_blocked_per_group_2d( |
230 | | - e8m0_scales, input_group_offsets, m, k, block_size=block_size |
| 232 | + ref_out_scales, _ = torch_to_blocked_2d_M_groups( |
| 233 | + e8m0_scales, input_group_offsets, k, block_size=block_size |
231 | 234 | ) |
232 | 235 |
|
233 | 236 | # triton kernel |
234 | | - _, output_group_offsets = compute_per_group_blocked_scale_offsets( |
| 237 | + _, output_group_offsets = compute_blocked_scale_offsets_for_M_groups( |
235 | 238 | input_group_offsets |
236 | 239 | ) |
237 | | - triton_out_scales = triton_mx_block_rearrange_per_group_2d( |
| 240 | + triton_out_scales = triton_mx_block_rearrange_2d_M_groups( |
238 | 241 | e8m0_scales, |
239 | 242 | input_group_offsets, |
240 | 243 | output_group_offsets, |
@@ -266,3 +269,47 @@ def test_mxfp8_per_group_blocked_scales_3d( |
266 | 269 | assert torch.allclose(ref_out_scales, triton_out_scales, atol=0, rtol=0), ( |
267 | 270 | "blocked scales not equal" |
268 | 271 | ) |
| 272 | + |
| 273 | + |
@skip_if_rocm("ROCm enablement in progress")
@pytest.mark.parametrize("m", [256, 512, 1024, 5120])
@pytest.mark.parametrize("total_k", [512, 1024, 2048, 4096, 8192, 16384])
@pytest.mark.parametrize("n_groups", [1, 4, 8, 16])
def test_mxfp8_per_group_blocked_scales_2d2d(
    m: int,
    total_k: int,
    n_groups: int,
):
    """Check that the Triton kernel rearranging e8m0 scales for K-grouped 2D
    inputs produces output identical to the torch reference implementation,
    including the padded per-group start-column offsets."""
    block_size = 32
    device = "cuda"

    # Quantize random data to mx format to obtain real e8m0 scales
    # (one scale per block_size-wide chunk along K).
    sample = torch.randn(m, total_k, device=device)
    e8m0_scales, _ = to_mx(
        sample, elem_dtype=torch.float8_e4m3fn, block_size=block_size
    )

    # Group end offsets along total_K (multiples of block_size), converted
    # to scale-column units by dividing through by block_size.
    scale_group_offsets = (
        generate_jagged_offs(n_groups, total_k, multiple_of=block_size, device=device)
        // block_size
    )

    # Pure-torch reference path.
    ref_out_scales, ref_start_cols_after_padding = torch_to_blocked_2d_K_groups(
        e8m0_scales,
        scale_group_offsets,
    )

    # Triton path: first compute the padded output group offsets, confirm they
    # agree with the reference, then run the rearrange kernel itself.
    _, blocked_group_offsets = compute_blocked_scale_offsets_for_K_groups(
        scale_group_offsets
    )
    assert torch.equal(blocked_group_offsets, ref_start_cols_after_padding), (
        "output scale group start offsets not equal"
    )
    triton_out_scales = triton_mx_block_rearrange_2d_K_groups(
        e8m0_scales,
        scale_group_offsets,
        blocked_group_offsets,
    )
    assert torch.equal(ref_out_scales, triton_out_scales), "blocked scales not equal"
0 commit comments