diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py
index 4a7e2a6512..46a0ecd852 100644
--- a/torchtitan/components/quantization/float8.py
+++ b/torchtitan/components/quantization/float8.py
@@ -53,6 +53,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         self.enabled = True
 
         self.filter_fqns = float8_config.filter_fqns
+        self.moe_fqns = float8_config.moe_fqns_prototype
 
         if float8_config.recipe_name is not None:
             assert (
@@ -114,6 +115,30 @@ def convert(self, model: nn.Module):
             f"{self.config.enable_fsdp_float8_all_gather}"
         )
 
+        # Mutates the model in place, replacing instances of nn.Parameter with ScaledGroupedMMTensor
+        # to perform dynamic float8 rowwise quantization + scaled grouped GEMMs for the target MoE FQNs.
+        if self.moe_fqns:
+            from torchao.quantization.quant_api import quantize_
+
+            try:
+                from torchao.prototype.moe_training.conversion_utils import (
+                    MoETrainingConfig,
+                )
+            except ImportError as e:
+                raise ImportError(
+                    "torchao installation does not have MoE training support. Please install the torchao nightly build."
+                ) from e
+
+            def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
+                for target_fqn in self.moe_fqns:
+                    if target_fqn in cur_fqn:
+                        return True
+                return False
+
+            config = MoETrainingConfig()
+            quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+            logger.info("Converted MoE to float8")
+
     def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
         if not self.enabled:
             return
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index d8299f91a9..e50a18a657 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -465,6 +465,13 @@ class Float8:
     Not compatible with torch.compile.
     """
 
+    moe_fqns_prototype: list[str] | str = field(default_factory=list)
+    """
+    Comma-separated list of fully qualified names of MoE modules to apply float8 rowwise training to.
+    This is a prototype feature that requires the torchao nightly build.
+    Example: --float8.moe_fqns_prototype="experts"
+    """
+
 
 @dataclass
 class MX:
diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml
index bb038e60c4..ad36d1ace3 100644
--- a/torchtitan/experiments/llama4/train_configs/debug_model.toml
+++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml
@@ -69,3 +69,4 @@ selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac bas
 enable_fsdp_float8_all_gather = false
 precompute_float8_dynamic_scale_for_fsdp = false
 filter_fqns = ["output", "router.gate"]
+moe_fqns_prototype = ["experts"]
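
Note: the conversion matches target FQNs by substring, so --float8.moe_fqns_prototype="experts"
converts every module whose qualified name contains "experts". A minimal sketch of the matching
rule used by moe_module_filter_fn (the module FQNs below are hypothetical examples, not taken
from a real model):

    moe_fqns = ["experts"]

    def matches(cur_fqn: str) -> bool:
        # Same substring rule as moe_module_filter_fn in the diff above.
        return any(target_fqn in cur_fqn for target_fqn in moe_fqns)

    assert matches("layers.0.moe.experts")          # converted
    assert matches("layers.0.moe.shared_experts")   # also converted (substring match)
    assert not matches("layers.0.moe.router.gate")  # left in high precision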
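
For reference, the new code path boils down to a single torchao call. A standalone sketch,
assuming a torchao nightly build that includes the MoE training prototype; convert_moe_to_float8
is an illustrative helper name, not part of the PR:

    import torch.nn as nn

    from torchao.quantization.quant_api import quantize_
    from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig

    def convert_moe_to_float8(model: nn.Module, moe_fqns: list[str]) -> None:
        # Swap nn.Parameter instances on matched modules for ScaledGroupedMMTensor,
        # enabling dynamic float8 rowwise quantization + scaled grouped GEMMs.
        def filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
            return any(target_fqn in cur_fqn for target_fqn in moe_fqns)

        quantize_(model, config=MoETrainingConfig(), filter_fn=filter_fn)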