From eacebf13c21ded8f29c280ab4b37e5c187bb7385 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 1 Dec 2022 15:01:50 +0100 Subject: [PATCH 1/6] support v prediction in other schedulers --- .../scheduling_euler_ancestral_discrete.py | 12 +++++++++++- src/diffusers/schedulers/scheduling_lms_discrete.py | 11 ++++++++++- src/diffusers/schedulers/scheduling_pndm.py | 8 ++++++++ 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index fe8a36c43f51..7011e1f30a4e 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -78,6 +78,7 @@ def __init__( beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + prediction_type: str = "epsilon", ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -202,7 +203,16 @@ def step( sigma = self.sigmas[step_index] # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise - pred_original_sample = sample - sigma * model_output + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + sigma_from = self.sigmas[step_index] sigma_to = self.sigmas[step_index + 1] sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 4c28db591a62..044eb70c213c 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -78,6 +78,7 @@ def __init__( beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + prediction_type: str = "epsilon", ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -215,7 +216,15 @@ def step( sigma = self.sigmas[step_index] # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise - pred_original_sample = sample - sigma * model_output + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) # 2. 
Convert to an ODE derivative derivative = (sample - pred_original_sample) / sigma diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index bb3e098f7e42..ddfd1338a4e5 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -102,6 +102,7 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, skip_prk_steps: bool = False, set_alpha_to_one: bool = False, + prediction_type: str = "epsilon", steps_offset: int = 0, ): if trained_betas is not None: @@ -368,6 +369,13 @@ def _get_prev_sample(self, sample, timestep, prev_timestep, model_output): beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev + if self.config.prediction_type == "v_prediction": + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + elif self.config.prediction_type != "epsilon": + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`" + ) + # corresponds to (α_(t−δ) - α_t) divided by # denominator of x_t in formula (9) and plus 1 # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = From 0af234b84351d0369d02497d5438719e4a153fa7 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 1 Dec 2022 16:03:50 +0100 Subject: [PATCH 2/6] v heun --- src/diffusers/schedulers/scheduling_heun.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_heun.py b/src/diffusers/schedulers/scheduling_heun.py index d21591b3df21..27a54f645e4e 100644 --- a/src/diffusers/schedulers/scheduling_heun.py +++ b/src/diffusers/schedulers/scheduling_heun.py @@ -54,6 +54,7 @@ def __init__( beta_end: float = 0.012, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + prediction_type: str = "epsilon", ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -184,7 +185,15 @@ def step( sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise - pred_original_sample = sample - sigma_hat * model_output + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma_hat * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) if self.state_in_first_order: # 2. 
Convert to an ODE derivative From abcf99ca46ba51e9bdc0c34d07187d124deaaef6 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 1 Dec 2022 17:36:38 +0100 Subject: [PATCH 3/6] add tests for v pred --- tests/test_scheduler.py | 173 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 621fbfe1253b..7a923d8b9639 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -42,6 +42,7 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin from diffusers.utils import deprecate, torch_device from diffusers.utils.testing_utils import CaptureLogger +from parameterized import parameterized torch.backends.cuda.matmul.allow_tf32 = False @@ -768,6 +769,10 @@ def test_schedules(self): for schedule in ["linear", "squaredcos_cap_v2"]: self.check_over_configs(beta_schedule=schedule) + def test_prediction_type(self): + for prediction_type in ["epsilon", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) + def test_clip_sample(self): for clip_sample in [True, False]: self.check_over_configs(clip_sample=clip_sample) @@ -805,6 +810,15 @@ def test_full_loop_no_noise(self): assert abs(result_sum.item() - 172.0067) < 1e-2 assert abs(result_mean.item() - 0.223967) < 1e-3 + def test_full_loop_with_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction") + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 52.5302) < 1e-2 + assert abs(result_mean.item() - 0.0684) < 1e-3 + def test_full_loop_with_set_alpha_to_one(self): # We specify different beta, so that the first alpha is 0.99 sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) @@ -971,6 +985,10 @@ def test_thresholding(self): solver_type=solver_type, ) + def test_prediction_type(self): + for prediction_type in ["epsilon", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) + def test_solver_order_and_type(self): for algorithm_type in ["dpmsolver", "dpmsolver++"]: for solver_type in ["midpoint", "heun"]: @@ -1004,6 +1022,12 @@ def test_full_loop_no_noise(self): assert abs(result_mean.item() - 0.3301) < 1e-3 + def test_full_loop_with_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction") + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.2251) < 1e-3 + def test_fp16_support(self): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0) @@ -1184,6 +1208,10 @@ def test_schedules(self): for schedule in ["linear", "squaredcos_cap_v2"]: self.check_over_configs(beta_schedule=schedule) + def test_prediction_type(self): + for prediction_type in ["epsilon", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) + def test_time_indices(self): for t in [1, 5, 10]: self.check_over_forward(time_step=t) @@ -1225,6 +1253,14 @@ def test_full_loop_no_noise(self): assert abs(result_sum.item() - 198.1318) < 1e-2 assert abs(result_mean.item() - 0.2580) < 1e-3 + def test_full_loop_with_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction") + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 67.3986) < 1e-2 + assert abs(result_mean.item() - 0.0878) < 1e-3 + def test_full_loop_with_set_alpha_to_one(self): # We specify different beta, so that the first alpha is 0.99 sample 
= self.full_loop(set_alpha_to_one=True, beta_start=0.01) @@ -1453,6 +1489,10 @@ def test_schedules(self): for schedule in ["linear", "scaled_linear"]: self.check_over_configs(beta_schedule=schedule) + def test_prediction_type(self): + for prediction_type in ["epsilon", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) + def test_time_indices(self): for t in [0, 500, 800]: self.check_over_forward(time_step=t) @@ -1481,6 +1521,30 @@ def test_full_loop_no_noise(self): assert abs(result_sum.item() - 1006.388) < 1e-2 assert abs(result_mean.item() - 1.31) < 1e-3 + def test_full_loop_with_v_prediction(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(prediction_type="v_prediction") + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + + for i, t in enumerate(scheduler.timesteps): + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 0.0017) < 1e-2 + assert abs(result_mean.item() - 2.2676e-06) < 1e-3 + def test_full_loop_device(self): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config() @@ -1534,6 +1598,10 @@ def test_schedules(self): for schedule in ["linear", "scaled_linear"]: self.check_over_configs(beta_schedule=schedule) + def test_prediction_type(self): + for prediction_type in ["epsilon", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) + def test_full_loop_no_noise(self): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config() @@ -1565,6 +1633,37 @@ def test_full_loop_no_noise(self): assert abs(result_sum.item() - 10.0807) < 1e-2 assert abs(result_mean.item() - 0.0131) < 1e-3 + def test_full_loop_with_v_prediction(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(prediction_type="v_prediction") + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps) + + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. 
+ generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for i, t in enumerate(scheduler.timesteps): + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample, generator=generator) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 0.0002) < 1e-2 + assert abs(result_mean.item() - 2.2676e-06) < 1e-3 + def test_full_loop_device(self): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config() @@ -1624,6 +1723,10 @@ def test_schedules(self): for schedule in ["linear", "scaled_linear"]: self.check_over_configs(beta_schedule=schedule) + def test_prediction_type(self): + for prediction_type in ["epsilon", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) + def test_full_loop_no_noise(self): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config() @@ -1660,6 +1763,42 @@ def test_full_loop_no_noise(self): assert abs(result_sum.item() - 144.8084) < 1e-2 assert abs(result_mean.item() - 0.18855) < 1e-3 + def test_full_loop_with_v_prediction(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(prediction_type="v_prediction") + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps) + + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for i, t in enumerate(scheduler.timesteps): + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample, generator=generator) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + if torch_device in ["cpu", "mps"]: + assert abs(result_sum.item() - 108.4439) < 1e-2 + assert abs(result_mean.item() - 0.1412) < 1e-3 + else: + # CUDA + assert abs(result_sum.item() - 144.8084) < 1e-2 + assert abs(result_mean.item() - 0.18855) < 1e-3 + def test_full_loop_device(self): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config() @@ -1932,6 +2071,10 @@ def test_schedules(self): for schedule in ["linear", "scaled_linear"]: self.check_over_configs(beta_schedule=schedule) + def test_prediction_type(self): + for prediction_type in ["epsilon", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) + def test_full_loop_no_noise(self): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config() @@ -1962,6 +2105,36 @@ def test_full_loop_no_noise(self): assert abs(result_sum.item() - 0.1233) < 1e-2 assert abs(result_mean.item() - 0.0002) < 1e-3 + def test_full_loop_with_v_prediction(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(prediction_type="v_prediction") + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps) + + model = 
self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for i, t in enumerate(scheduler.timesteps): + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + if torch_device in ["cpu", "mps"]: + assert abs(result_sum.item() - 4.6934e-07) < 1e-2 + assert abs(result_mean.item() - 6.1112e-10) < 1e-3 + else: + # CUDA + assert abs(result_sum.item() - 0.1233) < 1e-2 + assert abs(result_mean.item() - 0.0002) < 1e-3 + def test_full_loop_device(self): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config() From de82ba4557b15d28bf1055b948118346d0eb4a42 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 1 Dec 2022 17:54:58 +0100 Subject: [PATCH 4/6] fix tests --- tests/test_scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 7a923d8b9639..8ba717ca66b7 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -1796,7 +1796,7 @@ def test_full_loop_with_v_prediction(self): assert abs(result_mean.item() - 0.1412) < 1e-3 else: # CUDA - assert abs(result_sum.item() - 144.8084) < 1e-2 + assert abs(result_sum.item() - 102.5807) < 1e-2 assert abs(result_mean.item() - 0.18855) < 1e-3 def test_full_loop_device(self): @@ -2132,7 +2132,7 @@ def test_full_loop_with_v_prediction(self): assert abs(result_mean.item() - 6.1112e-10) < 1e-3 else: # CUDA - assert abs(result_sum.item() - 0.1233) < 1e-2 + assert abs(result_sum.item() - 4.693428650170972e-07) < 1e-2 assert abs(result_mean.item() - 0.0002) < 1e-3 def test_full_loop_device(self): From 9dc7192a642e9006db1420bc2411ca92bec423bc Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 1 Dec 2022 17:56:11 +0100 Subject: [PATCH 5/6] fix test euler a --- tests/test_scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 8ba717ca66b7..386098634468 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -1797,7 +1797,7 @@ def test_full_loop_with_v_prediction(self): else: # CUDA assert abs(result_sum.item() - 102.5807) < 1e-2 - assert abs(result_mean.item() - 0.18855) < 1e-3 + assert abs(result_mean.item() - 0.1335) < 1e-3 def test_full_loop_device(self): scheduler_class = self.scheduler_classes[0] From bf56d466e109daf976163a2ad00fad8d04c9822e Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 1 Dec 2022 18:05:00 +0100 Subject: [PATCH 6/6] v ddpm --- src/diffusers/schedulers/scheduling_ddpm.py | 6 ++-- tests/test_scheduler.py | 34 +++++++++++++++++++-- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 62a13c33cd89..dcf899935deb 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -280,10 +280,12 @@ def step( pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) elif self.config.prediction_type == "sample": pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output else: raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be 
one of `epsilon`, `sample` " - " for the DDPMScheduler." + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for the DDPMScheduler." ) # 3. Clip "predicted x_0" diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 386098634468..5d885fef6b12 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -42,7 +42,6 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin from diffusers.utils import deprecate, torch_device from diffusers.utils.testing_utils import CaptureLogger -from parameterized import parameterized torch.backends.cuda.matmul.allow_tf32 = False @@ -636,7 +635,7 @@ def test_clip_sample(self): self.check_over_configs(clip_sample=clip_sample) def test_prediction_type(self): - for prediction_type in ["epsilon", "sample"]: + for prediction_type in ["epsilon", "sample", "v_prediction"]: self.check_over_configs(prediction_type=prediction_type) def test_deprecated_predict_epsilon(self): @@ -712,6 +711,37 @@ def test_full_loop_no_noise(self): assert abs(result_sum.item() - 258.9070) < 1e-2 assert abs(result_mean.item() - 0.3374) < 1e-3 + def test_full_loop_with_v_prediction(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(prediction_type="v_prediction") + scheduler = scheduler_class(**scheduler_config) + + num_trained_timesteps = len(scheduler) + + model = self.dummy_model() + sample = self.dummy_sample_deter + generator = torch.manual_seed(0) + + for t in reversed(range(num_trained_timesteps)): + # 1. predict noise residual + residual = model(sample, t) + + # 2. predict previous mean of sample x_t-1 + pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample + + # if t > 0: + # noise = self.dummy_sample_deter + # variance = scheduler.get_variance(t) ** (0.5) * noise + # + # sample = pred_prev_sample + variance + sample = pred_prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 201.9864) < 1e-2 + assert abs(result_mean.item() - 0.2630) < 1e-3 + class DDIMSchedulerTest(SchedulerCommonTest): scheduler_classes = (DDIMScheduler,)
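
Note on the math behind this series: every scheduler touched here relies on the same v-parameterization from Salimans & Ho, "Progressive Distillation for Fast Sampling of Diffusion Models", v = sqrt(alpha_prod_t) * eps - sqrt(1 - alpha_prod_t) * x_0. The DDPM change recovers x_0 directly in alpha/beta space; the Euler-ancestral, LMS, and Heun changes recover x_0 in sigma space, where those schedulers keep the sample at scale x_0 + sigma * eps; the PNDM change instead converts v back into an epsilon prediction so the rest of `_get_prev_sample` can stay untouched. The standalone sketch below (illustrative variable names only, not library code) checks numerically that the three conversions are consistent:

    # Standalone numerical sketch of the v-prediction conversions used above.
    # Illustrative names only (x0, eps, v, ...); the formulas mirror the diffs.
    import torch

    torch.manual_seed(0)

    alpha_prod_t = torch.tensor(0.7)                 # cumulative product of alphas at step t
    beta_prod_t = 1 - alpha_prod_t
    sigma = (beta_prod_t / alpha_prod_t) ** 0.5      # k-diffusion sigma for the same step

    x0 = torch.randn(4)                              # clean sample
    eps = torch.randn(4)                             # Gaussian noise
    v = alpha_prod_t**0.5 * eps - beta_prod_t**0.5 * x0   # v-prediction target

    # DDPM (scheduling_ddpm.py): recover x_0 from v in alpha/beta space
    x_t = alpha_prod_t**0.5 * x0 + beta_prod_t**0.5 * eps
    x0_ddpm = alpha_prod_t**0.5 * x_t - beta_prod_t**0.5 * v

    # Euler-ancestral / LMS / Heun: recover x_0 in sigma space, where the
    # scheduler keeps the sample at scale x_0 + sigma * eps
    sample = x0 + sigma * eps
    x0_sigma = v * (-sigma / (sigma**2 + 1) ** 0.5) + sample / (sigma**2 + 1)

    # PNDM (scheduling_pndm.py): turn v back into an epsilon prediction and
    # reuse the existing epsilon-based update
    eps_from_v = alpha_prod_t**0.5 * v + beta_prod_t**0.5 * x_t

    print(torch.allclose(x0_ddpm, x0, atol=1e-5))     # True
    print(torch.allclose(x0_sigma, x0, atol=1e-5))    # True
    print(torch.allclose(eps_from_v, eps, atol=1e-5)) # True

The asymmetric treatment of PNDM is deliberate: its PLMS/PRK multistep update in `_get_prev_sample` is written in terms of the epsilon prediction (formula (9)), so converting v to epsilon at the top of the method is the smallest change that keeps the rest of that machinery unchanged.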