Skip to content

Commit a2ff75b

Browse files
[mxfp8 moe training] use dim1 cast cuda kernel in bwd
stack-info: PR: #2897, branch: danielvegamyhre/stack/64
1 parent 2b2af6a commit a2ff75b

File tree

2 files changed

+24
-11
lines changed

2 files changed

+24
-11
lines changed

test/prototype/moe_training/test_training.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,13 @@
4343
@pytest.mark.parametrize(
4444
"recipe_config",
4545
[
46-
# {"recipe": MoEScalingType.FP8_ROWWISE, "group_alignment_size": 16, "min_out_sqnr": 29.0, "min_input_grad_sqnr": 29.0, "min_param_grad_sqnr": 23.0},
46+
{
47+
"recipe": MoEScalingType.FP8_ROWWISE,
48+
"group_alignment_size": 16,
49+
"min_out_sqnr": 29.0,
50+
"min_input_grad_sqnr": 29.0,
51+
"min_param_grad_sqnr": 23.0,
52+
},
4753
{
4854
"recipe": MoEScalingType.MXFP8,
4955
"group_alignment_size": 32,

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@
2020
from torchao.prototype.moe_training.utils import (
2121
_is_column_major,
2222
)
23+
from torchao.prototype.mx_formats.config import (
24+
MXFP8Dim1CastKernelChoice,
25+
MXGemmKernelChoice,
26+
ScaleCalculationMode,
27+
)
28+
from torchao.prototype.mx_formats.mx_linear import _to_mxfp8_dim1_kernel_wrapper
2329
from torchao.prototype.mx_formats.mx_tensor import to_mx
2430

2531
logger: logging.Logger = logging.getLogger(__name__)
@@ -376,17 +382,18 @@ def backward(ctx, grad_out: torch.Tensor):
376382
# Transpose A so we can scale along the M dimension, then un-transpose.
377383
# A_t_data shape: (K, M)
378384
# A_t_scales shape: (K, M//block_size)
379-
A_t_scales, A_t_data = to_mx(
380-
A.transpose(-2, -1).contiguous(),
385+
A_t_mx = _to_mxfp8_dim1_kernel_wrapper(
386+
A,
387+
block_size,
381388
elem_dtype=torch.float8_e4m3fn,
382-
block_size=block_size,
383-
)
384-
385-
# A_data shape = (M, K)
386-
A_data = A_t_data.transpose(-2, -1)
387-
388-
# A_scales shape = (M//block_size, K)
389-
A_scales = A_t_scales.transpose(-2, -1)
389+
hp_dtype=A.dtype,
390+
gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, # Not used
391+
cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
392+
scale_calculation_mode=ScaleCalculationMode.FLOOR,
393+
)
394+
A_mx = A_t_mx.t()
395+
A_data = A_mx.qdata
396+
A_scales = A_mx._scale_e8m0.t()
390397

391398
# grad_B_t = scaled grouped mm of (N,M) @ (M,K) = (E,N,K)
392399
grad_B = _emulated_mxfp8_scaled_grouped_mm_2d_2d(

0 commit comments

Comments (0)