Commit 381f08c

fold batch and sequence dimensions to accelerate Sequence Parallel
ghstack-source-id: 878bd1d
Pull Request resolved: #437
1 parent b0ed7f0 commit 381f08c
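
The core idea, sketched below with hypothetical sizes (not code from the repo): activations are kept folded as a 2-D (bs*seqlen, dim) tensor inside the transformer blocks, so the Sequence Parallel shards and the associated all-gather/reduce-scatter run along dim 0 instead of dim 1; the tensor is unfolded back to 3-D only for the final output projection.

import torch

bs, seqlen, dim = 2, 8, 32  # hypothetical sizes for illustration
h = torch.randn(bs, seqlen, dim)

# fold: (bs, seqlen, dim) -> (bs * seqlen, dim); a view, no copy
h_folded = h.view(-1, dim)

# ... transformer blocks (and the SP collectives) see only the folded tensor ...

# unfold before the output projection: (bs * seqlen, dim) -> (bs, seqlen, dim)
h_unfolded = h_folded.view(bs, seqlen, dim)
assert torch.equal(h, h_unfolded)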

File tree

3 files changed: +32 -19 lines changed


torchtitan/models/llama/model.py

Lines changed: 15 additions & 6 deletions
@@ -79,9 +79,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
     """
     ndim = x.ndim
     assert 0 <= 1 < ndim
-    seqlen = x.shape[1]
-    freqs_cis = freqs_cis[0:seqlen]
-    assert freqs_cis.shape == (seqlen, x.shape[-1])
+    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
     return freqs_cis.view(*shape)
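
After this change, reshape_for_broadcast assumes the caller has already sliced freqs_cis to the current sequence length (that slice now happens once in Transformer.forward, in the last hunk of this file), so the function only asserts that the shapes agree. A minimal shape-level sketch with hypothetical sizes:

import torch

bs, seqlen, n_heads, head_dim = 2, 8, 4, 16  # hypothetical sizes

# the caller pre-slices freqs_cis: one complex entry per position and per head-dim pair
freqs_cis = torch.randn(seqlen, head_dim // 2, dtype=torch.complex64)
xq_ = torch.randn(bs, seqlen, n_heads, head_dim // 2, dtype=torch.complex64)

# the new assert compares directly against x.shape[1]
assert freqs_cis.shape == (xq_.shape[1], xq_.shape[-1])

# the broadcast shape keeps seqlen and the last dim: (1, seqlen, 1, head_dim // 2)
shape = [d if i == 1 or i == xq_.ndim - 1 else 1 for i, d in enumerate(xq_.shape)]
assert freqs_cis.view(*shape).shape == (1, seqlen, 1, head_dim // 2)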

@@ -187,7 +185,10 @@ def forward(
             torch.Tensor: Output tensor after attention.

         """
-        bs, seqlen, _ = x.shape
+        # dim 0 of x is a folded dimension of [bs, seqlen]
+        seqlen, _ = freqs_cis.shape
+        bs_seqlen, _ = x.shape
+        bs = bs_seqlen // seqlen
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

         xq = xq.view(bs, seqlen, self.n_heads, self.head_dim)
@@ -209,7 +210,8 @@ def forward(
         output = output.transpose(
             1, 2
         ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
-        output = output.view(bs, seqlen, -1)
+        # output stays folded with batch and sequence dimension
+        output = output.view(bs * seqlen, -1)
         return self.wo(output)
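
Because the attention block now receives x folded to (bs*seqlen, dim) and freqs_cis already sliced to seqlen, bs and seqlen are recovered from those two shapes. A rough sketch of the shape bookkeeping, with hypothetical sizes and the projections and rotary math omitted:

import torch

bs, seqlen, dim = 2, 8, 32  # hypothetical sizes
x = torch.randn(bs * seqlen, dim)          # input arrives folded
freqs_cis = torch.randn(seqlen, dim // 2)  # seqlen is read off freqs_cis

seqlen_, _ = freqs_cis.shape
bs_seqlen, _ = x.shape
bs_ = bs_seqlen // seqlen_                 # dim 0 factors back into (bs, seqlen)

x3d = x.view(bs_, seqlen_, dim)            # unfold for the attention math
out = x3d.reshape(bs_ * seqlen_, -1)       # refold before the wo projection
assert out.shape == x.shape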

@@ -427,11 +429,18 @@ def forward(self, tokens: torch.Tensor):
         """
         # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
+        # fold batch dimension and sequence dimension
+        # for more efficient allgather/reduce_scatter
+        h = h.view(-1, self.model_args.dim)

+        freqs_cis = self.freqs_cis[0 : self.model_args.max_seq_len]
         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis)
+            h = layer(h, freqs_cis)

         h = self.norm(h) if self.norm else h
+        # unfold batch and sequence dimension
+        bs, seqlen = tokens.shape
+        h = h.view(bs, seqlen, self.model_args.dim)
         output = self.output(h).float() if self.output else h
         return output
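
Why folding helps the collectives, briefly: under Sequence Parallel each TP rank owns a contiguous chunk of the sequence, and the all-gather before attention/feed-forward (and the reduce-scatter after) run along the sharded dimension. With Shard(0) on a (bs*seqlen, dim) tensor every local shard is a contiguous slab that the collectives can concatenate or split directly, whereas a Shard(1) layout on (bs, seqlen, dim) generally needs extra strided copies around the communication. A small single-process sketch of that difference using plain torch.chunk, with hypothetical sizes:

import torch

world_size, bs, seqlen, dim = 2, 2, 8, 32  # hypothetical sizes

# Shard(0) on the folded tensor: each shard is a contiguous (bs*seqlen // world_size, dim) slab
h_folded = torch.randn(bs * seqlen, dim)
new_shards = torch.chunk(h_folded, world_size, dim=0)
assert all(s.is_contiguous() for s in new_shards)

# Shard(1) on the unfolded tensor: chunking the middle dim yields non-contiguous views
h = torch.randn(bs, seqlen, dim)
old_shards = torch.chunk(h, world_size, dim=1)
assert not any(s.is_contiguous() for s in old_shards)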

torchtitan/models/norms.py

Lines changed: 4 additions & 4 deletions
@@ -221,8 +221,8 @@ def _rms_norm_bwd_kernel_sm(
 class TritonFusedRMSNorm(torch.autograd.Function):
     @partial(
         local_map,
-        out_placements=[Shard(1)],
-        in_placements=(None, [Shard(1)], [Replicate()], None),
+        out_placements=[Shard(0)],
+        in_placements=(None, [Shard(0)], [Replicate()], None),
     )
     @staticmethod
     def forward(ctx, x, weight, eps):
@@ -268,8 +268,8 @@ def forward(ctx, x, weight, eps):

     @partial(
         local_map,
-        out_placements=([Shard(1)], [Partial()], None),
-        in_placements=(None, [Shard(1)]),
+        out_placements=([Shard(0)], [Partial()], None),
+        in_placements=(None, [Shard(0)]),
     )
     @staticmethod
     def backward(ctx, dy):
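
The local_map placements change because the activation entering the fused norm is now sharded on dim 0 (the folded bs*seqlen dimension) rather than dim 1. The placements stay valid for the same reason as before: RMSNorm reduces only over the last (feature) dimension, so running the kernel independently on each sequence shard reproduces the unsharded result. A quick single-process sanity check of that property in plain PyTorch (hypothetical sizes, not the Triton kernel):

import torch

def rms_norm(x, weight, eps=1e-6):
    # the reduction is only over the last (feature) dim
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

world_size, bs_seqlen, dim = 2, 16, 32  # hypothetical sizes
x = torch.randn(bs_seqlen, dim)
w = torch.randn(dim)

# applying the norm per Shard(0) chunk matches the full computation,
# which is what the Shard(0)/Replicate placements above declare to DTensor
full = rms_norm(x, w)
per_shard = torch.cat([rms_norm(s, w) for s in torch.chunk(x, world_size, dim=0)], dim=0)
assert torch.allclose(full, per_shard)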

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 13 additions & 9 deletions
@@ -350,37 +350,41 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
         {
             "tok_embeddings": RowwiseParallel(
                 input_layouts=Replicate(),
-                output_layouts=Shard(1),
             ),
             "output": col_parallel_strategy(
-                input_layouts=Shard(1),
+                input_layouts=Shard(0),
                 output_layouts=Shard(-1) if loss_parallel else Replicate(),
                 use_local_output=not loss_parallel,
             ),
-            "norm": SequenceParallel(),
+            "norm": SequenceParallel(sequence_dim=0),
+            "layers.0": PrepareModuleInput(
+                input_layouts=(Replicate(), None),
+                desired_input_layouts=(Shard(0), None),
+                use_local_output=True,
+            ),
         },
     )

     # Apply tensor + sequence parallelism to every transformer block
     for layer_id, transformer_block in model.layers.items():
         layer_plan = {
             "attention": prepare_module_input(
-                input_layouts=(Shard(1), None),
+                input_layouts=(Shard(0), None),
                 desired_input_layouts=(Replicate(), None),
             ),
             "attention.wq": col_parallel_strategy(),
             "attention.wk": col_parallel_strategy(),
             "attention.wv": col_parallel_strategy(),
-            "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
-            "attention_norm": SequenceParallel(),
+            "attention.wo": row_parallel_strategy(output_layouts=Shard(0)),
+            "attention_norm": SequenceParallel(sequence_dim=0),
             "feed_forward": prepare_module_input(
-                input_layouts=(Shard(1),),
+                input_layouts=(Shard(0),),
                 desired_input_layouts=(Replicate(),),
             ),
             "feed_forward.w1": col_parallel_strategy(),
-            "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
+            "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(0)),
             "feed_forward.w3": col_parallel_strategy(),
-            "ffn_norm": SequenceParallel(),
+            "ffn_norm": SequenceParallel(sequence_dim=0),
         }

         # Adjust attention module to use the local number of heads
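
SequenceParallel(sequence_dim=0) tells the TP plan to shard/unshard activations along the folded dim 0 instead of the default dim 1, and the new "layers.0" PrepareModuleInput re-shards the replicated embedding output onto dim 0 before the first block (replacing the old output_layouts=Shard(1) on tok_embeddings). Below is a minimal, hedged sketch of applying a plan in this style to a toy block with the torch.distributed.tensor.parallel API; the module, submodule names, and mesh setup are placeholders, not the full torchtitan plan, and it assumes a torchrun launch with NCCL:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._tensor import Replicate, Shard  # torch.distributed.tensor in newer releases
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)

class ToyAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(self, x, freqs_cis=None):
        # projections only; the real attention math is omitted in this sketch
        return self.wo(self.wq(x))

class ToyBlock(nn.Module):
    def __init__(self, dim=32):
        super().__init__()
        self.attention_norm = nn.LayerNorm(dim)
        self.attention = ToyAttention(dim)

dist.init_process_group("nccl")  # launched via torchrun
tp_mesh = init_device_mesh("cuda", (dist.get_world_size(),))

block_plan = {
    # the norm consumes/produces activations sharded on the folded dim 0
    "attention_norm": SequenceParallel(sequence_dim=0),
    # all-gather the Shard(0) activations into a replicated tensor before attention
    "attention": PrepareModuleInput(
        input_layouts=(Shard(0), None),
        desired_input_layouts=(Replicate(), None),
    ),
    "attention.wq": ColwiseParallel(),
    # reduce-scatter back onto dim 0 after the output projection
    "attention.wo": RowwiseParallel(output_layouts=Shard(0)),
}

parallelize_module(ToyBlock(), tp_mesh, block_plan)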
