
Commit b579e87

Update on "[DO NOT MERGE][example] fold batch and sequence dimensions to accelerate Sequence Parallel"
Note: This PR is for showcasing purposes only and is almost a reverse of #190. At the cost of a model code change, we can obtain better Sequence Parallel performance. Without folding and unfolding, all-gather and reduce-scatter are performed on dim 1 (the sequence dim) instead of dim 0 (the folded dim), which incurs an extra `aten.cat` after each collective.

Stats from awgu:

> for 8k seq len, batch size 1 on H100, these two cats take about 0.18 ms out of 3 ms of FFN compute (6%)

Experiment on 8-layer `debug_model`:

before:

![before](https://github.com/pytorch/torchtitan/assets/150487191/04e5ea4b-fa9e-48e5-92be-582841cb2796)

after:

![after](https://github.com/pytorch/torchtitan/assets/150487191/38c39506-462d-485a-a16c-48770a28edb0)

[ghstack-poisoned]
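For readers who want to see the idea outside the model code, here is a minimal single-process sketch of the fold/unfold pattern described above; the module names and shapes are illustrative placeholders, not the torchtitan model. Flattening `(bs, seqlen, dim)` into `(bs * seqlen, dim)` lets Sequence Parallel shard and re-gather activations along dim 0, which (per the note above) is what removes the extra `aten.cat` on dim 1 after each collective.

```python
import torch
import torch.nn as nn

# Illustrative sketch only (single process, toy modules); not the torchtitan code.
bs, seqlen, dim, vocab = 2, 8, 16, 32
embed = nn.Embedding(vocab, dim)
ffn = nn.Linear(dim, dim)      # stand-in for the transformer layers
out_proj = nn.Linear(dim, vocab)

tokens = torch.randint(0, vocab, (bs, seqlen))

h = embed(tokens)              # (bs, seqlen, dim)
h = h.view(-1, dim)            # fold: (bs * seqlen, dim); SP collectives operate on dim 0
h = ffn(h)                     # token-wise compute is unaffected by the fold
h = h.view(bs, -1, dim)        # unfold: back to (bs, seqlen, dim), mirroring the diff below
logits = out_proj(h)
print(logits.shape)            # torch.Size([2, 8, 32])
```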
1 parent 7ac41b1 commit b579e87

File tree

1 file changed: +2, -1 lines changed


torchtitan/models/llama/model.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -427,6 +427,7 @@ def forward(self, tokens: torch.Tensor):
             torch.Tensor: Output logits after applying the Transformer model.
 
         """
+        bs = tokens.shape[0]
         # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stage
         # fold batch dimension and sequence dimension for more efficient allgather/reduce_scatter
         if self.tok_embeddings:
@@ -443,7 +444,7 @@ def forward(self, tokens: torch.Tensor):
 
         h = self.norm(h) if self.norm else h
         # unfold batch and sequence dimension
-        h = h.view(-1, seqlen, self.model_args.dim)
+        h = h.view(bs, -1, self.model_args.dim)
         output = self.output(h).float() if self.output else h
         return output
 
```
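As a side note, on a plain (un-sharded) tensor the old and new unfolds produce the same result; the changed line only swaps which dimension `view` infers. A small hypothetical check:

```python
import torch

bs, seqlen, dim = 2, 8, 16
h = torch.randn(bs * seqlen, dim)        # folded activations: (bs * seqlen, dim)

unfolded_new = h.view(bs, -1, dim)       # infers seqlen -> (2, 8, 16)
unfolded_old = h.view(-1, seqlen, dim)   # infers bs     -> (2, 8, 16)
assert torch.equal(unfolded_new, unfolded_old)
```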
