
Commit 5cd29d6

Fix tests for equivalence of DDIM and DDPM pipelines (#1069)
* Fix equality test for DDIM and DDPM
* Add docs for use_clipped_model_output in DDIM
* Fix inline comment
* Reorder imports in test_pipelines.py
* Ignore use_clipped_model_output if the scheduler doesn't take it
1 parent 1216a3b commit 5cd29d6

File tree

3 files changed: +42, -12 lines

src/diffusers/pipelines/ddim/pipeline_ddim.py

Lines changed: 14 additions & 2 deletions
```diff
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import inspect
 from typing import Optional, Tuple, Union
 
 import torch
@@ -44,6 +44,7 @@ def __call__(
         generator: Optional[torch.Generator] = None,
         eta: float = 0.0,
         num_inference_steps: int = 50,
+        use_clipped_model_output: Optional[bool] = None,
        output_type: Optional[str] = "pil",
         return_dict: bool = True,
         **kwargs,
@@ -60,6 +61,9 @@ def __call__(
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
+            use_clipped_model_output (`bool`, *optional*, defaults to `None`):
+                if `True` or `False`, see the documentation for `DDIMScheduler.step`. If `None`, nothing is passed
+                downstream to the scheduler, so use `None` for schedulers that don't support this argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -82,14 +86,22 @@ def __call__(
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
 
+        # Ignore use_clipped_model_output if the scheduler doesn't accept this argument
+        accepts_use_clipped_model_output = "use_clipped_model_output" in set(
+            inspect.signature(self.scheduler.step).parameters.keys()
+        )
+        extra_kwargs = {}
+        if accepts_use_clipped_model_output:
+            extra_kwargs["use_clipped_model_output"] = use_clipped_model_output
+
         for t in self.progress_bar(self.scheduler.timesteps):
             # 1. predict noise model_output
             model_output = self.unet(image, t).sample
 
             # 2. predict previous mean of image x_t-1 and add variance depending on eta
             # eta corresponds to η in paper and should be between [0, 1]
             # do x_t -> x_t-1
-            image = self.scheduler.step(model_output, t, image, eta).prev_sample
+            image = self.scheduler.step(model_output, t, image, eta, **extra_kwargs).prev_sample
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
```
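The `inspect.signature` check is what lets the pipeline forward `use_clipped_model_output` to `DDIMScheduler.step` while staying compatible with schedulers whose `step` doesn't take it. Here is a minimal self-contained sketch of the same pattern; the `step_with_clipping`, `step_plain`, and `call_step` names are hypothetical stand-ins for illustration, not diffusers code:

```python
import inspect


def step_with_clipping(model_output, t, sample, eta, use_clipped_model_output=False):
    # Stand-in for a scheduler step that supports the extra argument
    return f"stepped at t={t}, use_clipped_model_output={use_clipped_model_output}"


def step_plain(model_output, t, sample, eta):
    # Stand-in for a scheduler step that does not support it
    return f"stepped at t={t}"


def call_step(step_fn, model_output, t, sample, eta, use_clipped_model_output=None):
    # Forward the kwarg only if the callee's signature declares it;
    # otherwise drop it silently, as the pipeline does.
    extra_kwargs = {}
    if "use_clipped_model_output" in inspect.signature(step_fn).parameters:
        extra_kwargs["use_clipped_model_output"] = use_clipped_model_output
    return step_fn(model_output, t, sample, eta, **extra_kwargs)


print(call_step(step_with_clipping, None, 999, None, 1.0, use_clipped_model_output=True))
print(call_step(step_plain, None, 999, None, 1.0, use_clipped_model_output=True))
```

Because unsupported kwargs are dropped rather than raised, callers can always pass the argument; schedulers that don't know about it simply never see it.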

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -220,7 +220,10 @@ def step(
             sample (`torch.FloatTensor`):
                 current instance of sample being created by diffusion process.
             eta (`float`): weight of noise for added noise in diffusion step.
-            use_clipped_model_output (`bool`): TODO
+            use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
+                predicted original sample. Necessary because the predicted original sample is clipped to [-1, 1] when
+                `self.config.clip_sample` is `True`. If no clipping has happened, the "corrected" `model_output` would
+                coincide with the one provided as input and `use_clipped_model_output` will have no effect.
             generator: random number generator.
             return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
```

tests/test_pipelines.py

Lines changed: 24 additions & 9 deletions
```diff
@@ -42,6 +42,7 @@
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir
+from parameterized import parameterized
 from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
@@ -445,7 +446,9 @@ def test_output_format(self):
         assert isinstance(images, list)
         assert isinstance(images[0], PIL.Image.Image)
 
-    def test_ddpm_ddim_equality(self):
+    # Make sure the test passes for different values of random seed
+    @parameterized.expand([(0,), (4,)])
+    def test_ddpm_ddim_equality(self, seed):
         model_id = "google/ddpm-cifar10-32"
 
         unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
@@ -459,17 +462,24 @@ def test_ddpm_ddim_equality(self):
         ddim.to(torch_device)
         ddim.set_progress_bar_config(disable=None)
 
-        generator = torch.manual_seed(0)
+        generator = torch.manual_seed(seed)
         ddpm_image = ddpm(generator=generator, output_type="numpy").images
 
-        generator = torch.manual_seed(0)
-        ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy").images
+        generator = torch.manual_seed(seed)
+        ddim_image = ddim(
+            generator=generator,
+            num_inference_steps=1000,
+            eta=1.0,
+            output_type="numpy",
+            use_clipped_model_output=True,  # Need this to make DDIM match DDPM
+        ).images
 
         # the values aren't exactly equal, but the images look the same visually
         assert np.abs(ddpm_image - ddim_image).max() < 1e-1
 
-    @unittest.skip("(Anton) The test is failing for large batch sizes, needs investigation")
-    def test_ddpm_ddim_equality_batched(self):
+    # Make sure the test passes for different values of random seed
+    @parameterized.expand([(0,), (4,)])
+    def test_ddpm_ddim_equality_batched(self, seed):
         model_id = "google/ddpm-cifar10-32"
 
         unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
@@ -484,12 +494,17 @@ def test_ddpm_ddim_equality_batched(self):
         ddim.to(torch_device)
         ddim.set_progress_bar_config(disable=None)
 
-        generator = torch.manual_seed(0)
+        generator = torch.manual_seed(seed)
         ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy").images
 
-        generator = torch.manual_seed(0)
+        generator = torch.manual_seed(seed)
         ddim_images = ddim(
-            batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy"
+            batch_size=4,
+            generator=generator,
+            num_inference_steps=1000,
+            eta=1.0,
+            output_type="numpy",
+            use_clipped_model_output=True,  # Need this to make DDIM match DDPM
         ).images
 
         # the values aren't exactly equal, but the images look the same visually
```
