@@ -57,7 +57,7 @@ def __call__(

model = self.unet

- sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max
+ sample = torch.randn(*shape, generator=generator) * self.scheduler.init_noise_sigma
sample = sample.to(self.device)

self.scheduler.set_timesteps(num_inference_steps)
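The hunk above swaps a scheduler-specific config lookup for the new shared attribute. A minimal sketch of the resulting pipeline-side idiom; `scheduler` is assumed to be any diffusers scheduler from this PR, and the shape and seed are illustrative:

```python
import torch

shape = (1, 3, 256, 256)          # illustrative sample shape
generator = torch.manual_seed(0)  # illustrative seed

# x_T ~ N(0, init_noise_sigma^2 * I); init_noise_sigma is 1.0 for DDIM/DDPM
# and sigma_max for variance-exploding schedulers, so no isinstance checks
# or config poking are needed here anymore
sample = torch.randn(*shape, generator=generator) * scheduler.init_noise_sigma
```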
@@ -281,9 +281,8 @@ def __call__(
# It's more optimized to move all timesteps to correct device beforehand
timesteps_tensor = self.scheduler.timesteps.to(self.device)

- # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
- if isinstance(self.scheduler, LMSDiscreteScheduler):
-     latents = latents * self.scheduler.sigmas[0]
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -297,10 +296,7 @@ def __call__(
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
- if isinstance(self.scheduler, LMSDiscreteScheduler):
-     sigma = self.scheduler.sigmas[i]
-     # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-     latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -311,10 +307,7 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
- if isinstance(self.scheduler, LMSDiscreteScheduler):
-     latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
- else:
-     latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# call the callback, if provided
if callback is not None and i % callback_steps == 0:
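Taken together, the three hunks above collapse the per-scheduler branches into one loop. A hedged sketch of the scheduler-agnostic denoising loop this enables; `scheduler`, `unet`, `latents`, `text_embeddings`, `guidance_scale`, `do_classifier_free_guidance`, and `num_inference_steps` are assumed to be prepared as in the pipeline:

```python
scheduler.set_timesteps(num_inference_steps)
latents = latents * scheduler.init_noise_sigma  # no-op when init_noise_sigma == 1.0

for t in scheduler.timesteps:
    # classifier-free guidance runs the unconditional and text passes in one batch
    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
    # identity for DDIM/DDPM; divides by (sigma**2 + 1) ** 0.5 for K-LMS
    latent_model_input = scheduler.scale_model_input(latent_model_input, t)

    noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
    if do_classifier_free_guidance:
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    # every scheduler's step() now takes the timestep value itself, not a loop index
    latents = scheduler.step(noise_pred, t, latents).prev_sample
```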
@@ -226,13 +226,9 @@ def __call__(
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
- if isinstance(self.scheduler, LMSDiscreteScheduler):
-     timesteps = torch.tensor(
-         [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
-     )
- else:
-     timesteps = self.scheduler.timesteps[-init_timestep]
-     timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+ timesteps = self.scheduler.timesteps[-init_timestep]
+ timesteps = torch.tensor([timesteps] * batch_size, device=self.device)

# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
@@ -310,16 +306,9 @@ def __call__(
timesteps = self.scheduler.timesteps[t_start:].to(self.device)

for i, t in enumerate(self.progress_bar(timesteps)):
- t_index = t_start + i

# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

- # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
- if isinstance(self.scheduler, LMSDiscreteScheduler):
-     sigma = self.scheduler.sigmas[t_index]
-     # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-     latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -330,10 +319,7 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
- if isinstance(self.scheduler, LMSDiscreteScheduler):
-     latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
- else:
-     latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# call the callback, if provided
if callback is not None and i % callback_steps == 0:
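The img2img starting point is unchanged by this PR. For clarity, a worked example of the `init_timestep` arithmetic in the first hunk above, with illustrative values:

```python
num_inference_steps = 50
strength = 0.75  # fraction of the schedule to re-run on the init image
offset = 1       # scheduler.config.get("steps_offset", 0) for schedulers that define it

init_timestep = int(num_inference_steps * strength) + offset  # 38
init_timestep = min(init_timestep, num_inference_steps)       # 38
# denoising then starts from scheduler.timesteps[-init_timestep],
# i.e. the last 38 of the 50 scheduled timesteps are executed
```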
@@ -260,13 +260,9 @@ def __call__(
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
- if isinstance(self.scheduler, LMSDiscreteScheduler):
-     timesteps = torch.tensor(
-         [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
-     )
- else:
-     timesteps = self.scheduler.timesteps[-init_timestep]
-     timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+ timesteps = self.scheduler.timesteps[-init_timestep]
+ timesteps = torch.tensor([timesteps] * batch_size, device=self.device)

# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
@@ -348,13 +344,9 @@ def __call__(
timesteps = self.scheduler.timesteps[t_start:].to(self.device)

for i, t in tqdm(enumerate(timesteps)):
- t_index = t_start + i
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
- if isinstance(self.scheduler, LMSDiscreteScheduler):
-     sigma = self.scheduler.sigmas[t_index]
-     # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-     latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -365,14 +357,9 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
- if isinstance(self.scheduler, LMSDiscreteScheduler):
-     latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
-     # masking
-     init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t_index]))
- else:
-     latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-     # masking
-     init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t]))
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ # masking
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))

latents = (init_latents_proper * mask) + (latents * (1 - mask))

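The masking step in the inpainting loop above is worth spelling out. A short sketch, assuming (as this pipeline appears to) that `mask` is 1 where the original image is kept and 0 where it is regenerated, with all tensors prepared as in the pipeline:

```python
# re-noise the original latents to the current timestep so that both terms
# live at the same noise level, then blend the kept and generated regions
init_latents_proper = scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
latents = (init_latents_proper * mask) + (latents * (1 - mask))
```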
@@ -147,9 +147,7 @@ def __call__(
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)

- # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
- if isinstance(self.scheduler, LMSDiscreteScheduler):
-     latents = latents * self.scheduler.sigmas[0]
+ latents = latents * self.scheduler.init_noise_sigma

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -163,10 +161,7 @@ def __call__(
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
- if isinstance(self.scheduler, LMSDiscreteScheduler):
-     sigma = self.scheduler.sigmas[i]
-     # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-     latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(
@@ -180,11 +175,7 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
- if isinstance(self.scheduler, LMSDiscreteScheduler):
-     latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
- else:
-     latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
latents = np.array(latents)

# call the callback, if provided
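The ONNX pipeline mirrors the torch one but batches with numpy. A self-contained sketch of the classifier-free-guidance batching used above; the shapes and the zero arrays are stand-ins, since the real values come from the VAE and the ONNX unet session:

```python
import numpy as np

latents = np.zeros((1, 4, 64, 64), dtype=np.float32)  # stand-in for real latents
latent_model_input = np.concatenate([latents] * 2)    # (2, 4, 64, 64): uncond + text

noise_pred = np.zeros_like(latent_model_input)        # stand-in for the unet output
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)

guidance_scale = 7.5
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```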
@@ -69,7 +69,7 @@ def __call__(
model = self.unet

# sample x_0 ~ N(0, sigma_0^2 * I)
- sample = torch.randn(*shape) * self.scheduler.config.sigma_max
+ sample = torch.randn(*shape) * self.scheduler.init_noise_sigma
sample = sample.to(self.device)

self.scheduler.set_timesteps(num_inference_steps)
17 changes: 17 additions & 0 deletions src/diffusers/schedulers/scheduling_ddim.py
@@ -152,10 +152,27 @@ def __init__(
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0

# setable values
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+     """
+     Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+     current timestep.
+
+     Args:
+         sample (`torch.FloatTensor`): input sample
+         timestep (`int`, optional): current timestep
+
+     Returns:
+         `torch.FloatTensor`: scaled input sample
+     """
+     return sample
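The method is deliberately a no-op; it exists so callers can treat every scheduler the same. A hedged usage sketch, with `scheduler`, `latent_model_input`, and `t` assumed to come from a pipeline:

```python
# one call site works for any scheduler:
#   DDIMScheduler / DDPMScheduler: returns the sample unchanged
#   LMSDiscreteScheduler: divides by (sigma**2 + 1) ** 0.5 for the current timestep
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
```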

def _get_variance(self, timestep, prev_timestep):
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
17 changes: 17 additions & 0 deletions src/diffusers/schedulers/scheduling_ddpm.py
@@ -140,12 +140,29 @@ def __init__(
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.one = torch.tensor(1.0)

+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0

# setable values
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

self.variance_type = variance_type

+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+     """
+     Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+     current timestep.
+
+     Args:
+         sample (`torch.FloatTensor`): input sample
+         timestep (`int`, optional): current timestep
+
+     Returns:
+         `torch.FloatTensor`: scaled input sample
+     """
+     return sample

def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
17 changes: 17 additions & 0 deletions src/diffusers/schedulers/scheduling_karras_ve.py
@@ -95,11 +95,28 @@ def __init__(
take_from=kwargs,
)

+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = sigma_max

# setable values
self.num_inference_steps: int = None
self.timesteps: np.IntTensor = None
self.schedule: torch.FloatTensor = None # sigma(t_i)

+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+     """
+     Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+     current timestep.
+
+     Args:
+         sample (`torch.FloatTensor`): input sample
+         timestep (`int`, optional): current timestep
+
+     Returns:
+         `torch.FloatTensor`: scaled input sample
+     """
+     return sample

def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
57 changes: 49 additions & 8 deletions src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

+ import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

@@ -102,11 +102,36 @@ def __init__(
sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas)

+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = self.sigmas.max()

# setable values
self.num_inference_steps = None
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
self.timesteps = torch.from_numpy(timesteps)
self.derivatives = []
+ self.is_scale_input_called = False

+ def scale_model_input(
+     self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
+ ) -> torch.FloatTensor:
+     """
+     Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.
+
+     Args:
+         sample (`torch.FloatTensor`): input sample
+         timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
+
+     Returns:
+         `torch.FloatTensor`: scaled input sample
+     """
+     if isinstance(timestep, torch.Tensor):
+         timestep = timestep.to(self.timesteps.device)
+     step_index = (self.timesteps == timestep).nonzero().item()
+     sigma = self.sigmas[step_index]
+     sample = sample / ((sigma**2 + 1) ** 0.5)
+     self.is_scale_input_called = True
+     return sample
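Since LMS timesteps are now floats produced by `np.linspace`, the step index is recovered by matching the timestep value against the schedule. A small illustration of the lookup, with made-up numbers:

```python
import torch

timesteps = torch.tensor([999.0, 666.0, 333.0, 0.0])  # illustrative float schedule
t = timesteps[1]                                      # as iterated by a pipeline
step_index = (timesteps == t).nonzero().item()        # -> 1, used to index sigmas
```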

def get_lms_coefficient(self, order, t, current_order):
"""
@@ -154,7 +179,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
def step(
self,
model_output: torch.FloatTensor,
- timestep: int,
+ timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
order: int = 4,
return_dict: bool = True,
@@ -165,7 +190,7 @@ def step(

Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
- timestep (`int`): current discrete timestep in the diffusion chain.
+ timestep (`float`): current timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
order: coefficient for multi-step inference.
@@ -177,7 +202,21 @@ def step(
When returning a tuple, the first element is the sample tensor.

"""
- sigma = self.sigmas[timestep]
+ if not isinstance(timestep, float) and not isinstance(timestep, torch.FloatTensor):
+     warnings.warn(
+         f"`LMSDiscreteScheduler` timesteps must be `float` or `torch.FloatTensor`, not {type(timestep)}. "
+         "Make sure to pass one of the `scheduler.timesteps`"
+     )
+ if not self.is_scale_input_called:
+     warnings.warn(
+         "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
+         "See `StableDiffusionPipeline` for a usage example."
+     )
Review comment on lines +210 to +214 (Member, PR author):
This will pop up in existing community pipelines but won't break them like an exception would. The legacy pipelines can continue using the manual scaling code 👍
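In other words, pipelines that still scale manually keep producing correct samples; they just see the warning. A sketch of the equivalence the comment relies on, assuming an LMS-style scheduler with a `sigmas` table and a valid step index `i`:

```python
# legacy manual scaling (still valid, now duplicated by scale_model_input):
sigma = scheduler.sigmas[i]
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

# new scheduler-agnostic form:
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
```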


+ if isinstance(timestep, torch.Tensor):
+     timestep = timestep.to(self.timesteps.device)
+ step_index = (self.timesteps == timestep).nonzero().item()
+ sigma = self.sigmas[step_index]

# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
pred_original_sample = sample - sigma * model_output
@@ -189,8 +228,8 @@ def step(
self.derivatives.pop(0)

# 3. Compute linear multistep coefficients
- order = min(timestep + 1, order)
- lms_coeffs = [self.get_lms_coefficient(order, timestep, curr_order) for curr_order in range(order)]
+ order = min(step_index + 1, order)
+ lms_coeffs = [self.get_lms_coefficient(order, step_index, curr_order) for curr_order in range(order)]

# 4. Compute previous sample based on the derivatives path
prev_sample = sample + sum(
@@ -206,12 +245,14 @@ def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
- timesteps: torch.IntTensor,
+ timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
sigmas = self.sigmas.to(original_samples.device)
+ schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
Review comment from @anton-l (Member, PR author), Oct 4, 2022:
Really dislike what we have to do here, but unfortunately there's no good vectorized alternative to search for multiple indices and keep the order the same.

Reply from a Contributor:
don't think it's that bad honestly
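For illustration, a minimal reproduction of the order-preserving lookup the comment refers to; the numbers are made up:

```python
import torch

schedule_timesteps = torch.tensor([999.0, 666.0, 333.0, 0.0])
timesteps = torch.tensor([333.0, 999.0])  # arbitrary order, possibly repeated

# one equality scan per query keeps the output aligned with the input order
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
# -> [2, 0]
```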


- sigma = sigmas[timesteps].flatten()
+ sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
