From 67b8eb240853fc1a5ca4e0ee8c8d4a65eea14cae Mon Sep 17 00:00:00 2001 From: Josh Date: Tue, 13 Sep 2022 14:19:37 -0400 Subject: [PATCH 1/6] Allow resolutions that are not multiples of 64 --- src/diffusers/models/resnet.py | 7 +++++-- src/diffusers/models/unet_2d_condition.py | 14 ++++++++++++-- src/diffusers/models/unet_blocks.py | 8 ++++---- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 27fae24f71d8..d18e76b5e25c 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -35,12 +35,15 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann else: self.Conv2d_0 = conv - def forward(self, x): + def forward(self, x, size=None): assert x.shape[1] == self.channels if self.use_conv_transpose: return self.conv(x) - x = F.interpolate(x, scale_factor=2.0, mode="nearest") + if size is None: + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + else: + x = F.interpolate(x, size=size, mode="nearest") # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if self.use_conv: diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 92caaca92e24..163d27c5965b 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -247,16 +247,26 @@ def forward( res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + if upsample_block.upsamplers is not None: + upsample_size = down_block_res_samples[-1].shape[2:] + else: + upsample_size = None + if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size ) else: - sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples) - + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size + ) # 6. 
post-process # make sure hidden states is in float32 # when running in half-precision diff --git a/src/diffusers/models/unet_blocks.py b/src/diffusers/models/unet_blocks.py index 88349075d24f..19f87e350921 100644 --- a/src/diffusers/models/unet_blocks.py +++ b/src/diffusers/models/unet_blocks.py @@ -1072,7 +1072,7 @@ def set_attention_slice(self, slice_size): for attn in self.attentions: attn._set_attention_slice(slice_size) - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None): for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -1084,7 +1084,7 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hid if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states @@ -1135,7 +1135,7 @@ def __init__( else: self.upsamplers = None - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -1146,7 +1146,7 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states From 826fc3dd125f0fb796ae06c2f28dc07c203915bc Mon Sep 17 00:00:00 2001 From: Josh Date: Tue, 20 Sep 2022 12:56:55 -0400 Subject: [PATCH 2/6] ran black --- setup.py | 4 +++- src/diffusers/models/unet_2d_condition.py | 7 ++----- src/diffusers/models/unet_blocks.py | 4 +++- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index 6a929d8218fe..04fc7d5cdc67 100644 --- a/setup.py +++ b/setup.py @@ -182,7 +182,9 @@ def run(self): else: extras["flax"] = deps_list("jax", "jaxlib", "flax") -extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"] +extras["dev"] = ( + extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"] +) install_requires = [ deps["importlib_metadata"], diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 163d27c5965b..24bd67309d83 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -258,14 +258,11 @@ def forward( temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, - upsample_size=upsample_size + upsample_size=upsample_size, ) else: sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - upsample_size=upsample_size + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size ) # 6. 
post-process
         # make sure hidden states is in float32

diff --git a/src/diffusers/models/unet_blocks.py b/src/diffusers/models/unet_blocks.py
index 19f87e350921..9d3618449c65 100644
--- a/src/diffusers/models/unet_blocks.py
+++ b/src/diffusers/models/unet_blocks.py
@@ -1072,7 +1072,9 @@ def set_attention_slice(self, slice_size):
         for attn in self.attentions:
             attn._set_attention_slice(slice_size)

-    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
+    def forward(
+        self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None
+    ):
         for resnet, attn in zip(self.resnets, self.attentions):
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]

From e1d2dcc07225e5756ba9f13d804b035e589c1ecd Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Thu, 29 Sep 2022 20:48:11 +0000
Subject: [PATCH 3/6] fix bug

---
 src/diffusers/models/unet_blocks.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/diffusers/models/unet_blocks.py b/src/diffusers/models/unet_blocks.py
index dd39c7e1c3fc..a17b1d2a5333 100644
--- a/src/diffusers/models/unet_blocks.py
+++ b/src/diffusers/models/unet_blocks.py
@@ -1118,6 +1118,8 @@ def set_attention_slice(self, slice_size):
         for attn in self.attentions:
             attn._set_attention_slice(slice_size)

+        self.gradient_checkpointing = False
+
     def forward(
         self,
         hidden_states,

From 1840227803dd51aa494901714a32879183efc237 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Thu, 29 Sep 2022 20:57:44 +0000
Subject: [PATCH 4/6] add test

---
 tests/test_pipelines.py | 49 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 49 insertions(+)

diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 61d5ac3a4e28..4ce82cc732ff 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -336,6 +336,55 @@ def test_stable_diffusion_ddim(self):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

+    def test_stable_diffusion_ddim_factor_8(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        unet = self.dummy_cond_unet
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
+
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        # assemble the Stable Diffusion pipeline from the dummy components
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        output = sd_pipe(
+            [prompt],
+            generator=generator,
+            guidance_scale=6.0,
+            height=536,
+            width=536,
+            num_inference_steps=2,
+            output_type="np",
+        )
+        image = output.images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 134, 134, 3)
+        expected_slice = np.array([0.7834, 0.5488, 0.5781, 0.46, 0.3609, 0.5369, 0.542, 0.4855, 0.5557])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
     def test_stable_diffusion_pndm(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         unet = self.dummy_cond_unet
From 8e74ce6924f341e3cb82188c6618ff1105f221df Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Thu, 29 Sep 2022 21:48:04 +0000
Subject: [PATCH 5/6] more explanation

---
 src/diffusers/models/unet_2d_condition.py | 43 +++++++++++++++++++----
 1 file changed, 36 insertions(+), 7 deletions(-)

diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 15202f27343d..8b8fae7d2c84 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -7,7 +7,7 @@
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
-from ..utils import BaseOutput
+from ..utils import BaseOutput, logging
 from .embeddings import TimestepEmbedding, Timesteps
 from .unet_blocks import (
     CrossAttnDownBlock2D,
@@ -20,6 +20,9 @@
 )


+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
 @dataclass
 class UNet2DConditionOutput(BaseOutput):
     """
@@ -145,15 +148,25 @@ def __init__(
             resnet_groups=norm_num_groups,
         )

+        # count how many layers upsample the images
+        self.num_upsamplers = 0
+
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
+            is_final_block = i == len(block_out_channels) - 1
+
             prev_output_channel = output_channel
             output_channel = reversed_block_out_channels[i]
             input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

-            is_final_block = i == len(block_out_channels) - 1
+            # add an upsample block for every block but the final one
+            if not is_final_block:
+                add_upsample = True
+                self.num_upsamplers += 1
+            else:
+                add_upsample = False

             up_block = get_up_block(
                 up_block_type,
@@ -162,7 +175,7 @@
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
                 temb_channels=time_embed_dim,
-                add_upsample=not is_final_block,
+                add_upsample=add_upsample,
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
@@ -223,6 +236,20 @@ def forward(
             [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
             returning a tuple, the first element is the sample tensor.
         """
+        # By default, samples have to be at least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
+        # However, the upsampling interpolation output size can be forced to fit any upsampling
+        # size on the fly if necessary.
+        default_overall_up_factor = 2**self.num_upsamplers
+
+        # `upsample_size` should be forwarded when `sample` is not a multiple of `default_overall_up_factor`
+        forward_upsample_size = False
+        upsample_size = None
+
+        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+            logger.info("Forward upsample size to force interpolation output size.")
+            forward_upsample_size = True
+
         # 0. center input if necessary
         if self.config.center_input_sample:
             sample = 2 * sample - 1.0
@@ -262,14 +289,16 @@
         sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)

         # 5. up
-        for upsample_block in self.up_blocks:
+        for i, upsample_block in enumerate(self.up_blocks):
+            is_final_block = i == len(self.up_blocks) - 1
+
             res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
             down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

-            if upsample_block.upsamplers is not None:
+            # if we have not reached the final block and we need to forward the
+            # upsample size, we do it here
+            if not is_final_block and forward_upsample_size:
                 upsample_size = down_block_res_samples[-1].shape[2:]
-            else:
-                upsample_size = None

             if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
                 sample = upsample_block(

From cfbc700be2adaf3c77194b9d55608b3a675f4e7f Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Thu, 29 Sep 2022 21:54:10 +0000
Subject: [PATCH 6/6] more comments

---
 src/diffusers/models/resnet.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index fb46b6096a82..ad2f93d2dbc5 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -34,16 +34,18 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann
         else:
             self.Conv2d_0 = conv

-    def forward(self, hidden_states, size=None):
+    def forward(self, hidden_states, output_size=None):
         assert hidden_states.shape[1] == self.channels

         if self.use_conv_transpose:
             return self.conv(hidden_states)

-        if size is None:
+        # if `output_size` is passed, we force the interpolation output size
+        # and do not make use of `scale_factor=2`
+        if output_size is None:
             hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
         else:
-            hidden_states = F.interpolate(hidden_states, size=size, mode="nearest")
+            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if self.use_conv:
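
Taken together, the six patches relax the multiple-of-64 constraint: height and width now only need to be a multiple of the VAE scaling factor (8 for Stable Diffusion), because each up block can be forced to interpolate to the exact resolution of its skip connection instead of blindly doubling. Below is a minimal sketch, not part of the patches, of why the forced `size=` matters; it assumes a UNet with three stride-2 downsamples and a 536px input, so the shapes are illustrative.

import torch
import torch.nn.functional as F

# A 536px image becomes a 67px latent (536 / 8). Three stride-2 downsamples
# then give 67 -> 34 -> 17 -> 9, while naive x2 nearest upsampling gives
# 9 -> 18 -> 36 -> 72 and no longer matches the 67/34/17 skip connections.
# Forcing `size=` to each skip connection's resolution keeps shapes aligned.
skip_sizes = [(67, 67), (34, 34), (17, 17)]
hidden = torch.randn(1, 4, 9, 9)

for size in reversed(skip_sizes):
    naive = F.interpolate(hidden, scale_factor=2.0, mode="nearest")
    forced = F.interpolate(hidden, size=size, mode="nearest")
    print(tuple(naive.shape[-2:]), "vs", tuple(forced.shape[-2:]))
    hidden = forced

This is the path exercised by test_stable_diffusion_ddim_factor_8 above, which runs the dummy pipeline at height=536, width=536 and expects a (1, 134, 134, 3) output.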