From bb0823623b099d59d5f6ef851594bd40e43a9189 Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Thu, 1 Aug 2024 19:48:19 -0700
Subject: [PATCH 1/2] Update

[ghstack-poisoned]
---
 torchtitan/models/llama/model.py             | 9 ++++++---
 torchtitan/parallelisms/parallelize_llama.py | 5 -----
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py
index 49cda6241d..4f5529a6ae 100644
--- a/torchtitan/models/llama/model.py
+++ b/torchtitan/models/llama/model.py
@@ -190,9 +190,12 @@ def forward(
         bs, seqlen, _ = x.shape
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
 
-        xq = xq.view(bs, seqlen, self.n_heads, self.head_dim)
-        xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim)
-        xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim)
+        # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual
+        # local heads from sizes of xq, xk, and xv as TP may shard them after
+        # the above linear ops.
+        xq = xq.view(bs, seqlen, -1, self.head_dim)
+        xk = xk.view(bs, seqlen, -1, self.head_dim)
+        xv = xv.view(bs, seqlen, -1, self.head_dim)
 
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
 
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index be627432a3..5c3f161416 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -383,11 +383,6 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
             "ffn_norm": SequenceParallel(),
         }
 
-        # Adjust attention module to use the local number of heads
-        attn_layer = transformer_block.attention
-        attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
-        attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
-
         parallelize_module(
             module=transformer_block,
             device_mesh=tp_mesh,

From a1ef2b4ccbd28ea0c6d8aa8b485bbc1e40274930 Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Thu, 1 Aug 2024 19:50:50 -0700
Subject: [PATCH 2/2] Update

[ghstack-poisoned]
---
 torchtitan/models/llama/model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py
index 4f5529a6ae..a2c52dfa18 100644
--- a/torchtitan/models/llama/model.py
+++ b/torchtitan/models/llama/model.py
@@ -191,8 +191,8 @@ def forward(
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
 
         # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual
-        # local heads from sizes of xq, xk, and xv as TP may shard them after
-        # the above linear ops.
+        # local heads from sizes of xq, xk, and xv as TP may have sharded them
+        # after the above linear ops.
         xq = xq.view(bs, seqlen, -1, self.head_dim)
         xk = xk.view(bs, seqlen, -1, self.head_dim)
         xv = xv.view(bs, seqlen, -1, self.head_dim)
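
For reference, a minimal standalone sketch (not part of the patch series; all sizes and names below are illustrative, not torchtitan's) of why viewing with -1 tolerates tensor-parallel sharding: a column-parallel wq leaves each rank with only n_heads // tp_degree heads worth of output features, so inferring the head count from the tensor itself lets the same model code run sharded or unsharded, without patching n_heads on the module as the removed lines in parallelize_llama.py previously did.

# Standalone sketch -- NOT part of the patch. Shows that
# view(bs, seqlen, -1, head_dim) works whether or not TP has sharded the
# attention projections. All sizes here are made up for illustration.
import torch

bs, seqlen, n_heads, head_dim, tp_degree = 2, 8, 16, 64, 4

# Unsharded case: wq produces [bs, seqlen, n_heads * head_dim].
xq_full = torch.randn(bs, seqlen, n_heads * head_dim)

# Tensor-parallel case: each rank holds only n_heads // tp_degree heads worth
# of features. Here one rank's local output is faked by slicing; in real TP
# each rank's linear produces its own local tensor.
xq_local = xq_full[..., : (n_heads // tp_degree) * head_dim].contiguous()

# -1 infers the (local or global) head count from the tensor itself, so the
# same line of model code handles both cases without touching self.n_heads.
print(xq_full.view(bs, seqlen, -1, head_dim).shape)   # torch.Size([2, 8, 16, 64])
print(xq_local.view(bs, seqlen, -1, head_dim).shape)  # torch.Size([2, 8, 4, 64])

# Hard-coding self.n_heads in the view would instead require mutating the
# module's head counts per rank, which is what the removed block in
# parallelize_llama.py used to do.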