Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 47 additions & 8 deletions src/diffusers/models/unet_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,25 +181,38 @@ def __init__(
resnet_groups=norm_num_groups,
add_attention=add_attention,
)

# count how many layers upsample the images
self.num_upsamplers = 0


# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1


prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

is_final_block = i == len(block_out_channels) - 1


# add upsample block for all BUT final layer
if not is_final_block:
add_upsample = True
self.num_upsamplers += 1
else:
add_upsample = False


up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=not is_final_block,
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
Expand Down Expand Up @@ -235,6 +248,25 @@ def forward(
[`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""



# By default, samples have to be at least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (number of upsampling layers).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers

# upsample size should be forwarded when either of the sample's last two dimensions is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None

if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
forward_upsample_size = True




# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
Expand Down Expand Up @@ -288,14 +320,21 @@ def forward(

# 5. up
skip_sample = None
for upsample_block in self.up_blocks:
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1

res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]


# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]

if hasattr(upsample_block, "skip_conv"):
sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample,upsample_size=upsample_size)
else:
sample = upsample_block(sample, res_samples, emb)
sample = upsample_block(sample, res_samples, emb,upsample_size=upsample_size)

# 6. post-process
sample = self.conv_norm_out(sample)
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/models/unet_2d_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1673,7 +1673,7 @@ def __init__(
else:
self.upsamplers = None

def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
def forward(self, hidden_states, res_hidden_states_tuple, temb=None,upsample_size=None):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
Expand All @@ -1685,7 +1685,7 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):

if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
hidden_states = upsampler(hidden_states,upsample_size)

return hidden_states

Expand Down