73 changes: 35 additions & 38 deletions torchtitan/distributed/expert_parallel.py
@@ -29,12 +29,7 @@
class _A2A(torch.autograd.Function):
@staticmethod
def forward(ctx, x, out_splits, in_splits, group):
if isinstance(out_splits, torch.Tensor):
out_splits = out_splits.tolist()
if isinstance(in_splits, torch.Tensor):
in_splits = in_splits.tolist()
T_out = int(sum(out_splits))

y = x.new_empty((T_out,) + tuple(x.shape[1:])) # allocate by output splits
dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group)

@@ -176,6 +171,7 @@ def _token_dispatch(self, mod, inputs, device_mesh):
def _token_dispatch(self, mod, inputs, device_mesh):
# annotate module input placements/sharding with input_layouts
routed_input, num_tokens_per_expert = inputs
ep_size = device_mesh.shape[0]

# generate the input splits and output splits for all-to-all
with torch.no_grad():
@@ -187,15 +183,20 @@ def _token_dispatch(self, mod, inputs, device_mesh):
num_tokens_per_expert,
group=device_mesh.get_group(),
)
# NOTE: this would incur a device-to-host sync
self.input_splits = (
num_tokens_per_expert.view(device_mesh.shape[0], -1).sum(dim=1).tolist()
input_splits = (
num_tokens_per_expert.view(ep_size, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=True)
)
self.output_splits = (
num_tokens_per_expert_group.view(device_mesh.shape[0], -1)
output_splits = (
num_tokens_per_expert_group.view(ep_size, -1)
.sum(dim=1)
.tolist()
.to(torch.device("cpu"), non_blocking=True)
)
# NOTE: this would incur a device-to-host sync
torch.cuda.current_stream().synchronize()
self.input_splits = input_splits.tolist()
self.output_splits = output_splits.tolist()

# perform all-to-all
routed_input = all_to_all_single_autograd(
@@ -320,45 +321,41 @@ def wrapper(
w2: torch.Tensor,
w3: torch.Tensor,
x: torch.Tensor,
num_tokens_per_expert: torch.Tensor | None = None,
num_tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
global TOKEN_GROUP_ALIGN_SIZE_M
if isinstance(w1, DTensor):
w1 = w1.to_local()
w2 = w2.to_local()
w3 = w3.to_local()

if num_tokens_per_expert is not None:
from torchtitan.experiments.kernels.moe.indices import (
generate_permute_indices,
from torchtitan.experiments.kernels.moe.indices import generate_permute_indices

experts_per_ep_rank = w1.shape[0]
num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank

with torch.no_grad():
(
permuted_indices,
num_tokens_per_expert,
_, # offsets,
) = generate_permute_indices(
num_tokens_per_expert,
experts_per_ep_rank,
num_ep_ranks,
x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M,
Contributor: Just to confirm understanding, this is upper-bounding the max padding right?
Contributor Author: yes

TOKEN_GROUP_ALIGN_SIZE_M,
)

experts_per_ep_rank = w1.shape[0]
num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank

with torch.no_grad():
(
permuted_indices,
num_tokens_per_expert,
_, # offsets,
) = generate_permute_indices(
num_tokens_per_expert,
experts_per_ep_rank,
num_ep_ranks,
x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M,
TOKEN_GROUP_ALIGN_SIZE_M,
)

x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
input_shape = x.shape
x = x[permuted_indices, :]
x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
input_shape = x.shape
x = x[permuted_indices, :]

out = func(w1, w2, w3, x, num_tokens_per_expert)

if num_tokens_per_expert is not None:
out_unpermuted = out.new_empty(input_shape)
out_unpermuted[permuted_indices, :] = out
out = out_unpermuted[:-1]
out_unpermuted = out.new_empty(input_shape)
out_unpermuted[permuted_indices, :] = out
out = out_unpermuted[:-1]

return out

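A note on the _token_dispatch change above: the two .tolist() calls (each an implicit device-to-host sync) are replaced by two non-blocking copies to CPU followed by a single explicit stream synchronization, so the host stalls once instead of twice. A minimal standalone sketch of the pattern, with made-up tensors in place of the real token counts and no process group involved (assumes a CUDA device):

import torch

# made-up stand-ins: 4 EP ranks x 4 local experts
ep_size = 4
num_tokens_per_expert = torch.randint(0, 8, (16,), device="cuda")
num_tokens_per_expert_group = torch.randint(0, 8, (16,), device="cuda")  # stand-in for the all-to-all result

# enqueue both device-to-host copies without blocking the host
input_splits_cpu = (
    num_tokens_per_expert.view(ep_size, -1)
    .sum(dim=1)
    .to(torch.device("cpu"), non_blocking=True)
)
output_splits_cpu = (
    num_tokens_per_expert_group.view(ep_size, -1)
    .sum(dim=1)
    .to(torch.device("cpu"), non_blocking=True)
)

# a single host/device sync before reading the values as Python ints
torch.cuda.current_stream().synchronize()
input_splits = input_splits_cpu.tolist()
output_splits = output_splits_cpu.tolist()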
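The review thread in the diff above asks whether x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M upper-bounds the padded token count, and the author confirms it does. A quick sanity check of that bound, if this reading is right: with 4 local experts, an alignment of 8, and 100 tokens on the rank, rounding each of the 4 per-expert token groups up to a multiple of 8 adds at most 7 rows per group, so the permuted buffer never needs more than 100 + 4 * 8 = 132 rows. The single zero row appended via torch.vstack presumably serves as the row that the padded slots of permuted_indices point at, and it is dropped again by out_unpermuted[:-1] after the experts run.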
83 changes: 63 additions & 20 deletions torchtitan/experiments/llama4/infra/parallelize.py
Original file line number Diff line number Diff line change
@@ -137,9 +137,10 @@ def parallelize_llama(
pp_enabled=parallel_dims.pp_enabled,
cpu_offload=job_config.training.enable_cpu_offload,
reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
ep_degree=parallel_dims.ep,
dp_mod_ep_mesh=(
world_mesh[tuple(dp_mod_ep_mesh_dim_names)]
if dp_mod_ep_mesh_dim_names
if parallel_dims.ep_enabled
else None
),
gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
@@ -273,6 +274,7 @@ def apply_fsdp(
pp_enabled: bool,
cpu_offload: bool = False,
reshard_after_forward_policy: str = "default",
ep_degree: int = 1,
dp_mod_ep_mesh: DeviceMesh | None = None,
gradient_divide_factor: int | None = None,
):
@@ -298,35 +300,57 @@
if cpu_offload:
fsdp_config["offload_policy"] = CPUOffloadPolicy()

for layer_id, transformer_block in model.layers.items():
if reshard_after_forward_policy == "always":
match reshard_after_forward_policy:
case "always":
reshard_after_forward = True
elif reshard_after_forward_policy == "never":
case "never":
reshard_after_forward = False
elif reshard_after_forward_policy == "default":
if pp_enabled:
# For PP, do not reshard after forward to avoid per-microbatch
# all-gathers, which can be expensive and non-overlapped
reshard_after_forward = False
else:
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = int(layer_id) < len(model.layers) - 1
else:
case "default":
# For PP, by default do not reshard after forward to avoid per-microbatch
# all-gathers, which can be expensive and non-overlapped
reshard_after_forward = not pp_enabled
case _:
raise ValueError(
f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
)

# NOTE: in an MoE layer, the router and the shared experts
# are sharded together with the TransformerBlock
if transformer_block.moe_enabled and dp_mod_ep_mesh:
if model.tok_embeddings is not None:
fully_shard(
model.tok_embeddings,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)

for layer_id, transformer_block in model.layers.items():
# NOTE: When EP is enabled, in an MoE layer we use the following FSDP wrapping:
# - the router and the shared experts are sharded together with the TransformerBlock
# - the routed experts are sharded with the remaining dp_mod_ep_mesh
if transformer_block.moe_enabled and ep_degree > 1:
fsdp_mod_ep_config = fsdp_config.copy()
fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh

# NOTE: EP already shards the routed experts on dim 0 (num_experts).
# When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding
# causes inefficiency, so we choose to do FSDP sharding on dim-1.
# Even when EP is not used, we may still want to shard the experts
# on non-0 dim. For now it may not be worth the complexity to support
# shard_placement_fn on the outer TransformerBlock-level FSDP.
_experts_shard_placement_fn = None
assert dp_mod_ep_mesh is not None
assert hasattr(transformer_block, "moe")
if (
dp_mod_ep_mesh.size() * ep_degree
> transformer_block.moe.experts.num_experts
):
_experts_shard_placement_fn = lambda param: Shard(1)

fully_shard(
transformer_block.moe.experts,
**fsdp_mod_ep_config,
reshard_after_forward=reshard_after_forward,
shard_placement_fn=_experts_shard_placement_fn,
)

# NOTE: Although the FSDP sharding of experts is done on a mesh of
# a different size than other parameters, the gradient division
# factor should be consistent with data.
@@ -339,7 +363,17 @@ def apply_fsdp(
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)

# As an optimization, do not reshard_after_forward the last layers by default
# since FSDP would prefetch them immediately after the forward pass
if model.norm is not None and model.output is not None:
fully_shard(
[model.norm, model.output],
**fsdp_config,
reshard_after_forward=reshard_after_forward_policy == "always",
)

fully_shard(model, **fsdp_config)


def apply_moe_ep_tp(
@@ -366,14 +400,23 @@ def apply_moe_ep_tp(
),
# replicate computation for the router
"moe.router.gate": NoParallel(),
# input Replicate, output Partial
"moe.shared_expert": TensorParallel(),
}
if not etp_enabled:
# If TP is borrowed for EP, then split the tokens across TP ranks so that
# the reorderer, the all-to-all comms, and routed experts computation
# are effectively running Sequence Parallel (split along the folded bs*slen dim)
moe_layer_plan.update({"moe.reorderer": ReordererSequenceParallel()})
if transformer_block.moe.shared_experts is not None:
# input Replicate, output Partial
moe_layer_plan.update(
{
"moe.shared_experts.w1": ColwiseParallel(),
"moe.shared_experts.w2": RowwiseParallel(
output_layouts=Partial()
),
"moe.shared_experts.w3": ColwiseParallel(),
}
)
parallelize_module(
module=transformer_block,
device_mesh=tp_mesh,
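On the shard_placement_fn added in apply_fsdp above: expert weights are stacked as [num_experts, ...], and EP has already split dim 0 across ep_degree ranks, so each EP rank keeps only num_experts / ep_degree experts locally. When the number of FSDP shards wanted on dp_mod_ep_mesh exceeds that local count (which is exactly the dp_mod_ep * ep > num_experts condition), default dim-0 sharding would presumably be uneven, so the wrapper switches to Shard(1). A toy check of the condition with made-up sizes:

num_experts = 16
ep_degree = 8
dp_mod_ep = 4

local_experts_on_dim0 = num_experts // ep_degree          # 2 experts left per EP rank
use_dim1_sharding = dp_mod_ep * ep_degree > num_experts    # 32 > 16 -> True
# splitting a size-2 dim 0 into 4 FSDP shards would be uneven, hence Shard(1)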
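The per-weight TP plan for moe.shared_experts (w1/w3 ColwiseParallel, w2 RowwiseParallel with Partial output) is the standard MLP split: column-shard the up/gate projections, row-shard the down projection, and leave the output as a pending sum across TP ranks. A small single-process check of that algebra, assuming a SwiGLU-style combination and weight layouts chosen purely for convenience (the real parameter layout in torchtitan may differ):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 8)        # [tokens, dim]
w1 = torch.randn(8, 16)      # up projection
w3 = torch.randn(8, 16)      # gate projection
w2 = torch.randn(16, 8)      # down projection

full = (F.silu(x @ w1) * (x @ w3)) @ w2

tp, h = 2, 16 // 2
partials = []
for r in range(tp):
    w1_r, w3_r = w1[:, r * h:(r + 1) * h], w3[:, r * h:(r + 1) * h]  # column shards
    w2_r = w2[r * h:(r + 1) * h, :]                                  # row shard
    partials.append((F.silu(x @ w1_r) * (x @ w3_r)) @ w2_r)

# the Partial() output layout is exactly this pending sum across ranks
assert torch.allclose(sum(partials), full, atol=1e-5)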
10 changes: 5 additions & 5 deletions torchtitan/experiments/llama4/model/args.py
Original file line number Diff line number Diff line change
@@ -85,28 +85,28 @@ def get_nparams_and_flops(
) -> tuple[int, float]:
nparams_embedding = 0
nparams_moe_router = 0
nparams_shared_expert = 0
nparams_shared_experts = 0
nparams_experts = 0
nparams_dense = 0

for name, p in model.named_parameters():
if "embedding" in name:
nparams_embedding += p.numel()
nparams_dense += p.numel()
elif "moe.shared_expert" in name:
nparams_shared_expert += p.numel()
elif "moe.shared_experts" in name:
nparams_shared_experts += p.numel()
elif "moe.router" in name:
nparams_moe_router += p.numel()
elif "moe.experts" in name:
nparams_experts += p.numel()
else:
nparams_dense += p.numel()

nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
nparams_sparse = nparams_moe_router + nparams_shared_experts + nparams_experts
nparams = nparams_dense + nparams_sparse
nparams_sparse_active = (
nparams_moe_router
+ nparams_shared_expert
+ nparams_shared_experts
+ nparams_experts * self.moe_args.top_k // self.moe_args.num_experts
)

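The active-parameter count above scales only the routed experts by top_k / num_experts, while the router and the shared experts always run for every token. A worked example with made-up sizes:

nparams_moe_router = 1_000_000
nparams_shared_experts = 50_000_000
nparams_experts = 1_600_000_000
top_k, num_experts = 2, 16

nparams_sparse_active = (
    nparams_moe_router
    + nparams_shared_experts
    + nparams_experts * top_k // num_experts   # 200_000_000: only top_k of 16 experts fire
)
assert nparams_sparse_active == 251_000_000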
@@ -57,11 +57,11 @@ def convert_to_titan_fqns(fqn: str) -> list[str]:
elif "feed_forward.router.weight" in fqn:
return [f"layers.{layer}.moe.router.gate.weight"]
elif "feed_forward.shared_expert.down_proj.weight" in fqn:
return [f"layers.{layer}.moe.shared_expert.w2"]
return [f"layers.{layer}.moe.shared_experts.w2.weight"]
elif "feed_forward.shared_expert.gate_proj.weight" in fqn:
return [f"layers.{layer}.moe.shared_expert.w3"]
return [f"layers.{layer}.moe.shared_experts.w3.weight"]
elif "feed_forward.shared_expert.up_proj.weight" in fqn:
return [f"layers.{layer}.moe.shared_expert.w1"]
return [f"layers.{layer}.moe.shared_experts.w1.weight"]
elif "post_attention_layernorm.weight" in fqn:
return [f"layers.{layer}.ffn_norm.weight"]
elif "self_attn.k_proj" in fqn:
@@ -86,7 +86,7 @@ def convert_to_hf_shape(fqn: str, titan_fqns: list[str], dtensor: DTensor) -> li
elif "shared_expert" in fqn:
s = dtensor.shape
# TODO: this is not right but I have to do this to load the checkpoint.
return torch.Size((s[2], s[1]))
return torch.Size((s[1], s[0]))
return dtensor.shape


@@ -96,7 +96,7 @@ def convert_to_titan_tensors(fqn: str, full_tensor: torch.Tensor) -> torch.Tenso
elif "shared_expert" in fqn:
# TODO: this is not right but I have to do this to load the checkpoint.
full_tensor = full_tensor.transpose(1, 0)
full_tensors = [full_tensor.unsqueeze(0)]
full_tensors = [full_tensor]
else:
full_tensors = [full_tensor]
return full_tensors
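On the shared-expert branch of the conversion helpers above: the titan side now stores a plain 2D weight and the HF side holds its transpose, so convert_to_hf_shape swaps the two dims and convert_to_titan_tensors transposes back without adding a leading expert dim. The TODO in the diff notes this mapping is still a workaround; a hedged sketch of the round-trip it implies, with made-up shapes:

import torch

titan_weight = torch.randn(256, 1024)      # e.g. layers.N.moe.shared_experts.w2.weight
hf_shape = torch.Size((titan_weight.shape[1], titan_weight.shape[0]))  # what convert_to_hf_shape returns

hf_weight = torch.randn(*hf_shape)         # tensor as read from the HF checkpoint
loaded = hf_weight.transpose(1, 0)         # what convert_to_titan_tensors does
assert loaded.shape == titan_weight.shape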
10 changes: 5 additions & 5 deletions torchtitan/models/deepseek_v3/model/args.py
@@ -126,28 +126,28 @@ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, in
"""
nparams_embedding = 0
nparams_moe_router = 0
nparams_shared_expert = 0
nparams_shared_experts = 0
nparams_experts = 0
nparams_dense = 0

for name, p in model.named_parameters():
if "embedding" in name:
nparams_embedding += p.numel()
nparams_dense += p.numel()
elif "moe.shared_expert" in name:
nparams_shared_expert += p.numel()
elif "moe.shared_experts" in name:
nparams_shared_experts += p.numel()
elif "moe.router" in name:
nparams_moe_router += p.numel()
elif "moe.experts" in name:
nparams_experts += p.numel()
else:
nparams_dense += p.numel()

nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
nparams_sparse = nparams_moe_router + nparams_shared_experts + nparams_experts
nparams = nparams_dense + nparams_sparse
nparams_sparse_active = (
nparams_moe_router
+ nparams_shared_expert
+ nparams_shared_experts
+ nparams_experts * self.moe_args.top_k // self.moe_args.num_experts
)
