diff --git a/test/prototype/scaled_grouped_mm/__init__.py b/test/prototype/moe_training/__init__.py similarity index 100% rename from test/prototype/scaled_grouped_mm/__init__.py rename to test/prototype/moe_training/__init__.py diff --git a/test/prototype/scaled_grouped_mm/test_kernels.py b/test/prototype/moe_training/test_kernels.py similarity index 96% rename from test/prototype/scaled_grouped_mm/test_kernels.py rename to test/prototype/moe_training/test_kernels.py index ec18dd45bf..ed68e8fa23 100644 --- a/test/prototype/scaled_grouped_mm/test_kernels.py +++ b/test/prototype/moe_training/test_kernels.py @@ -19,11 +19,11 @@ pytest.skip("Unsupported PyTorch version", allow_module_level=True) -from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import ( +from torchao.prototype.moe_training.kernels.jagged_float8_scales import ( triton_fp8_col_major_jagged_colwise_scales, triton_fp8_row_major_jagged_rowwise_scales, ) -from torchao.prototype.scaled_grouped_mm.utils import ( +from torchao.prototype.moe_training.utils import ( _is_column_major, _to_2d_jagged_float8_tensor_colwise, _to_2d_jagged_float8_tensor_rowwise, diff --git a/test/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py b/test/prototype/moe_training/test_scaled_grouped_mm.py similarity index 98% rename from test/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py rename to test/prototype/moe_training/test_scaled_grouped_mm.py index 30af1abc04..844220c49c 100644 --- a/test/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py +++ b/test/prototype/moe_training/test_scaled_grouped_mm.py @@ -26,7 +26,7 @@ from torchao.float8.float8_linear import matmul_with_hp_or_float8_args from torchao.float8.float8_tensor import LinearMMConfig from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated -from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import ( +from torchao.prototype.moe_training.scaled_grouped_mm import ( _scaled_grouped_mm, ) from torchao.testing.utils import skip_if_rocm diff --git a/test/prototype/moe_training/test_training.py b/test/prototype/moe_training/test_training.py new file mode 100644 index 0000000000..71320af83e --- /dev/null +++ b/test/prototype/moe_training/test_training.py @@ -0,0 +1,140 @@ +import copy + +import pytest +import torch +from torch import nn +from torch.nn import functional as F + +# this feature requires CUDA and SM89+ +if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9): + pytest.skip( + "CUDA not available or compute capability < 8.9", allow_module_level=True + ) + +from torchao.float8.float8_utils import compute_error +from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig +from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor +from torchao.quantization.quant_api import quantize_ + +# this test requires torchtitan +try: + from torchtitan.experiments.llama4.model.args import TransformerModelArgs + from torchtitan.experiments.llama4.model.moe import MoE +except ImportError: + import warnings + + warnings.warn("torchtitan not installed, skipping MoE tests.") + pytest.skip(allow_module_level=True) + + +@pytest.mark.parametrize( + "target_fqns", + [ + ["experts"], + ["does.not.exist"], + ], +) +def test_moe_float8_training(target_fqns: list[str]): + model_args = TransformerModelArgs( + moe_enabled=True, + num_experts=8, + dim=256, + ) + init_std = 0.02 + device = torch.device("cuda") + + # reference bf16 MoE + ref_model = MoE(model_args).to(torch.bfloat16).cuda() + torch.manual_seed(42) 
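+    # seeded above so init_weights is deterministic; the test model below is
+    # a deepcopy of ref_model, so both models start from identical weights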
+ ref_model.init_weights(init_std, device) + + # target MoE for testing conversion + model = copy.deepcopy(ref_model) + + # assert starting params are identical for both models + for param1, param2 in zip(model.parameters(), ref_model.parameters()): + assert torch.equal(param1, param2) + + # convert MoE to float8 training + def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: + for target_fqn in target_fqns: + if target_fqn in cur_fqn: + return True + return False + + # quantize test model + config = MoETrainingConfig() + quantize_(model, config=config, filter_fn=moe_module_filter_fn) + + # validate that only the experts were converted + _validate_model_conversion( + model, + target_fqns=target_fqns, + ) + + # inputs + batch, seq, dim = 8, 2048, 256 + ref_x = torch.randn( + batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device + ) + x = ref_x.detach().clone().requires_grad_(True) + + # forward pass + ref_out = ref_model(ref_x) + out = model(x) + + # validate output + out_sqnr = compute_error(out, ref_out) + assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}." + + # compute loss + labels = torch.ones_like(ref_out) + ref_loss = F.mse_loss(ref_out, labels) + out_loss = F.mse_loss(out, labels) + + # backward pass + ref_loss.backward() + out_loss.backward() + + # validate input gradient + input_grad_sqnr = compute_error(x.grad, ref_x.grad) + assert input_grad_sqnr.item() >= 30.0, ( + f"SQNR must be >= 30.0, got {input_grad_sqnr.item()}." + ) + + # validate param gradients + for param1, param2 in zip(model.parameters(), ref_model.parameters()): + param_grad_sqnr = compute_error(param1.grad, param2.grad) + assert param_grad_sqnr.item() >= 25.0, ( + f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}." + ) + + +def _validate_model_conversion( + root_module: nn.Module, + target_fqns: list[str], +): + def _recursive_validate( + module: nn.Module, + cur_fqn: str, + ): + is_allowed_module = cur_fqn in target_fqns + + # check current module params + for param_name, param in module.named_parameters(recurse=False): + is_converted_type = isinstance(param, ScaledGroupedMMTensor) + if is_converted_type: + assert is_allowed_module, ( + f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}." + ) + if not is_allowed_module: + assert not is_converted_type, ( + f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}." 
+                )
+
+        # recursively check child modules
+        for child_name, child_module in module.named_children():
+            child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name
+            _recursive_validate(child_module, child_fqn)
+
+    _recursive_validate(root_module, "")
diff --git a/torchao/prototype/moe_training/__init__.py b/torchao/prototype/moe_training/__init__.py
new file mode 100644
index 0000000000..8118193aff
--- /dev/null
+++ b/torchao/prototype/moe_training/__init__.py
@@ -0,0 +1,3 @@
+from torchao.prototype.moe_training.scaled_grouped_mm import _scaled_grouped_mm
+
+__all__ = ["_scaled_grouped_mm"]
diff --git a/torchao/prototype/scaled_grouped_mm/benchmarks/benchmark_kernels.py b/torchao/prototype/moe_training/benchmarks/benchmark_kernels.py
similarity index 97%
rename from torchao/prototype/scaled_grouped_mm/benchmarks/benchmark_kernels.py
rename to torchao/prototype/moe_training/benchmarks/benchmark_kernels.py
index cf40220ae0..37701e6545 100644
--- a/torchao/prototype/scaled_grouped_mm/benchmarks/benchmark_kernels.py
+++ b/torchao/prototype/moe_training/benchmarks/benchmark_kernels.py
@@ -14,11 +14,11 @@
 from tabulate import tabulate
 from tqdm import tqdm
 
-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.utils import (
+from torchao.prototype.moe_training.utils import (
     _to_2d_jagged_float8_tensor_colwise,
     _to_2d_jagged_float8_tensor_rowwise,
 )
diff --git a/torchao/prototype/scaled_grouped_mm/benchmarks/benchmark_scaled_grouped_mm.py b/torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py
similarity index 98%
rename from torchao/prototype/scaled_grouped_mm/benchmarks/benchmark_scaled_grouped_mm.py
rename to torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py
index 74921895ab..af1a652fc0 100644
--- a/torchao/prototype/scaled_grouped_mm/benchmarks/benchmark_scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py
@@ -14,7 +14,7 @@
 from tabulate import tabulate
 from tqdm import tqdm
 
-from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm
+from torchao.prototype.moe_training import _scaled_grouped_mm
 
 device = torch.device("cuda")
diff --git a/torchao/prototype/moe_training/conversion_utils.py b/torchao/prototype/moe_training/conversion_utils.py
new file mode 100644
index 0000000000..928af1cf2e
--- /dev/null
+++ b/torchao/prototype/moe_training/conversion_utils.py
@@ -0,0 +1,112 @@
+from typing import Callable, Optional
+
+from torch import nn
+
+from torchao.core.config import AOBaseConfig
+from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
+
+
+class MoETrainingConfig(AOBaseConfig):
+    """
+    MoETrainingConfig is designed for MoE models that use `torch._grouped_mm`
+    for expert computation in token-choice routing, where expert weights are
+    3D nn.Parameters with `num_experts` as the leading dim.
+
+    A module handler registered to MoETrainingConfig finds all nn.Parameters
+    whose parent module matches the module filter function, and swaps their
+    data tensors with ScaledGroupedMMTensors.
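+
+    For example, to convert only a model's expert weights in place (mirroring
+    the usage in test/prototype/moe_training/test_training.py above):
+
+        from torchao.quantization.quant_api import quantize_
+
+        quantize_(
+            model,
+            config=MoETrainingConfig(),
+            filter_fn=lambda mod, fqn: "experts" in fqn,
+        )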
+
+    The ScaledGroupedMMTensor is a tensor subclass which overrides the
+    `torch._grouped_mm` op by dispatching to a differentiable scaled grouped mm,
+    which performs dynamic float8 rowwise quantization on grouped GEMM
+    operands in both the forward and backward pass.
+
+    For all other ops, ScaledGroupedMMTensor behaves like a regular torch.Tensor.
+    """
+
+    pass
+
+
+@register_quantize_module_handler(MoETrainingConfig)
+def _moe_training_transform(
+    module: nn.Module,
+    config: MoETrainingConfig,
+) -> nn.Module:
+    """
+    Swaps `torch.nn.Parameter` data tensors with ScaledGroupedMMTensors.
+
+    Args:
+        module: Module to modify.
+        config: MoETrainingConfig which defines how to perform the MoE training transform.
+
+    Returns:
+        nn.Module: The modified module with swapped parameters.
+    """
+    out = _swap_params(module)
+    return out
+
+
+def _swap_params(
+    module: nn.Module,
+    *,
+    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
+) -> nn.Module:
+    """
+    Recursively walks the nn.Module tree, swapping the data tensor of each
+    nn.Parameter with a ScaledGroupedMMTensor. A parameter is only swapped if
+    its parent module passes the module_filter_fn, when one is specified.
+
+    Args:
+        module: Module to modify.
+        module_filter_fn: If specified, only the `torch.nn.Parameter` subclasses
+            that pass the filter function will be swapped. The inputs to the
+            filter function are the module instance and the FQN.
+
+    Returns:
+        nn.Module: The modified module with swapped parameters.
+    """
+    if isinstance(module, nn.Parameter) and (
+        module_filter_fn is None or module_filter_fn(module, "")
+    ):
+        if len(list(module.children())) > 0:
+            raise AssertionError(
+                f"Does not support a root nn.Parameter with children: {module}"
+            )
+        if not isinstance(module.data, ScaledGroupedMMTensor):
+            new_data = ScaledGroupedMMTensor(module.data)
+            return nn.Parameter(new_data, requires_grad=module.requires_grad)
+        return module
+
+    root_module = module
+
+    def post_order_traversal(
+        module: nn.Module,
+        cur_fqn: Optional[str] = None,
+        parent_module: Optional[nn.Module] = None,
+    ):
+        if cur_fqn is None:
+            cur_fqn = ""
+
+        for child_module_name, child_module in module.named_children():
+            if cur_fqn == "":
+                new_fqn = child_module_name
+            else:
+                new_fqn = f"{cur_fqn}.{child_module_name}"
+
+            post_order_traversal(child_module, new_fqn, module)
+
+        if module_filter_fn is None or module_filter_fn(module, cur_fqn):
+            for param_name, param in module.named_parameters(recurse=False):
+                if not isinstance(param.data, ScaledGroupedMMTensor):
+                    new_param = nn.Parameter(
+                        ScaledGroupedMMTensor(param), requires_grad=param.requires_grad
+                    )
+                    setattr(module, param_name, new_param)
+                    print(f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor")
+
+    post_order_traversal(root_module)
+    return root_module
diff --git a/torchao/prototype/scaled_grouped_mm/kernels/__init__.py b/torchao/prototype/moe_training/kernels/__init__.py
similarity index 54%
rename from torchao/prototype/scaled_grouped_mm/kernels/__init__.py
rename to torchao/prototype/moe_training/kernels/__init__.py
index 1c75303568..b5446849b6 100644
--- a/torchao/prototype/scaled_grouped_mm/kernels/__init__.py
+++ b/torchao/prototype/moe_training/kernels/__init__.py
@@ -1,6 +1,6 @@
-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales as triton_fp8_col_major_jagged_colwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_row_major_jagged_rowwise_scales as triton_fp8_row_major_jagged_rowwise_scales,
 )
diff --git a/torchao/prototype/scaled_grouped_mm/kernels/jagged_float8_scales.py b/torchao/prototype/moe_training/kernels/jagged_float8_scales.py
similarity index 99%
rename from torchao/prototype/scaled_grouped_mm/kernels/jagged_float8_scales.py
rename to torchao/prototype/moe_training/kernels/jagged_float8_scales.py
index 4cc6177a48..3a497bf4a6 100644
--- a/torchao/prototype/scaled_grouped_mm/kernels/jagged_float8_scales.py
+++ b/torchao/prototype/moe_training/kernels/jagged_float8_scales.py
@@ -16,7 +16,7 @@
 import triton
 import triton.language as tl
 
-from torchao.prototype.scaled_grouped_mm.utils import _is_column_major
+from torchao.prototype.moe_training.utils import _is_column_major
 
 EPS = 1e-12
diff --git a/torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
similarity index 95%
rename from torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py
rename to torchao/prototype/moe_training/scaled_grouped_mm.py
index 169e2c5407..d3aaf615db 100644
--- a/torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -10,11 +10,11 @@
 from torchao.float8.config import ScalingGranularity
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
 
-from torchao.prototype.scaled_grouped_mm.kernels import (
+from torchao.prototype.moe_training.kernels import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.utils import _is_column_major
+from torchao.prototype.moe_training.utils import _is_column_major
 
 
 def _scaled_grouped_mm(
@@ -83,7 +83,10 @@ def forward(
         assert not _is_column_major(A), "A must be row-major"
 
         # Due to hardware requirements, the right operand in a scaled grouped GEMM must be column-major.
-        assert _is_column_major(B_t), "B must be column-major"
+        if not _is_column_major(B_t):
+            # FSDP requires contiguous (row-major) weights, so we cannot require B_t to be column-major here.
+            # TODO: find a better solution than transposing on every forward pass.
+            B_t = B_t.transpose(-2, -1).contiguous().transpose(-2, -1)
 
         # Convert high precision input tensor to float8, row-major for left operand of grouped GEMM.
         # A shape: (M, K)
diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py
new file mode 100644
index 0000000000..2a929d3b76
--- /dev/null
+++ b/torchao/prototype/moe_training/tensor.py
@@ -0,0 +1,35 @@
+import torch
+
+from torchao.prototype.moe_training import _scaled_grouped_mm
+
+
+class ScaledGroupedMMTensor(torch.Tensor):
+    """
+    ScaledGroupedMMTensor is a simple tensor subclass that wraps a regular tensor
+    and overrides the torch._grouped_mm op by dispatching to the
+    differentiable _scaled_grouped_mm autograd function.
+    """
+
+    grouped_mm_func_name = "_grouped_mm"
+    offs_arg_name = "offs"
+
+    def __init__(self, data: torch.Tensor):
+        self._data = data
+
+    @classmethod
+    def __torch_function__(cls, func, types, args, kwargs=None):
+        # avoid a mutable `{}` default argument; normalize to a dict here
+        kwargs = kwargs if kwargs is not None else {}
+        if func.__name__ == cls.grouped_mm_func_name:
+            # Use torchao scaled grouped mm with dynamic quant for
+            # "2d x 3d with offsets" case (used for routed experts).
+            # Otherwise, fall back to regular grouped mm.
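+            #
+            # Here, "2d x 3d with offsets" means A packs the tokens routed to
+            # all experts along dim 0, B holds one weight matrix per expert,
+            # and `offs` marks where each expert's token group ends in A.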
+ # + # TODO: support "3d x 3d without offsets" case, which is + # used for shared experts. This is basically the grouped_mm + # kernel handling a bmm. + A, B = args[0], args[1] + A_is_2d = A.dim() == 2 + B_is_3d = B.dim() == 3 + has_offs = kwargs.get(cls.offs_arg_name) is not None + if A_is_2d and B_is_3d and has_offs: + return _scaled_grouped_mm(*args, **kwargs) + return super().__torch_function__(func, types, args, kwargs) diff --git a/torchao/prototype/scaled_grouped_mm/utils.py b/torchao/prototype/moe_training/utils.py similarity index 100% rename from torchao/prototype/scaled_grouped_mm/utils.py rename to torchao/prototype/moe_training/utils.py diff --git a/torchao/prototype/scaled_grouped_mm/__init__.py b/torchao/prototype/scaled_grouped_mm/__init__.py deleted file mode 100644 index 9c6278884a..0000000000 --- a/torchao/prototype/scaled_grouped_mm/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import ( - _scaled_grouped_mm as _scaled_grouped_mm, -)
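Example: a minimal sketch of calling the renamed entry point directly, in the spirit of benchmarks/benchmark_scaled_grouped_mm.py above. The `out_dtype` keyword and the `offs` convention (int32 cumulative end index of each expert's token group along dim 0 of A) are assumptions, since the full `_scaled_grouped_mm` signature is not shown in this diff:

    import torch

    from torchao.prototype.moe_training import _scaled_grouped_mm

    # Routed-experts shapes: A packs tokens for all experts along dim 0;
    # B holds one (N, K) weight per expert and is transposed to the
    # column-major right operand expected by the forward() above.
    num_experts, total_tokens, dim_k, dim_n = 4, 1024, 256, 512
    A = torch.randn(
        total_tokens, dim_k, dtype=torch.bfloat16, device="cuda", requires_grad=True
    )
    B = torch.randn(
        num_experts, dim_n, dim_k, dtype=torch.bfloat16, device="cuda", requires_grad=True
    )
    B_t = B.transpose(-2, -1)  # column-major per-expert weights

    # assumed convention: cumulative end offsets of each expert's token group
    offs = torch.arange(256, total_tokens + 1, 256, dtype=torch.int32, device="cuda")

    out = _scaled_grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16)
    out.sum().backward()  # differentiable: float8 rowwise quant in fwd and bwd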