From f57bdcff8e8f8090eb11f87705b06c628d92b479 Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Tue, 12 Aug 2025 21:58:15 -0700
Subject: [PATCH] [MoE/EP] apply dim-1 FSDP sharding for routed experts and
 rewrite shared experts with FFN

---
 torchtitan/distributed/expert_parallel.py     |  73 +++++-----
 .../experiments/llama4/infra/parallelize.py   |  83 ++++++++---
 torchtitan/experiments/llama4/model/args.py   |  10 +-
 .../scripts/convert_hf_to_dcp_with_gpus.py    |  10 +-
 torchtitan/models/deepseek_v3/model/args.py   |  10 +-
 torchtitan/models/deepseek_v3/model/model.py  |  39 +----
 .../deepseek_v3/model/state_dict_adapter.py   |  16 +--
 torchtitan/models/moe.py                      | 136 +++++++++---------
 8 files changed, 189 insertions(+), 188 deletions(-)

diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py
index 915a5ac107..eef4bda714 100644
--- a/torchtitan/distributed/expert_parallel.py
+++ b/torchtitan/distributed/expert_parallel.py
@@ -29,12 +29,7 @@ class _A2A(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, x, out_splits, in_splits, group):
-        if isinstance(out_splits, torch.Tensor):
-            out_splits = out_splits.tolist()
-        if isinstance(in_splits, torch.Tensor):
-            in_splits = in_splits.tolist()
         T_out = int(sum(out_splits))
-
         y = x.new_empty((T_out,) + tuple(x.shape[1:]))  # allocate by output splits
         dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group)
 
@@ -176,6 +171,7 @@ def __init__(self):
     def _token_dispatch(self, mod, inputs, device_mesh):
         # annotate module input placements/sharding with input_layouts
         routed_input, num_tokens_per_expert = inputs
+        ep_size = device_mesh.shape[0]
 
         # generate the input splits and output splits for all-to-all
         with torch.no_grad():
@@ -187,15 +183,20 @@ def _token_dispatch(self, mod, inputs, device_mesh):
                 num_tokens_per_expert,
                 group=device_mesh.get_group(),
             )
-            # NOTE: this would incur a device-to-host sync
-            self.input_splits = (
-                num_tokens_per_expert.view(device_mesh.shape[0], -1).sum(dim=1).tolist()
+            input_splits = (
+                num_tokens_per_expert.view(ep_size, -1)
+                .sum(dim=1)
+                .to(torch.device("cpu"), non_blocking=True)
             )
-            self.output_splits = (
-                num_tokens_per_expert_group.view(device_mesh.shape[0], -1)
+            output_splits = (
+                num_tokens_per_expert_group.view(ep_size, -1)
                 .sum(dim=1)
-                .tolist()
+                .to(torch.device("cpu"), non_blocking=True)
             )
+            # NOTE: this would incur a device-to-host sync
+            torch.cuda.current_stream().synchronize()
+            self.input_splits = input_splits.tolist()
+            self.output_splits = output_splits.tolist()
 
         # perform all-to-all
         routed_input = all_to_all_single_autograd(
@@ -320,7 +321,7 @@ def wrapper(
         w2: torch.Tensor,
         w3: torch.Tensor,
         x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
+        num_tokens_per_expert: torch.Tensor,
     ) -> torch.Tensor:
         global TOKEN_GROUP_ALIGN_SIZE_M
         if isinstance(w1, DTensor):
@@ -328,37 +329,33 @@ def wrapper(
             w2 = w2.to_local()
             w3 = w3.to_local()
 
-        if num_tokens_per_expert is not None:
-            from torchtitan.experiments.kernels.moe.indices import (
-                generate_permute_indices,
+        from torchtitan.experiments.kernels.moe.indices import generate_permute_indices
+
+        experts_per_ep_rank = w1.shape[0]
+        num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank
+
+        with torch.no_grad():
+            (
+                permuted_indices,
+                num_tokens_per_expert,
+                _,  # offsets,
+            ) = generate_permute_indices(
+                num_tokens_per_expert,
+                experts_per_ep_rank,
+                num_ep_ranks,
+                x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M,
+                TOKEN_GROUP_ALIGN_SIZE_M,
             )
-
-            experts_per_ep_rank = w1.shape[0]
-            num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank
-
-            with torch.no_grad():
-                (
-                    permuted_indices,
-                    num_tokens_per_expert,
-                    _,  # offsets,
-                ) = generate_permute_indices(
-                    num_tokens_per_expert,
-                    experts_per_ep_rank,
-                    num_ep_ranks,
-                    x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M,
-                    TOKEN_GROUP_ALIGN_SIZE_M,
-                )
-
-            x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
-            input_shape = x.shape
-            x = x[permuted_indices, :]
+        x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
+        input_shape = x.shape
+        x = x[permuted_indices, :]
 
         out = func(w1, w2, w3, x, num_tokens_per_expert)
 
-        if num_tokens_per_expert is not None:
-            out_unpermuted = out.new_empty(input_shape)
-            out_unpermuted[permuted_indices, :] = out
-            out = out_unpermuted[:-1]
+        out_unpermuted = out.new_empty(input_shape)
+        out_unpermuted[permuted_indices, :] = out
+        out = out_unpermuted[:-1]
 
         return out
 
diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py
index bc6f828980..6d75b4986a 100644
--- a/torchtitan/experiments/llama4/infra/parallelize.py
+++ b/torchtitan/experiments/llama4/infra/parallelize.py
@@ -137,9 +137,10 @@ def parallelize_llama(
             pp_enabled=parallel_dims.pp_enabled,
             cpu_offload=job_config.training.enable_cpu_offload,
             reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
+            ep_degree=parallel_dims.ep,
             dp_mod_ep_mesh=(
                 world_mesh[tuple(dp_mod_ep_mesh_dim_names)]
-                if dp_mod_ep_mesh_dim_names
+                if parallel_dims.ep_enabled
                 else None
             ),
             gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
@@ -273,6 +274,7 @@ def apply_fsdp(
     pp_enabled: bool,
     cpu_offload: bool = False,
     reshard_after_forward_policy: str = "default",
+    ep_degree: int = 1,
    dp_mod_ep_mesh: DeviceMesh | None = None,
     gradient_divide_factor: int | None = None,
 ):
@@ -298,35 +300,57 @@ def apply_fsdp(
     if cpu_offload:
         fsdp_config["offload_policy"] = CPUOffloadPolicy()
 
-    for layer_id, transformer_block in model.layers.items():
-        if reshard_after_forward_policy == "always":
+    match reshard_after_forward_policy:
+        case "always":
             reshard_after_forward = True
-        elif reshard_after_forward_policy == "never":
+        case "never":
             reshard_after_forward = False
-        elif reshard_after_forward_policy == "default":
-            if pp_enabled:
-                # For PP, do not reshard after forward to avoid per-microbatch
-                # all-gathers, which can be expensive and non-overlapped
-                reshard_after_forward = False
-            else:
-                # As an optimization, do not reshard after forward for the last
-                # transformer block since FSDP would prefetch it immediately
-                reshard_after_forward = int(layer_id) < len(model.layers) - 1
-        else:
+        case "default":
+            # For PP, by default do not reshard after forward to avoid per-microbatch
+            # all-gathers, which can be expensive and non-overlapped
+            reshard_after_forward = not pp_enabled
+        case _:
             raise ValueError(
                 f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
             )
 
-        # NOTE: in an MoE layer, the router and the shared experts
-        # are sharded together with the TransformerBlock
-        if transformer_block.moe_enabled and dp_mod_ep_mesh:
+    if model.tok_embeddings is not None:
+        fully_shard(
+            model.tok_embeddings,
+            **fsdp_config,
+            reshard_after_forward=reshard_after_forward,
+        )
+
+    for layer_id, transformer_block in model.layers.items():
+        # NOTE: When EP is enabled, in an MoE layer we use the following FSDP wrapping:
+        # - the router and the shared experts are sharded together with the TransformerBlock
+        # - the routed experts are sharded with the remaining dp_mod_ep_mesh
+        if transformer_block.moe_enabled and ep_degree > 1:
             fsdp_mod_ep_config = fsdp_config.copy()
             fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh
+
+            # NOTE: EP already shards the routed experts on dim 0 (num_experts).
+            # When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding
+            # causes inefficiency, so we choose to do FSDP sharding on dim-1.
+            # Even when EP is not used, we may still want to shard the experts
+            # on non-0 dim. For now it may not be worth the complexity to support
+            # shard_placement_fn on the outer TransformerBlock-level FSDP.
+            _experts_shard_placement_fn = None
+            assert dp_mod_ep_mesh is not None
+            assert hasattr(transformer_block, "moe")
+            if (
+                dp_mod_ep_mesh.size() * ep_degree
+                > transformer_block.moe.experts.num_experts
+            ):
+                _experts_shard_placement_fn = lambda param: Shard(1)
+
             fully_shard(
                 transformer_block.moe.experts,
                 **fsdp_mod_ep_config,
                 reshard_after_forward=reshard_after_forward,
+                shard_placement_fn=_experts_shard_placement_fn,
             )
+
             # NOTE: # Although the FSDP sharding of experts is done on a mesh of
             # a different size than other parameters, the gradient division
             # factor should be consistent with data.
@@ -339,7 +363,17 @@ def apply_fsdp(
             **fsdp_config,
             reshard_after_forward=reshard_after_forward,
         )
-    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
+
+    # As an optimization, do not reshard_after_forward the last layers by default
+    # since FSDP would prefetch them immediately after the forward pass
+    if model.norm is not None and model.output is not None:
+        fully_shard(
+            [model.norm, model.output],
+            **fsdp_config,
+            reshard_after_forward=reshard_after_forward_policy == "always",
+        )
+
+    fully_shard(model, **fsdp_config)
 
 
 def apply_moe_ep_tp(
@@ -366,14 +400,23 @@ def apply_moe_ep_tp(
                 ),
                 # replicate computation for the router
                 "moe.router.gate": NoParallel(),
-                # input Replicate, output Partial
-                "moe.shared_expert": TensorParallel(),
             }
             if not etp_enabled:
                 # If TP is borrowed for EP, then split the tokens across TP ranks so that
                 # the reorderer, the all-to-all comms, and routed experts computation
                 # are effectively running Sequence Parallel (split along the folded bs*slen dim)
                 moe_layer_plan.update({"moe.reorderer": ReordererSequenceParallel()})
+            if transformer_block.moe.shared_experts is not None:
+                # input Replicate, output Partial
+                moe_layer_plan.update(
+                    {
+                        "moe.shared_experts.w1": ColwiseParallel(),
+                        "moe.shared_experts.w2": RowwiseParallel(
+                            output_layouts=Partial()
+                        ),
+                        "moe.shared_experts.w3": ColwiseParallel(),
+                    }
+                )
             parallelize_module(
                 module=transformer_block,
                 device_mesh=tp_mesh,
diff --git a/torchtitan/experiments/llama4/model/args.py b/torchtitan/experiments/llama4/model/args.py
index dda130548d..949f4cf052 100644
--- a/torchtitan/experiments/llama4/model/args.py
+++ b/torchtitan/experiments/llama4/model/args.py
@@ -85,7 +85,7 @@ def get_nparams_and_flops(
     ) -> tuple[int, float]:
         nparams_embedding = 0
         nparams_moe_router = 0
-        nparams_shared_expert = 0
+        nparams_shared_experts = 0
         nparams_experts = 0
         nparams_dense = 0
 
@@ -93,8 +93,8 @@ def get_nparams_and_flops(
             if "embedding" in name:
                 nparams_embedding += p.numel()
                 nparams_dense += p.numel()
-            elif "moe.shared_expert" in name:
-                nparams_shared_expert += p.numel()
+            elif "moe.shared_experts" in name:
+                nparams_shared_experts += p.numel()
             elif "moe.router" in name:
                 nparams_moe_router += p.numel()
             elif "moe.experts" in name:
@@ -102,11 +102,11 @@ def get_nparams_and_flops(
             else:
                 nparams_dense += p.numel()
 
-        nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
+        nparams_sparse = nparams_moe_router + nparams_shared_experts + nparams_experts
         nparams = nparams_dense + nparams_sparse
         nparams_sparse_active = (
             nparams_moe_router
-            + nparams_shared_expert
+            + nparams_shared_experts
             + nparams_experts * self.moe_args.top_k // self.moe_args.num_experts
         )
 
diff --git a/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py b/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py
index bad69c0f7a..5cac0bba3e 100644
--- a/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py
+++ b/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py
@@ -57,11 +57,11 @@ def convert_to_titan_fqns(fqn: str) -> list[str]:
     elif "feed_forward.router.weight" in fqn:
         return [f"layers.{layer}.moe.router.gate.weight"]
     elif "feed_forward.shared_expert.down_proj.weight" in fqn:
-        return [f"layers.{layer}.moe.shared_expert.w2"]
+        return [f"layers.{layer}.moe.shared_experts.w2.weight"]
     elif "feed_forward.shared_expert.gate_proj.weight" in fqn:
-        return [f"layers.{layer}.moe.shared_expert.w3"]
+        return [f"layers.{layer}.moe.shared_experts.w3.weight"]
     elif "feed_forward.shared_expert.up_proj.weight" in fqn:
-        return [f"layers.{layer}.moe.shared_expert.w1"]
+        return [f"layers.{layer}.moe.shared_experts.w1.weight"]
     elif "post_attention_layernorm.weight" in fqn:
         return [f"layers.{layer}.ffn_norm.weight"]
     elif "self_attn.k_proj" in fqn:
@@ -86,7 +86,7 @@ def convert_to_hf_shape(fqn: str, titan_fqns: list[str], dtensor: DTensor) -> li
     elif "shared_expert" in fqn:
         s = dtensor.shape
         # TODO: this is not right but I have to do this to load the checkpoint.
-        return torch.Size((s[2], s[1]))
+        return torch.Size((s[1], s[0]))
     return dtensor.shape
 
 
@@ -96,7 +96,7 @@ def convert_to_titan_tensors(fqn: str, full_tensor: torch.Tensor) -> torch.Tenso
     elif "shared_expert" in fqn:
         # TODO: this is not right but I have to do this to load the checkpoint.
         full_tensor = full_tensor.transpose(1, 0)
-        full_tensors = [full_tensor.unsqueeze(0)]
+        full_tensors = [full_tensor]
     else:
         full_tensors = [full_tensor]
     return full_tensors
diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py
index 025a550b9b..044420d37a 100644
--- a/torchtitan/models/deepseek_v3/model/args.py
+++ b/torchtitan/models/deepseek_v3/model/args.py
@@ -126,7 +126,7 @@ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, in
         """
         nparams_embedding = 0
         nparams_moe_router = 0
-        nparams_shared_expert = 0
+        nparams_shared_experts = 0
         nparams_experts = 0
         nparams_dense = 0
 
@@ -134,8 +134,8 @@ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, in
             if "embedding" in name:
                 nparams_embedding += p.numel()
                 nparams_dense += p.numel()
-            elif "moe.shared_expert" in name:
-                nparams_shared_expert += p.numel()
+            elif "moe.shared_experts" in name:
+                nparams_shared_experts += p.numel()
             elif "moe.router" in name:
                 nparams_moe_router += p.numel()
             elif "moe.experts" in name:
@@ -143,11 +143,11 @@ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, in
             else:
                 nparams_dense += p.numel()
 
-        nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
+        nparams_sparse = nparams_moe_router + nparams_shared_experts + nparams_experts
         nparams = nparams_dense + nparams_sparse
         nparams_sparse_active = (
             nparams_moe_router
-            + nparams_shared_expert
+            + nparams_shared_experts
             + nparams_experts * self.moe_args.top_k // self.moe_args.num_experts
         )
 
diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py
index cfdc794ca9..dd31fc3181 100644
--- a/torchtitan/models/deepseek_v3/model/model.py
+++ b/torchtitan/models/deepseek_v3/model/model.py
@@ -8,52 +8,15 @@
 from typing import Tuple
 
 import torch
-import torch.nn.functional as F
 from torch import nn
 
 from torchtitan.models.attention import build_attention, init_attention_mask
-from torchtitan.models.moe import MoE
+from torchtitan.models.moe import FeedForward, MoE
 from torchtitan.protocols.train_spec import ModelProtocol
 
 from .args import DeepSeekV3ModelArgs
 
 
-class FeedForward(nn.Module):
-    """
-    FeedForward module
-
-    Args:
-        dim (int): Input dimension.
-        hidden_dim (int): Hidden dimension of the feedforward layer.
-        multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
-        ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None.
-
-    Attributes:
-        w1 (Linear): Linear transformation for the first layer.
-        w2 (Linear): Linear transformation for the second layer.
-        w3 (Linear): Linear transformation for the third layer.
-
-    """
-
-    def __init__(
-        self,
-        dim: int,
-        hidden_dim: int,
-    ):
-        super().__init__()
-        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
-        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
-        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.w2(F.silu(self.w1(x)) * self.w3(x))
-
-    def init_weights(self, init_std: float = 0.02):
-        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
-        for linear in (self.w2, self.w3):
-            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
-
-
 # Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294
 def precompute_freqs_cis(args: DeepSeekV3ModelArgs) -> torch.Tensor:
     """
diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py
index 890ae00f36..5a676b5a07 100644
--- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py
+++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py
@@ -44,9 +44,9 @@ def __init__(self, model_args: DeepSeekV3ModelArgs, hf_assets_path: str | None):
             "model.layers.{}.mlp.experts.{}.up_proj.weight": "layers.{}.moe.experts.w3",
             "model.layers.{}.mlp.experts.{}.down_proj.weight": "layers.{}.moe.experts.w2",
             "model.layers.{}.mlp.gate.weight": "layers.{}.moe.router.gate.weight",
-            "model.layers.{}.mlp.shared_experts.gate_proj.weight": "layers.{}.moe.shared_expert.w1",
-            "model.layers.{}.mlp.shared_experts.up_proj.weight": "layers.{}.moe.shared_expert.w3",
-            "model.layers.{}.mlp.shared_experts.down_proj.weight": "layers.{}.moe.shared_expert.w2",
+            "model.layers.{}.mlp.shared_experts.gate_proj.weight": "layers.{}.moe.shared_experts.w1.weight",
+            "model.layers.{}.mlp.shared_experts.up_proj.weight": "layers.{}.moe.shared_experts.w3.weight",
+            "model.layers.{}.mlp.shared_experts.down_proj.weight": "layers.{}.moe.shared_experts.w2.weight",
             "model.layers.{}.mlp.gate.e_score_correction_bias": "layers.{}.moe.expert_bias",
             "model.norm.weight": "norm.weight",
             "lm_head.weight": "output.weight",
@@ -163,11 +163,6 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
                 layer_num = re.search(r"\d+", key).group(0)
                 new_key = to_hf_map[abstract_key]
                 new_key = new_key.format(layer_num)
-
-                # torchtitan shape: (1, s[1], s[2]) -> HF shape: (s[1], s[2])
-                if "shared_expert" in key:
-                    value = value.squeeze(0)
-
                 hf_state_dict[new_key] = value
             else:
@@ -217,11 +212,6 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
                 layer_num = re.search(r"\d+", key).group(0)
                 new_key = self.from_hf_map[abstract_key]
                 new_key = new_key.format(layer_num)
-
-                # HF shape: (s[1], s[2]) -> torchtitan shape: (1, s[1], s[2])
-                if "shared_experts" in key:
-                    value = value.unsqueeze(0)
-
                 state_dict[new_key] = value
             else:
diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py
index bd8116ea15..40bd6c2cca 100644
--- a/torchtitan/models/moe.py
+++ b/torchtitan/models/moe.py
@@ -31,6 +31,38 @@ class MoEArgs:
     load_balance_coeff: float | None = 1e-3
 
 
+# can be used as dense FFN layer or shared experts in MoE layers
+class FeedForward(nn.Module):
+    """
+    Args:
+        dim (int): Input dimension.
+        hidden_dim (int): Hidden dimension of the feedforward layer.
+
+    Attributes:
+        w1 (Linear): Linear transformation for the first layer.
+        w2 (Linear): Linear transformation for the second layer.
+        w3 (Linear): Linear transformation for the third layer.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+    ):
+        super().__init__()
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+    def init_weights(self, init_std: float = 0.02):
+        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
+        for linear in (self.w2, self.w3):
+            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
+
+
 # TODO: keeping this for-loop implementation for comparison
 # and readability, may remove later
 @expert_parallel
@@ -39,39 +71,32 @@ def _run_experts_for_loop(
     w2: torch.Tensor,
     w3: torch.Tensor,
     x: torch.Tensor,
-    num_tokens_per_expert: torch.Tensor | None = None,
+    num_tokens_per_expert: torch.Tensor,
 ) -> torch.Tensor:
-    if num_tokens_per_expert is not None:
-        # NOTE: this would incur a synchronization between device and host
-        num_tokens_per_expert = num_tokens_per_expert.tolist()
-
-        # side-effect code due to the usage of generate_permute_indices
-        num_padding = x.shape[0] - sum(num_tokens_per_expert)
-
-        # a tuple of tensors indexed by experts
-        # each with shape (tokens_per_expert(varying), dim)
-        x = torch.split(
-            x[: sum(num_tokens_per_expert)],
-            split_size_or_sections=num_tokens_per_expert,
-            dim=0,
-        )
-        out_experts_splits = []
-        for expert_idx, x_expert in enumerate(x):
-            h = F.silu(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1)))
-            h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1))
-            h = torch.matmul(h, w2[expert_idx].transpose(-2, -1))
-            # h shape (tokens_per_expert(varying), dim)
-            out_experts_splits.append(h)
-        out = torch.cat(out_experts_splits, dim=0)
-
-        # side-effect code due to the usage of generate_permute_indices
-        out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
-    else:
-        # x shape (num_experts, tokens_per_expert, dim)
-        h = F.silu(torch.bmm(x, w1.transpose(-2, -1)))
-        h = h * torch.bmm(x, w3.transpose(-2, -1))
-        # out shape (num_experts, tokens_per_expert, dim)
-        out = torch.bmm(h, w2.transpose(-2, -1))
+    # NOTE: this would incur a synchronization between device and host
+    num_tokens_per_expert = num_tokens_per_expert.tolist()
+
+    # side-effect code due to the usage of generate_permute_indices
+    num_padding = x.shape[0] - sum(num_tokens_per_expert)
+
+    # a tuple of tensors indexed by experts
+    # each with shape (tokens_per_expert(varying), dim)
+    x = torch.split(
+        x[: sum(num_tokens_per_expert)],
+        split_size_or_sections=num_tokens_per_expert,
+        dim=0,
+    )
+    out_experts_splits = []
+    for expert_idx, x_expert in enumerate(x):
+        h = F.silu(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1)))
+        h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1))
+        h = torch.matmul(h, w2[expert_idx].transpose(-2, -1))
+        # h shape (tokens_per_expert(varying), dim)
+        out_experts_splits.append(h)
+    out = torch.cat(out_experts_splits, dim=0)
+
+    # side-effect code due to the usage of generate_permute_indices
+    out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
 
     return out
 
@@ -82,16 +107,11 @@ def _run_experts_grouped_mm(
     w2: torch.Tensor,
     w3: torch.Tensor,
     x: torch.Tensor,
-    num_tokens_per_expert: torch.Tensor | None = None,
+    num_tokens_per_expert: torch.Tensor,
 ) -> torch.Tensor:
-    if num_tokens_per_expert is not None:
-        offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
-        # grouped mm between a 2D tensor and a 3D tensor
-        assert x.dim() == 2
-    else:
-        offsets = None
-        # fall back to regular bmm between 3D tensors
-        assert x.dim() == 3
+    offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
+    # grouped mm between a 2D tensor and a 3D tensor
+    assert x.dim() == 2
 
     h = F.silu(
         torch._grouped_mm(x.bfloat16(), w1.bfloat16().transpose(-2, -1), offs=offsets)
     )
@@ -122,7 +142,7 @@ def __init__(
     def forward(
         self,
         x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
+        num_tokens_per_expert: torch.Tensor,
     ) -> torch.Tensor:
         if self.use_grouped_mm:
             return _run_experts_grouped_mm(
@@ -311,15 +331,8 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
             route_scale=moe_args.route_scale,
         )
         self.reorderer = TokenReorderer(num_experts=num_experts, top_k=moe_args.top_k)
-        self.shared_expert = (
-            GroupedExperts(
-                dim=dim,
-                # TODO: if it doesn't use GroupedExperts.num_experts
-                # we can just use normal FeedForward
-                hidden_dim=hidden_dim * moe_args.num_shared_experts,
-                num_experts=1,
-                use_grouped_mm=moe_args.use_grouped_mm,
-            )
+        self.shared_experts = (
+            FeedForward(dim=dim, hidden_dim=hidden_dim * moe_args.num_shared_experts)
             if moe_args.num_shared_experts > 0
             else None
         )
@@ -354,6 +367,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``.
         """
         bs, slen, dim = x.shape
+        x = x.view(-1, dim)
 
         # top_scores and selected_experts_indices shape (bs*slen*top_k,)
         # num_tokens_per_expert shape (num_experts,)
@@ -361,7 +375,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             top_scores,
             selected_experts_indices,
             num_tokens_per_expert,
-        ) = self.router(x.reshape(bs * slen, dim), self.expert_bias)
+        ) = self.router(x, self.expert_bias)
 
         # tokens_per_expert will be used to update the expert bias for load balancing.
         # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert --
@@ -391,11 +405,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         ).expand(-1, dim)
 
         # shape (bs*slen*top_k, dim)
-        routed_input = torch.gather(
-            x.view(-1, dim),
-            dim=0,
-            index=token_indices_experts_sorted,
-        )
+        routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted)
 
         if self.score_before_experts:
             routed_input = (
@@ -413,12 +423,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         ).to(x.dtype)
 
         # shared expert
-        if self.shared_expert is not None:
-            out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape(
-                bs * slen, dim
-            )
+        if self.shared_experts is not None:
+            out = self.shared_experts(x)
         else:
-            out = torch.zeros_like(x.reshape(bs * slen, dim))
+            out = torch.zeros_like(x)
 
         out = out.scatter_add(
             dim=0, index=token_indices_experts_sorted, src=routed_output
@@ -433,8 +441,8 @@ def init_weights(
     ):
         self.experts.init_weights(init_std)
         self.router.init_weights(init_std)
-        if self.shared_expert is not None:
-            self.shared_expert.init_weights(init_std)
+        if self.shared_experts is not None:
+            self.shared_experts.init_weights(init_std)
 
         if self.load_balance_coeff is not None:
             with torch.device(buffer_device):
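Illustrative sketch (not part of the diff above; the helper name is made up): the dim-1 sharding rule added in apply_fsdp can be read in isolation as a small factory for the shard_placement_fn that fully_shard accepts. Only the condition and the Shard(1) placement mirror the patch.

from torch.distributed.tensor import Shard


def make_experts_shard_placement_fn(dp_mod_ep_size: int, ep_degree: int, num_experts: int):
    # EP already splits dim 0 (num_experts) of the routed experts' 3D weights.
    # If dp_mod_ep * ep exceeds num_experts, FSDP's default dim-0 sharding would
    # leave some ranks with no expert rows, so shard dim 1 of each weight instead.
    if dp_mod_ep_size * ep_degree > num_experts:
        return lambda param: Shard(1)
    # None keeps FSDP's default dim-0 sharding
    return None


# usage, mirroring the fully_shard call in apply_fsdp:
#   fully_shard(
#       transformer_block.moe.experts,
#       **fsdp_mod_ep_config,
#       reshard_after_forward=reshard_after_forward,
#       shard_placement_fn=make_experts_shard_placement_fn(
#           dp_mod_ep_mesh.size(), ep_degree, transformer_block.moe.experts.num_experts
#       ),
#   )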
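A second illustrative sketch (assuming a CUDA device; the helper is hypothetical) of the split-size transfer pattern now used in _token_dispatch: both split tensors are copied to the CPU with non_blocking=True and a single explicit synchronization is paid before .tolist(), rather than one implicit device-to-host sync per .tolist() call on a GPU tensor.

import torch


def splits_to_cpu(
    num_tokens_per_expert: torch.Tensor,
    num_tokens_per_expert_group: torch.Tensor,
    ep_size: int,
) -> tuple[list[int], list[int]]:
    # queue both device-to-host copies without blocking the host
    input_splits = (
        num_tokens_per_expert.view(ep_size, -1)
        .sum(dim=1)
        .to(torch.device("cpu"), non_blocking=True)
    )
    output_splits = (
        num_tokens_per_expert_group.view(ep_size, -1)
        .sum(dim=1)
        .to(torch.device("cpu"), non_blocking=True)
    )
    # one sync covers both copies before reading the values on the host
    torch.cuda.current_stream().synchronize()
    return input_splits.tolist(), output_splits.tolist()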