
Commit 60645bc

Fix EP token group padding issue (#1718)

Fixes #1651

## Summary

- Round up `max_len` of permuted token indices in the expert parallel decorator to be a multiple of the token group alignment size.

## Test plan

- Llama4 debug model with FSDP=2, EP=2: `NGPU=2 CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --parallelism.data_parallel_shard_degree=2 --parallelism.expert_parallel_degree=2 --compile.enable`
1 parent d66b72a commit 60645bc

2 files changed: +15 −1 lines changed


torchtitan/distributed/expert_parallel.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -22,6 +22,8 @@
 )
 from torch.distributed.tensor.parallel import ParallelStyle
 
+from torchtitan.distributed.utils import _round_up
+
 
 TOKEN_GROUP_ALIGN_SIZE_M = 8
 ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
@@ -253,6 +255,12 @@ def wrapper(
         experts_per_ep_rank = w1.shape[0]
         num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank
 
+        # Make sure max_len of permuted token indices is divisible by TOKEN_GROUP_ALIGN_SIZE_M,
+        # by padding it to the nearest multiple of TOKEN_GROUP_ALIGN_SIZE_M.
+        x_padded_per_expert = (
+            x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M
+        )
+        padded_max_len = _round_up(x_padded_per_expert, TOKEN_GROUP_ALIGN_SIZE_M)
         with torch.no_grad():
             (
                 permuted_indices,
@@ -262,7 +270,7 @@ def wrapper(
                 num_tokens_per_expert,
                 experts_per_ep_rank,
                 num_ep_ranks,
-                x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M,
+                padded_max_len,
                 TOKEN_GROUP_ALIGN_SIZE_M,
             )
 
```
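The padding arithmetic is easiest to see with concrete numbers. A minimal sketch of the before/after `max_len` computation (the shapes here are hypothetical, chosen for illustration; they are not from the commit):

```python
# Minimal sketch of the padding arithmetic (hypothetical values, not from the commit).
TOKEN_GROUP_ALIGN_SIZE_M = 8

def _round_up(x: int, y: int) -> int:
    """Round up x to the nearest multiple of y."""
    return ((x + y - 1) // y) * y

num_tokens = 100         # stand-in for x.shape[0]
experts_per_ep_rank = 3  # stand-in for w1.shape[0]

# Before the fix: max_len = 100 + 3 * 8 = 124, which is NOT a multiple of 8.
old_max_len = num_tokens + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M
assert old_max_len % TOKEN_GROUP_ALIGN_SIZE_M != 0

# After the fix: 124 is rounded up to 128, a token-group-aligned length.
padded_max_len = _round_up(old_max_len, TOKEN_GROUP_ALIGN_SIZE_M)
assert padded_max_len == 128 and padded_max_len % TOKEN_GROUP_ALIGN_SIZE_M == 0
```

Note that `experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M` is itself a multiple of the alignment size, so the old expression was misaligned exactly when `x.shape[0]` was not already a multiple of `TOKEN_GROUP_ALIGN_SIZE_M`.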
torchtitan/distributed/utils.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -448,3 +448,9 @@ def _clip_grad_norm_with_ep(
     torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, foreach)
 
     return total_norm
+
+
+def _round_up(x: int, y: int) -> int:
+    """Round up x to the nearest multiple of y."""
+    x_ceil_div_y = (x + y - 1) // y
+    return x_ceil_div_y * y
```
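For reference, a few illustrative checks of `_round_up` behavior (my examples, not part of the commit):

```python
assert _round_up(124, 8) == 128  # rounds up to the next multiple
assert _round_up(128, 8) == 128  # exact multiples are left unchanged
assert _round_up(0, 8) == 0      # zero stays zero
assert _round_up(1, 32) == 32
```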
