diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py
index cfca1cd3c8..33fc965d59 100644
--- a/torchtitan/models/llama/model.py
+++ b/torchtitan/models/llama/model.py
@@ -132,7 +132,6 @@ class Attention(nn.Module):
     Attributes:
         n_kv_heads (int): Number of key and value heads.
         n_heads (int): Number of query heads.
-        n_local_kv_heads (int): Number of local key and value heads.
         n_rep (int): Number of repetitions for local heads.
         head_dim (int): Dimension size of each attention head.
         wq (Linear): Linear transformation for queries.
@@ -183,12 +182,12 @@ def forward(
             torch.Tensor: Output tensor after attention.

         """
-        bsz, seqlen, _ = x.shape
+        bs, seqlen, _ = x.shape
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

-        xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
-        xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
-        xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
+        xq = xq.view(bs, seqlen, self.n_heads, self.head_dim)
+        xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim)
+        xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim)

         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

@@ -205,7 +204,7 @@ def forward(
         output = output.transpose(
             1, 2
         ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
-        output = output.view(bsz, seqlen, -1)
+        output = output.view(bs, seqlen, -1)
         return self.wo(output)


@@ -421,7 +420,7 @@ def forward(self, tokens: torch.Tensor):
             torch.Tensor: Output logits after applying the Transformer model.

         """
-        _bsz, seqlen = tokens.shape
+        _bs, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
         self.freqs_cis = self.freqs_cis.to(h.device)
         freqs_cis = self.freqs_cis[0:seqlen]
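
For context: the patch is a pure rename (`bsz`/`_bsz` -> `bs`/`_bs`) plus removal of the stale `n_local_kv_heads` docstring entry; the attention shape flow itself is unchanged. Below is a minimal, self-contained sketch of that shape flow, for reference only. It is not the torchtitan `Attention` module: the head counts and dimensions are made-up illustration values, rotary embeddings are omitted, and the `repeat_interleave`-based GQA expansion is a simplified stand-in for the module's key/value head repetition.

```python
# Minimal standalone sketch of the attention reshape path touched by this diff.
# NOT the torchtitan Attention module: sizes below and the repeat_interleave-based
# GQA expansion are illustrative assumptions; RoPE is omitted.
import torch
import torch.nn.functional as F

bs, seqlen = 2, 16                    # batch size (renamed from bsz to bs), sequence length
n_heads, n_kv_heads, head_dim = 8, 4, 32
dim = n_heads * head_dim
n_rep = n_heads // n_kv_heads         # how often each kv head is repeated for GQA

x = torch.randn(bs, seqlen, dim)
wq = torch.nn.Linear(dim, n_heads * head_dim, bias=False)
wk = torch.nn.Linear(dim, n_kv_heads * head_dim, bias=False)
wv = torch.nn.Linear(dim, n_kv_heads * head_dim, bias=False)

# The reshapes the diff renames: (bs, seqlen, dim) -> (bs, seqlen, heads, head_dim)
xq = wq(x).view(bs, seqlen, n_heads, head_dim)
xk = wk(x).view(bs, seqlen, n_kv_heads, head_dim)
xv = wv(x).view(bs, seqlen, n_kv_heads, head_dim)

# Expand kv heads to match the query heads (simplified stand-in for repeat_kv).
xk = xk.repeat_interleave(n_rep, dim=2)
xv = xv.repeat_interleave(n_rep, dim=2)

# Attention expects (bs, n_heads, seqlen, head_dim).
q, k, v = (t.transpose(1, 2) for t in (xq, xk, xv))
output = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Back to (bs, seqlen, n_heads, head_dim), then flatten the heads; this is the
# line the diff changes from output.view(bsz, seqlen, -1) to output.view(bs, seqlen, -1).
output = output.transpose(1, 2).contiguous().view(bs, seqlen, -1)
assert output.shape == (bs, seqlen, dim)
```

The final assert simply confirms that flattening the head dimension recovers `(bs, seqlen, dim)`, which is what the renamed `output.view(bs, seqlen, -1)` call relies on.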