13 changes: 6 additions & 7 deletions torchtitan/models/llama/model.py
@@ -132,7 +132,6 @@ class Attention(nn.Module):
Attributes:
n_kv_heads (int): Number of key and value heads.
n_heads (int): Number of query heads.
- n_local_kv_heads (int): Number of local key and value heads.
Collaborator Author commented:
n_local_kv_heads is not an attribute (there is only one occurrence of n_local_kv_heads if you search this file).

n_rep (int): Number of repetitions for local heads.
head_dim (int): Dimension size of each attention head.
wq (Linear): Linear transformation for queries.
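
The comment above notes that n_local_kv_heads is not a stored attribute. For context, here is a minimal, self-contained sketch of a grouped-query attention setup in which only n_heads, n_kv_heads, n_rep, head_dim, and the projection layers live on the module; the class and helper below are illustrative assumptions, not the torchtitan implementation.

```python
import torch
import torch.nn as nn


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Repeat each key/value head n_rep times so k/v line up with the query heads.
    bs, seqlen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, seqlen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
    )


class ToyGQAAttention(nn.Module):
    """Illustrative only: no n_local_kv_heads attribute is needed."""

    def __init__(self, dim: int, n_heads: int, n_kv_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep = n_heads // n_kv_heads  # repetitions per kv head
        self.head_dim = dim // n_heads
        self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)
```
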
@@ -183,12 +182,12 @@ def forward(
torch.Tensor: Output tensor after attention.

"""
- bsz, seqlen, _ = x.shape
Collaborator Author commented:
All inline comments in this method use bs for batch size, so this can be bs for consistency.

+ bs, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

- xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
- xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
- xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
+ xq = xq.view(bs, seqlen, self.n_heads, self.head_dim)
+ xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim)
+ xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim)

xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
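
As an aside on the bs naming discussed above, here is a standalone sketch of the query/key/value reshape pattern with the batch dimension spelled out as bs; all sizes are illustrative and not taken from the PR.

```python
import torch

bs, seqlen, dim = 2, 16, 512
n_heads, n_kv_heads = 8, 4
head_dim = dim // n_heads

x = torch.randn(bs, seqlen, dim)
wq = torch.nn.Linear(dim, n_heads * head_dim, bias=False)
wk = torch.nn.Linear(dim, n_kv_heads * head_dim, bias=False)
wv = torch.nn.Linear(dim, n_kv_heads * head_dim, bias=False)

xq = wq(x).view(bs, seqlen, n_heads, head_dim)     # (bs, seqlen, n_heads, head_dim)
xk = wk(x).view(bs, seqlen, n_kv_heads, head_dim)  # (bs, seqlen, n_kv_heads, head_dim)
xv = wv(x).view(bs, seqlen, n_kv_heads, head_dim)  # (bs, seqlen, n_kv_heads, head_dim)
print(xq.shape, xk.shape, xv.shape)
```
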

@@ -205,7 +204,7 @@ def forward(
output = output.transpose(
1, 2
).contiguous() # (bs, seqlen, n_local_heads, head_dim)
- output = output.view(bsz, seqlen, -1)
+ output = output.view(bs, seqlen, -1)
return self.wo(output)


@@ -421,7 +420,7 @@ def forward(self, tokens: torch.Tensor):
torch.Tensor: Output logits after applying the Transformer model.

"""
- _bsz, seqlen = tokens.shape
Collaborator Author commented:
Similarly, _bsz is unused, so it could just be removed.

Collaborator Author (@awgu, May 3, 2024) commented:
If it helps readability to know that tokens.shape is (batch size, sequence length), I can keep it and maybe rename it to _bs?

Contributor commented:
Although it is not used, it improves code readability -- it tells how many dimensions tokens has and what they are, so IMO I'd rather keep it. Also, the "unusedness" is already indicated by the _ prefix.

Contributor commented, quoting the earlier reply:
> If it helps readability to know that tokens.shape is (batch size, sequence length), I can keep it and maybe rename it to _bs?
Just saw this message; yeah, I agree.

Collaborator Author commented:
Changed it to _bs.

+ _bs, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[0:seqlen]
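
To illustrate the convention settled on in the thread above, here is a tiny sketch showing how the underscore-prefixed _bs documents that tokens has shape (batch size, sequence length) even though only seqlen is used afterwards; the sizes and the freqs_cis shape are assumptions, not the real model.

```python
import torch

vocab_size, dim, max_seqlen = 1000, 64, 128
tok_embeddings = torch.nn.Embedding(vocab_size, dim)
# Stand-in for the precomputed rotary frequencies; the (max_seqlen, dim // 2) shape is an assumption.
freqs_cis_table = torch.randn(max_seqlen, dim // 2, dtype=torch.complex64)

tokens = torch.randint(0, vocab_size, (2, 16))
_bs, seqlen = tokens.shape  # _bs is unused, but records that tokens is (batch size, sequence length)
h = tok_embeddings(tokens)             # (bs, seqlen, dim)
freqs_cis = freqs_cis_table[0:seqlen]  # slice frequencies to the current sequence length
print(h.shape, freqs_cis.shape)
```
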