77
88from ..configuration_utils import ConfigMixin , register_to_config
99from ..modeling_utils import ModelMixin
10- from ..utils import BaseOutput
10+ from ..utils import BaseOutput , logging
1111from .embeddings import TimestepEmbedding , Timesteps
1212from .unet_blocks import (
1313 CrossAttnDownBlock2D ,
2020)
2121
2222
23+ logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
24+
25+
2326@dataclass
2427class UNet2DConditionOutput (BaseOutput ):
2528 """
@@ -145,15 +148,25 @@ def __init__(
145148 resnet_groups = norm_num_groups ,
146149 )
147150
151+ # count how many layers upsample the images
152+ self .num_upsamplers = 0
153+
148154 # up
149155 reversed_block_out_channels = list (reversed (block_out_channels ))
150156 output_channel = reversed_block_out_channels [0 ]
151157 for i , up_block_type in enumerate (up_block_types ):
158+ is_final_block = i == len (block_out_channels ) - 1
159+
152160 prev_output_channel = output_channel
153161 output_channel = reversed_block_out_channels [i ]
154162 input_channel = reversed_block_out_channels [min (i + 1 , len (block_out_channels ) - 1 )]
155163
156- is_final_block = i == len (block_out_channels ) - 1
164+ # add upsample block for all BUT final layer
165+ if not is_final_block :
166+ add_upsample = True
167+ self .num_upsamplers += 1
168+ else :
169+ add_upsample = False
157170
158171 up_block = get_up_block (
159172 up_block_type ,
@@ -162,7 +175,7 @@ def __init__(
162175 out_channels = output_channel ,
163176 prev_output_channel = prev_output_channel ,
164177 temb_channels = time_embed_dim ,
165- add_upsample = not is_final_block ,
178+ add_upsample = add_upsample ,
166179 resnet_eps = norm_eps ,
167180 resnet_act_fn = act_fn ,
168181 resnet_groups = norm_num_groups ,
@@ -223,6 +236,20 @@ def forward(
223236 [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
224237 returning a tuple, the first element is the sample tensor.
225238 """
239+ # By default, sample sizes have to be at least a multiple of the overall upsampling factor.
240+ # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
241+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
242+ # on the fly if necessary.
243+ default_overall_up_factor = 2 ** self .num_upsamplers
244+
245+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
246+ forward_upsample_size = False
247+ upsample_size = None
248+
249+ if any (s % default_overall_up_factor != 0 for s in sample .shape [- 2 :]):
250+ logger .info ("Forward upsample size to force interpolation output size." )
251+ forward_upsample_size = True
252+
226253 # 0. center input if necessary
227254 if self .config .center_input_sample :
228255 sample = 2 * sample - 1.0
@@ -262,20 +289,29 @@ def forward(
262289 sample = self .mid_block (sample , emb , encoder_hidden_states = encoder_hidden_states )
263290
264291 # 5. up
265- for upsample_block in self .up_blocks :
292+ for i , upsample_block in enumerate (self .up_blocks ):
293+ is_final_block = i == len (self .up_blocks ) - 1
294+
266295 res_samples = down_block_res_samples [- len (upsample_block .resnets ) :]
267296 down_block_res_samples = down_block_res_samples [: - len (upsample_block .resnets )]
268297
298+ # if we have not reached the final block and need to forward the
299+ # upsample size, we do it here
300+ if not is_final_block and forward_upsample_size :
301+ upsample_size = down_block_res_samples [- 1 ].shape [2 :]
302+
269303 if hasattr (upsample_block , "attentions" ) and upsample_block .attentions is not None :
270304 sample = upsample_block (
271305 hidden_states = sample ,
272306 temb = emb ,
273307 res_hidden_states_tuple = res_samples ,
274308 encoder_hidden_states = encoder_hidden_states ,
309+ upsample_size = upsample_size ,
275310 )
276311 else :
277- sample = upsample_block (hidden_states = sample , temb = emb , res_hidden_states_tuple = res_samples )
278-
312+ sample = upsample_block (
313+ hidden_states = sample , temb = emb , res_hidden_states_tuple = res_samples , upsample_size = upsample_size
314+ )
279315 # 6. post-process
280316 # make sure hidden states is in float32
281317 # when running in half-precision
0 commit comments