@@ -19,6 +19,7 @@
 import torch
 
 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
+from diffusers.utils import deprecate
 from diffusers.utils.testing_utils import require_torch, slow, torch_device
 
 from ...test_pipelines_common import PipelineTesterMixin
@@ -28,8 +29,74 @@
 
 
 class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
-    # FIXME: add fast tests
-    pass
+    @property
+    def dummy_uncond_unet(self):
+        torch.manual_seed(0)
+        model = UNet2DModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=3,
+            out_channels=3,
+            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
+        )
+        return model
+
+    def test_inference(self):
+        unet = self.dummy_uncond_unet
+        scheduler = DDPMScheduler()
+
+        ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
+        ddpm.to(torch_device)
+        ddpm.set_progress_bar_config(disable=None)
+
+        # Warmup pass when using mps (see #372)
+        if torch_device == "mps":
+            _ = ddpm(num_inference_steps=1)
+
+        generator = torch.manual_seed(0)
+        image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
+
+        generator = torch.manual_seed(0)
+        image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
+
+        image_slice = image[0, -3:, -3:, -1]
+        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array(
+            [5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02]
+        )
+        tolerance = 1e-2 if torch_device != "mps" else 3e-2
+        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
+
+    def test_inference_predict_epsilon(self):
+        deprecate("remove this test", "0.10.0", "remove")
+        unet = self.dummy_uncond_unet
+        scheduler = DDPMScheduler(predict_epsilon=False)
+
+        ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
+        ddpm.to(torch_device)
+        ddpm.set_progress_bar_config(disable=None)
+
+        # Warmup pass when using mps (see #372)
+        if torch_device == "mps":
+            _ = ddpm(num_inference_steps=1)
+
+        generator = torch.manual_seed(0)
+        image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
+
+        generator = torch.manual_seed(0)
+        image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
+
+        image_slice = image[0, -3:, -3:, -1]
+        image_eps_slice = image_eps[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+        tolerance = 1e-2 if torch_device != "mps" else 3e-2
+        assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance
 
 
 @slow
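A note on what the new fast tests exercise: `dummy_uncond_unet` builds a tiny seeded 32×32 `UNet2DModel` so that a full `DDPMPipeline` sampling run with only two inference steps stays cheap enough for CI, and the resulting image slice is compared against hard-coded expected values. Below is a minimal sketch of the same flow outside the test harness; the config values are copied from the diff above, while the standalone usage itself is my illustration, not part of this commit:

```python
# Minimal sketch (assumption: mirrors the fast test above, run standalone).
# A tiny 32x32 UNet keeps the two-step DDPM sampling run fast on CPU.
import torch
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel

torch.manual_seed(0)
unet = UNet2DModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=3,
    out_channels=3,
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
ddpm = DDPMPipeline(unet=unet, scheduler=DDPMScheduler())

generator = torch.manual_seed(0)
# output_type="numpy" returns images as a float array in (batch, H, W, C) layout
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
assert image.shape == (1, 32, 32, 3)
```

Note that `test_inference_predict_epsilon` marks itself for removal in 0.10.0 via `deprecate`, since the `predict_epsilon` flag it exercises was being phased out; as far as I know, later releases express the same choice through the scheduler's `prediction_type` config (e.g. `DDPMScheduler(prediction_type="sample")` in place of `predict_epsilon=False`).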
|