Commit 2a81dc1

fold batch and sequence dimensions to accelerate Sequence Parallel

ghstack-source-id: e777f30
Pull Request resolved: #437

1 parent b0ed7f0 commit 2a81dc1

3 files changed, +33 -21 lines changed


torchtitan/models/llama/model.py

Lines changed: 20 additions & 8 deletions
@@ -79,9 +79,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
     """
     ndim = x.ndim
     assert 0 <= 1 < ndim
-    seqlen = x.shape[1]
-    freqs_cis = freqs_cis[0:seqlen]
-    assert freqs_cis.shape == (seqlen, x.shape[-1])
+    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
     return freqs_cis.view(*shape)
 
@@ -187,7 +185,10 @@ def forward(
             torch.Tensor: Output tensor after attention.
 
         """
-        bs, seqlen, _ = x.shape
+        # dim 0 of x is a folded dimension of (bs, seqlen)
+        seqlen, _ = freqs_cis.shape
+        bs_seqlen, _ = x.shape
+        bs = bs_seqlen // seqlen
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
 
         xq = xq.view(bs, seqlen, self.n_heads, self.head_dim)
@@ -209,7 +210,8 @@ def forward(
         output = output.transpose(
             1, 2
         ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
-        output = output.view(bs, seqlen, -1)
+        # output stay folded with batch and sequence dimension
+        output = output.view(bs * seqlen, -1)
         return self.wo(output)
 
 
@@ -425,13 +427,23 @@ def forward(self, tokens: torch.Tensor):
             torch.Tensor: Output logits after applying the Transformer model.
 
         """
-        # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
-        h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
+        # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stage
+        # fold batch dimension and sequence dimension for more efficient allgather/reduce_scatter
+        if self.tok_embeddings:
+            tokens = tokens.view(-1)
+            h = self.tok_embeddings(tokens)
+        else:
+            h = tokens
+        h = h.view(-1, self.model_args.dim)
 
+        seqlen = self.model_args.max_seq_len
+        freqs_cis = self.freqs_cis[0:seqlen]
         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis)
+            h = layer(h, freqs_cis)
 
         h = self.norm(h) if self.norm else h
+        # unfold batch and sequence dimension
+        h = h.view(-1, seqlen, self.model_args.dim)
         output = self.output(h).float() if self.output else h
         return output
 
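In the Attention.forward change above, bs and seqlen can no longer be read off a 3-D input, so they are recovered from freqs_cis and the folded leading dimension, and the output is folded back before wo. A minimal single-process sketch of that bookkeeping (placeholder sizes; the attention math itself is elided):

import torch

dim, n_heads = 16, 4
head_dim = dim // n_heads
bs, seqlen = 2, 8

# stand-in for the precomputed rotary table; only its first dim is read here
freqs_cis = torch.zeros(seqlen, head_dim // 2)

# dim 0 of x is the folded (bs, seqlen) dimension, as in the new forward()
x = torch.randn(bs * seqlen, dim)

# recover bs from the folded leading dimension
seqlen_, _ = freqs_cis.shape
bs_seqlen, _ = x.shape
bs_ = bs_seqlen // seqlen_

# unfold into per-head views for the attention math (elided here)
xq = x.view(bs_, seqlen_, n_heads, head_dim)

# refold so wo and the surrounding SP collectives keep seeing a 2-D tensor
out = xq.reshape(bs_ * seqlen_, -1)
assert out.shape == (bs * seqlen, dim)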

torchtitan/models/norms.py

Lines changed: 4 additions & 4 deletions
@@ -221,8 +221,8 @@ def _rms_norm_bwd_kernel_sm(
 class TritonFusedRMSNorm(torch.autograd.Function):
     @partial(
         local_map,
-        out_placements=[Shard(1)],
-        in_placements=(None, [Shard(1)], [Replicate()], None),
+        out_placements=[Shard(0)],
+        in_placements=(None, [Shard(0)], [Replicate()], None),
     )
     @staticmethod
     def forward(ctx, x, weight, eps):
@@ -268,8 +268,8 @@ def forward(ctx, x, weight, eps):
 
     @partial(
         local_map,
-        out_placements=([Shard(1)], [Partial()], None),
-        in_placements=(None, [Shard(1)]),
+        out_placements=([Shard(0)], [Partial()], None),
+        in_placements=(None, [Shard(0)]),
     )
     @staticmethod
     def backward(ctx, dy):
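The norms.py change moves the DTensor placements for the fused RMSNorm from Shard(1) (the sequence dim of a 3-D tensor) to Shard(0) (the leading dim of the folded 2-D tensor). This works because RMSNorm is purely token-wise, so it can run independently on any row-shard of (bs * seqlen, dim). A minimal single-process sketch of that property (plain tensors and torch.chunk standing in for DTensor sharding; rms_norm is a hand-written reference, not the Triton kernel):

import torch

def rms_norm(x, weight, eps=1e-6):
    # token-wise RMSNorm: each row is normalized independently
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

bs, seqlen, dim = 2, 8, 16
weight = torch.ones(dim)
x = torch.randn(bs * seqlen, dim)  # folded (bs * seqlen, dim) activations

# normalizing row-shards independently and concatenating matches normalizing
# the full tensor, which is why a Shard(0) placement is valid for the norm
shards = torch.chunk(x, 2, dim=0)
out_sharded = torch.cat([rms_norm(s, weight) for s in shards], dim=0)
torch.testing.assert_close(out_sharded, rms_norm(x, weight))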

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 9 additions & 9 deletions
@@ -350,37 +350,37 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
         {
             "tok_embeddings": RowwiseParallel(
                 input_layouts=Replicate(),
-                output_layouts=Shard(1),
+                output_layouts=Shard(0),
             ),
             "output": col_parallel_strategy(
-                input_layouts=Shard(1),
+                input_layouts=Shard(0),
                 output_layouts=Shard(-1) if loss_parallel else Replicate(),
                 use_local_output=not loss_parallel,
             ),
-            "norm": SequenceParallel(),
+            "norm": SequenceParallel(sequence_dim=0),
         },
     )
 
     # Apply tensor + sequence parallelism to every transformer block
     for layer_id, transformer_block in model.layers.items():
         layer_plan = {
             "attention": prepare_module_input(
-                input_layouts=(Shard(1), None),
+                input_layouts=(Shard(0), None),
                 desired_input_layouts=(Replicate(), None),
             ),
             "attention.wq": col_parallel_strategy(),
             "attention.wk": col_parallel_strategy(),
             "attention.wv": col_parallel_strategy(),
-            "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
-            "attention_norm": SequenceParallel(),
+            "attention.wo": row_parallel_strategy(output_layouts=Shard(0)),
+            "attention_norm": SequenceParallel(sequence_dim=0),
             "feed_forward": prepare_module_input(
-                input_layouts=(Shard(1),),
+                input_layouts=(Shard(0),),
                 desired_input_layouts=(Replicate(),),
             ),
             "feed_forward.w1": col_parallel_strategy(),
-            "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
+            "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(0)),
             "feed_forward.w3": col_parallel_strategy(),
-            "ffn_norm": SequenceParallel(),
+            "ffn_norm": SequenceParallel(sequence_dim=0),
         }
 
         # Adjust attention module to use the local number of heads
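For reference, SequenceParallel shards on dim 1 by default, which is why every norm in the plan now passes sequence_dim=0 to match the folded layout. A standalone sketch of a plan in this style (not the torchtitan code: the toy block, the 2-way mesh, and the launch setup are made up for illustration, and it assumes a torchrun --nproc_per_node=2 launch with GPUs available):

import torch.nn as nn
from torch.distributed._tensor import Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)

# hypothetical toy block (norm -> w1 -> w2), just to show the plan's shape
block = nn.Sequential(nn.LayerNorm(16), nn.Linear(16, 64), nn.Linear(64, 16))

mesh = init_device_mesh("cuda", (2,))  # assumes torchrun with 2 ranks
parallelize_module(
    block,
    mesh,
    {
        # the norm sees the folded (bs * seqlen, dim) activations, hence dim 0
        "0": SequenceParallel(sequence_dim=0),
        "1": ColwiseParallel(),
        # row-wise output returns to the folded, dim-0-sharded layout
        "2": RowwiseParallel(output_layouts=Shard(0)),
    },
)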
