
Commit f806068: add pytorch unit test for dpmsolver
Parent: fda444a

File tree: 3 files changed, +193 / -2 lines


src/diffusers/schedulers/scheduling_dpmsolver_discrete.py (5 additions, 1 deletion)

@@ -396,7 +396,11 @@ def step(

         if isinstance(timestep, torch.Tensor):
             timestep = timestep.to(self.timesteps.device)
-        step_index = (self.timesteps == timestep).nonzero().item()
+        step_index = (self.timesteps == timestep).nonzero()
+        if len(step_index) == 0:
+            step_index = len(self.timesteps) - 1
+        else:
+            step_index = step_index.item()
         prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
         denoise_final = (step_index == len(self.timesteps) - 1) and self.denoise_final
         denoise_second = (step_index == len(self.timesteps) - 2) and self.denoise_final
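The scheduler change makes the timestep lookup robust: `(self.timesteps == timestep).nonzero()` returns an empty tensor when `timestep` is not in the schedule, so the old unconditional `.item()` would raise; the new code falls back to the last step index instead. A minimal standalone sketch of the same lookup in plain PyTorch (`lookup_step_index` and the toy schedule are illustrative, not part of the commit):

    import torch

    # a toy 5-step schedule, analogous to self.timesteps
    timesteps = torch.tensor([999, 749, 499, 249, 0])

    def lookup_step_index(timesteps: torch.Tensor, timestep: int) -> int:
        # nonzero() lists the indices where the mask is True;
        # it is empty when `timestep` does not occur in `timesteps`
        matches = (timesteps == timestep).nonzero()
        if len(matches) == 0:
            return len(timesteps) - 1  # fall back to the final step
        return matches.item()

    print(lookup_step_index(timesteps, 499))  # 2: exact match
    print(lookup_step_index(timesteps, 123))  # 4: absent, use last index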

tests/test_config.py (20 additions, 1 deletion)

@@ -19,7 +19,14 @@
 import unittest

 import diffusers
-from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, PNDMScheduler, logging
+from diffusers import (
+    DDIMScheduler,
+    DPMSolverDiscreteScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
+    PNDMScheduler,
+    logging,
+)
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.utils.testing_utils import CaptureLogger

@@ -283,3 +290,15 @@ def test_load_pndm(self):
         assert pndm.__class__ == PNDMScheduler
         # no warning should be thrown
         assert cap_logger.out == ""
+
+    def test_load_dpmsolver(self):
+        logger = logging.get_logger("diffusers.configuration_utils")
+
+        with CaptureLogger(logger) as cap_logger:
+            dpm = DPMSolverDiscreteScheduler.from_config(
+                "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
+            )
+
+        assert dpm.__class__ == DPMSolverDiscreteScheduler
+        # no warning should be thrown
+        assert cap_logger.out == ""
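This mirrors `test_load_pndm` directly above it: loading the `scheduler` subfolder of the `hf-internal-testing/tiny-stable-diffusion-torch` repo should yield a `DPMSolverDiscreteScheduler` while `configuration_utils` stays silent, i.e. the checked-in config supplies the keys the new scheduler expects. To run only this case (standard pytest filtering; assumes the repo's usual test setup):

    python -m pytest tests/test_config.py -k test_load_dpmsolver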

tests/test_scheduler.py (168 additions, 0 deletions)

@@ -24,6 +24,7 @@
 from diffusers import (
     DDIMScheduler,
     DDPMScheduler,
+    DPMSolverDiscreteScheduler,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
     IPNDMScheduler,
@@ -549,6 +550,173 @@ def test_full_loop_with_no_set_alpha_to_one(self):
         assert abs(result_mean.item() - 0.1941) < 1e-3


+class DPMSolverDiscreteSchedulerTest(SchedulerCommonTest):
+    scheduler_classes = (DPMSolverDiscreteScheduler,)
+    forward_default_kwargs = (("num_inference_steps", 25),)
+
+    def get_scheduler_config(self, **kwargs):
+        config = {
+            "num_train_timesteps": 1000,
+            "beta_start": 0.0001,
+            "beta_end": 0.02,
+            "beta_schedule": "linear",
+            "solver_order": 2,
+            "predict_x0": True,
+            "thresholding": False,
+            "sample_max_value": 1.0,
+            "solver_type": "dpm_solver",
+            "denoise_final": False,
+        }
+
+        config.update(**kwargs)
+        return config
+
+    def check_over_configs(self, time_step=0, **config):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+        sample = self.dummy_sample
+        residual = 0.1 * sample
+        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+            scheduler.set_timesteps(num_inference_steps)
+            # copy over dummy past residuals
+            scheduler.model_outputs = dummy_past_residuals[: scheduler.solver_order]
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_config(tmpdirname)
+                new_scheduler = scheduler_class.from_config(tmpdirname)
+                new_scheduler.set_timesteps(num_inference_steps)
+                # copy over dummy past residuals
+                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.solver_order]
+
+            output, new_output = sample, sample
+            for t in range(time_step, time_step + scheduler.solver_order + 1):
+                output = scheduler.step(residual, t, output, **kwargs).prev_sample
+                new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
+
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def test_from_pretrained_save_pretrained(self):
+        pass
+
+    def check_over_forward(self, time_step=0, **forward_kwargs):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+        sample = self.dummy_sample
+        residual = 0.1 * sample
+        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+            scheduler.set_timesteps(num_inference_steps)
+
+            # copy over dummy past residuals (must be after setting timesteps)
+            scheduler.model_outputs = dummy_past_residuals[: scheduler.solver_order]
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_config(tmpdirname)
+                new_scheduler = scheduler_class.from_config(tmpdirname)
+                # copy over dummy past residuals
+                new_scheduler.set_timesteps(num_inference_steps)
+
+                # copy over dummy past residual (must be after setting timesteps)
+                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.solver_order]
+
+            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
+            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
+
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def full_loop(self, **config):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(**config)
+        scheduler = scheduler_class(**scheduler_config)
+
+        num_inference_steps = 10
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter
+        scheduler.set_timesteps(num_inference_steps)
+
+        for i, t in enumerate(scheduler.timesteps):
+            residual = model(sample, t)
+            sample = scheduler.step(residual, t, sample).prev_sample
+
+        return sample
+
+    def test_step_shape(self):
+        kwargs = dict(self.forward_default_kwargs)
+
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+
+            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(num_inference_steps)
+            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+                kwargs["num_inference_steps"] = num_inference_steps
+
+            # copy over dummy past residuals (must be done after set_timesteps)
+            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
+            scheduler.model_outputs = dummy_past_residuals[: scheduler.solver_order]
+
+            time_step_0 = scheduler.timesteps[5]
+            time_step_1 = scheduler.timesteps[6]
+
+            output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
+            output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample
+
+            self.assertEqual(output_0.shape, sample.shape)
+            self.assertEqual(output_0.shape, output_1.shape)
+
+    def test_timesteps(self):
+        for timesteps in [25, 50, 100, 999, 1000]:
+            self.check_over_configs(num_train_timesteps=timesteps)
+
+    def test_thresholding(self):
+        self.check_over_configs(thresholding=False)
+        for order in [1, 2, 3]:
+            for solver_type in ["dpm_solver", "taylor"]:
+                for threshold in [0.5, 1.0, 2.0]:
+                    self.check_over_configs(
+                        thresholding=True,
+                        sample_max_value=threshold,
+                        predict_x0=True,
+                        solver_order=order,
+                        solver_type=solver_type,
+                    )
+
+    def test_solver_order_and_type(self):
+        for solver_type in ["dpm_solver", "taylor"]:
+            for order in [1, 2, 3]:
+                for predict_x0 in [True, False]:
+                    self.check_over_configs(solver_order=order, solver_type=solver_type, predict_x0=predict_x0)
+                    sample = self.full_loop(solver_order=order, solver_type=solver_type, predict_x0=predict_x0)
+                    assert not torch.isnan(sample).any(), "Samples have nan numbers"
+
+    def test_denoise_final(self):
+        self.check_over_configs(denoise_final=True)
+        self.check_over_configs(denoise_final=False)
+
+    def test_inference_steps(self):
+        for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
+            self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)
+
+    def test_full_loop_no_noise(self):
+        sample = self.full_loop()
+        result_mean = torch.mean(torch.abs(sample))

+        assert abs(result_mean.item() - 0.3301) < 1e-3
+
+
 class PNDMSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (PNDMScheduler,)
     forward_default_kwargs = (("num_inference_steps", 50),)
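The `full_loop` helper above doubles as the scheduler's intended inference loop. A condensed sketch of the same pattern outside the test harness, assuming this commit's branch of diffusers is installed (the constructor arguments mirror `get_scheduler_config`; `fake_denoiser` is a hypothetical stand-in for a trained model's noise prediction):

    import torch
    from diffusers import DPMSolverDiscreteScheduler

    scheduler = DPMSolverDiscreteScheduler(
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        solver_order=2,
        predict_x0=True,
        thresholding=False,
        sample_max_value=1.0,
        solver_type="dpm_solver",
        denoise_final=False,
    )
    scheduler.set_timesteps(10)  # 10 inference steps

    def fake_denoiser(sample: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # placeholder for a UNet-style noise prediction
        return 0.1 * sample

    sample = torch.randn(1, 3, 8, 8)  # toy latent
    for t in scheduler.timesteps:
        residual = fake_denoiser(sample, t)
        # step() consumes the model output and returns the updated sample
        sample = scheduler.step(residual, t, sample).prev_sample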
