Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, EulerAScheduler
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers.scheduling_euler_a import CFGDenoiserForward


class StableDiffusionPipeline(DiffusionPipeline):
r"""
Expand Down Expand Up @@ -263,10 +263,16 @@ def __call__(

noise_pred = None
if isinstance(self.scheduler, EulerAScheduler):
sigma = t.reshape(1)
sigma_in = torch.cat([sigma] * 2)
# noise_pred = model(latent_model_input,sigma_in,uncond_embeddings, text_embeddings,guidance_scale)
noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, text_embeddings , guidance_scale,DSsigmas=self.scheduler.DSsigmas)
# sigma = t.reshape(1) #A# potential bug: doesn't work on samples > 1
# sigma_in = torch.cat([sigma] * 2)
# # noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, text_embeddings , guidance_scale,DSsigmas=self.scheduler.DSsigmas)
# # noise_pred = DiscreteEpsDDPMDenoiserForward(self.unet,latent_model_input, sigma_in,DSsigmas=self.scheduler.DSsigmas, cond=cond_in)
# c_out, c_in = [self.scheduler.append_dims(x, latent_model_input.ndim) for x in self.scheduler.get_scalings(sigma_in)]
c_out, c_in, sigma_in = self.scheduler.prepare_input(latent_model_input, t, batch_size)

eps = self.unet(latent_model_input * c_in, sigma_in , encoder_hidden_states=text_embeddings).sample
noise_pred = latent_model_input + eps * c_out

# noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample
else:
# predict the noise residual
Expand Down
194 changes: 101 additions & 93 deletions src/diffusers/schedulers/scheduling_euler_a.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@



import math
import warnings
from typing import Optional, Tuple, Union

import numpy as np
Expand All @@ -14,97 +9,37 @@

'''
helper functions: append_zero(),
t_to_sigma(),
get_sigmas(),
append_dims(),
CFGDenoiserForward(),
get_scalings(),
DSsigma_to_t(),
DiscreteEpsDDPMDenoiserForward(),
to_d(),
get_ancestral_step()

need cleaning
'''


def append_zero(x):
    """Return ``x`` with one zero element concatenated at the end.

    The zero is built with ``x.new_zeros`` so it shares the dtype and
    device of ``x``.
    """
    trailing_zero = x.new_zeros([1])
    return torch.cat([x, trailing_zero])

def t_to_sigma(t, sigmas):
    """Linearly interpolate ``sigmas`` at the (possibly fractional) indices ``t``.

    Args:
        t: tensor of index positions into ``sigmas``; cast to float first.
        sigmas: indexable table of sigma values.

    Returns:
        The blend of the entries at ``floor(t)`` and ``ceil(t)``, weighted by
        the fractional part of ``t``.
    """
    position = t.float()
    lower = position.floor().long()
    upper = position.ceil().long()
    weight = position.frac()
    return (1 - weight) * sigmas[lower] + weight * sigmas[upper]


def get_sigmas(sigmas, n=None):
    """Build a descending sigma schedule terminated by a zero.

    Args:
        sigmas: table of sigma values, one per training timestep.
        n: number of schedule points to sample. When ``None`` the whole table
            is used, reversed.

    Returns:
        The selected sigmas with a single zero appended.
    """
    if n is not None:
        last_index = len(sigmas) - 1
        # Evenly spaced (fractional) indices from the last entry down to 0,
        # always created on the CPU.
        steps = torch.linspace(last_index, 0, n, device="cpu")
        return append_zero(t_to_sigma(steps, sigmas))
    return append_zero(sigmas.flip(0))

#from k_samplers utils.py
def append_dims(x, target_dims):
    """Append trailing singleton dimensions until ``x`` has ``target_dims`` dims.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing >= 0:
        # Ellipsis keeps every existing axis; each None adds one new axis.
        return x[(...,) + (None,) * missing]
    raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')

def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, DSsigmas=None):
    """Run the denoiser on already-batched inputs and return its output.

    This is a thin pass-through to ``DiscreteEpsDDPMDenoiserForward``.
    ``cond_scale`` is accepted but not used here: no guidance mixing of the
    unconditional/conditional halves is performed in this function.

    Args:
        Unet: callable model forwarded to the denoiser.
        x_in: latent input.
        sigma_in: sigma values for ``x_in``.
        cond_in: conditioning passed to the denoiser as ``cond``.
        cond_scale: unused.
        DSsigmas: sigma table forwarded to the denoiser.

    Returns:
        Whatever ``DiscreteEpsDDPMDenoiserForward`` returns.
    """
    return DiscreteEpsDDPMDenoiserForward(Unet, x_in, sigma_in, DSsigmas=DSsigmas, cond=cond_in)

# from k_samplers sampling.py
def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative.

    Computes ``(x - denoised) / sigma`` after giving ``sigma`` trailing
    singleton dimensions so it has as many dims as ``x``.
    """
    residual = x - denoised
    sigma_expanded = append_dims(sigma, x.ndim)
    return residual / sigma_expanded


def get_scalings(sigma):
    """Return the ``(c_out, c_in)`` scaling factors for a noise level.

    ``c_out`` is ``-sigma`` and ``c_in`` is ``1 / sqrt(sigma**2 + sigma_data**2)``
    with ``sigma_data`` fixed at 1.
    """
    sigma_data = 1.
    input_scale = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5
    output_scale = -sigma
    return output_scale, input_scale

#DiscreteSchedule DS
def DSsigma_to_t(sigma, quantize=None, DSsigmas=None):
    """Map sigma values to (possibly fractional) indices into ``DSsigmas``.

    Args:
        sigma: tensor of sigma values to look up.
        quantize: when truthy, return the index of the nearest table entry.
            When ``None`` or falsy (the default), interpolate between the two
            nearest entries. Previously this argument was overwritten with
            ``False`` and therefore ignored.
        DSsigmas: tensor holding the sigma table; must be supplied even though
            it has a ``None`` default. It is indexed as ``DSsigmas[:, None]``.

    Returns:
        A tensor shaped like ``sigma`` holding the matching indices.
    """
    quantize = False if quantize is None else quantize
    # Distance from every table entry (dim 0) to every queried sigma.
    dists = torch.abs(sigma - DSsigmas[:, None])
    if quantize:
        return torch.argmin(dists, dim=0).view(sigma.shape)
    # Indices of the two closest table entries, ordered so low_idx < high_idx.
    nearest = torch.topk(dists, dim=0, k=2, largest=False).indices
    low_idx, high_idx = torch.sort(nearest, dim=0)[0]
    low, high = DSsigmas[low_idx], DSsigmas[high_idx]
    # Interpolation weight between the two entries, clamped so sigmas outside
    # the bracket snap to the nearer end.
    w = (low - sigma) / (low - high)
    w = w.clamp(0, 1)
    t = (1 - w) * low_idx + w * high_idx
    return t.view(sigma.shape)

# def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in,DSsigmas=None):
# # x_in = torch.cat([x] * 2)#A# concat the latent
# # sigma_in = torch.cat([sigma] * 2) #A# concat sigma
# # cond_in = torch.cat([uncond, cond])
# # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
# # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2)
# # return uncond + (cond - uncond) * cond_scale
# noise_pred = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in)
# return noise_pred

def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, **kwargs):
    """Turn the model's eps prediction into a denoised-sample estimate.

    Args:
        Unet: callable model; its result must expose a ``.sample`` attribute.
        input: noisy latent tensor.
        sigma: sigma values for ``input``.
        DSsigmas: sigma table used to convert ``sigma`` to a timestep.
        **kwargs: must contain ``cond``, passed to the model as
            ``encoder_hidden_states``.

    Returns:
        ``input + eps * c_out`` where ``eps`` is the model output.
    """
    scalings = get_scalings(sigma)
    # Broadcast both scalings to the rank of the latent.
    c_out, c_in = (append_dims(scale, input.ndim) for scale in scalings)
    scaled_input = input * c_in
    timestep = DSsigma_to_t(sigma, DSsigmas=DSsigmas)
    # eps is the model's prediction for the scaled input at this timestep.
    eps = Unet(scaled_input, timestep, encoder_hidden_states=kwargs['cond']).sample
    return input + eps * c_out

# def DiscreteEpsDDPMDenoiserForward(Unet,input,sigma,DSsigmas=None,**kwargs):
# c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)]
# #??? what is eps?
# # eps is the predicted added noise to the image Xt for noise level t
# # eps = CVDget_eps(Unet,input * c_in, DSsigma_to_t(sigma), **kwargs)
# eps = Unet(input * c_in, DSsigma_to_t(sigma,DSsigmas=DSsigmas), encoder_hidden_states=kwargs['cond']).sample
# return input + eps * c_out


#from k_samplers sampling.py
def get_ancestral_step(sigma_from, sigma_to):
    """Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step.

    Returns:
        A ``(sigma_down, sigma_up)`` pair.
    """
    from_sq = sigma_from ** 2
    to_sq = sigma_to ** 2
    sigma_up = (to_sq * (from_sq - to_sq) / from_sq) ** 0.5
    sigma_down = (to_sq - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up


'''
Expand Down Expand Up @@ -156,7 +91,8 @@ def __init__(
set_alpha_to_one: bool = True,
steps_offset: int = 0,
tensor_format: str = "pt",
num_inference_steps = None
num_inference_steps = None,
device = 'cuda'
):
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
Expand All @@ -171,8 +107,9 @@ def __init__(
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
self.device = device
self.alphas = 1.0 - torch.from_numpy(self.betas).to(self.device)
self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)

# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
Expand All @@ -183,13 +120,12 @@ def __init__(
# setable values
self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()

# get sigmas
self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
self.sigmas = get_sigmas(self.DSsigmas,self.num_inference_steps)
self.sigmas = self.get_sigmas(self.DSsigmas,self.num_inference_steps)
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)


#A# take number of steps as input
#A# store 1) number of steps 2) timesteps 3) schedule
Expand Down Expand Up @@ -249,7 +185,7 @@ def step(
timestep: int,
timestep_prev: int,
sample:float,
generator:None,
generator: Optional[torch.Generator] = None,
# ,sigma_hat: float,
# sigma_prev: float,
# sample_hat: Union[torch.FloatTensor, np.ndarray],
Expand All @@ -274,11 +210,11 @@ def step(

"""
latents = sample
sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev)
sigma_down, sigma_up = self.get_ancestral_step(timestep, timestep_prev)

# if callback is not None:
# callback({'x': latents, 'i': i, 'sigma': timestep, 'sigma_hat': timestep, 'denoised': model_output})
d = to_d(latents, timestep, model_output)
d = self.to_d(latents, timestep, model_output)
# Euler method
dt = sigma_down - timestep
latents = latents + d * dt
Expand All @@ -296,7 +232,7 @@ def step_correct(
sample_hat: Union[torch.FloatTensor, np.ndarray],
sample_prev: Union[torch.FloatTensor, np.ndarray],
derivative: Union[torch.FloatTensor, np.ndarray],
generator: None,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Expand Down Expand Up @@ -327,6 +263,78 @@ def step_correct(
def add_noise(self, original_samples, noise, timesteps):
raise NotImplementedError()


#from k_samplers sampling.py
def get_ancestral_step(self, sigma_from, sigma_to):
    """Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step.

    Returns:
        A ``(sigma_down, sigma_up)`` pair.
    """
    from_sq = sigma_from ** 2
    to_sq = sigma_to ** 2
    sigma_up = (to_sq * (from_sq - to_sq) / from_sq) ** 0.5
    sigma_down = (to_sq - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up

def t_to_sigma(self, t, sigmas):
    """Linearly interpolate ``sigmas`` at the (possibly fractional) indices ``t``.

    The entries at ``floor(t)`` and ``ceil(t)`` are blended using the
    fractional part of ``t`` as the weight.
    """
    position = t.float()
    weight = position.frac()
    below = sigmas[position.floor().long()]
    above = sigmas[position.ceil().long()]
    return (1 - weight) * below + weight * above


def append_zero(self, x):
    """Return ``x`` with one zero element (same dtype/device) appended."""
    zero = x.new_zeros([1])
    return torch.cat([x, zero])


def get_sigmas(self, sigmas, n=None):
    """Build a descending sigma schedule terminated by a zero.

    Args:
        sigmas: table of sigma values, one per training timestep.
        n: number of schedule points to sample. When ``None`` the whole table
            is used, reversed.

    Returns:
        The selected sigmas with a single zero appended.
    """
    if n is not None:
        last_index = len(sigmas) - 1
        # Evenly spaced (fractional) indices from the last entry down to 0,
        # created on the scheduler's configured device.
        steps = torch.linspace(last_index, 0, n, device=self.device)
        return self.append_zero(self.t_to_sigma(steps, sigmas))
    return self.append_zero(sigmas.flip(0))

#from k_samplers utils.py
def append_dims(self, x, target_dims):
    """Append trailing singleton dimensions until ``x`` has ``target_dims`` dims.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    extra = target_dims - x.ndim
    if extra >= 0:
        new_axes = (None,) * extra
        # Ellipsis keeps every existing axis; each None adds one new axis.
        return x[(...,) + new_axes]
    raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')

# from k_samplers sampling.py
def to_d(self, x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative.

    Computes ``(x - denoised) / sigma`` after giving ``sigma`` trailing
    singleton dimensions so it has as many dims as ``x``.
    """
    residual = x - denoised
    sigma_expanded = self.append_dims(sigma, x.ndim)
    return residual / sigma_expanded


def get_scalings(self, sigma):
    """Return the ``(c_out, c_in)`` scaling factors for a noise level.

    ``c_out`` is ``-sigma`` and ``c_in`` is ``1 / sqrt(sigma**2 + sigma_data**2)``
    with ``sigma_data`` fixed at 1.
    """
    sigma_data = 1.
    input_scale = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5
    output_scale = -sigma
    return output_scale, input_scale


#DiscreteSchedule DS
def DSsigma_to_t(self, sigma, quantize=None):
    """Map sigma values to (possibly fractional) indices into ``self.DSsigmas``.

    Args:
        sigma: tensor of sigma values to look up.
        quantize: when truthy, return the index of the nearest table entry.
            When ``None`` or falsy (the default), interpolate between the two
            nearest entries. Previously this argument was overwritten with
            ``False`` and therefore ignored.

    Returns:
        A tensor shaped like ``sigma`` holding the matching indices.
    """
    quantize = False if quantize is None else quantize
    table = self.DSsigmas
    # Distance from every table entry (dim 0) to every queried sigma.
    dists = torch.abs(sigma - table[:, None])
    if quantize:
        return torch.argmin(dists, dim=0).view(sigma.shape)
    # Indices of the two closest table entries, ordered so low_idx < high_idx.
    nearest = torch.topk(dists, dim=0, k=2, largest=False).indices
    low_idx, high_idx = torch.sort(nearest, dim=0)[0]
    low, high = table[low_idx], table[high_idx]
    # Interpolation weight between the two entries, clamped so sigmas outside
    # the bracket snap to the nearer end.
    w = (low - sigma) / (low - high)
    w = w.clamp(0, 1)
    t = (1 - w) * low_idx + w * high_idx
    return t.view(sigma.shape)

def prepare_input(self, latent_in, t, batch_size):
    """Compute the scalings and timestep input for one model call.

    Args:
        latent_in: latent tensor that will be fed to the model; only its
            ``ndim`` is read here.
        t: current sigma. It is reshaped to a single element, so it must hold
            exactly one value (known limitation: no per-sample sigmas).
        batch_size: the sigma is repeated ``2 * batch_size`` times
            (presumably one copy per unconditional and conditional input of
            each sample -- confirm against the caller).

    Returns:
        ``(c_out, c_in, sigma_in)`` where the scalings have trailing singleton
        dims up to ``latent_in.ndim`` and ``sigma_in`` has been converted by
        ``DSsigma_to_t``.
    """
    sigma = t.reshape(1)
    repeated_sigma = torch.cat([sigma] * 2 * batch_size)

    scalings = self.get_scalings(repeated_sigma)
    c_out, c_in = (self.append_dims(scale, latent_in.ndim) for scale in scalings)

    # The scalings use the raw sigmas; the value returned as sigma_in is the
    # index-space conversion of those sigmas.
    timestep_in = self.DSsigma_to_t(repeated_sigma)
    return c_out, c_in, timestep_in