
Commit 006ccb8

removing all reshapes to test perf
1 parent c0dd0e9 commit 006ccb8

1 file changed (+8, -8 lines)

src/diffusers/models/attention.py

Lines changed: 8 additions & 8 deletions
@@ -91,7 +91,7 @@ def forward(self, hidden_states):
 
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
-        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+        hidden_states = hidden_states.reshape(batch, channel, height, width)
 
         # res connect and rescale
         hidden_states = (hidden_states + residual) / self.rescale_output_factor
@@ -150,10 +150,10 @@ def forward(self, hidden_states, context=None):
         residual = hidden_states
         hidden_states = self.norm(hidden_states)
         hidden_states = self.proj_in(hidden_states)
-        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
+        hidden_states = hidden_states.reshape(batch, height * weight, channel)
         for block in self.transformer_blocks:
             hidden_states = block(hidden_states, context=context)
-        hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
+        hidden_states = hidden_states.reshape(batch, channel, height, weight)
         hidden_states = self.proj_out(hidden_states)
         return hidden_states + residual
 
@@ -262,9 +262,9 @@ def forward(self, hidden_states, context=None, mask=None):
         key = self.to_k(context)
         value = self.to_v(context)
 
-        query = self.reshape_heads_to_batch_dim(query)
-        key = self.reshape_heads_to_batch_dim(key)
-        value = self.reshape_heads_to_batch_dim(value)
+        # query = self.reshape_heads_to_batch_dim(query)
+        # key = self.reshape_heads_to_batch_dim(key)
+        # value = self.reshape_heads_to_batch_dim(value)
 
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
 
@@ -290,7 +290,7 @@ def _attention(self, query, key, value):
         # compute attention output
         hidden_states = torch.matmul(attention_probs, value)
         # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
 
     def _sliced_attention(self, query, key, value, sequence_length, dim):
@@ -309,7 +309,7 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
            hidden_states[start_idx:end_idx] = attn_slice
 
        # reshape hidden_states
-       hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+       # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
       return hidden_states
 
 

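Note on the change itself: the deleted transpose/permute calls are not layout no-ops, so dropping them before a reshape yields a tensor with the same shape but a different element order. That is consistent with the commit message framing this as a perf test rather than an equivalent refactor. Below is a minimal sketch of the difference, assuming a (batch, channel, height, width) activation as in the first two hunks; the names and sizes are illustrative only (the file itself spells the width variable "weight"):

# Sketch only: permute(...).reshape(...) and a bare reshape(...) agree on shape
# but generally not on values, because reshape alone reinterprets the contiguous
# channel-major memory while permute reorders elements first.
import torch

batch, channel, height, width = 2, 4, 3, 3
hidden_states = torch.randn(batch, channel, height, width)

# Path removed by the commit: channels last, then flatten the spatial dims,
# so each of the height * width rows is one pixel's channel vector.
tokens_before = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, channel)

# Path after the commit: flatten the tensor exactly as it is laid out in memory.
tokens_after = hidden_states.reshape(batch, height * width, channel)

print(tokens_before.shape == tokens_after.shape)  # True
print(torch.equal(tokens_before, tokens_after))   # False in general

The same caveat applies to the commented-out reshape_heads_to_batch_dim / reshape_batch_dim_to_heads pair: since both ends of the head split are skipped, the shapes in the default _attention path still line up, but the computation effectively runs as a single attention head over the full hidden dimension, which is fine for timing but not for output quality.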