|
54 | 54 | logging, |
55 | 55 | ) |
56 | 56 | from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME |
57 | | -from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, is_flax_available, nightly, slow, torch_device |
| 57 | +from diffusers.utils import ( |
| 58 | + CONFIG_NAME, |
| 59 | + WEIGHTS_NAME, |
| 60 | + floats_tensor, |
| 61 | + is_flax_available, |
| 62 | + nightly, |
| 63 | + require_torch_2, |
| 64 | + slow, |
| 65 | + torch_device, |
| 66 | +) |
58 | 67 | from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, load_numpy, require_compel, require_torch_gpu |
59 | 68 |
|
60 | 69 |
|
@@ -966,9 +975,41 @@ def test_from_save_pretrained(self): |
966 | 975 | down_block_types=("DownBlock2D", "AttnDownBlock2D"), |
967 | 976 | up_block_types=("AttnUpBlock2D", "UpBlock2D"), |
968 | 977 | ) |
969 | | - schedular = DDPMScheduler(num_train_timesteps=10) |
| 978 | + scheduler = DDPMScheduler(num_train_timesteps=10) |
| 979 | + |
| 980 | + ddpm = DDPMPipeline(model, scheduler) |
| 981 | + ddpm.to(torch_device) |
| 982 | + ddpm.set_progress_bar_config(disable=None) |
| 983 | + |
| 984 | + with tempfile.TemporaryDirectory() as tmpdirname: |
| 985 | + ddpm.save_pretrained(tmpdirname) |
| 986 | + new_ddpm = DDPMPipeline.from_pretrained(tmpdirname) |
| 987 | + new_ddpm.to(torch_device) |
| 988 | + |
| 989 | + generator = torch.Generator(device=torch_device).manual_seed(0) |
| 990 | + image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images |
| 991 | + |
| 992 | + generator = torch.Generator(device=torch_device).manual_seed(0) |
| 993 | + new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images |
| 994 | + |
| 995 | + assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" |
| 996 | + |
| 997 | + @require_torch_2 |
| 998 | + def test_from_save_pretrained_dynamo(self): |
| 999 | + # 1. Load models |
| 1000 | + model = UNet2DModel( |
| 1001 | + block_out_channels=(32, 64), |
| 1002 | + layers_per_block=2, |
| 1003 | + sample_size=32, |
| 1004 | + in_channels=3, |
| 1005 | + out_channels=3, |
| 1006 | + down_block_types=("DownBlock2D", "AttnDownBlock2D"), |
| 1007 | + up_block_types=("AttnUpBlock2D", "UpBlock2D"), |
| 1008 | + ) |
| 1009 | + model = torch.compile(model) |
| 1010 | + scheduler = DDPMScheduler(num_train_timesteps=10) |
970 | 1011 |
|
971 | | - ddpm = DDPMPipeline(model, schedular) |
| 1012 | + ddpm = DDPMPipeline(model, scheduler) |
972 | 1013 | ddpm.to(torch_device) |
973 | 1014 | ddpm.set_progress_bar_config(disable=None) |
974 | 1015 |
|
|
0 commit comments