 from diffusers import (
     DDIMScheduler,
     DDPMScheduler,
+    DPMSolverDiscreteScheduler,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
     IPNDMScheduler,
@@ -549,6 +550,173 @@ def test_full_loop_with_no_set_alpha_to_one(self): |
         assert abs(result_mean.item() - 0.1941) < 1e-3


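+# DPM-Solver (https://arxiv.org/abs/2206.00927) is a fast, high-order multistep
+# solver for the diffusion ODE. It keeps a short history of previous model
+# outputs in `model_outputs` and combines them into second- or third-order
+# updates, so good samples can be drawn in only ~10-25 inference steps.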
+class DPMSolverDiscreteSchedulerTest(SchedulerCommonTest):
+    scheduler_classes = (DPMSolverDiscreteScheduler,)
+    forward_default_kwargs = (("num_inference_steps", 25),)
+
+    def get_scheduler_config(self, **kwargs):
+        config = {
+            "num_train_timesteps": 1000,
+            "beta_start": 0.0001,
+            "beta_end": 0.02,
+            "beta_schedule": "linear",
+            "solver_order": 2,
+            "predict_x0": True,
+            "thresholding": False,
+            "sample_max_value": 1.0,
+            "solver_type": "dpm_solver",
+            "denoise_final": False,
+        }
+
+        config.update(**kwargs)
+        return config
+
+    def check_over_configs(self, time_step=0, **config):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+        sample = self.dummy_sample
+        residual = 0.1 * sample
+        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
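+        # three dummy outputs provide enough history for the highest solver
+        # order (3) exercised by these tests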
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+            scheduler.set_timesteps(num_inference_steps)
+            # copy over dummy past residuals
+            scheduler.model_outputs = dummy_past_residuals[: scheduler.solver_order]
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_config(tmpdirname)
+                new_scheduler = scheduler_class.from_config(tmpdirname)
+                new_scheduler.set_timesteps(num_inference_steps)
+                # copy over dummy past residuals
+                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.solver_order]
+
+            output, new_output = sample, sample
+            for t in range(time_step, time_step + scheduler.solver_order + 1):
+                output = scheduler.step(residual, t, output, **kwargs).prev_sample
+                new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
+
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def test_from_pretrained_save_pretrained(self):
+        # the save/load round trip for this scheduler is already exercised by
+        # check_over_configs and check_over_forward above
+        pass
+
+    def check_over_forward(self, time_step=0, **forward_kwargs):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+        sample = self.dummy_sample
+        residual = 0.1 * sample
+        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+            scheduler.set_timesteps(num_inference_steps)
+
+            # copy over dummy past residuals (must be after setting timesteps)
+            scheduler.model_outputs = dummy_past_residuals[: scheduler.solver_order]
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_config(tmpdirname)
+                new_scheduler = scheduler_class.from_config(tmpdirname)
+                new_scheduler.set_timesteps(num_inference_steps)
+
+                # copy over dummy past residuals (must be after setting timesteps)
+                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.solver_order]
+
+            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
+            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
+
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def full_loop(self, **config):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(**config)
+        scheduler = scheduler_class(**scheduler_config)
+
+        num_inference_steps = 10
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter
+        scheduler.set_timesteps(num_inference_steps)
+
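+        # run the full sampling loop; each step() call also updates the
+        # scheduler's stored model-output history used by the multistep updates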
+        for t in scheduler.timesteps:
+            residual = model(sample, t)
+            sample = scheduler.step(residual, t, sample).prev_sample
+
+        return sample
+
+    def test_step_shape(self):
+        kwargs = dict(self.forward_default_kwargs)
+
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+
+            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(num_inference_steps)
+            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+                kwargs["num_inference_steps"] = num_inference_steps
+
+            # copy over dummy past residuals (must be done after set_timesteps)
+            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
+            scheduler.model_outputs = dummy_past_residuals[: scheduler.solver_order]
+
+            time_step_0 = scheduler.timesteps[5]
+            time_step_1 = scheduler.timesteps[6]
+
+            output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
+            output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample
+
+            self.assertEqual(output_0.shape, sample.shape)
+            self.assertEqual(output_0.shape, output_1.shape)
+
+    def test_timesteps(self):
+        for timesteps in [25, 50, 100, 999, 1000]:
+            self.check_over_configs(num_train_timesteps=timesteps)
+
+    def test_thresholding(self):
+        self.check_over_configs(thresholding=False)
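+        # thresholding clips the predicted x0 into a bounded range (with
+        # sample_max_value as the clipping threshold), so it is only
+        # meaningful when predict_x0=True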
+        for order in [1, 2, 3]:
+            for solver_type in ["dpm_solver", "taylor"]:
+                for threshold in [0.5, 1.0, 2.0]:
+                    self.check_over_configs(
+                        thresholding=True,
+                        sample_max_value=threshold,
+                        predict_x0=True,
+                        solver_order=order,
+                        solver_type=solver_type,
+                    )
+
+    def test_solver_order_and_type(self):
+        for solver_type in ["dpm_solver", "taylor"]:
+            for order in [1, 2, 3]:
+                for predict_x0 in [True, False]:
+                    self.check_over_configs(solver_order=order, solver_type=solver_type, predict_x0=predict_x0)
+                    sample = self.full_loop(solver_order=order, solver_type=solver_type, predict_x0=predict_x0)
+                    assert not torch.isnan(sample).any(), "Samples contain NaN values"
+
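+    # denoise_final switches to lower-order updates for the last steps, which
+    # can stabilize sampling when very few inference steps are used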
+    def test_denoise_final(self):
+        self.check_over_configs(denoise_final=True)
+        self.check_over_configs(denoise_final=False)
+
+    def test_inference_steps(self):
+        for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
+            self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)
+
+    def test_full_loop_no_noise(self):
+        sample = self.full_loop()
+        result_mean = torch.mean(torch.abs(sample))
+
+        # regression value recorded from a reference run of this scheduler
+        assert abs(result_mean.item() - 0.3301) < 1e-3
+
+
 class PNDMSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (PNDMScheduler,)
     forward_default_kwargs = (("num_inference_steps", 50),)