@@ -42,7 +42,10 @@
 )
 
 from torchao.float8.float8_utils import compute_error
-from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
+from torchao.prototype.moe_training.conversion_utils import (
+    MoEScalingType,
+    MoETrainingConfig,
+)
 from torchao.quantization.quant_api import quantize_
 
 from .testing_utils import _validate_model_conversion
@@ -71,11 +74,24 @@
         # ["experts,shared_expert"],
     ],
 )
-def test_moe_float8_training_tp(target_fqns: list[str]):
+@pytest.mark.parametrize(
+    "recipe, min_out_sqnr, alignment_size, min_param_grad_sqnr",
+    [
+        (MoEScalingType.FP8_ROWWISE, 29.0, 16, 23.0),
+        (MoEScalingType.MXFP8, 28.0, 32, 21.0),
+    ],
+)
+def test_moe_float8_training_tp(
+    target_fqns: list[str],
+    recipe: MoEScalingType,
+    min_out_sqnr: float,
+    alignment_size: int,
+    min_param_grad_sqnr: float,
+):
     assert torch.cuda.is_available()
 
     # token group alignment size must be 16 for fp8 rowwise, 32 for mxfp8
-    set_token_group_alignment_size_m(16)
+    set_token_group_alignment_size_m(alignment_size)
 
     # setup distributed for tp
     mesh = setup_distributed()
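For context: the fp8 rowwise recipe requires token groups padded to a multiple of 16, while mxfp8 requires a multiple of 32, since MX formats share one scale per 32-element block, which is why the alignment size is now parameterized alongside the recipe. A minimal illustration of the padding arithmetic this implies (round_up is a hypothetical helper, not part of torchao):

# Illustration only: token group sizes get padded to the next multiple of
# the kernel's required alignment (16 for fp8 rowwise, 32 for mxfp8).
def round_up(n: int, multiple: int) -> int:
    return ((n + multiple - 1) // multiple) * multiple

assert round_up(100, 16) == 112
assert round_up(100, 32) == 128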
@@ -108,7 +124,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False
 
     # quantize test model
-    config = MoETrainingConfig()
+    config = MoETrainingConfig(recipe)
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)
 
     # validate that only the experts were converted
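Outside the test harness, the conversion step shown above is just the public quantize_ API with a recipe-specific config. A minimal sketch, assuming a `model` with expert submodules has already been constructed (distributed setup omitted):

import torch.nn as nn

from torchao.prototype.moe_training.conversion_utils import (
    MoEScalingType,
    MoETrainingConfig,
)
from torchao.quantization.quant_api import quantize_

def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    # Mirror the test's filter: convert only modules whose FQN names the experts.
    return "experts" in cur_fqn

config = MoETrainingConfig(MoEScalingType.MXFP8)  # or MoEScalingType.FP8_ROWWISE
quantize_(model, config=config, filter_fn=moe_module_filter_fn)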
@@ -154,7 +170,6 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
     # validate output
     out_sqnr = compute_error(out, ref_out)
-    min_out_sqnr = 29.0
     assert out_sqnr.item() >= min_out_sqnr, (
         f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
     )
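The SQNR thresholds (29.0 dB output SQNR for fp8 rowwise, 28.0 dB for mxfp8) measure how closely the quantized model tracks the bf16 reference; higher is better. torchao's compute_error returns this ratio in dB, roughly equivalent to the following sketch (not the exact source):

import torch

def compute_error_sketch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # SQNR in dB: 20 * log10(||signal|| / ||noise||), treating x as the signal.
    signal = torch.linalg.vector_norm(x)
    noise = torch.linalg.vector_norm(x - y)
    return 20 * torch.log10(signal / noise)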
@@ -176,7 +191,6 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     )
 
     # validate param gradients
-    min_param_grad_sqnr = 23.0
     for param1, param2 in zip(model.parameters(), ref_model.parameters()):
         param_grad_sqnr = compute_error(param1.grad, param2.grad)
         assert param_grad_sqnr.item() >= min_param_grad_sqnr, (