Skip to content

Commit a54cfe6

Browse files
Add LMSDiscreteSchedulerTest (#467)
* [WIP] add LMSDiscreteSchedulerTest * fixes for comments * add torch numpy test * rebase * Update tests/test_scheduler.py * Update tests/test_scheduler.py * style * return residuals Co-authored-by: Anton Lozhkov <[email protected]>
1 parent 8897217 commit a54cfe6

File tree

1 file changed

+81
-1
lines changed

1 file changed

+81
-1
lines changed

tests/test_scheduler.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import numpy as np
2020
import torch
2121

22-
from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler, ScoreSdeVeScheduler
22+
from diffusers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler, ScoreSdeVeScheduler
2323

2424

2525
torch.backends.cuda.matmul.allow_tf32 = False
@@ -853,3 +853,83 @@ def test_step_shape(self):
853853

854854
self.assertEqual(output_0.shape, sample.shape)
855855
self.assertEqual(output_0.shape, output_1.shape)
856+
857+
858+
class LMSDiscreteSchedulerTest(SchedulerCommonTest):
859+
scheduler_classes = (LMSDiscreteScheduler,)
860+
num_inference_steps = 10
861+
862+
def get_scheduler_config(self, **kwargs):
863+
config = {
864+
"num_train_timesteps": 1100,
865+
"beta_start": 0.0001,
866+
"beta_end": 0.02,
867+
"beta_schedule": "linear",
868+
"trained_betas": None,
869+
"tensor_format": "pt",
870+
}
871+
872+
config.update(**kwargs)
873+
return config
874+
875+
def test_timesteps(self):
876+
for timesteps in [10, 50, 100, 1000]:
877+
self.check_over_configs(num_train_timesteps=timesteps)
878+
879+
def test_betas(self):
880+
for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
881+
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
882+
883+
def test_schedules(self):
884+
for schedule in ["linear", "scaled_linear"]:
885+
self.check_over_configs(beta_schedule=schedule)
886+
887+
def test_time_indices(self):
888+
for t in [0, 500, 800]:
889+
self.check_over_forward(time_step=t)
890+
891+
def test_pytorch_equal_numpy(self):
892+
for scheduler_class in self.scheduler_classes:
893+
sample_pt = self.dummy_sample
894+
residual_pt = 0.1 * sample_pt
895+
896+
sample = sample_pt.numpy()
897+
residual = 0.1 * sample
898+
899+
scheduler_config = self.get_scheduler_config()
900+
scheduler_config["tensor_format"] = "np"
901+
scheduler = scheduler_class(**scheduler_config)
902+
903+
scheduler_config["tensor_format"] = "pt"
904+
scheduler_pt = scheduler_class(**scheduler_config)
905+
906+
scheduler.set_timesteps(self.num_inference_steps)
907+
scheduler_pt.set_timesteps(self.num_inference_steps)
908+
909+
output = scheduler.step(residual, 1, sample).prev_sample
910+
output_pt = scheduler_pt.step(residual_pt, 1, sample_pt).prev_sample
911+
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
912+
913+
def test_full_loop_no_noise(self):
914+
scheduler_class = self.scheduler_classes[0]
915+
scheduler_config = self.get_scheduler_config()
916+
scheduler = scheduler_class(**scheduler_config)
917+
918+
scheduler.set_timesteps(self.num_inference_steps)
919+
920+
model = self.dummy_model()
921+
sample = self.dummy_sample_deter * scheduler.sigmas[0]
922+
923+
for i, t in enumerate(scheduler.timesteps):
924+
sample = sample / ((scheduler.sigmas[i] ** 2 + 1) ** 0.5)
925+
926+
model_output = model(sample, t)
927+
928+
output = scheduler.step(model_output, i, sample)
929+
sample = output.prev_sample
930+
931+
result_sum = torch.sum(torch.abs(sample))
932+
result_mean = torch.mean(torch.abs(sample))
933+
934+
assert abs(result_sum.item() - 1006.388) < 1e-2
935+
assert abs(result_mean.item() - 1.31) < 1e-3

0 commit comments

Comments
 (0)