4 changes: 3 additions & 1 deletion setup.py

@@ -193,7 +193,9 @@ def run(self):
 else:
     extras["flax"] = deps_list("jax", "jaxlib", "flax")

-extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
+extras["dev"] = (
+    extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
+)

 install_requires = [
     deps["importlib_metadata"],
10 changes: 8 additions & 2 deletions src/diffusers/models/resnet.py

@@ -34,12 +34,18 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann
         else:
             self.Conv2d_0 = conv

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, output_size=None):
         assert hidden_states.shape[1] == self.channels

         if self.use_conv_transpose:
             return self.conv(hidden_states)

-        hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+        # if `output_size` is passed we force the interpolation output
+        # size and do not make use of `scale_factor=2`
+        if output_size is None:
+            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+        else:
+            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if self.use_conv:
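Note: the new `output_size` branch simply switches `torch.nn.functional.interpolate` from a fixed `scale_factor` to an explicit target `size`. A minimal standalone sketch of the two modes, with illustrative shapes that are not taken from the diff:

import torch
import torch.nn.functional as F

hidden_states = torch.randn(1, 4, 67, 67)  # hypothetical feature map

# default path: double the spatial size
print(F.interpolate(hidden_states, scale_factor=2.0, mode="nearest").shape)
# torch.Size([1, 4, 134, 134])

# forced path: hit an arbitrary target size, e.g. an odd skip-connection shape
print(F.interpolate(hidden_states, size=(135, 135), mode="nearest").shape)
# torch.Size([1, 4, 135, 135])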
48 changes: 42 additions & 6 deletions src/diffusers/models/unet_2d_condition.py

@@ -7,7 +7,7 @@

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
-from ..utils import BaseOutput
+from ..utils import BaseOutput, logging
 from .embeddings import TimestepEmbedding, Timesteps
 from .unet_blocks import (
     CrossAttnDownBlock2D,

@@ -20,6 +20,9 @@
 )


+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
 @dataclass
 class UNet2DConditionOutput(BaseOutput):
     """

@@ -145,15 +148,25 @@ def __init__(
             resnet_groups=norm_num_groups,
         )

+        # count how many layers upsample the images
+        self.num_upsamplers = 0
+
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
-            is_final_block = i == len(block_out_channels) - 1

             prev_output_channel = output_channel
             output_channel = reversed_block_out_channels[i]
             input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

+            is_final_block = i == len(block_out_channels) - 1
+            # add upsample block for all BUT final layer
+            if not is_final_block:
+                add_upsample = True
+                self.num_upsamplers += 1
+            else:
+                add_upsample = False
+
             up_block = get_up_block(
                 up_block_type,

@@ -162,7 +175,7 @@ def __init__(
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
                 temb_channels=time_embed_dim,
-                add_upsample=not is_final_block,
+                add_upsample=add_upsample,
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
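Note: with this loop, every up block except the final one receives an upsampler, so `num_upsamplers` ends up one less than the number of up blocks. A standalone sketch, assuming a Stable Diffusion-like configuration of four up block types:

up_block_types = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")

num_upsamplers = 0
for i, up_block_type in enumerate(up_block_types):
    is_final_block = i == len(up_block_types) - 1
    if not is_final_block:
        num_upsamplers += 1  # mirrors the counting above

print(num_upsamplers)     # 3
print(2**num_upsamplers)  # 8, the default overall up factor used in forward()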
@@ -223,6 +236,20 @@ def forward(
             [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
             returning a tuple, the first element is the sample tensor.
         """
+        # By default, samples have to be a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
+        # However, the upsampling interpolation output size can be forced to fit any upsampling size
+        # on the fly if necessary.
+        default_overall_up_factor = 2**self.num_upsamplers
+
+        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+        forward_upsample_size = False
+        upsample_size = None
+
+        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+            logger.info("Forward upsample size to force interpolation output size.")
+            forward_upsample_size = True
+
         # 0. center input if necessary
         if self.config.center_input_sample:
             sample = 2 * sample - 1.0
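Note: a standalone sketch of the new guard, assuming Stable Diffusion-like numbers (a 536x536 image encoded by an x8 VAE into a 67x67 latent, and a UNet with three upsamplers):

import torch

sample = torch.randn(1, 4, 67, 67)  # hypothetical latent for a 536x536 image
default_overall_up_factor = 2**3    # three upsampling layers assumed

# same check as above: 67 % 8 != 0, so the upsample sizes must be forced
forward_upsample_size = any(s % default_overall_up_factor != 0 for s in sample.shape[-2:])
print(forward_upsample_size)  # True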
@@ -262,20 +289,29 @@
         sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)

         # 5. up
-        for upsample_block in self.up_blocks:
+        for i, upsample_block in enumerate(self.up_blocks):
+            is_final_block = i == len(self.up_blocks) - 1
+
             res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
             down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

+            # if we have not reached the final block and need to forward the
+            # upsample size, we do it here
+            if not is_final_block and forward_upsample_size:
+                upsample_size = down_block_res_samples[-1].shape[2:]
+
             if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
                 sample = upsample_block(
                     hidden_states=sample,
                     temb=emb,
                     res_hidden_states_tuple=res_samples,
                     encoder_hidden_states=encoder_hidden_states,
+                    upsample_size=upsample_size,
                 )
             else:
-                sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)
-
+                sample = upsample_block(
+                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+                )
         # 6. post-process
         # make sure hidden states is in float32
         # when running in half-precision
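Note: reading `upsample_size` off the remaining skip connections is what makes odd latent sizes work, because stride-2 downsampling cannot be undone by a plain x2 upsample. A standalone sketch, assuming stride-2, kernel-3, padding-1 downsampling convolutions:

def down(size):
    # spatial output size of a stride-2 conv with kernel 3 and padding 1
    return (size + 1) // 2

sizes = [67]
for _ in range(3):
    sizes.append(down(sizes[-1]))
print(sizes)  # [67, 34, 17, 9], the shapes recorded in the skip connections

# a plain x2 upsample from 9 would produce 18 -> 36 -> 72, so each up block
# is instead forced to the recorded skip sizes in reverse: 9 -> 17 -> 34 -> 67
print([9 * 2**k for k in range(1, 4)])  # [18, 36, 72]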
7 changes: 4 additions & 3 deletions src/diffusers/models/unet_blocks.py

@@ -1126,6 +1126,7 @@ def forward(
         res_hidden_states_tuple,
         temb=None,
         encoder_hidden_states=None,
+        upsample_size=None,
     ):
         for resnet, attn in zip(self.resnets, self.attentions):
             # pop res hidden states

@@ -1151,7 +1152,7 @@ def custom_forward(*inputs):

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states)
+                hidden_states = upsampler(hidden_states, upsample_size)

         return hidden_states

@@ -1204,7 +1205,7 @@ def __init__(

         self.gradient_checkpointing = False

-    def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
+    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
         for resnet in self.resnets:
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]

@@ -1225,7 +1226,7 @@ def custom_forward(*inputs):

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states)
+                hidden_states = upsampler(hidden_states, upsample_size)

         return hidden_states
49 changes: 49 additions & 0 deletions tests/test_pipelines.py

@@ -336,6 +336,55 @@ def test_stable_diffusion_ddim(self):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

+    def test_stable_diffusion_ddim_factor_8(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        unet = self.dummy_cond_unet
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
+
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        # assemble the pipeline with the DDIM scheduler configured above
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        output = sd_pipe(
+            [prompt],
+            generator=generator,
+            guidance_scale=6.0,
+            height=536,
+            width=536,
+            num_inference_steps=2,
+            output_type="np",
+        )
+        image = output.images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 134, 134, 3)
+        expected_slice = np.array([0.7834, 0.5488, 0.5781, 0.46, 0.3609, 0.5369, 0.542, 0.4855, 0.5557])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
     def test_stable_diffusion_pndm(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         unet = self.dummy_cond_unet
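Note: the asserted shape is consistent with the new path. A standalone arithmetic sketch; the overall pixel-to-pixel scale factor of 4 is inferred from the asserted output shape, not read from the test fixtures:

height = width = 536
print(height % 8, height % 64)  # 0 24: a multiple of 8 but not of 64
print(height // 4)              # 134, matching the asserted (1, 134, 134, 3)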