
Commit 83c206c

[TP] Infer local n_heads instead of ad-hoc model changes
ghstack-source-id: 77baf1a
Pull Request resolved: #498
1 parent: ea8c5c8

2 files changed (+6 -8 lines)

torchtitan/models/llama/model.py

Lines changed: 6 additions & 3 deletions
@@ -190,9 +190,12 @@ def forward(
         bs, seqlen, _ = x.shape
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
 
-        xq = xq.view(bs, seqlen, self.n_heads, self.head_dim)
-        xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim)
-        xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim)
+        # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual
+        # local heads from sizes of xq, xk, and xv as TP may shard them after
+        # the above linear ops.
+        xq = xq.view(bs, seqlen, -1, self.head_dim)
+        xk = xk.view(bs, seqlen, -1, self.head_dim)
+        xv = xv.view(bs, seqlen, -1, self.head_dim)
 
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
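Why `-1` works here, as a standalone sketch (not part of the commit; the toy sizes below are assumptions for illustration): a column-wise tensor-parallel shard shrinks the last dimension of the q/k/v projections to the per-rank slice, and `view(bs, seqlen, -1, head_dim)` recovers the local head count from that size, so the same model code runs with or without TP.

import torch

# Toy sizes, assumed only for this sketch: 8 global heads, head_dim 64, TP degree 2.
bs, seqlen, n_heads, head_dim, tp_degree = 2, 16, 8, 64, 2

# Unsharded query projection output: last dim is n_heads * head_dim.
xq_full = torch.randn(bs, seqlen, n_heads * head_dim)

# A column-wise TP shard keeps n_heads // tp_degree heads per rank, so the
# local tensor's last dim shrinks to (n_heads // tp_degree) * head_dim.
xq_local = xq_full[..., : (n_heads // tp_degree) * head_dim]

# view(..., -1, head_dim) infers the head count from whichever tensor it is given.
print(xq_full.view(bs, seqlen, -1, head_dim).shape)   # torch.Size([2, 16, 8, 64])
print(xq_local.view(bs, seqlen, -1, head_dim).shape)  # torch.Size([2, 16, 4, 64])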

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 0 additions & 5 deletions
@@ -383,11 +383,6 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
             "ffn_norm": SequenceParallel(),
         }
 
-        # Adjust attention module to use the local number of heads
-        attn_layer = transformer_block.attention
-        attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
-        attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
-
         parallelize_module(
             module=transformer_block,
             device_mesh=tp_mesh,
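
The removal above works because the TP plan already shards the q/k/v projections column-wise, leaving each rank with only its slice of the projection outputs, so the attention module no longer needs its n_heads / n_kv_heads attributes rewritten. A minimal runnable sketch of that interplay, using the public torch.distributed.tensor.parallel API but with a toy module, sizes, and launch setup that are assumptions, not torchtitan's actual plan:

# Launch with: torchrun --nproc_per_node=2 tp_view_sketch.py  (file name is illustrative)
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyAttention(nn.Module):
    # Stripped-down stand-in for an attention block; not torchtitan's class.
    def __init__(self, dim: int = 256, n_heads: int = 8):
        super().__init__()
        self.head_dim = dim // n_heads
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        bs, seqlen, _ = x.shape
        # -1 yields 8 heads without TP and 4 heads per rank under 2-way TP,
        # with no attribute patching anywhere.
        xq = self.wq(x).view(bs, seqlen, -1, self.head_dim)
        return self.wo(xq.flatten(2))


torch.manual_seed(0)  # same replicated input on every rank
mesh = init_device_mesh("cpu", (2,))  # 2-way TP over the gloo backend
block = ToyAttention()
parallelize_module(
    block,
    mesh,
    {"wq": ColwiseParallel(), "wo": RowwiseParallel()},
)
out = block(torch.randn(4, 16, 256))
print(out.shape)  # torch.Size([4, 16, 256]) on every rank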
