4242from diffusers .schedulers .scheduling_utils import SCHEDULER_CONFIG_NAME
4343from diffusers .utils import CONFIG_NAME , WEIGHTS_NAME , floats_tensor , slow , torch_device
4444from diffusers .utils .testing_utils import CaptureLogger , get_tests_dir
45+ from parameterized import parameterized
4546from PIL import Image
4647from transformers import CLIPFeatureExtractor , CLIPModel , CLIPTextConfig , CLIPTextModel , CLIPTokenizer
4748
@@ -445,7 +446,9 @@ def test_output_format(self):
445446 assert isinstance (images , list )
446447 assert isinstance (images [0 ], PIL .Image .Image )
447448
448- def test_ddpm_ddim_equality (self ):
449+ # Make sure the test passes for different values of random seed
450+ @parameterized .expand ([(0 ,), (4 ,)])
451+ def test_ddpm_ddim_equality (self , seed ):
449452 model_id = "google/ddpm-cifar10-32"
450453
451454 unet = UNet2DModel .from_pretrained (model_id , device_map = "auto" )
@@ -459,17 +462,24 @@ def test_ddpm_ddim_equality(self):
459462 ddim .to (torch_device )
460463 ddim .set_progress_bar_config (disable = None )
461464
462- generator = torch .manual_seed (0 )
465+ generator = torch .manual_seed (seed )
463466 ddpm_image = ddpm (generator = generator , output_type = "numpy" ).images
464467
465- generator = torch .manual_seed (0 )
466- ddim_image = ddim (generator = generator , num_inference_steps = 1000 , eta = 1.0 , output_type = "numpy" ).images
468+ generator = torch .manual_seed (seed )
469+ ddim_image = ddim (
470+ generator = generator ,
471+ num_inference_steps = 1000 ,
472+ eta = 1.0 ,
473+ output_type = "numpy" ,
474+ use_clipped_model_output = True , # Need this to make DDIM match DDPM
475+ ).images
467476
468477 # the values aren't exactly equal, but the images look the same visually
469478 assert np .abs (ddpm_image - ddim_image ).max () < 1e-1
470479
471- @unittest .skip ("(Anton) The test is failing for large batch sizes, needs investigation" )
472- def test_ddpm_ddim_equality_batched (self ):
480+ # Make sure the test passes for different values of random seed
481+ @parameterized .expand ([(0 ,), (4 ,)])
482+ def test_ddpm_ddim_equality_batched (self , seed ):
473483 model_id = "google/ddpm-cifar10-32"
474484
475485 unet = UNet2DModel .from_pretrained (model_id , device_map = "auto" )
@@ -484,12 +494,17 @@ def test_ddpm_ddim_equality_batched(self):
484494 ddim .to (torch_device )
485495 ddim .set_progress_bar_config (disable = None )
486496
487- generator = torch .manual_seed (0 )
497+ generator = torch .manual_seed (seed )
488498 ddpm_images = ddpm (batch_size = 4 , generator = generator , output_type = "numpy" ).images
489499
490- generator = torch .manual_seed (0 )
500+ generator = torch .manual_seed (seed )
491501 ddim_images = ddim (
492- batch_size = 4 , generator = generator , num_inference_steps = 1000 , eta = 1.0 , output_type = "numpy"
502+ batch_size = 4 ,
503+ generator = generator ,
504+ num_inference_steps = 1000 ,
505+ eta = 1.0 ,
506+ output_type = "numpy" ,
507+ use_clipped_model_output = True , # Need this to make DDIM match DDPM
493508 ).images
494509
495510 # the values aren't exactly equal, but the images look the same visually
0 commit comments