@@ -204,17 +204,17 @@ def forward(
         """
         # 1. Input
         if self.is_input_continuous:
-            batch, channel, height, weight = hidden_states.shape
+            batch, channel, height, width = hidden_states.shape
             residual = hidden_states
 
             hidden_states = self.norm(hidden_states)
             if not self.use_linear_projection:
                 hidden_states = self.proj_in(hidden_states)
                 inner_dim = hidden_states.shape[1]
-                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
             else:
                 inner_dim = hidden_states.shape[1]
-                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
                 hidden_states = self.proj_in(hidden_states)
         elif self.is_input_vectorized:
             hidden_states = self.latent_image_embedding(hidden_states)
@@ -231,15 +231,11 @@ def forward(
         # 3. Output
         if self.is_input_continuous:
             if not self.use_linear_projection:
-                hidden_states = (
-                    hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
-                )
+                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
                 hidden_states = self.proj_out(hidden_states)
             else:
                 hidden_states = self.proj_out(hidden_states)
-                hidden_states = (
-                    hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
-                )
+                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
 
             output = hidden_states + residual
         elif self.is_input_vectorized:
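Note: the change only renames the variable holding the spatial width unpacked from `hidden_states.shape` (previously misnamed `weight`); the computation is unchanged. For reference, a minimal standalone sketch of the flatten/restore round-trip these lines perform, with illustrative tensor shapes not taken from the model:

```python
import torch

# Sketch of the reshape round-trip in forward(): NCHW feature maps are
# flattened to (batch, height * width, channels) for the transformer
# blocks, then restored afterwards. Shapes below are illustrative.
batch, channel, height, width = 2, 320, 8, 8
hidden_states = torch.randn(batch, channel, height, width)

# Input path: (B, C, H, W) -> (B, H*W, C)
inner_dim = hidden_states.shape[1]
tokens = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)

# Output path: (B, H*W, C) -> (B, C, H, W)
restored = tokens.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

# The two paths are exact inverses of each other.
assert torch.equal(restored, hidden_states)
```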