4 changes: 4 additions & 0 deletions src/diffusers/configuration_utils.py
@@ -55,6 +55,10 @@ class ConfigMixin:
     def register_to_config(self, **kwargs):
         if self.config_name is None:
             raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
+
+        if hasattr(self, "type") and self.type is None:
+            raise NotImplementedError(f"Make sure that {self.__class__} has defined a scheduler type")
+
         kwargs["_class_name"] = self.__class__.__name__
         kwargs["_diffusers_version"] = __version__
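
As a quick, hedged illustration of what the new guard buys (a sketch that assumes the @register_to_config decorator ends up calling register_to_config, and uses only module paths that appear elsewhere in this diff; MyScheduler is hypothetical): a scheduler that inherits SchedulerMixin, which this PR gives a default type of None further down, but never overrides it now fails loudly at construction time.

    from diffusers.configuration_utils import ConfigMixin, register_to_config
    from diffusers.schedulers.scheduling_utils import SchedulerMixin


    class MyScheduler(SchedulerMixin, ConfigMixin):
        # type is intentionally left at the SchedulerMixin default (None)

        @register_to_config
        def __init__(self, num_train_timesteps=1000):
            pass


    MyScheduler()  # raises NotImplementedError: "... has defined a scheduler type"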

@@ -9,7 +9,7 @@
 from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
-from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, SchedulerType
 from ...utils import logging
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
@@ -259,8 +259,8 @@ def __call__(
         timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device)

         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            latents = latents * self.scheduler.sigmas[0]
+        if self.scheduler.type == SchedulerType.CONTINUOUS:
+            latents = latents * self.scheduler.init_sigma
Comment on lines +262 to +263
Contributor:
A step in the right direction. I'd love to get rid of the if entirely, but as long as we have it, defining a scheduler.type enum is much preferable to isinstance!

Contributor Author:
@anton-l I can also be convinced to change this to something cleaner (again, as long as we don't have to force a function upon DDIM or DDPM).

Overall, I feel quite strongly about the following though:

  • Let's not make simple schedulers more complex just because we'd like to support newer, more complex ones.
  • Forcing every scheduler to implement a method that can grow arbitrarily complex is much worse than annotating the if statements with a clear comment; a sketch of the two options follows.
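
To make the trade-off concrete, here is a minimal, hedged sketch of the two options being weighed (standalone code: SchedulerType mirrors the enum this PR adds, while the scheduler classes and scale_model_input are hypothetical names, not code from this PR). Option A is a cheap type tag plus one well-commented if in the pipeline; option B is a hook method that every scheduler has to implement, even as a no-op.

    from enum import Enum


    class SchedulerType(Enum):
        CONTINUOUS = "continuous"
        DISCRETE = "discrete"


    # Option A (this PR): the scheduler only declares a tag ...
    class ContinuousScheduler:
        type = SchedulerType.CONTINUOUS


    # ... and the pipeline keeps a single, commented branch:
    def scale_input_option_a(scheduler, latent_model_input, sigma):
        if scheduler.type == SchedulerType.CONTINUOUS:
            # continuous (sigma-based) schedulers expect a rescaled model input
            latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
        return latent_model_input


    # Option B (argued against above): every scheduler, including the simple
    # discrete ones, must implement a hook such as a hypothetical
    # scale_model_input, usually just as a pass-through:
    class SimpleDiscreteScheduler:
        type = SchedulerType.DISCRETE

        def scale_model_input(self, latent_model_input, sigma=None):
            return latent_model_input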


         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -274,10 +274,9 @@ def __call__(
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[i]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+
+            if self.scheduler.type == SchedulerType.CONTINUOUS and self.model.config.trained_scheduler_type == SchedulerType.DISCRETE:
Contributor Author:
Here I'm very open to hearing better suggestions @anton-l (if you can think of something cleaner that doesn't require changes to DDIM or DDPM).

+                latent_model_input = self.scheduler.scale_(latent_model_input)
Contributor:
is this scale_ implemented?
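
For context, the diff never defines scale_, which is presumably what prompts the question. Below is a hedged sketch of what it would roughly have to do on the continuous scheduler, based only on the K-LMS rescaling that the removed pipeline lines did inline; note the call site above passes no timestep, so the scheduler would have to track the current sigma itself, which is an assumption here.

    # Hypothetical sketch, not code from this PR:
    def scale_(self, sample):
        # same math as the removed pipeline lines: x / sqrt(sigma**2 + 1)
        sigma = self.current_sigma  # assumed bookkeeping, not shown in this PR
        return sample / ((sigma**2 + 1) ** 0.5)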


             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -288,10 +287,7 @@ def __call__(
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
Contributor Author:
Could be considered a bug correction IMO


             # call the callback, if provided
             if callback is not None and i % callback_steps == 0:
2 changes: 1 addition & 1 deletion src/diffusers/schedulers/__init__.py
@@ -23,7 +23,7 @@
     from .scheduling_pndm import PNDMScheduler
     from .scheduling_sde_ve import ScoreSdeVeScheduler
     from .scheduling_sde_vp import ScoreSdeVpScheduler
-    from .scheduling_utils import SchedulerMixin
+    from .scheduling_utils import SchedulerMixin, SchedulerType
 else:
     from ..utils.dummy_pt_objects import * # noqa F403

10 changes: 6 additions & 4 deletions src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -23,7 +23,7 @@

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput
-from .scheduling_utils import SchedulerMixin
+from .scheduling_utils import SchedulerMixin, SchedulerType


 @dataclass
@@ -66,6 +66,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
     option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.

     """
+    type = SchedulerType.CONTINUOUS

     @register_to_config
     def __init__(
@@ -178,7 +179,8 @@ def step(
             When returning a tuple, the first element is the sample tensor.

         """
-        sigma = self.sigmas[timestep]
+        index = (self.config.num_train_timesteps - timestep) // (self.config.num_train_timesteps // self.num_inference_steps)
Contributor:
👍 I like this bit, having the scheduler responsible for figuring out what to do with the timestep instead of having the pipeline keep track of how schedulers interpret their arguments.

Member:
Since timestep is a float here (never mind the wrong type annotation and docstring for now), I think the surest way to get its index is self.timesteps.where(timestep), as the timesteps are linearly interpolated. But where() still looks like a hack; we should instead refactor the timesteps and the scheduler properly.
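
As a quick arithmetic check of the new index computation: with num_train_timesteps=1000 and num_inference_steps=50 the stride is 1000 // 50 = 20, so a timestep of around 960 maps to index (1000 - 960) // 20 = 2. And here is a hedged sketch of the lookup suggested above, assuming self.timesteps is (or is converted to) a 1-D torch tensor of the interpolated timesteps:

    # Not code from this PR, just the reviewer's idea spelled out:
    index = (self.timesteps == timestep).nonzero().item()
    # Exact float equality only works if the very same interpolated values are
    # passed back into step(); a more forgiving variant would be:
    index = (self.timesteps - timestep).abs().argmin().item()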

+        sigma = self.sigmas[index]

         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
         pred_original_sample = sample - sigma * model_output
@@ -190,8 +192,8 @@ def step(
             self.derivatives.pop(0)

         # 3. Compute linear multistep coefficients
-        order = min(timestep + 1, order)
-        lms_coeffs = [self.get_lms_coefficient(order, timestep, curr_order) for curr_order in range(order)]
+        order = min(index + 1, order)
+        lms_coeffs = [self.get_lms_coefficient(order, index, curr_order) for curr_order in range(order)]

         # 4. Compute previous sample based on the derivatives path
         prev_sample = sample + sum(
9 changes: 9 additions & 0 deletions src/diffusers/schedulers/scheduling_utils.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 import warnings
 from dataclasses import dataclass
+from typing import Optional
+from enum import Enum

 import torch

@@ -22,6 +24,12 @@
 SCHEDULER_CONFIG_NAME = "scheduler_config.json"


+class SchedulerType(Enum):
+
+    CONTINUOUS = "continuous"
+    DISCRETE = "discrete"
+
+
 @dataclass
 class SchedulerOutput(BaseOutput):
     """
@@ -42,6 +50,7 @@ class SchedulerMixin:
     """

     config_name = SCHEDULER_CONFIG_NAME
+    type: Optional[SchedulerType] = None

     def set_format(self, tensor_format="pt"):
         warnings.warn(
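
Putting the pieces together, a hedged end-to-end sketch of how the new tag is meant to be consumed once this PR is in (only names that appear in this diff are used; sigmas is referenced elsewhere in the diff, set_timesteps is assumed to populate it, and init_sigma from the pipeline change is not defined anywhere here, so sigmas[0] stands in for it):

    import torch

    from diffusers.schedulers import LMSDiscreteScheduler, SchedulerType

    scheduler = LMSDiscreteScheduler()
    scheduler.set_timesteps(50)

    # The scheduler now advertises how it wants to be driven:
    assert scheduler.type == SchedulerType.CONTINUOUS

    latents = torch.randn(1, 4, 64, 64)
    if scheduler.type == SchedulerType.CONTINUOUS:
        # continuous (sigma-based) schedulers need the initial latents scaled by
        # the first sigma (what the pipeline diff above calls init_sigma)
        latents = latents * scheduler.sigmas[0]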