From 2e1278f134805b2972a6fb271eed4c1f37d55c7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Mon, 12 Sep 2022 15:57:48 +0200 Subject: [PATCH 1/9] Unify offset configuration in DDIM and PNDM schedulers --- .../pipeline_stable_diffusion.py | 7 +--- .../pipeline_stable_diffusion_img2img.py | 9 +---- .../pipeline_stable_diffusion_inpaint.py | 9 +---- .../pipeline_stable_diffusion_onnx.py | 7 +--- src/diffusers/schedulers/scheduling_ddim.py | 27 +++++++++---- src/diffusers/schedulers/scheduling_pndm.py | 40 ++++++++++++++----- 6 files changed, 54 insertions(+), 45 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index f02fa114a8e1..b15339d88553 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -217,12 +217,7 @@ def __call__( latents = latents.to(self.device) # set timesteps - accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) - extra_set_kwargs = {} - if accepts_offset: - extra_set_kwargs["offset"] = 1 - - self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + self.scheduler.set_timesteps(num_inference_steps) # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas if isinstance(self.scheduler, LMSDiscreteScheduler): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 475ceef4f002..e4caddc467ba 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -169,14 +169,7 @@ def __call__( raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") # set timesteps - accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) - extra_set_kwargs = {} - offset = 0 - if accepts_offset: - offset = 1 - extra_set_kwargs["offset"] = 1 - - self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + self.scheduler.set_timesteps(num_inference_steps) if not isinstance(init_image, torch.FloatTensor): init_image = preprocess(init_image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 05ea84ae0326..27b8993d7c09 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -192,14 +192,7 @@ def __call__( raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") # set timesteps - accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) - extra_set_kwargs = {} - offset = 0 - if accepts_offset: - offset = 1 - extra_set_kwargs["offset"] = 1 - - self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + self.scheduler.set_timesteps(num_inference_steps) # preprocess image init_image = preprocess_image(init_image).to(self.device) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py index 7ff3ff22fc21..68ea0f476f33 100644 --- 
a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -100,12 +100,7 @@ def __call__( raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") # set timesteps - accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) - extra_set_kwargs = {} - if accepts_offset: - extra_set_kwargs["offset"] = 1 - - self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + self.scheduler.set_timesteps(num_inference_steps) # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas if isinstance(self.scheduler, LMSDiscreteScheduler): diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 894d63bf2df0..dbe1ccf525e9 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -17,6 +17,7 @@ import math from typing import Optional, Tuple, Union +import warnings import numpy as np import torch @@ -78,7 +79,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. set_alpha_to_one (`bool`, default `True`): - if alpha for final step is 1 or the final alpha of the "non-previous" one. + each diffusion step uses the value of alphas product at that step and at the previous one. + For the final step there is no previous alpha. When this option is `True` the previous alpha + product is fixed to `1`, otherwise it uses the value of alpha at step 0. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, + to make the last step use step 0 for the previous alpha product. tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. """ @@ -94,6 +100,7 @@ def __init__( timestep_values: Optional[np.ndarray] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, + steps_offset: int = 0, tensor_format: str = "pt", ): if trained_betas is not None: @@ -112,10 +119,6 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = np.cumprod(self.alphas, axis=0) - # At every step in ddim, we are looking into the previous alphas_cumprod - # For the final step, there is no previous alphas_cumprod because we are already at 0 - # `set_alpha_to_one` decides whether we set this paratemer simply to one or - # whether we use the final alpha of the "non-previous" one. self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0] # setable values @@ -135,15 +138,25 @@ def _get_variance(self, timestep, prev_timestep): return variance - def set_timesteps(self, num_inference_steps: int, offset: int = 0): + def set_timesteps(self, num_inference_steps: int, **kwargs): """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. Args: num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. - offset (`int`): TODO """ + + offset = self.config.steps_offset + + if "offset" in kwargs: + warnings.warn( + "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." + " Please pass `steps_offset` to `__init__` instead." 
+ ) + + offset = kwargs["offset"] + self.num_inference_steps = num_inference_steps self.timesteps = np.arange( 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index b43d88bbab77..123d4b60b6bb 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -16,6 +16,7 @@ import math from typing import Optional, Tuple, Union +import warnings import numpy as np import torch @@ -73,10 +74,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, optional): TODO - tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays skip_prk_steps (`bool`): allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required before plms steps; defaults to `False`. + set_alpha_to_one (`bool`, default `True`): + each diffusion step uses the value of alphas product at that step and at the previous one. + For the final step there is no previous alpha. When this option is `True` the previous alpha + product is fixed to `1`, otherwise it uses the value of alpha at step 0. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, + to make the last step use step 0 for the previous alpha product. + tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays """ @@ -88,8 +96,10 @@ def __init__( beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, - tensor_format: str = "pt", skip_prk_steps: bool = False, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + tensor_format: str = "pt", ): if trained_betas is not None: self.betas = np.asarray(trained_betas) @@ -107,6 +117,8 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + self.one = np.array(1.0) # For now we only support F-PNDM, i.e. the runge-kutta method @@ -123,7 +135,6 @@ def __init__( # setable values self.num_inference_steps = None self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy() - self._offset = 0 self.prk_timesteps = None self.plms_timesteps = None self.timesteps = None @@ -131,21 +142,30 @@ def __init__( self.tensor_format = tensor_format self.set_format(tensor_format=tensor_format) - def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.FloatTensor: + def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. Args: num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. - offset (`int`): TODO """ + + offset = self.config.steps_offset + + if "offset" in kwargs: + warnings.warn( + "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." + " Please pass `steps_offset` to `__init__` instead." 
+ ) + + offset = kwargs["offset"] + self.num_inference_steps = num_inference_steps self._timesteps = list( range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps) ) - self._offset = offset - self._timesteps = np.array([t + self._offset for t in self._timesteps]) + self._timesteps = np.array(self._timesteps) + offset if self.config.skip_prk_steps: # for some models like stable diffusion the prk steps can/should be skipped to @@ -322,7 +342,7 @@ def step_plms( return SchedulerOutput(prev_sample=prev_sample) - def _get_prev_sample(self, sample, timestep, timestep_prev, model_output): + def _get_prev_sample(self, sample, timestep, prev_timestep, model_output): # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf # this function computes x_(t−δ) using the formula of (9) # Note that x_t needs to be added to both sides of the equation @@ -335,8 +355,8 @@ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output): # sample -> x_t # model_output -> e_θ(x_t, t) # prev_sample -> x_(t−δ) - alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset] - alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset] + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev From c727b46050ddbba0ebfdf7605ef5c704a25a94a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Mon, 12 Sep 2022 16:04:23 +0200 Subject: [PATCH 2/9] Format Add missing variables --- .../pipeline_stable_diffusion_img2img.py | 1 + .../pipeline_stable_diffusion_inpaint.py | 1 + src/diffusers/schedulers/scheduling_ddim.py | 12 ++++++------ src/diffusers/schedulers/scheduling_pndm.py | 12 ++++++------ 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index e4caddc467ba..3d515fc4aa8a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -170,6 +170,7 @@ def __call__( # set timesteps self.scheduler.set_timesteps(num_inference_steps) + offset = 1 if not isinstance(init_image, torch.FloatTensor): init_image = preprocess(init_image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 27b8993d7c09..e6cd7fd3d045 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -193,6 +193,7 @@ def __call__( # set timesteps self.scheduler.set_timesteps(num_inference_steps) + offset = 1 # preprocess image init_image = preprocess_image(init_image).to(self.device) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index dbe1ccf525e9..9c7f4c0d6f4b 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -16,8 +16,8 @@ # and https://github.com/hojonathanho/diffusion import math -from typing import Optional, Tuple, Union import warnings +from typing import Optional, Tuple, Union import numpy as np import torch @@ -79,12 +79,12 @@ class 
DDIMScheduler(SchedulerMixin, ConfigMixin): clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. set_alpha_to_one (`bool`, default `True`): - each diffusion step uses the value of alphas product at that step and at the previous one. - For the final step there is no previous alpha. When this option is `True` the previous alpha - product is fixed to `1`, otherwise it uses the value of alpha at step 0. + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. steps_offset (`int`, default `0`): - an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, - to make the last step use step 0 for the previous alpha product. + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product. tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. """ diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 123d4b60b6bb..43eab5630e98 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -15,8 +15,8 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim import math -from typing import Optional, Tuple, Union import warnings +from typing import Optional, Tuple, Union import numpy as np import torch @@ -78,12 +78,12 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required before plms steps; defaults to `False`. set_alpha_to_one (`bool`, default `True`): - each diffusion step uses the value of alphas product at that step and at the previous one. - For the final step there is no previous alpha. When this option is `True` the previous alpha - product is fixed to `1`, otherwise it uses the value of alpha at step 0. + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. steps_offset (`int`, default `0`): - an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, - to make the last step use step 0 for the previous alpha product. + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product. 
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays """ From 00eafd1659157f89519e5121c6ef5ff1dc040f4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Mon, 12 Sep 2022 16:38:07 +0200 Subject: [PATCH 3/9] Fix pipeline test --- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 2 +- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 3d515fc4aa8a..be6c57d99779 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -170,7 +170,6 @@ def __call__( # set timesteps self.scheduler.set_timesteps(num_inference_steps) - offset = 1 if not isinstance(init_image, torch.FloatTensor): init_image = preprocess(init_image) @@ -184,6 +183,7 @@ def __call__( init_latents = torch.cat([init_latents] * batch_size) # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) if isinstance(self.scheduler, LMSDiscreteScheduler): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index e6cd7fd3d045..cab9e29ee8e1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -193,7 +193,6 @@ def __call__( # set timesteps self.scheduler.set_timesteps(num_inference_steps) - offset = 1 # preprocess image init_image = preprocess_image(init_image).to(self.device) @@ -217,6 +216,7 @@ def __call__( raise ValueError("The mask and init_image should be the same size!") # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) timesteps = self.scheduler.timesteps[-init_timestep] From 63c6ed0d9cf0ab85721df97f946a6b1d7e650d4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Tue, 13 Sep 2022 15:11:38 +0200 Subject: [PATCH 4/9] Update src/diffusers/schedulers/scheduling_ddim.py Co-authored-by: Patrick von Platen --- src/diffusers/schedulers/scheduling_ddim.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 9c7f4c0d6f4b..ee6e99b15deb 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -152,7 +152,8 @@ def set_timesteps(self, num_inference_steps: int, **kwargs): if "offset" in kwargs: warnings.warn( "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." - " Please pass `steps_offset` to `__init__` instead." 
+ " Please pass `steps_offset` to `__init__` instead.", + DeprecationWarning ) offset = kwargs["offset"] From 3f116116fea5d83617f264d70ab6739384de2a90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Tue, 13 Sep 2022 15:12:28 +0200 Subject: [PATCH 5/9] Default set_alpha_to_one to false --- src/diffusers/schedulers/scheduling_pndm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 43eab5630e98..0d6be4f0976a 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -77,7 +77,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): skip_prk_steps (`bool`): allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required before plms steps; defaults to `False`. - set_alpha_to_one (`bool`, default `True`): + set_alpha_to_one (`bool`, default `False`): each diffusion step uses the value of alphas product at that step and at the previous one. For the final step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, otherwise it uses the value of alpha at step 0. @@ -97,7 +97,7 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, skip_prk_steps: bool = False, - set_alpha_to_one: bool = True, + set_alpha_to_one: bool = False, steps_offset: int = 0, tensor_format: str = "pt", ): From ea4c848a0447c14fc6bbad618a174cae2bada4f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Tue, 13 Sep 2022 15:15:51 +0200 Subject: [PATCH 6/9] Format --- src/diffusers/schedulers/scheduling_ddim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index ee6e99b15deb..7936cec3bb0c 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -153,7 +153,7 @@ def set_timesteps(self, num_inference_steps: int, **kwargs): warnings.warn( "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." " Please pass `steps_offset` to `__init__` instead.", - DeprecationWarning + DeprecationWarning, ) offset = kwargs["offset"] From 916e5352c501806c772f9c013f7e361213e9e36b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Tue, 13 Sep 2022 16:08:47 +0200 Subject: [PATCH 7/9] Add tests --- src/diffusers/schedulers/scheduling_pndm.py | 4 +- tests/test_scheduler.py | 128 +++++++++++++++----- 2 files changed, 101 insertions(+), 31 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 0d6be4f0976a..1f08dd912115 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -250,7 +250,7 @@ def step_prk( ) diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2 - prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1]) + prev_timestep = timestep - diff_to_prev timestep = self.prk_timesteps[self.counter // 4 * 4] if self.counter % 4 == 0: @@ -312,7 +312,7 @@ def step_plms( "for more information." 
) - prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps if self.counter != 1: self.ets.append(model_output) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 3c2e786fc1f4..0a8a6224a88a 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -356,10 +356,38 @@ def get_scheduler_config(self, **kwargs): config.update(**kwargs) return config + def full_loop(self, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + + num_inference_steps, eta = 10, 0.0 + + model = self.dummy_model() + sample = self.dummy_sample_deter + + scheduler.set_timesteps(num_inference_steps) + + for t in scheduler.timesteps: + residual = model(sample, t) + sample = scheduler.step(residual, t, sample, eta).prev_sample + + return sample + def test_timesteps(self): for timesteps in [100, 500, 1000]: self.check_over_configs(num_train_timesteps=timesteps) + def test_steps_offset(self): + for steps_offset in [0, 1]: + self.check_over_configs(steps_offset=steps_offset) + + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(steps_offset=1) + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(5) + assert torch.equal(scheduler.timesteps, torch.tensor([801, 601, 401, 201, 1])) + def test_betas(self): for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): self.check_over_configs(beta_start=beta_start, beta_end=beta_end) @@ -397,26 +425,31 @@ def test_variance(self): assert torch.sum(torch.abs(scheduler._get_variance(999, 998) - 0.02)) < 1e-5 def test_full_loop_no_noise(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) + sample = self.full_loop() - num_inference_steps, eta = 10, 0.0 + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) - model = self.dummy_model() - sample = self.dummy_sample_deter + assert abs(result_sum.item() - 172.0067) < 1e-2 + assert abs(result_mean.item() - 0.223967) < 1e-3 - scheduler.set_timesteps(num_inference_steps) - for t in scheduler.timesteps: - residual = model(sample, t) + def test_full_loop_with_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) - sample = scheduler.step(residual, t, sample, eta).prev_sample + assert abs(result_sum.item() - 149.8295) < 1e-2 + assert abs(result_mean.item() - 0.1951) < 1e-3 + def test_full_loop_with_no_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_sum.item() - 172.0067) < 1e-2 - assert abs(result_mean.item() - 0.223967) < 1e-3 + assert abs(result_sum.item() - 149.0784) < 1e-2 + assert abs(result_mean.item() - 0.1941) < 1e-3 class PNDMSchedulerTest(SchedulerCommonTest): @@ -502,6 +535,26 @@ def check_over_forward(self, time_step=0, **forward_kwargs): assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + def full_loop(self, 
**config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter + scheduler.set_timesteps(num_inference_steps) + + for i, t in enumerate(scheduler.prk_timesteps): + residual = model(sample, t) + sample = scheduler.step_prk(residual, t, sample).prev_sample + + for i, t in enumerate(scheduler.plms_timesteps): + residual = model(sample, t) + sample = scheduler.step_plms(residual, t, sample).prev_sample + + return sample + def test_pytorch_equal_numpy(self): kwargs = dict(self.forward_default_kwargs) num_inference_steps = kwargs.pop("num_inference_steps", None) @@ -605,8 +658,23 @@ def test_timesteps(self): for timesteps in [100, 1000]: self.check_over_configs(num_train_timesteps=timesteps) + def test_steps_offset(self): + for steps_offset in [0, 1]: + self.check_over_configs(steps_offset=steps_offset) + + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(steps_offset=1) + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(10) + assert torch.equal( + scheduler.timesteps, + torch.tensor( + [901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1] + ), + ) + def test_betas(self): - for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]): + for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]): self.check_over_configs(beta_start=beta_start, beta_end=beta_end) def test_schedules(self): @@ -619,7 +687,7 @@ def test_time_indices(self): def test_inference_steps(self): for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]): - self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) + self.check_over_forward(num_inference_steps=num_inference_steps) def test_inference_plms_no_past_residuals(self): with self.assertRaises(ValueError): @@ -630,28 +698,30 @@ def test_inference_plms_no_past_residuals(self): scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample).prev_sample def test_full_loop_no_noise(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) + sample = self.full_loop() + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) - num_inference_steps = 10 - model = self.dummy_model() - sample = self.dummy_sample_deter - scheduler.set_timesteps(num_inference_steps) + assert abs(result_sum.item() - 198.1318) < 1e-2 + assert abs(result_mean.item() - 0.2580) < 1e-3 - for i, t in enumerate(scheduler.prk_timesteps): - residual = model(sample, t) - sample = scheduler.step_prk(residual, i, sample).prev_sample + def test_full_loop_with_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) - for i, t in enumerate(scheduler.plms_timesteps): - residual = model(sample, t) - sample = scheduler.step_plms(residual, i, sample).prev_sample + assert abs(result_sum.item() - 230.0399) < 1e-2 + assert abs(result_mean.item() - 0.2995) < 1e-3 + def test_full_loop_with_no_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) result_sum = 
torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_sum.item() - 428.8788) < 1e-2 - assert abs(result_mean.item() - 0.5584) < 1e-3 + assert abs(result_sum.item() - 186.9482) < 1e-2 + assert abs(result_mean.item() - 0.2434) < 1e-3 class ScoreSdeVeSchedulerTest(unittest.TestCase): From 5afa7801f79633e62d477892de5354c4f48f948e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Tue, 13 Sep 2022 17:30:12 +0200 Subject: [PATCH 8/9] Format --- src/diffusers/schedulers/scheduling_ddim.py | 4 ++-- src/diffusers/schedulers/scheduling_pndm.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 07a163175c81..242cf4d1f3d3 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -84,8 +84,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): otherwise it uses the value of alpha at step 0. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and - `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, - as done in stable diffusion. + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. """ diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index ba428cec3c5b..6e592c884d73 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -84,8 +84,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): otherwise it uses the value of alpha at step 0. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and - `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, - as done in stable diffusion. + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. 
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays """ From babee7050ad36874dddd6d92f989917c6bc6f482 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 16 Sep 2022 19:04:22 +0000 Subject: [PATCH 9/9] add deprecation warning --- .../pipeline_stable_diffusion.py | 16 ++++++++++++++ .../pipeline_stable_diffusion_img2img.py | 17 +++++++++++++++ .../pipeline_stable_diffusion_inpaint.py | 21 +++++++++++++++++-- 3 files changed, 52 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 3ae50fc19f76..9f1211b43013 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -6,6 +6,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler @@ -53,6 +54,21 @@ def __init__( ): super().__init__() scheduler = scheduler.set_format("pt") + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + warnings.warn( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file", + DeprecationWarning, + ) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index e442c94a7ac4..e7adb4d1a33b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -1,4 +1,5 @@ import inspect +import warnings from typing import List, Optional, Union import numpy as np @@ -7,6 +8,7 @@ import PIL from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler @@ -64,6 +66,21 @@ def __init__( ): super().__init__() scheduler = scheduler.set_format("pt") + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + warnings.warn( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file", + DeprecationWarning, + ) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 1401b369b218..b9ad36f1a2bf 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -1,4 +1,5 @@ import inspect +import warnings from typing import List, Optional, Union import numpy as np @@ -8,6 +9,7 @@ from tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, PNDMScheduler @@ -83,6 +85,21 @@ def __init__( super().__init__() scheduler = scheduler.set_format("pt") logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.") + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + warnings.warn( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file", + DeprecationWarning, + ) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -198,7 +215,7 @@ def __call__( # preprocess image if not isinstance(init_image, torch.FloatTensor): init_image = preprocess_image(init_image) - init_image.to(self.device) + init_image = init_image.to(self.device) # encode the init image into latents and scale the latents init_latent_dist = self.vae.encode(init_image).latent_dist @@ -213,7 +230,7 @@ def __call__( # preprocess mask if not isinstance(mask_image, torch.FloatTensor): mask_image = preprocess_mask(mask_image) - mask_image.to(self.device) + mask_image = mask_image.to(self.device) mask = torch.cat([mask_image] * batch_size) # check sizes