@@ -148,11 +148,11 @@ def forward(self, hidden_states, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
         batch, channel, height, weight = hidden_states.shape
         residual = hidden_states
-        hidden_states = self.norm(hidden_states)
-        hidden_states = self.proj_in(hidden_states)
+        hidden_states = self.norm(hidden_states)  # 2, 320, 64, 64
+        hidden_states = self.proj_in(hidden_states)  # 2, 320, 64, 64
         hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
         for block in self.transformer_blocks:
-            hidden_states = block(hidden_states, context=context)
+            hidden_states = block(hidden_states, context=context)  # 2, 4096, 320
         hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
         hidden_states = self.proj_out(hidden_states)
         return hidden_states + residual
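The shape comments in this hunk correspond to the usual text-to-image setting: a batch of 2 (presumably the conditional/unconditional pair under classifier-free guidance), 320 channels, and a 64×64 latent feature map. A minimal standalone sketch of the spatial-to-token reshapes under those assumed numbers (not the diffusers module itself):

```python
import torch

# Minimal sketch of the reshapes above, assuming the annotated shapes:
# batch 2, 320 channels, 64x64 latent feature map.
batch, channel, height, weight = 2, 320, 64, 64
hidden_states = torch.randn(batch, channel, height, weight)  # 2, 320, 64, 64

# NCHW -> NHWC -> (batch, height*weight, channel): every spatial position becomes a token
tokens = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
print(tokens.shape)  # torch.Size([2, 4096, 320])

# ...the transformer blocks operate on (2, 4096, 320)...

# back to NCHW for proj_out and the residual connection
restored = tokens.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
print(restored.shape)  # torch.Size([2, 320, 64, 64])
```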
@@ -241,10 +241,10 @@ def __init__(
         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

     def reshape_heads_to_batch_dim(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
+        batch_size, seq_len, dim = tensor.shape  # 2, 4096, 320
         head_size = self.heads
-        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)  # 2, 4096, 8, 40
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)  # 16, 4096, 40
         return tensor

     def reshape_batch_dim_to_heads(self, tensor):
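To make the head split concrete: with 320 channels and the annotated 16 = 2 × 8 head-batched dimension, this layer presumably runs 8 heads with a head dimension of 40. A standalone sketch of reshape_heads_to_batch_dim under that assumption:

```python
import torch

# Standalone sketch of reshape_heads_to_batch_dim, assuming 8 heads (head dim = 320 / 8 = 40).
heads = 8
tensor = torch.randn(2, 4096, 320)

batch_size, seq_len, dim = tensor.shape                            # 2, 4096, 320
tensor = tensor.reshape(batch_size, seq_len, heads, dim // heads)  # 2, 4096, 8, 40
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * heads, seq_len, dim // heads)
print(tensor.shape)  # torch.Size([16, 4096, 40]) -- heads folded into the batch dimension
```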
@@ -271,7 +271,7 @@ def forward(self, hidden_states, context=None, mask=None):
         # attention, what we cannot get enough of

         if self._slice_size is None or query.shape[0] // self._slice_size == 1:
-            hidden_states = self._attention(query, key, value)
+            hidden_states = self._attention(query, key, value)  # 2, 4096, 320
         else:
             hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)

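As a side note, the branch above only falls back to sliced attention when the head-batched dimension can actually be split into more than one slice. A rough sketch of that decision, with hypothetical slice sizes:

```python
# Sketch of the dispatch above; the slice sizes here are made up for illustration.
# head_batch corresponds to query.shape[0], i.e. 16 in the annotated case.
def uses_full_attention(head_batch: int, slice_size) -> bool:
    return slice_size is None or head_batch // slice_size == 1

print(uses_full_attention(16, None))  # True  -> _attention
print(uses_full_attention(16, 16))    # True  -> a single slice covers everything, still _attention
print(uses_full_attention(16, 4))     # False -> _sliced_attention, 4 head-batches at a time
```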
@@ -286,11 +286,11 @@ def _attention(self, query, key, value):
             beta=0,
             alpha=self.scale,
         )
-        attention_probs = attention_scores.softmax(dim=-1)
+        attention_probs = attention_scores.softmax(dim=-1)  # 16, 4096, 77
         # compute attention output
-        hidden_states = torch.matmul(attention_probs, value)
+        hidden_states = torch.matmul(attention_probs, value)  # 16, 4096, 40
         # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)  # 2, 4096, 320
         return hidden_states

     def _sliced_attention(self, query, key, value, sequence_length, dim):
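Putting the annotated shapes together, here is a self-contained sketch of what _attention computes in the cross-attention case (16 = 2 × 8 head-batches over 4096 image tokens, head dim 40, and a 77-token text context assumed from the comments). It mirrors the code above but is not the diffusers implementation:

```python
import torch

# Cross-attention shapes assumed from the annotations:
# 2 images x 8 heads = 16, 4096 image tokens, 77 text tokens, head dim 40.
head_dim = 40
scale = head_dim ** -0.5
query = torch.randn(16, 4096, head_dim)
key = torch.randn(16, 77, head_dim)
value = torch.randn(16, 77, head_dim)

# baddbmm with beta=0 ignores the (uninitialized) input tensor, so this is just alpha * Q @ K^T
attention_scores = torch.baddbmm(
    torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype),
    query,
    key.transpose(-1, -2),
    beta=0,
    alpha=scale,
)                                                      # 16, 4096, 77
attention_probs = attention_scores.softmax(dim=-1)     # 16, 4096, 77
hidden_states = torch.matmul(attention_probs, value)   # 16, 4096, 40

# undo reshape_heads_to_batch_dim: (batch*heads, seq, head_dim) -> (batch, seq, heads*head_dim)
batch, heads = 2, 8
hidden_states = (
    hidden_states.reshape(batch, heads, 4096, head_dim)
    .permute(0, 2, 1, 3)
    .reshape(batch, 4096, heads * head_dim)
)
print(hidden_states.shape)  # torch.Size([2, 4096, 320])
```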