Skip to content

Commit 8a4c3e5

Browse files
authored
Width was typod as weight (#1800)
* Width was typod as weight * Run Black
1 parent 68e2425 commit 8a4c3e5

File tree

1 file changed

+5
-9
lines changed

1 file changed

+5
-9
lines changed

src/diffusers/models/attention.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -204,17 +204,17 @@ def forward(
204204
"""
205205
# 1. Input
206206
if self.is_input_continuous:
207-
batch, channel, height, weight = hidden_states.shape
207+
batch, channel, height, width = hidden_states.shape
208208
residual = hidden_states
209209

210210
hidden_states = self.norm(hidden_states)
211211
if not self.use_linear_projection:
212212
hidden_states = self.proj_in(hidden_states)
213213
inner_dim = hidden_states.shape[1]
214-
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
214+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
215215
else:
216216
inner_dim = hidden_states.shape[1]
217-
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
217+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
218218
hidden_states = self.proj_in(hidden_states)
219219
elif self.is_input_vectorized:
220220
hidden_states = self.latent_image_embedding(hidden_states)
@@ -231,15 +231,11 @@ def forward(
231231
# 3. Output
232232
if self.is_input_continuous:
233233
if not self.use_linear_projection:
234-
hidden_states = (
235-
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
236-
)
234+
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
237235
hidden_states = self.proj_out(hidden_states)
238236
else:
239237
hidden_states = self.proj_out(hidden_states)
240-
hidden_states = (
241-
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
242-
)
238+
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
243239

244240
output = hidden_states + residual
245241
elif self.is_input_vectorized:

0 commit comments

Comments
 (0)