
Commit c93bc7d

add mxfp8 to test_tp (#2870)
1 parent: 72222d1

File tree

1 file changed: +20 -6 lines changed

test/prototype/moe_training/test_tp.py

Lines changed: 20 additions & 6 deletions
@@ -42,7 +42,10 @@
 )
 
 from torchao.float8.float8_utils import compute_error
-from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
+from torchao.prototype.moe_training.conversion_utils import (
+    MoEScalingType,
+    MoETrainingConfig,
+)
 from torchao.quantization.quant_api import quantize_
 
 from .testing_utils import _validate_model_conversion
@@ -71,11 +74,24 @@
         # ["experts,shared_expert"],
     ],
 )
-def test_moe_float8_training_tp(target_fqns: list[str]):
+@pytest.mark.parametrize(
+    "recipe, min_out_sqnr, alignment_size, min_param_grad_sqnr",
+    [
+        (MoEScalingType.FP8_ROWWISE, 29.0, 16, 23.0),
+        (MoEScalingType.MXFP8, 28.0, 32, 21.0),
+    ],
+)
+def test_moe_float8_training_tp(
+    target_fqns: list[str],
+    recipe: MoEScalingType,
+    min_out_sqnr: float,
+    alignment_size: int,
+    min_param_grad_sqnr: float,
+):
     assert torch.cuda.is_available()
 
     # token group alignment size must be 16 for fp8
-    set_token_group_alignment_size_m(16)
+    set_token_group_alignment_size_m(alignment_size)
 
     # setup distributed for tp
     mesh = setup_distributed()
@@ -108,7 +124,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False
 
     # quantize test model
-    config = MoETrainingConfig()
+    config = MoETrainingConfig(recipe)
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)
 
     # validate that only the experts were converted
@@ -154,7 +170,6 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
     # validate output
     out_sqnr = compute_error(out, ref_out)
-    min_out_sqnr = 29.0
     assert out_sqnr.item() >= min_out_sqnr, (
         f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
     )
@@ -176,7 +191,6 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     )
 
     # validate param gradients
-    min_param_grad_sqnr = 23.0
     for param1, param2 in zip(model.parameters(), ref_model.parameters()):
         param_grad_sqnr = compute_error(param1.grad, param2.grad)
         assert param_grad_sqnr.item() >= min_param_grad_sqnr, (
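
For context, a minimal sketch of what the converted call sequence looks like with the new recipe argument, assembled only from the APIs visible in this diff. `model` and the filter function are hypothetical stand-ins for the test's fixtures, and the import for the alignment helper is not shown on this page:

# Sketch (hypothetical stand-ins): converting an MoE model for mxfp8
# training, using only the APIs exercised by this test.
import torch.nn as nn

from torchao.prototype.moe_training.conversion_utils import (
    MoEScalingType,
    MoETrainingConfig,
)
from torchao.quantization.quant_api import quantize_

def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    # Hypothetical filter: convert only the expert modules, as the test does.
    return "experts" in cur_fqn

# Token group alignment must match the recipe: 16 for fp8 rowwise, 32 for
# mxfp8. The test sets it via set_token_group_alignment_size_m(alignment_size).
config = MoETrainingConfig(MoEScalingType.MXFP8)  # recipe is passed positionally, as in the diff
quantize_(model, config=config, filter_fn=moe_module_filter_fn)  # `model`: a stand-in MoE model

The per-recipe thresholds encode expected numerical fidelity: compute_error returns SQNR in decibels (in torchao, 20 * log10(norm(ref) / norm(ref - out))), and mxfp8 is held to slightly looser bounds (28.0 dB output, 21.0 dB param gradients) than fp8 rowwise (29.0 dB and 23.0 dB). Since the test builds a device mesh via setup_distributed(), it needs a multi-process launcher, e.g. torchrun --nproc_per_node=2 -m pytest test/prototype/moe_training/test_tp.py (the exact command is an assumption, not taken from this page).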
