@@ -19,7 +19,7 @@
 import numpy as np
 import torch
 
-from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler, ScoreSdeVeScheduler
+from diffusers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler, ScoreSdeVeScheduler
 
 
 torch.backends.cuda.matmul.allow_tf32 = False
@@ -853,3 +853,83 @@ def test_step_shape(self):
 
         self.assertEqual(output_0.shape, sample.shape)
         self.assertEqual(output_0.shape, output_1.shape)
+
+
+class LMSDiscreteSchedulerTest(SchedulerCommonTest):
+    scheduler_classes = (LMSDiscreteScheduler,)
+    num_inference_steps = 10
+
+    def get_scheduler_config(self, **kwargs):
+        config = {
+            "num_train_timesteps": 1100,
+            "beta_start": 0.0001,
+            "beta_end": 0.02,
+            "beta_schedule": "linear",
+            "trained_betas": None,
+            "tensor_format": "pt",
+        }
+
+        config.update(**kwargs)
+        return config
+
+    def test_timesteps(self):
+        for timesteps in [10, 50, 100, 1000]:
+            self.check_over_configs(num_train_timesteps=timesteps)
+
+    def test_betas(self):
+        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
+            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
+
+    def test_schedules(self):
+        for schedule in ["linear", "scaled_linear"]:
+            self.check_over_configs(beta_schedule=schedule)
+
+    def test_time_indices(self):
+        for t in [0, 500, 800]:
+            self.check_over_forward(time_step=t)
+
+    def test_pytorch_equal_numpy(self):
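+        # Same config, two tensor formats: the numpy and pytorch backends should agree.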
+        for scheduler_class in self.scheduler_classes:
+            sample_pt = self.dummy_sample
+            residual_pt = 0.1 * sample_pt
+
+            sample = sample_pt.numpy()
+            residual = 0.1 * sample
+
+            scheduler_config = self.get_scheduler_config()
+            scheduler_config["tensor_format"] = "np"
+            scheduler = scheduler_class(**scheduler_config)
+
+            scheduler_config["tensor_format"] = "pt"
+            scheduler_pt = scheduler_class(**scheduler_config)
+
+            scheduler.set_timesteps(self.num_inference_steps)
+            scheduler_pt.set_timesteps(self.num_inference_steps)
+
+            output = scheduler.step(residual, 1, sample).prev_sample
+            output_pt = scheduler_pt.step(residual_pt, 1, sample_pt).prev_sample
+            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
+
+    def test_full_loop_no_noise(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        scheduler.set_timesteps(self.num_inference_steps)
+
+        model = self.dummy_model()
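+        # Start from a deterministic sample scaled to the largest sigma in the schedule.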
+        sample = self.dummy_sample_deter * scheduler.sigmas[0]
+
+        for i, t in enumerate(scheduler.timesteps):
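+            # Scale the model input by 1 / sqrt(sigma^2 + 1) so the denoiser sees
+            # a roughly unit-variance sample.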
+            sample = sample / ((scheduler.sigmas[i] ** 2 + 1) ** 0.5)
+
+            model_output = model(sample, t)
+
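+            # step() is called with the loop index here, which this scheduler uses
+            # to look up the matching sigma.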
+            output = scheduler.step(model_output, i, sample)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 1006.388) < 1e-2
+        assert abs(result_mean.item() - 1.31) < 1e-3
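
For reference, the denoising loop exercised by `test_full_loop_no_noise` corresponds to the following usage pattern. This is a minimal sketch against the diffusers API as of this change (`tensor_format` argument, index-based `step()`, later releases changed both); the `unet` function is a hypothetical stand-in for a real noise-prediction model:

```python
import torch

from diffusers import LMSDiscreteScheduler


def unet(sample, t):
    # Hypothetical denoiser stub; a real model would predict the noise residual.
    return 0.1 * sample


scheduler = LMSDiscreteScheduler(
    num_train_timesteps=1000,
    beta_start=0.0001,
    beta_end=0.02,
    beta_schedule="linear",
    tensor_format="pt",  # accepted at the time of this change; dropped in later releases
)
scheduler.set_timesteps(num_inference_steps=10)

# Start from noise scaled to the largest sigma, as the test does.
sample = torch.randn(1, 3, 8, 8) * scheduler.sigmas[0]

for i, t in enumerate(scheduler.timesteps):
    # Scale the input so the model sees a roughly unit-variance sample.
    model_input = sample / ((scheduler.sigmas[i] ** 2 + 1) ** 0.5)
    noise_pred = unet(model_input, t)
    # This version of the scheduler steps by index into the sigma schedule.
    sample = scheduler.step(noise_pred, i, sample).prev_sample
```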