Commit a784be2

Allow resolutions that are not multiples of 64 (#505)
* Allow resolutions that are not multiples of 64
* ran black
* fix bug
* add test
* more explanation
* more comments

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 9ebaea5 commit a784be2
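
In practice: the Stable Diffusion UNet contains three upsampling layers, each of which doubles the spatial size, so latents previously had to be divisible by 2**3 = 8 and pixel resolutions by 8 * 8 = 64 (including the VAE's own factor of 8). After this commit the upsamplers can be forced to match the skip connections' spatial sizes, so any resolution that is a multiple of 8 should work. A hedged usage sketch (the model id and prompt are illustrative, not part of this commit):

    from diffusers import StableDiffusionPipeline

    # 536 = 8 * 67 is a multiple of 8 but NOT of 64; before this commit such a
    # size caused shape mismatches between upsampled features and skip connections.
    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
    image = pipe("An astronaut riding a horse", height=536, width=536).images[0]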

5 files changed (+106, -12 lines)

setup.py
Lines changed: 3 additions & 1 deletion

@@ -193,7 +193,9 @@ def run(self):
 else:
     extras["flax"] = deps_list("jax", "jaxlib", "flax")

-extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
+extras["dev"] = (
+    extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
+)

 install_requires = [
     deps["importlib_metadata"],

src/diffusers/models/resnet.py
Lines changed: 8 additions & 2 deletions

@@ -34,12 +34,18 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann
         else:
             self.Conv2d_0 = conv

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, output_size=None):
         assert hidden_states.shape[1] == self.channels
+
         if self.use_conv_transpose:
             return self.conv(hidden_states)

-        hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+        # if `output_size` is passed we force the interpolation output
+        # size and do not make use of `scale_factor=2`
+        if output_size is None:
+            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+        else:
+            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if self.use_conv:
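
Why forcing the size matters, as a minimal sketch in isolation (plain tensors rather than the Upsample2D module; shapes chosen to mimic an odd-sized latent):

    import torch
    import torch.nn.functional as F

    # A 67x67 feature map becomes 34x34 after one stride-2 downsample (ceil(67 / 2)).
    x = torch.randn(1, 4, 34, 34)

    # scale_factor=2.0 can only double: 34 -> 68, which does NOT match the 67x67 skip connection.
    print(F.interpolate(x, scale_factor=2.0, mode="nearest").shape)  # torch.Size([1, 4, 68, 68])

    # An explicit size forces the output back to the skip connection's shape.
    print(F.interpolate(x, size=(67, 67), mode="nearest").shape)  # torch.Size([1, 4, 67, 67])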

src/diffusers/models/unet_2d_condition.py
Lines changed: 42 additions & 6 deletions

@@ -7,7 +7,7 @@

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
-from ..utils import BaseOutput
+from ..utils import BaseOutput, logging
 from .embeddings import TimestepEmbedding, Timesteps
 from .unet_blocks import (
     CrossAttnDownBlock2D,

@@ -20,6 +20,9 @@
 )


+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
 @dataclass
 class UNet2DConditionOutput(BaseOutput):
     """

@@ -145,15 +148,25 @@ def __init__(
             resnet_groups=norm_num_groups,
         )

+        # count how many layers upsample the images
+        self.num_upsamplers = 0
+
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
+            is_final_block = i == len(block_out_channels) - 1
+
             prev_output_channel = output_channel
             output_channel = reversed_block_out_channels[i]
             input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

-            is_final_block = i == len(block_out_channels) - 1
+            # add upsample block for all BUT final layer
+            if not is_final_block:
+                add_upsample = True
+                self.num_upsamplers += 1
+            else:
+                add_upsample = False

             up_block = get_up_block(
                 up_block_type,

@@ -162,7 +175,7 @@ def __init__(
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
                 temb_channels=time_embed_dim,
-                add_upsample=not is_final_block,
+                add_upsample=add_upsample,
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,

@@ -223,6 +236,20 @@ def forward(
             [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
             returning a tuple, the first element is the sample tensor.
         """
+        # By default, samples have to be at least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# of upsampling layers).
+        # However, the upsampling interpolation output size can be forced to fit any upsampling size
+        # on the fly if necessary.
+        default_overall_up_factor = 2**self.num_upsamplers
+
+        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+        forward_upsample_size = False
+        upsample_size = None
+
+        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+            logger.info("Forward upsample size to force interpolation output size.")
+            forward_upsample_size = True
+
         # 0. center input if necessary
         if self.config.center_input_sample:
             sample = 2 * sample - 1.0

@@ -262,20 +289,29 @@ def forward(
         sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)

         # 5. up
-        for upsample_block in self.up_blocks:
+        for i, upsample_block in enumerate(self.up_blocks):
+            is_final_block = i == len(self.up_blocks) - 1
+
             res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
             down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

+            # if we have not reached the final block and need to forward the
+            # upsample size, we do it here
+            if not is_final_block and forward_upsample_size:
+                upsample_size = down_block_res_samples[-1].shape[2:]
+
             if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
                 sample = upsample_block(
                     hidden_states=sample,
                     temb=emb,
                     res_hidden_states_tuple=res_samples,
                     encoder_hidden_states=encoder_hidden_states,
+                    upsample_size=upsample_size,
                 )
             else:
-                sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)
-
+                sample = upsample_block(
+                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+                )
         # 6. post-process
         # make sure hidden states is in float32
         # when running in half-precision
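
The new shape logic, replayed in isolation (a sketch; the numbers assume the full Stable Diffusion UNet, where three of the four up blocks contain an upsampler):

    import torch

    num_upsamplers = 3
    default_overall_up_factor = 2**num_upsamplers  # 8

    sample = torch.randn(1, 4, 67, 67)  # latent for a 536x536 image (536 / 8 = 67)
    forward_upsample_size = any(s % default_overall_up_factor != 0 for s in sample.shape[-2:])
    print(forward_upsample_size)  # True: 67 % 8 != 0, so each up block receives the exact
                                  # spatial size of its skip connection instead of blindly doubling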

src/diffusers/models/unet_blocks.py
Lines changed: 4 additions & 3 deletions

@@ -1126,6 +1126,7 @@ def forward(
         res_hidden_states_tuple,
         temb=None,
         encoder_hidden_states=None,
+        upsample_size=None,
     ):
         for resnet, attn in zip(self.resnets, self.attentions):
             # pop res hidden states

@@ -1151,7 +1152,7 @@ def custom_forward(*inputs):

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states)
+                hidden_states = upsampler(hidden_states, upsample_size)

         return hidden_states

@@ -1204,7 +1205,7 @@ def __init__(

         self.gradient_checkpointing = False

-    def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
+    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
         for resnet in self.resnets:
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]

@@ -1225,7 +1226,7 @@ def custom_forward(*inputs):

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states)
+                hidden_states = upsampler(hidden_states, upsample_size)

         return hidden_states

tests/test_pipelines.py
Lines changed: 49 additions & 0 deletions

@@ -336,6 +336,55 @@ def test_stable_diffusion_ddim(self):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

+    def test_stable_diffusion_ddim_factor_8(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        unet = self.dummy_cond_unet
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
+
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        output = sd_pipe(
+            [prompt],
+            generator=generator,
+            guidance_scale=6.0,
+            height=536,
+            width=536,
+            num_inference_steps=2,
+            output_type="np",
+        )
+        image = output.images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 134, 134, 3)
+        expected_slice = np.array([0.7834, 0.5488, 0.5781, 0.46, 0.3609, 0.5369, 0.542, 0.4855, 0.5557])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
     def test_stable_diffusion_pndm(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         unet = self.dummy_cond_unet
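
A back-of-the-envelope check of the asserted shape (assuming, as in the pipelines of this era, that latents are sampled at 1/8 of the requested resolution and that the dummy test VAE decodes with a single 2x up stage):

    height = width = 536  # multiple of 8, but not of 64
    latent = 536 // 8     # 67 -- odd, so the dummy UNet's upsampler must be forced
                          # to an explicit output size
    decoded = 67 * 2      # 134 -- matching the asserted image shape (1, 134, 134, 3)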
