Merged
Changes from 7 commits
27 changes: 27 additions & 0 deletions src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -51,6 +51,23 @@
"""


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_pred
def rescale_noise_pred(noise_pred, noise_pred_text, guidance_rescale=0.0):
@PeterL1n (Contributor) commented on Jun 5, 2023:
I think this is quite a weird design. Let's move the classifier-free guidance calculation into this function as well?

Currently the argument `noise_pred` is confusing: does it mean the unconditional prediction or the prediction after CFG? The latter is correct, but the name kind of implies the former.

So let's move CFG into this function too:

def classifier_free_guidance(pred_pos, pred_neg, guidance_weight, guidance_rescale=0):
    # Apply classifier-free guidance.
    pred_cfg = pred_neg + guidance_weight * (pred_pos - pred_neg)

    # Apply guidance rescale. From paper [Common Diffusion Noise Schedules 
    # and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf) section 3.4.
    if guidance_rescale != 0:
        std_pos = pred_pos.std(dim=list(range(1, pred_pos.ndim)), keepdim=True)
        std_cfg = pred_cfg.std(dim=list(range(1, pred_cfg.ndim)), keepdim=True)
        # Fuse equation 15,16 for more efficient computation.
        pred_cfg *= guidance_rescale * (std_pos / std_cfg) + (1 - guidance_rescale)
    
    return pred_cfg

We wrote equations 15 and 16 separately in the paper for easier understanding, but the computation can be fused for efficiency.
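
(For readers following along, my reading of how equations 15 and 16 fuse into the single expression above, writing φ for `guidance_rescale` and σ for the per-sample standard deviations:)

\begin{aligned}
x_{\text{final}} &= \phi\, x_{\text{rescaled}} + (1 - \phi)\, x_{\text{cfg}}
    && \text{(Eq. 16)} \\
&= \phi\, x_{\text{cfg}}\, \frac{\sigma_{\text{pos}}}{\sigma_{\text{cfg}}} + (1 - \phi)\, x_{\text{cfg}}
    && \text{(substituting Eq. 15: } x_{\text{rescaled}} = x_{\text{cfg}}\, \sigma_{\text{pos}} / \sigma_{\text{cfg}} \text{)} \\
&= x_{\text{cfg}} \left( \phi\, \frac{\sigma_{\text{pos}}}{\sigma_{\text{cfg}}} + 1 - \phi \right)
\end{aligned}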

@patrickvonplaten (Contributor) commented on Jun 6, 2023:
If ok, I'd like to keep them separated, to better distinguish classifier-free guidance (which is used by essentially every diffusion pipeline) from rescaling, which is a bit newer and, IMO, more experimental.

Happy to rename `noise_pred` to `pred_cfg`.

"""
Rescale `noise_pred` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
# std_text = torch.std(noise_pred_text)
# std_pred = torch.std(noise_pred)
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_pred = noise_pred.std(dim=list(range(1, noise_pred.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_pred * (std_text / std_pred)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_pred = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_pred
return noise_pred


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
@@ -559,6 +576,7 @@ def __call__(
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
):
r"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -619,6 +637,11 @@ def __call__(
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
that paper. Guidance rescale factor should fix overexposure when using zero terminal SNR.

Examples:

@@ -705,6 +728,10 @@ def __call__(
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_pred(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -55,6 +55,22 @@
"""


def rescale_noise_pred(noise_pred, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_pred` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
# std_text = torch.std(noise_pred_text)
# std_pred = torch.std(noise_pred)
Member commented:
Let's remove these comments.

std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_pred = noise_pred.std(dim=list(range(1, noise_pred.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_pred * (std_text / std_pred)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_pred = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_pred
return noise_pred


class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
@@ -560,6 +576,7 @@ def __call__(
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
):
r"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -620,6 +637,11 @@ def __call__(
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
that paper. Guidance rescale factor should fix overexposure when using zero terminal SNR.

Examples:

@@ -706,6 +728,10 @@ def __call__(
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_pred(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

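(Not part of the diff, but for context: a minimal sketch of how the pipeline-side `guidance_rescale` argument and the scheduler-side flags added below might be combined. The checkpoint name is only a placeholder, since the paper assumes a model fine-tuned with v-prediction on the rescaled zero-terminal-SNR schedule; `timestep_scaling` is the argument name as of this commit and may be renamed; and 0.7 is the paper's suggested rescale factor, not a library default.)

import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler

# Placeholder checkpoint; the paper's fixes assume a model fine-tuned with
# v-prediction on the zero-terminal-SNR (rescaled) schedule.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config,
    rescale_betas_zero_snr=True,   # new flag in scheduling_ddim.py below
    timestep_scaling="trailing",   # name as of this commit; see the naming discussion below
)

image = pipe(
    "a photo of an astronaut riding a horse",
    guidance_scale=7.5,
    guidance_rescale=0.7,  # φ from Eq. 16; 0.0 (the default) disables rescaling
).images[0]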
64 changes: 62 additions & 2 deletions src/diffusers/schedulers/scheduling_ddim.py
@@ -76,6 +76,42 @@ def alpha_bar(time_step):
return torch.tensor(betas, dtype=torch.float32)


def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).


Args:
betas (`torch.FloatTensor`):
the betas that the scheduler is being initialized with.

Returns:
`torch.FloatTensor`: rescaled betas with zero terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_bar_sqrt = alphas_cumprod.sqrt()

# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T

# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
alphas = torch.cat([alphas_bar[0:1], alphas])
betas = 1 - alphas

return betas
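
(A quick sanity check of the algorithm above, not part of the diff; the beta_start/beta_end values are just the usual Stable Diffusion "scaled_linear" settings, used here as an example. After rescaling, the cumulative product of alphas ends at exactly zero, i.e. the terminal SNR is zero, while the first step is unchanged.)

import torch

# Example betas: "scaled_linear" schedule with Stable Diffusion's usual settings.
betas = torch.linspace(0.00085**0.5, 0.012**0.5, 1000, dtype=torch.float32) ** 2

rescaled = rescale_zero_terminal_snr(betas)
alphas_cumprod = torch.cumprod(1.0 - rescaled, dim=0)

# SNR(t) = alphas_cumprod[t] / (1 - alphas_cumprod[t]), so a terminal value of
# exactly 0 means zero terminal SNR; the first value is preserved by the scaling.
print(alphas_cumprod[-1].item())                           # 0.0
print(torch.isclose(alphas_cumprod[0], 1.0 - betas[0]))    # tensor(True)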


class DDIMScheduler(SchedulerMixin, ConfigMixin):
"""
Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
@@ -122,6 +158,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
(https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
timestep_scaling (`str`, default `"leading"`):
The way the timesteps should be scaled. Refer to Table 2 of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
rescale_betas_zero_snr (`bool`, default `False`):
whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf).
This can enable the model to generate very bright and dark samples instead of limiting it to samples with
medium brightness.
Member commented:
Could this potentially deprecate offset noising?

Contributor commented:
Guess it's worth linking to it!

Another comment:

From my own testing of zero terminal SNR + CFG rescaling, I believe so. I'm producing samples with a great variety in contrast, similar to offset noise.

Offset noising does not seem to be stable and eventually diverges, and it requires some other yet-to-be-invented control to adjust how much offset noise is used. The blog post's value of 0.1 works only for some number of steps, and I've found 0.1 to be far too high for much beyond a short DreamBooth-style training.

[attached image: "nayo ff7r full length shot" example]
Typically the above would produce a grey gradient background, as the model tries to paint an image with a mean brightness of 0.5.

[attached image: "a blue volvo wagon made of legos, miniature lego model of a volvo wagon" example]
Likewise, it is sometimes difficult to get bright white backgrounds.

To reproduce the above with offset noise, it may take a few attempts at guessing an appropriate constant to multiply the offset noise by (0.01, 0.02, etc.). Zero terminal SNR instead appears to simply be stable and requires no tuning.

Contributor Author commented:
I think the implementation can replace offset noise; see section 5.3 in the paper.

Contributor commented:
it does! offset noise doesn't work as well as zero SNR.

and when you enable offset noise with zero SNR, they fight each other, and the model can't learn properly.

"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -143,6 +186,8 @@ def __init__(
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
timestep_scaling: str = "leading",
Contributor commented:
Maybe sample_step_selection_mode is better?

@patrickvonplaten (Contributor) commented on Jun 6, 2023:
True, scaling is not very well chosen.

`timestep_spacing` maybe? I think we refer mostly to how the timestep values are spaced, no? Also want to try not to make it too long.

rescale_betas_zero_snr: bool = False,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
Expand All @@ -159,6 +204,10 @@ def __init__(
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

# Rescale for zero SNR
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)

self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

@@ -254,9 +303,20 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
Contributor commented:
step_ratio must not be floored to an int here. It should be kept as a float so that trailing spaces more accurately below. See the paper; we deliberately do not use a floor operation to calculate the interval.

So this is tricky: unlike leading, which always has the same integer spacing, trailing and linspace may have uneven integer spacing.

We need to change the step() function to make sure it also supports this.

Overall, I think we should let users define any custom sampling timesteps, such as [0, 99, 499, 999], in the future. The step() function should not expect even spacing!

Contributor commented:
Ok let's move to uneven step ratio for "trailing" - don't see a problem with this :-)
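
(Not part of the diff: a small illustration of the difference between the two spacings, mirroring the arithmetic in the hunk below for num_train_timesteps=1000, num_inference_steps=10 and the integer step_ratio as currently written.)

import numpy as np

num_train_timesteps, num_inference_steps = 1000, 10
step_ratio = num_train_timesteps // num_inference_steps  # 100

# "leading": start at 0 and step forwards, then reverse (steps_offset omitted here).
leading = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64)
print(leading)   # [900 800 700 600 500 400 300 200 100   0]

# "trailing": start at T and step backwards, then subtract 1 to stay in [0, T-1].
trailing = np.round(np.arange(num_train_timesteps, 0, -step_ratio)).astype(np.int64) - 1
print(trailing)  # [999 899 799 699 599 499 399 299 199  99]

# Note that "leading" never samples the final training timestep (999),
# while "trailing" does, which is what matters for zero terminal SNR.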

# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)

# "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891
if self.config.timestep_scaling == "leading":
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
timesteps += self.config.steps_offset
elif self.config.timestep_scaling == "trailing":
timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64).copy()
timesteps -= 1
Contributor commented:
  1. What's the point of doing copy()?
  2. Why do we use np then convert to torch tensor? Why not just use torch.arange()?

Contributor commented:
`copy()` we could remove; I think keeping it in numpy is mainly just a style choice.

else:
raise ValueError(
f"{self.config.timestep_scaling} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
)

self.timesteps = torch.from_numpy(timesteps).to(device)
self.timesteps += self.config.steps_offset

def step(
self,