
Commit 31c58ea

add shapes comments
1 parent 75fa029 commit 31c58ea

File tree

1 file changed: +10 −10 lines


src/diffusers/models/attention.py

Lines changed: 10 additions & 10 deletions
@@ -148,11 +148,11 @@ def forward(self, hidden_states, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
         batch, channel, height, weight = hidden_states.shape
         residual = hidden_states
-        hidden_states = self.norm(hidden_states)
-        hidden_states = self.proj_in(hidden_states)
+        hidden_states = self.norm(hidden_states)  # 2, 320, 64, 64
+        hidden_states = self.proj_in(hidden_states)  # 2, 320, 64, 64
         hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
         for block in self.transformer_blocks:
-            hidden_states = block(hidden_states, context=context)
+            hidden_states = block(hidden_states, context=context)  # 2, 4096, 320
         hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
         hidden_states = self.proj_out(hidden_states)
         return hidden_states + residual
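For reference, a minimal standalone sketch of the shape flow annotated in this hunk, assuming the tensors from the comments (batch 2, 320 channels, 64×64 latent); the variable names here are illustrative and not part of the file:

```python
import torch

# Illustrative trace of the commented shapes, not the module itself.
batch, channel, height, weight = 2, 320, 64, 64
hidden_states = torch.randn(batch, channel, height, weight)   # 2, 320, 64, 64

# flatten the spatial grid into a token axis for the transformer blocks
tokens = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
print(tokens.shape)        # torch.Size([2, 4096, 320])

# the transformer blocks keep this (2, 4096, 320) shape

# fold the token axis back into a feature map for proj_out and the residual add
restored = tokens.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
print(restored.shape)      # torch.Size([2, 320, 64, 64])
```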
@@ -241,10 +241,10 @@ def __init__(
         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

     def reshape_heads_to_batch_dim(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
+        batch_size, seq_len, dim = tensor.shape  # 2, 4096, 320
         head_size = self.heads
-        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)  # 2, 4096, 8, 40
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)  # 16, 4096, 40
         return tensor

     def reshape_batch_dim_to_heads(self, tensor):
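A small sketch of the two reshape helpers as free functions, assuming heads=8 so that 320 = 8 × 40 as in the comments; the function signatures and default argument are hypothetical, only for illustrating the round trip:

```python
import torch

def reshape_heads_to_batch_dim(tensor, head_size=8):
    # (batch, seq, dim) -> (batch * heads, seq, dim // heads)
    batch_size, seq_len, dim = tensor.shape                                     # 2, 4096, 320
    tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)   # 2, 4096, 8, 40
    return tensor.permute(0, 2, 1, 3).reshape(
        batch_size * head_size, seq_len, dim // head_size                       # 16, 4096, 40
    )

def reshape_batch_dim_to_heads(tensor, head_size=8):
    # (batch * heads, seq, dim_per_head) -> (batch, seq, dim)
    batch_heads, seq_len, dim = tensor.shape                                    # 16, 4096, 40
    tensor = tensor.reshape(batch_heads // head_size, head_size, seq_len, dim)
    return tensor.permute(0, 2, 1, 3).reshape(
        batch_heads // head_size, seq_len, dim * head_size                      # 2, 4096, 320
    )

x = torch.randn(2, 4096, 320)
y = reshape_heads_to_batch_dim(x)
print(y.shape)                                         # torch.Size([16, 4096, 40])
print(torch.equal(reshape_batch_dim_to_heads(y), x))   # True: exact round trip
```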
@@ -271,7 +271,7 @@ def forward(self, hidden_states, context=None, mask=None):
         # attention, what we cannot get enough of

         if self._slice_size is None or query.shape[0] // self._slice_size == 1:
-            hidden_states = self._attention(query, key, value)
+            hidden_states = self._attention(query, key, value)  # 2, 4096, 320
         else:
             hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)

@@ -286,11 +286,11 @@ def _attention(self, query, key, value):
             beta=0,
             alpha=self.scale,
         )
-        attention_probs = attention_scores.softmax(dim=-1)
+        attention_probs = attention_scores.softmax(dim=-1)  # 16, 4096, 77
         # compute attention output
-        hidden_states = torch.matmul(attention_probs, value)
+        hidden_states = torch.matmul(attention_probs, value)  # 16, 4096, 40
         # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)  # 2, 4096, 320
         return hidden_states

     def _sliced_attention(self, query, key, value, sequence_length, dim):
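Finally, a minimal sketch of the full-attention math with the commented shapes, assuming 16 = batch(2) × heads(8), 4096 latent tokens, 40 dims per head, and 77 cross-attention context tokens (the text-encoder sequence length):

```python
import torch

query = torch.randn(16, 4096, 40)
key = torch.randn(16, 77, 40)
value = torch.randn(16, 77, 40)
scale = 40 ** -0.5  # 1 / sqrt(dim_per_head)

# scaled dot-product scores; beta=0 ignores the empty placeholder input
attention_scores = torch.baddbmm(
    torch.empty(16, 4096, 77), query, key.transpose(-1, -2), beta=0, alpha=scale
)
attention_probs = attention_scores.softmax(dim=-1)    # 16, 4096, 77
hidden_states = torch.matmul(attention_probs, value)  # 16, 4096, 40
print(hidden_states.shape)                            # torch.Size([16, 4096, 40])
# reshape_batch_dim_to_heads then folds this back to (2, 4096, 320)
```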
