-
Notifications
You must be signed in to change notification settings - Fork 6.5k
Unify offset configuration in DDIM and PNDM schedulers #479
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2e1278f
c727b46
00eafd1
63c6ed0
3f11611
ea4c848
916e535
89b3eec
5afa780
19568b5
aa25955
babee70
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,7 @@ | |
|
|
||
| from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer | ||
|
|
||
| from ...configuration_utils import FrozenDict | ||
| from ...models import AutoencoderKL, UNet2DConditionModel | ||
| from ...pipeline_utils import DiffusionPipeline | ||
| from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler | ||
|
|
@@ -53,6 +54,21 @@ def __init__( | |
| ): | ||
| super().__init__() | ||
| scheduler = scheduler.set_format("pt") | ||
|
|
||
| if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: | ||
| warnings.warn( | ||
| f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" | ||
| f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure " | ||
| "to update the config accordingly as leaving `steps_offset` might led to incorrect results" | ||
| " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," | ||
| " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" | ||
| " file", | ||
| DeprecationWarning, | ||
| ) | ||
| new_config = dict(scheduler.config) | ||
| new_config["steps_offset"] = 1 | ||
| scheduler._internal_dict = FrozenDict(new_config) | ||
|
|
||
| self.register_modules( | ||
| vae=vae, | ||
| text_encoder=text_encoder, | ||
|
|
@@ -217,12 +233,7 @@ def __call__( | |
| latents = latents.to(self.device) | ||
|
|
||
| # set timesteps | ||
| accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) | ||
| extra_set_kwargs = {} | ||
| if accepts_offset: | ||
| extra_set_kwargs["offset"] = 1 | ||
|
|
||
| self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) | ||
| self.scheduler.set_timesteps(num_inference_steps) | ||
|
Comment on lines
235
to
+236
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we try to override the configuration here, or are you going to update the configuration in the relevant hf repositories? |
||
|
|
||
| # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas | ||
| if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| import inspect | ||
| import warnings | ||
| from typing import List, Optional, Union | ||
|
|
||
| import numpy as np | ||
|
|
@@ -7,6 +8,7 @@ | |
| import PIL | ||
| from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer | ||
|
|
||
| from ...configuration_utils import FrozenDict | ||
| from ...models import AutoencoderKL, UNet2DConditionModel | ||
| from ...pipeline_utils import DiffusionPipeline | ||
| from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler | ||
|
|
@@ -64,6 +66,21 @@ def __init__( | |
| ): | ||
| super().__init__() | ||
| scheduler = scheduler.set_format("pt") | ||
|
|
||
| if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: | ||
| warnings.warn( | ||
| f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" | ||
| f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure " | ||
| "to update the config accordingly as leaving `steps_offset` might led to incorrect results" | ||
| " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," | ||
| " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" | ||
| " file", | ||
| DeprecationWarning, | ||
| ) | ||
| new_config = dict(scheduler.config) | ||
| new_config["steps_offset"] = 1 | ||
| scheduler._internal_dict = FrozenDict(new_config) | ||
|
|
||
| self.register_modules( | ||
| vae=vae, | ||
| text_encoder=text_encoder, | ||
|
|
@@ -169,14 +186,7 @@ def __call__( | |
| raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") | ||
|
|
||
| # set timesteps | ||
| accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) | ||
| extra_set_kwargs = {} | ||
| offset = 0 | ||
| if accepts_offset: | ||
| offset = 1 | ||
| extra_set_kwargs["offset"] = 1 | ||
|
|
||
| self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) | ||
| self.scheduler.set_timesteps(num_inference_steps) | ||
|
|
||
| if isinstance(init_image, PIL.Image.Image): | ||
| init_image = preprocess(init_image) | ||
|
|
@@ -190,6 +200,7 @@ def __call__( | |
| init_latents = torch.cat([init_latents] * batch_size) | ||
|
|
||
| # get the original timestep using init_timestep | ||
| offset = self.scheduler.config.get("steps_offset", 0) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice! |
||
| init_timestep = int(num_inference_steps * strength) + offset | ||
| init_timestep = min(init_timestep, num_inference_steps) | ||
| if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -15,6 +15,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||
| # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| import math | ||||||||||||||||||||||||||||||||||||||||||||||||
| import warnings | ||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Optional, Tuple, Union | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -74,10 +75,18 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): | |||||||||||||||||||||||||||||||||||||||||||||||
| `linear`, `scaled_linear`, or `squaredcos_cap_v2`. | ||||||||||||||||||||||||||||||||||||||||||||||||
| trained_betas (`np.ndarray`, optional): | ||||||||||||||||||||||||||||||||||||||||||||||||
| option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. | ||||||||||||||||||||||||||||||||||||||||||||||||
| tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays | ||||||||||||||||||||||||||||||||||||||||||||||||
| skip_prk_steps (`bool`): | ||||||||||||||||||||||||||||||||||||||||||||||||
| allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required | ||||||||||||||||||||||||||||||||||||||||||||||||
| before plms steps; defaults to `False`. | ||||||||||||||||||||||||||||||||||||||||||||||||
| set_alpha_to_one (`bool`, default `False`): | ||||||||||||||||||||||||||||||||||||||||||||||||
| each diffusion step uses the value of alphas product at that step and at the previous one. For the final | ||||||||||||||||||||||||||||||||||||||||||||||||
| step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, | ||||||||||||||||||||||||||||||||||||||||||||||||
| otherwise it uses the value of alpha at step 0. | ||||||||||||||||||||||||||||||||||||||||||||||||
| steps_offset (`int`, default `0`): | ||||||||||||||||||||||||||||||||||||||||||||||||
| an offset added to the inference steps. You can use a combination of `offset=1` and | ||||||||||||||||||||||||||||||||||||||||||||||||
| `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in | ||||||||||||||||||||||||||||||||||||||||||||||||
| stable diffusion. | ||||||||||||||||||||||||||||||||||||||||||||||||
| tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -89,8 +98,10 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||
| beta_end: float = 0.02, | ||||||||||||||||||||||||||||||||||||||||||||||||
| beta_schedule: str = "linear", | ||||||||||||||||||||||||||||||||||||||||||||||||
| trained_betas: Optional[np.ndarray] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||
| tensor_format: str = "pt", | ||||||||||||||||||||||||||||||||||||||||||||||||
| skip_prk_steps: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||
| set_alpha_to_one: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||
| steps_offset: int = 0, | ||||||||||||||||||||||||||||||||||||||||||||||||
| tensor_format: str = "pt", | ||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||
| if trained_betas is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||
| self.betas = np.asarray(trained_betas) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -108,6 +119,8 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||
| self.alphas = 1.0 - self.betas | ||||||||||||||||||||||||||||||||||||||||||||||||
| self.alphas_cumprod = np.cumprod(self.alphas, axis=0) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0] | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # For now we only support F-PNDM, i.e. the runge-kutta method | ||||||||||||||||||||||||||||||||||||||||||||||||
| # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf | ||||||||||||||||||||||||||||||||||||||||||||||||
| # mainly at formula (9), (12), (13) and the Algorithm 2. | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -122,31 +135,38 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||
| # setable values | ||||||||||||||||||||||||||||||||||||||||||||||||
| self.num_inference_steps = None | ||||||||||||||||||||||||||||||||||||||||||||||||
| self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy() | ||||||||||||||||||||||||||||||||||||||||||||||||
| self._offset = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||
| self.prk_timesteps = None | ||||||||||||||||||||||||||||||||||||||||||||||||
| self.plms_timesteps = None | ||||||||||||||||||||||||||||||||||||||||||||||||
| self.timesteps = None | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| self.tensor_format = tensor_format | ||||||||||||||||||||||||||||||||||||||||||||||||
| self.set_format(tensor_format=tensor_format) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.FloatTensor: | ||||||||||||||||||||||||||||||||||||||||||||||||
| def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor: | ||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||
| Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||
| num_inference_steps (`int`): | ||||||||||||||||||||||||||||||||||||||||||||||||
| the number of diffusion steps used when generating samples with a pre-trained model. | ||||||||||||||||||||||||||||||||||||||||||||||||
| offset (`int`): | ||||||||||||||||||||||||||||||||||||||||||||||||
| optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference. | ||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| offset = self.config.steps_offset | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| if "offset" in kwargs: | ||||||||||||||||||||||||||||||||||||||||||||||||
| warnings.warn( | ||||||||||||||||||||||||||||||||||||||||||||||||
| "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." | ||||||||||||||||||||||||||||||||||||||||||||||||
| " Please pass `steps_offset` to `__init__` instead." | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| offset = kwargs["offset"] | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| self.num_inference_steps = num_inference_steps | ||||||||||||||||||||||||||||||||||||||||||||||||
| step_ratio = self.config.num_train_timesteps // self.num_inference_steps | ||||||||||||||||||||||||||||||||||||||||||||||||
| # creates integer timesteps by multiplying by ratio | ||||||||||||||||||||||||||||||||||||||||||||||||
| # casting to int to avoid issues when num_inference_step is power of 3 | ||||||||||||||||||||||||||||||||||||||||||||||||
| self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().tolist() | ||||||||||||||||||||||||||||||||||||||||||||||||
| self._offset = offset | ||||||||||||||||||||||||||||||||||||||||||||||||
| self._timesteps = np.array([t + self._offset for t in self._timesteps]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round() | ||||||||||||||||||||||||||||||||||||||||||||||||
| self._timesteps += offset | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| if self.config.skip_prk_steps: | ||||||||||||||||||||||||||||||||||||||||||||||||
| # for some models like stable diffusion the prk steps can/should be skipped to | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -231,7 +251,7 @@ def step_prk( | |||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2 | ||||||||||||||||||||||||||||||||||||||||||||||||
| prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| prev_timestep = timestep - diff_to_prev | ||||||||||||||||||||||||||||||||||||||||||||||||
| timestep = self.prk_timesteps[self.counter // 4 * 4] | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| if self.counter % 4 == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -293,7 +313,7 @@ def step_plms( | |||||||||||||||||||||||||||||||||||||||||||||||
| "for more information." | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0) | ||||||||||||||||||||||||||||||||||||||||||||||||
| prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| if self.counter != 1: | ||||||||||||||||||||||||||||||||||||||||||||||||
| self.ets.append(model_output) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -323,7 +343,7 @@ def step_plms( | |||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| return SchedulerOutput(prev_sample=prev_sample) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def _get_prev_sample(self, sample, timestep, timestep_prev, model_output): | ||||||||||||||||||||||||||||||||||||||||||||||||
| def _get_prev_sample(self, sample, timestep, prev_timestep, model_output): | ||||||||||||||||||||||||||||||||||||||||||||||||
| # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf | ||||||||||||||||||||||||||||||||||||||||||||||||
| # this function computes x_(t−δ) using the formula of (9) | ||||||||||||||||||||||||||||||||||||||||||||||||
| # Note that x_t needs to be added to both sides of the equation | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -336,8 +356,8 @@ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output): | |||||||||||||||||||||||||||||||||||||||||||||||
| # sample -> x_t | ||||||||||||||||||||||||||||||||||||||||||||||||
| # model_output -> e_θ(x_t, t) | ||||||||||||||||||||||||||||||||||||||||||||||||
| # prev_sample -> x_(t−δ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset] | ||||||||||||||||||||||||||||||||||||||||||||||||
| alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset] | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
-339
to
-340
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @patrickvonplaten this change breaks this test: diffusers/tests/test_scheduler.py Lines 633 to 655 in f4781a0
However, I think the test must be wrong, because there we use the default
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, I think the test breaks because
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually no, adding
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I kept the change and adjusted the test value. |
||||||||||||||||||||||||||||||||||||||||||||||||
| alpha_prod_t = self.alphas_cumprod[timestep] | ||||||||||||||||||||||||||||||||||||||||||||||||
| alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod | ||||||||||||||||||||||||||||||||||||||||||||||||
| beta_prod_t = 1 - alpha_prod_t | ||||||||||||||||||||||||||||||||||||||||||||||||
| beta_prod_t_prev = 1 - alpha_prod_t_prev | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jonatanklosko - Currently the slow tests are broken because we haven't (and can't yet) update the configs online. So for now I think we need to do some deprecation warning here - ok for you?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perfect!
Our of curiosity, is there an issue with updating the configs now? I assume an additional property in the config doesn't hurt even if ignored by previous versions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Most users are on the
0.3.0version therefore will not make use of this code but will nevertheless use the newest config file. Given the high usage of stable diffusion: https://huggingface.co/CompVis/stable-diffusion-v1-4 (>500k downloads per month), I think it would be better to not burden users with an extra config parameter that should throw a warning in0.3.0.So, it would indeed not break anything, but an unused config parameter should throw a warning (haven't checked though whether our code actually works correctly here), which will probably scare users.
Does that make sense?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah totally, I didn't realise the extra option would cause a warning :)