
Commit ee622c0

Prathik Rao authored and ydshieh committed
Fix SpatialTransformer (huggingface#578)
* Fix SpatialTransformer

Co-authored-by: ydshieh <[email protected]>
1 parent fff1416 commit ee622c0

File tree

1 file changed: +3 additions, -2 deletions

src/diffusers/models/attention.py

Lines changed: 3 additions & 2 deletions
@@ -144,10 +144,11 @@ def forward(self, hidden_states, context=None):
         residual = hidden_states
         hidden_states = self.norm(hidden_states)
         hidden_states = self.proj_in(hidden_states)
-        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
+        inner_dim = hidden_states.shape[1]
+        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
         for block in self.transformer_blocks:
             hidden_states = block(hidden_states, context=context)
-        hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
+        hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
         hidden_states = self.proj_out(hidden_states)
         return hidden_states + residual
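Why the fix works: `proj_in` projects the input from its original `channel` count to the transformer's `inner_dim`, so after that projection the old reshape with `channel` raises a size-mismatch error (or would mangle the layout) whenever `inner_dim != channel`. Reading the channel count from the tensor itself after `proj_in`, and using it for both the flatten and the un-flatten, keeps the two reshapes consistent. Below is a minimal sketch of the fixed data flow; the concrete sizes and the standalone `proj_in` layer are illustrative assumptions, not the full diffusers `SpatialTransformer`, and `weight` follows the file's variable naming at this commit (it holds the feature-map width).

import torch
from torch import nn

# Hypothetical sizes for illustration only; `weight` means width, per the file's naming.
batch, channel, height, weight = 2, 320, 16, 16
proj_in = nn.Conv2d(channel, 512, kernel_size=1)  # projects channel -> inner_dim (512 here)

hidden_states = torch.randn(batch, channel, height, weight)
hidden_states = proj_in(hidden_states)  # now (batch, inner_dim, height, weight)

# The fix: read the channel count *after* proj_in instead of reusing `channel`,
# which is wrong whenever inner_dim != channel.
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)

# ... the transformer blocks would run here on (batch, height * weight, inner_dim) ...

hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
assert hidden_states.shape == (batch, 512, height, weight)

Reading the dimension from `hidden_states.shape[1]` at runtime also keeps the forward pass agnostic to how `proj_in` is configured, so the same code path works whether or not the projection changes the channel count.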
