From ccd84c8f48038a8aa542055f6debca75ae744fd1 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Thu, 2 May 2024 17:06:24 -0700 Subject: [PATCH 1/2] Renamed `bsz` to `bs` for consistency; removed dead code [ghstack-poisoned] --- torchtitan/models/llama/model.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index cfca1cd3c8..c5fd0452ea 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -132,7 +132,6 @@ class Attention(nn.Module): Attributes: n_kv_heads (int): Number of key and value heads. n_heads (int): Number of query heads. - n_local_kv_heads (int): Number of local key and value heads. n_rep (int): Number of repetitions for local heads. head_dim (int): Dimension size of each attention head. wq (Linear): Linear transformation for queries. @@ -183,12 +182,12 @@ def forward( torch.Tensor: Output tensor after attention. """ - bsz, seqlen, _ = x.shape + bs, seqlen, _ = x.shape xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) - xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) - xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) + xq = xq.view(bs, seqlen, self.n_heads, self.head_dim) + xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim) + xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim) xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) @@ -205,7 +204,7 @@ def forward( output = output.transpose( 1, 2 ).contiguous() # (bs, seqlen, n_local_heads, head_dim) - output = output.view(bsz, seqlen, -1) + output = output.view(bs, seqlen, -1) return self.wo(output) @@ -421,7 +420,7 @@ def forward(self, tokens: torch.Tensor): torch.Tensor: Output logits after applying the Transformer model. """ - _bsz, seqlen = tokens.shape + seqlen = tokens.shape[1] h = self.tok_embeddings(tokens) self.freqs_cis = self.freqs_cis.to(h.device) freqs_cis = self.freqs_cis[0:seqlen] From 6a3e7b946766b1483609f3d5a14134de04b0f31b Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Thu, 2 May 2024 17:38:36 -0700 Subject: [PATCH 2/2] Update on "Renamed `bsz` to `bs` for consistency; removed dead code" some minor cleanups [ghstack-poisoned] --- torchtitan/models/llama/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index c5fd0452ea..33fc965d59 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -420,7 +420,7 @@ def forward(self, tokens: torch.Tensor): torch.Tensor: Output logits after applying the Transformer model. """ - seqlen = tokens.shape[1] + _bs, seqlen = tokens.shape h = self.tok_embeddings(tokens) self.freqs_cis = self.freqs_cis.to(h.device) freqs_cis = self.freqs_cis[0:seqlen]