diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py
index e47d0fb9f8..e357f4323f 100644
--- a/torchtitan/models/llama/model.py
+++ b/torchtitan/models/llama/model.py
@@ -190,9 +190,12 @@ def forward(
         bs, seqlen, _ = x.shape
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
 
-        xq = xq.view(bs, seqlen, self.n_heads, self.head_dim)
-        xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim)
-        xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim)
+        # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual
+        # local heads from sizes of xq, xk, and xv as TP may have sharded them
+        # after the above linear ops.
+        xq = xq.view(bs, seqlen, -1, self.head_dim)
+        xk = xk.view(bs, seqlen, -1, self.head_dim)
+        xv = xv.view(bs, seqlen, -1, self.head_dim)
 
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
 
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index bdafc8e215..c21479a24e 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -398,11 +398,6 @@ def apply_tp(
             "feed_forward.w3": colwise_parallel_weight(),
         }
 
-        # Adjust attention module to use the local number of heads
-        attn_layer = transformer_block.attention
-        attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
-        attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
-
         parallelize_module(
             module=transformer_block,
             device_mesh=tp_mesh,
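
The diff replaces the hard-coded `n_heads` / `n_kv_heads` in the `.view()` calls with `-1`, so the local head count is inferred from the projection output itself, and it drops the monkey-patching of `attn_layer.n_heads` in `apply_tp`. The sketch below is a minimal, standalone illustration (not torchtitan code) of why this works; `tp_size` and the simulated per-rank tensor are hypothetical values chosen for the example.

```python
import torch

bs, seqlen, n_heads, head_dim = 2, 16, 8, 64
tp_size = 4                       # hypothetical TP degree
local_heads = n_heads // tp_size  # heads owned by this rank after column-wise sharding of wq

# Simulate the per-rank output of a column-parallel wq: the feature dimension
# only covers the local shard of heads, i.e. local_heads * head_dim.
xq_local = torch.randn(bs, seqlen, local_heads * head_dim)

# Hard-coding the global head count no longer matches the local tensor size:
# xq_local.view(bs, seqlen, n_heads, head_dim)  # would raise a RuntimeError

# Letting view() infer the head dimension works for any TP degree, including tp_size == 1.
xq = xq_local.view(bs, seqlen, -1, head_dim)
assert xq.shape == (bs, seqlen, local_heads, head_dim)
```

Because the reshape now adapts to whatever the sharded linear produced, there is no need to mutate the module's `n_heads` attribute when tensor parallelism is applied, which is why the corresponding block in `apply_tp` can be removed.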