From 1e55807316598d5fafcd19200d1d503ee8a05b26 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 29 May 2025 20:40:41 -0700 Subject: [PATCH 1/8] add float8 moe prototype --- torchtitan/components/quantization/float8.py | 17 +++++++++++++++++ torchtitan/config_manager.py | 6 ++++++ .../llama4/train_configs/debug_model.toml | 1 + 3 files changed, 24 insertions(+) diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py index 4a7e2a6512..9d0bcd5694 100644 --- a/torchtitan/components/quantization/float8.py +++ b/torchtitan/components/quantization/float8.py @@ -53,6 +53,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): self.enabled = True self.filter_fqns = float8_config.filter_fqns + self.moe_fqns = float8_config.moe_fqns if float8_config.recipe_name is not None: assert ( @@ -114,6 +115,22 @@ def convert(self, model: nn.Module): f"{self.config.enable_fsdp_float8_all_gather}" ) + # Mutates the model inplace replacing instances of nn.Parameter with ScaledGroupedMMTensor, + # to perform dynamic float8 rowwise quantization + scaled grouped GEMMs for the target MoE FQNs. + if self.moe_fqns: + from torchao.prototype.scaled_grouped_mm.conversion_utils import ( + convert_moe_to_float8_training, + ) + + def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: + for target_fqn in self.moe_fqns: + if target_fqn in cur_fqn: + return True + return False + + convert_moe_to_float8_training(model, module_filter_fn=moe_module_filter_fn) + logger.info("Converted MoE to float8") + def post_optimizer_hook(self, model: nn.Module | list[nn.Module]): if not self.enabled: return diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index d8299f91a9..c3e55896e5 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -465,6 +465,12 @@ class Float8: Not compatible with torch.compile. """ + moe_fqns: list[str] | str = field(default_factory=list) + """ + Comma-separated list of fully qualified names of MoE modules to apply float8 rowwise training to. + Example: --float8.moe_fqns="experts" + """ + @dataclass class MX: diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml index bb038e60c4..9486c08528 100644 --- a/torchtitan/experiments/llama4/train_configs/debug_model.toml +++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml @@ -69,3 +69,4 @@ selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac bas enable_fsdp_float8_all_gather = false precompute_float8_dynamic_scale_for_fsdp = false filter_fqns = ["output", "router.gate"] +moe_fqns = [] From 9325c13a7dcfa2f9dbae49ac602990fcb948c7a0 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 5 Jun 2025 13:32:45 -0700 Subject: [PATCH 2/8] use quantize_ for moe --- torchtitan/components/quantization/float8.py | 8 +++++--- torchtitan/train.py | 1 + 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py index 9d0bcd5694..85b5285e9d 100644 --- a/torchtitan/components/quantization/float8.py +++ b/torchtitan/components/quantization/float8.py @@ -118,8 +118,9 @@ def convert(self, model: nn.Module): # Mutates the model inplace replacing instances of nn.Parameter with ScaledGroupedMMTensor, # to perform dynamic float8 rowwise quantization + scaled grouped GEMMs for the target MoE FQNs. 
if self.moe_fqns: + from torchao.quantization.quant_api import quantize_ from torchao.prototype.scaled_grouped_mm.conversion_utils import ( - convert_moe_to_float8_training, + MoETrainingConfig, ) def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: @@ -127,8 +128,9 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: if target_fqn in cur_fqn: return True return False - - convert_moe_to_float8_training(model, module_filter_fn=moe_module_filter_fn) + + config = MoETrainingConfig(module_filter_fn=moe_module_filter_fn) + quantize_(model, config=config) logger.info("Converted MoE to float8") def post_optimizer_hook(self, model: nn.Module | list[nn.Module]): diff --git a/torchtitan/train.py b/torchtitan/train.py index 14edd70ad4..2221d8e04d 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -228,6 +228,7 @@ def __init__(self, job_config: JobConfig): model.to_empty(device=init_device) with torch.no_grad(): model.init_weights(buffer_device=buffer_device) + model = model.to(torch.bfloat16) model.train() self.model_parts = [model] From 75328eaf89f62a42cbede6108144a0488ce2b6ec Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Mon, 9 Jun 2025 16:41:18 -0700 Subject: [PATCH 3/8] use filter_fn in quantize_ --- torchtitan/components/quantization/float8.py | 4 ++-- torchtitan/experiments/llama4/model/moe.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py index 85b5285e9d..d96ea848b9 100644 --- a/torchtitan/components/quantization/float8.py +++ b/torchtitan/components/quantization/float8.py @@ -129,8 +129,8 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: return True return False - config = MoETrainingConfig(module_filter_fn=moe_module_filter_fn) - quantize_(model, config=config) + config = MoETrainingConfig() + quantize_(model, config=config, filter_fn=moe_module_filter_fn) logger.info("Converted MoE to float8") def post_optimizer_hook(self, model: nn.Module | list[nn.Module]): diff --git a/torchtitan/experiments/llama4/model/moe.py b/torchtitan/experiments/llama4/model/moe.py index a07bf0f7b2..fe07a3acc4 100644 --- a/torchtitan/experiments/llama4/model/moe.py +++ b/torchtitan/experiments/llama4/model/moe.py @@ -82,6 +82,7 @@ def forward( assert ( x.dtype == self.w1.dtype == self.w2.dtype == self.w3.dtype == torch.bfloat16 ), "torch._grouped_mm only supports bf16 dtypes" + h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets)) h = h * torch._grouped_mm(x, self.w3, offs=offsets) out = torch._grouped_mm(h, self.w2, offs=offsets) From de7cf4ec1e3bf98310a07bea20f2702d029987a5 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Tue, 10 Jun 2025 07:30:04 -0700 Subject: [PATCH 4/8] update prototype import path --- torchtitan/components/quantization/float8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py index d96ea848b9..b8469b2e3e 100644 --- a/torchtitan/components/quantization/float8.py +++ b/torchtitan/components/quantization/float8.py @@ -119,7 +119,7 @@ def convert(self, model: nn.Module): # to perform dynamic float8 rowwise quantization + scaled grouped GEMMs for the target MoE FQNs. 
if self.moe_fqns: from torchao.quantization.quant_api import quantize_ - from torchao.prototype.scaled_grouped_mm.conversion_utils import ( + from torchao.prototype.moe_training.conversion_utils import ( MoETrainingConfig, ) From 712828f6ebce4f3cd716385cb50dbbbf61f74a82 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Tue, 10 Jun 2025 07:49:21 -0700 Subject: [PATCH 5/8] migrate api name --- torchtitan/components/quantization/float8.py | 13 +++++++++---- torchtitan/config_manager.py | 5 +++-- torchtitan/experiments/llama4/model/moe.py | 1 - 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py index b8469b2e3e..ee92863cba 100644 --- a/torchtitan/components/quantization/float8.py +++ b/torchtitan/components/quantization/float8.py @@ -53,7 +53,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): self.enabled = True self.filter_fqns = float8_config.filter_fqns - self.moe_fqns = float8_config.moe_fqns + self.moe_fqns = float8_config.moe_fqns_prototype if float8_config.recipe_name is not None: assert ( @@ -119,9 +119,14 @@ def convert(self, model: nn.Module): # to perform dynamic float8 rowwise quantization + scaled grouped GEMMs for the target MoE FQNs. if self.moe_fqns: from torchao.quantization.quant_api import quantize_ - from torchao.prototype.moe_training.conversion_utils import ( - MoETrainingConfig, - ) + try: + from torchao.prototype.moe_training.conversion_utils import ( + MoETrainingConfig, + ) + except ImportError as e: + raise ImportError( + "torchao installation does not have MoE training support. Please install torchao nightly build." + ) def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: for target_fqn in self.moe_fqns: diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index c3e55896e5..e50a18a657 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -465,10 +465,11 @@ class Float8: Not compatible with torch.compile. """ - moe_fqns: list[str] | str = field(default_factory=list) + moe_fqns_prototype: list[str] | str = field(default_factory=list) """ Comma-separated list of fully qualified names of MoE modules to apply float8 rowwise training to. - Example: --float8.moe_fqns="experts" + This is a prototype feature that requires the torchao nightly build. 
+ Example: --float8.moe_fqns_prototype="experts" """ diff --git a/torchtitan/experiments/llama4/model/moe.py b/torchtitan/experiments/llama4/model/moe.py index fe07a3acc4..a07bf0f7b2 100644 --- a/torchtitan/experiments/llama4/model/moe.py +++ b/torchtitan/experiments/llama4/model/moe.py @@ -82,7 +82,6 @@ def forward( assert ( x.dtype == self.w1.dtype == self.w2.dtype == self.w3.dtype == torch.bfloat16 ), "torch._grouped_mm only supports bf16 dtypes" - h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets)) h = h * torch._grouped_mm(x, self.w3, offs=offsets) out = torch._grouped_mm(h, self.w2, offs=offsets) From 5df35ceedebc74133529aacf250128c3d44c2178 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Tue, 10 Jun 2025 07:51:18 -0700 Subject: [PATCH 6/8] remove bf16 hack for single gpu testing --- torchtitan/train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 2221d8e04d..14edd70ad4 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -228,7 +228,6 @@ def __init__(self, job_config: JobConfig): model.to_empty(device=init_device) with torch.no_grad(): model.init_weights(buffer_device=buffer_device) - model = model.to(torch.bfloat16) model.train() self.model_parts = [model] From b2c60fafb56d8dc2c1d9ea8522a5b40966d33e81 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Tue, 10 Jun 2025 07:53:27 -0700 Subject: [PATCH 7/8] lint --- torchtitan/components/quantization/float8.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py index ee92863cba..46a0ecd852 100644 --- a/torchtitan/components/quantization/float8.py +++ b/torchtitan/components/quantization/float8.py @@ -119,6 +119,7 @@ def convert(self, model: nn.Module): # to perform dynamic float8 rowwise quantization + scaled grouped GEMMs for the target MoE FQNs. if self.moe_fqns: from torchao.quantization.quant_api import quantize_ + try: from torchao.prototype.moe_training.conversion_utils import ( MoETrainingConfig, @@ -126,14 +127,14 @@ def convert(self, model: nn.Module): except ImportError as e: raise ImportError( "torchao installation does not have MoE training support. Please install torchao nightly build." - ) + ) from e def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: for target_fqn in self.moe_fqns: if target_fqn in cur_fqn: return True return False - + config = MoETrainingConfig() quantize_(model, config=config, filter_fn=moe_module_filter_fn) logger.info("Converted MoE to float8") From 43afdef9f0d487e30861560e7e806a3493875d2a Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Tue, 10 Jun 2025 13:59:41 -0700 Subject: [PATCH 8/8] add default moe_fqns --- torchtitan/experiments/llama4/train_configs/debug_model.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml index 9486c08528..ad36d1ace3 100644 --- a/torchtitan/experiments/llama4/train_configs/debug_model.toml +++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml @@ -69,4 +69,4 @@ selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac bas enable_fsdp_float8_all_gather = false precompute_float8_dynamic_scale_for_fsdp = false filter_fqns = ["output", "router.gate"] -moe_fqns = [] +moe_fqns = ["experts"]
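
Usage sketch, end state of this series: the convert() path targets MoE modules whose FQN contains one of the configured substrings and hands them to torchao's quantize_ with the prototype MoETrainingConfig. Below is a minimal standalone version, assuming a torchao nightly build with torchao.prototype.moe_training available; build_moe_model() and the "experts" FQN are placeholders for an actual MoE model such as the llama4 debug model above.

    import torch.nn as nn
    from torchao.quantization.quant_api import quantize_
    from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig

    moe_fqns = ["experts"]  # mirrors --float8.moe_fqns_prototype="experts"

    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        # Convert only modules whose fully qualified name contains a target substring.
        return any(target_fqn in cur_fqn for target_fqn in moe_fqns)

    model = build_moe_model()  # placeholder: any model with bf16 grouped-MM expert weights

    # Swaps the matched experts' nn.Parameters for ScaledGroupedMMTensor so the grouped
    # GEMMs run with dynamic float8 rowwise quantization, as in the final convert() above.
    quantize_(model, config=MoETrainingConfig(), filter_fn=moe_module_filter_fn)

Filtering by substring match rather than exact FQN is what lets a single entry like "experts" cover the expert block in every layer.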