@@ -91,7 +91,7 @@ def forward(self, hidden_states):
 
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
-        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+        hidden_states = hidden_states.reshape(batch, channel, height, width)
 
         # res connect and rescale
         hidden_states = (hidden_states + residual) / self.rescale_output_factor
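For context, the `transpose(-1, -2)` dropped in this hunk is what maps the attention output from (batch, height*width, channel) back to channels-first before the spatial reshape; a bare `reshape` yields the right shape but a different element ordering. Below is a minimal shape-only sketch of that round trip, with made-up sizes and no claim to match the block's actual code:

import torch

batch, channel, height, width = 2, 8, 4, 4
x = torch.randn(batch, channel, height, width)

# flatten the spatial dims into a sequence: (batch, height*width, channel)
seq = x.view(batch, channel, height * width).transpose(1, 2)

# ... attention over `seq` would go here ...

# invert it: transpose back to (batch, channel, height*width), then reshape
restored = seq.transpose(-1, -2).reshape(batch, channel, height, width)
assert torch.equal(restored, x)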
@@ -150,10 +150,10 @@ def forward(self, hidden_states, context=None):
         residual = hidden_states
         hidden_states = self.norm(hidden_states)
         hidden_states = self.proj_in(hidden_states)
-        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
+        hidden_states = hidden_states.reshape(batch, height * weight, channel)
         for block in self.transformer_blocks:
             hidden_states = block(hidden_states, context=context)
-        hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
+        hidden_states = hidden_states.reshape(batch, channel, height, weight)
         hidden_states = self.proj_out(hidden_states)
         return hidden_states + residual
 
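The `permute` calls removed here play the same role for the channels-first feature map: they move channels last before flattening to (batch, height * weight, channel) and restore channels-first afterwards (`weight` is the variable name used in the hunk). A small sketch with assumed sizes:

import torch

batch, channel, height, weight = 2, 8, 4, 4   # `weight` mirrors the hunk's variable name
x = torch.randn(batch, channel, height, weight)

# channels-last, then flatten spatial dims: (batch, height*weight, channel)
seq = x.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)

# inverse: unflatten to (batch, height, weight, channel), then back to channels-first
back = seq.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
assert torch.equal(back, x)

# a reshape without the permute has the right shape but permuted values
assert not torch.equal(x.reshape(batch, height * weight, channel), seq)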
@@ -262,9 +262,9 @@ def forward(self, hidden_states, context=None, mask=None):
         key = self.to_k(context)
         value = self.to_v(context)
 
-        query = self.reshape_heads_to_batch_dim(query)
-        key = self.reshape_heads_to_batch_dim(key)
-        value = self.reshape_heads_to_batch_dim(value)
+        # query = self.reshape_heads_to_batch_dim(query)
+        # key = self.reshape_heads_to_batch_dim(key)
+        # value = self.reshape_heads_to_batch_dim(value)
 
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
 
@@ -290,7 +290,7 @@ def _attention(self, query, key, value):
         # compute attention output
         hidden_states = torch.matmul(attention_probs, value)
         # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states
 
     def _sliced_attention(self, query, key, value, sequence_length, dim):
@@ -309,7 +309,7 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
             hidden_states[start_idx:end_idx] = attn_slice
 
         # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
 
 
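For reference, `reshape_heads_to_batch_dim` / `reshape_batch_dim_to_heads`, which the hunks above comment out, conventionally fold the head dimension into the batch dimension and back so attention can be computed per head. A self-contained sketch of that shape logic, with assumed head counts and sizes, not claimed to be the library's exact implementation:

import torch

def heads_to_batch_dim(tensor: torch.Tensor, heads: int) -> torch.Tensor:
    # (batch, seq_len, heads * dim_head) -> (batch * heads, seq_len, dim_head)
    batch, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch, seq_len, heads, dim // heads)
    return tensor.permute(0, 2, 1, 3).reshape(batch * heads, seq_len, dim // heads)

def batch_dim_to_heads(tensor: torch.Tensor, heads: int) -> torch.Tensor:
    # (batch * heads, seq_len, dim_head) -> (batch, seq_len, heads * dim_head)
    batch_heads, seq_len, dim_head = tensor.shape
    batch = batch_heads // heads
    tensor = tensor.reshape(batch, heads, seq_len, dim_head)
    return tensor.permute(0, 2, 1, 3).reshape(batch, seq_len, heads * dim_head)

x = torch.randn(2, 77, 8 * 40)   # assumed: batch=2, seq=77, heads=8, dim_head=40
assert torch.equal(batch_dim_to_heads(heads_to_batch_dim(x, 8), 8), x)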