|
@@ -27,6 +27,7 @@
 from diffusers import (
     DDIMScheduler,
     DDPMScheduler,
+    DEISMultistepScheduler,
     DPMSolverMultistepScheduler,
     DPMSolverSinglestepScheduler,
     EulerAncestralDiscreteScheduler,
@@ -2505,6 +2506,207 @@ def test_full_loop_device(self):
         assert abs(result_mean.item() - 0.0266) < 1e-3


+class DEISMultistepSchedulerTest(SchedulerCommonTest):
+    scheduler_classes = (DEISMultistepScheduler,)
+    forward_default_kwargs = (("num_inference_steps", 25),)
+
+    def get_scheduler_config(self, **kwargs):
+        config = {
+            "num_train_timesteps": 1000,
+            "beta_start": 0.0001,
+            "beta_end": 0.02,
+            "beta_schedule": "linear",
+            "solver_order": 2,
+        }
+
+        config.update(**kwargs)
+        return config
+
+    def check_over_configs(self, time_step=0, **config):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+        sample = self.dummy_sample
+        residual = 0.1 * sample
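+        # stand-in "previous model outputs" used to seed the multistep solver's history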
+        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+            scheduler.set_timesteps(num_inference_steps)
+            # copy over dummy past residuals
+            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_config(tmpdirname)
+                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
+                new_scheduler.set_timesteps(num_inference_steps)
+                # copy over dummy past residuals
+                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
+
+            output, new_output = sample, sample
+            for t in range(time_step, time_step + scheduler.config.solver_order + 1):
+                output = scheduler.step(residual, t, output, **kwargs).prev_sample
+                new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
+
+                assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def test_from_save_pretrained(self):
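+        # save/load round-trips are already exercised by check_over_configs and check_over_forward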
+        pass
+
+    def check_over_forward(self, time_step=0, **forward_kwargs):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+        sample = self.dummy_sample
+        residual = 0.1 * sample
+        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+            scheduler.set_timesteps(num_inference_steps)
+
+            # copy over dummy past residuals (must be after setting timesteps)
+            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_config(tmpdirname)
+                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
+                new_scheduler.set_timesteps(num_inference_steps)
+
+                # copy over dummy past residuals (must be after setting timesteps)
+                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
+
+            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
+            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
+
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def full_loop(self, **config):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(**config)
+        scheduler = scheduler_class(**scheduler_config)
+
+        num_inference_steps = 10
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter
+        scheduler.set_timesteps(num_inference_steps)
+
+        for i, t in enumerate(scheduler.timesteps):
+            residual = model(sample, t)
+            sample = scheduler.step(residual, t, sample).prev_sample
+
+        return sample
+
+    def test_step_shape(self):
+        kwargs = dict(self.forward_default_kwargs)
+
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+
+            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(num_inference_steps)
+            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+                kwargs["num_inference_steps"] = num_inference_steps
+
+            # copy over dummy past residuals (must be done after set_timesteps)
+            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
+            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
+
+            time_step_0 = scheduler.timesteps[5]
+            time_step_1 = scheduler.timesteps[6]
+
+            output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
+            output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample
+
+            self.assertEqual(output_0.shape, sample.shape)
+            self.assertEqual(output_0.shape, output_1.shape)
+
+    def test_timesteps(self):
+        for timesteps in [25, 50, 100, 999, 1000]:
+            self.check_over_configs(num_train_timesteps=timesteps)
+
+    def test_thresholding(self):
+        self.check_over_configs(thresholding=False)
+        for order in [1, 2, 3]:
+            for solver_type in ["logrho"]:
+                for threshold in [0.5, 1.0, 2.0]:
+                    for prediction_type in ["epsilon", "sample"]:
+                        self.check_over_configs(
+                            thresholding=True,
+                            prediction_type=prediction_type,
+                            sample_max_value=threshold,
+                            algorithm_type="deis",
+                            solver_order=order,
+                            solver_type=solver_type,
+                        )
+
+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "v_prediction"]:
+            self.check_over_configs(prediction_type=prediction_type)
+
+    def test_solver_order_and_type(self):
+        for algorithm_type in ["deis"]:
+            for solver_type in ["logrho"]:
+                for order in [1, 2, 3]:
+                    for prediction_type in ["epsilon", "sample"]:
+                        self.check_over_configs(
+                            solver_order=order,
+                            solver_type=solver_type,
+                            prediction_type=prediction_type,
+                            algorithm_type=algorithm_type,
+                        )
+                        sample = self.full_loop(
+                            solver_order=order,
+                            solver_type=solver_type,
+                            prediction_type=prediction_type,
+                            algorithm_type=algorithm_type,
+                        )
+                        assert not torch.isnan(sample).any(), "Samples contain NaN values"
+
+    def test_lower_order_final(self):
+        self.check_over_configs(lower_order_final=True)
+        self.check_over_configs(lower_order_final=False)
+
+    def test_inference_steps(self):
+        for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
+            self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)
+
+    def test_full_loop_no_noise(self):
+        sample = self.full_loop()
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.23916) < 1e-3
+
+    def test_full_loop_with_v_prediction(self):
+        sample = self.full_loop(prediction_type="v_prediction")
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.091) < 1e-3
+
+    def test_fp16_support(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
+        scheduler = scheduler_class(**scheduler_config)
+
+        num_inference_steps = 10
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter.half()
+        scheduler.set_timesteps(num_inference_steps)
+
+        for i, t in enumerate(scheduler.timesteps):
+            residual = model(sample, t)
+            sample = scheduler.step(residual, t, sample).prev_sample
+
+        assert sample.dtype == torch.float16
+
+
 class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (KDPM2AncestralDiscreteScheduler,)
     num_inference_steps = 10
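A minimal usage sketch of the scheduler exercised above, assuming only the API the tests rely on (set_timesteps and step(...).prev_sample); the tensor shape and the random stand-in for the model's noise prediction are placeholders, not something defined by this diff:

    import torch
    from diffusers import DEISMultistepScheduler

    # same defaults as get_scheduler_config() in the test above
    scheduler = DEISMultistepScheduler(
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        solver_order=2,
    )

    scheduler.set_timesteps(10)       # num_inference_steps used by full_loop()
    sample = torch.randn(1, 3, 8, 8)  # placeholder latent; the real shape depends on the model

    for t in scheduler.timesteps:
        model_output = torch.randn_like(sample)  # stand-in for model(sample, t)
        sample = scheduler.step(model_output, t, sample).prev_sample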
|