Fix schedulers zero SNR and rescale classifier free guidance #3664
Changes from 7 commits
The diff touches two files. First, additions to the Stable Diffusion pipeline:

@@ -55,6 +55,22 @@
```python
"""


def rescale_noise_pred(noise_pred, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_pred` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    # std_text = torch.std(noise_pred_text)
    # std_pred = torch.std(noise_pred)
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_pred = noise_pred.std(dim=list(range(1, noise_pred.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_pred * (std_text / std_pred)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_pred = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_pred
    return noise_pred


class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.
```
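To see what this does numerically, here is a small illustration (my own sketch, not part of the diff; it assumes `rescale_noise_pred` from the hunk above is in scope):

```python
import torch

# Toy predictions: CFG inflates the std of the combined noise prediction
# relative to the text-conditioned prediction alone.
noise_pred_text = torch.randn(2, 4, 64, 64)
noise_pred_uncond = torch.randn(2, 4, 64, 64)
noise_pred = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)

rescaled = rescale_noise_pred(noise_pred, noise_pred_text, guidance_rescale=1.0)

print(noise_pred.std(dim=[1, 2, 3]))  # roughly 9.9 for this toy data: std blown up by CFG
print(rescaled.std(dim=[1, 2, 3]))    # roughly 1.0: matched to noise_pred_text's per-sample std
```

With `guidance_rescale=1.0` the statistics are matched exactly; the paper's recommended value of about 0.7 only partially mixes the rescaled prediction back in.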
@@ -560,6 +576,7 @@ def __call__(
```python
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
    ):
        r"""
        Function invoked when calling the pipeline for generation.
```

@@ -620,6 +637,11 @@
```python
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
                the paper. Guidance rescale factor should fix overexposure when using zero terminal SNR.

        Examples:
```
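From the user side, the new argument would be threaded through a normal pipeline call like this (a sketch; the checkpoint id is just a placeholder):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder checkpoint
    torch_dtype=torch.float16,
).to("cuda")

# guidance_rescale=0.0 (the default) keeps the old behavior;
# the paper recommends around 0.7 together with zero-terminal-SNR schedules.
image = pipe(
    "a photo of an astronaut riding a horse",
    guidance_scale=7.5,
    guidance_rescale=0.7,
).images[0]
```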
@@ -706,6 +728,10 @@ def __call__(
```python
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                if do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_pred(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
```
Second, the DDIM scheduler:

@@ -76,6 +76,42 @@ def alpha_bar(time_step):
```python
    return torch.tensor(betas, dtype=torch.float32)


def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR. Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)

    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas

    return betas


class DDIMScheduler(SchedulerMixin, ConfigMixin):
    """
    Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
```
@@ -122,6 +158,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
```python
            (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
        sample_max_value (`float`, default `1.0`):
            the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        timestep_scaling (`str`, default `"leading"`):
            The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
            Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
        rescale_betas_zero_snr (`bool`, default `False`):
            whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf).
            This can enable the model to generate very bright and dark samples instead of limiting it to samples with
            medium brightness.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
```
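Putting the two new options together, a user would opt in roughly like this (argument names as of this commit; the checkpoint id is a placeholder):

```python
from diffusers import DDIMScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")  # placeholder

# The paper's recipe: trailing timesteps plus a zero-terminal-SNR schedule.
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config,
    timestep_scaling="trailing",
    rescale_betas_zero_snr=True,
)
```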
@@ -143,6 +186,8 @@ def __init__(
```python
        dynamic_thresholding_ratio: float = 0.995,
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_scaling: str = "leading",
        rescale_betas_zero_snr: bool = False,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
```

@@ -159,6 +204,10 @@ def __init__(
```python
        else:
            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
```
@@ -254,9 +303,20 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
```diff
         step_ratio = self.config.num_train_timesteps // self.num_inference_steps
         # creates integer timesteps by multiplying by ratio
         # casting to int to avoid issues when num_inference_step is power of 3
-        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+        # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891
+        if self.config.timestep_scaling == "leading":
+            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+            timesteps += self.config.steps_offset
+        elif self.config.timestep_scaling == "trailing":
+            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64).copy()
+            timesteps -= 1
+        else:
+            raise ValueError(
+                f"{self.config.timestep_scaling} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
+            )

         self.timesteps = torch.from_numpy(timesteps).to(device)
         self.timesteps += self.config.steps_offset

     def step(
         self,
```

> **Review comment (Contributor), on the `.copy()` in the "trailing" branch:** the `copy` we could remove; I think keeping it in numpy is mainly just a style choice.
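To make the two spacings concrete, here is a small worked example (my own, ignoring `steps_offset`) with 1000 training steps and 10 inference steps:

```python
import numpy as np

num_train_timesteps, num_inference_steps = 1000, 10
step_ratio = num_train_timesteps // num_inference_steps  # 100

# "leading": counts up from t=0, so the final training step t=999 is never sampled.
leading = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
print(leading)   # [900 800 700 600 500 400 300 200 100   0]

# "trailing": counts down from t=1000, so sampling starts at t=999,
# exactly the step whose SNR becomes zero after rescale_zero_terminal_snr.
trailing = np.round(np.arange(num_train_timesteps, 0, -step_ratio)).astype(np.int64).copy() - 1
print(trailing)  # [999 899 799 699 599 499 399 299 199  99]
```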


Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
> **Reviewer:** I think this is quite a weird design. Let's move the classifier-free guidance calculation into this function as well. Currently the argument `noise_pred` is confusing: does it mean the unconditional prediction or the prediction after CFG? (The latter is correct, but the name kind of implies the former.) So let's move CFG into this function too. We wrote equations 15 and 16 separately for simpler understanding, but the computation can be fused for more efficient computation.
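For reference, the fused helper suggested here might look roughly like this (my sketch of equations 15 and 16; the name `guided_noise_pred` is hypothetical, not from the PR):

```python
def guided_noise_pred(noise_pred_uncond, noise_pred_text, guidance_scale, guidance_rescale=0.0):
    # Hypothetical fused helper: classifier-free guidance (eq. 15) and
    # std rescaling (eq. 16) in a single place.
    noise_cfg = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    if guidance_rescale > 0.0:
        std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
        std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
        noise_rescaled = noise_cfg * (std_text / std_cfg)
        noise_cfg = guidance_rescale * noise_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg
```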
> **Reply:** If OK, I'd like to keep them separated, to better separate classifier-free guidance (which is used by essentially every diffusion pipeline) from the rescaling, which is a bit newer IMO. Happy to rename `noise_pred` to `pred_cfg`.