[SchedulerDesign] Alternative scheduler design #711
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

```diff
@@ -9,7 +9,7 @@
 from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
-from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, SchedulerType
 from ...utils import logging
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker

@@ -259,8 +259,8 @@ def __call__(
         timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device)

         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            latents = latents * self.scheduler.sigmas[0]
+        if self.scheduler.type == SchedulerType.CONTINUOUS:
+            latents = latents * self.scheduler.init_sigma

         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
```
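The `init_sigma` attribute replaces the direct `self.scheduler.sigmas[0]` lookup. As a side note on why this scaling exists at all: in the continuous (K-LMS) formulation the process starts from noise with standard deviation `sigma_max`, while the pipeline samples latents from a unit Gaussian, so they are rescaled once up front. A minimal illustration (the sigma value and the way `init_sigma` is obtained below are assumptions, not part of the shown diff):

```python
import torch

# Illustrative only: scale unit-Gaussian latents up to the scheduler's starting noise
# level, which is what `latents * scheduler.init_sigma` (previously `sigmas[0]`) does.
torch.manual_seed(0)
init_sigma = 14.6  # placeholder; in LMSDiscreteScheduler this would be sigmas[0], the largest sigma
latents = torch.randn(1, 4, 64, 64)
latents = latents * init_sigma
print(float(latents.std()))  # ~14.6: the latents now match the noise level the continuous scheduler starts from
```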
```diff
@@ -274,10 +274,9 @@ def __call__(
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[i]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            if self.scheduler.type == SchedulerType.CONTINUOUS and self.model.config.trained_scheduler_type == SchedulerType.DISCRETE:
```
Comment (author, contributor): Here I'm very open to hearing better suggestions @anton-l (if you can think of something that doesn't require changes to DDIM or DDPM but is cleaner).
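For readers following the diff: `SchedulerType` itself is not shown here. A minimal sketch of what it could look like in `scheduling_utils.py`, with the two members this pipeline code relies on (the string values are assumptions):

```python
# Sketch only: these are the members referenced in this PR; the actual definition in
# src/diffusers/schedulers/scheduling_utils.py is not part of the shown diff.
from enum import Enum


class SchedulerType(Enum):
    DISCRETE = "discrete"      # DDPM/DDIM/PNDM-style schedulers stepping over integer timesteps
    CONTINUOUS = "continuous"  # K-LMS-style schedulers parameterized by continuous sigmas
```

The condition on the line above then reads: the scheduler works in the continuous (sigma) formulation while the model was trained against a discrete one, so the model input has to be rescaled. `trained_scheduler_type` would presumably be a field registered on the model config; its definition is not shown in this diff either.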
```diff
+                latent_model_input = self.scheduler.scale_(latent_model_input)
```
Comment (contributor): is this …
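The body of `scale_` is not part of the shown diff; judging from the pipeline lines it replaces, it would presumably encapsulate the K-LMS input scaling. A rough reconstruction (the standalone function name and signature here are made up for illustration; how the scheduler knows the current sigma is left open):

```python
import torch

# Hypothetical standalone version of the scaling that `scheduler.scale_` presumably performs,
# reconstructed from the removed pipeline code `latent_model_input / ((sigma**2 + 1) ** 0.5)`.
def scale_model_input(sample: torch.Tensor, sigma: float) -> torch.Tensor:
    # scale the model input to match the continuous ODE formulation in K-LMS
    return sample / ((sigma**2 + 1) ** 0.5)
```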
```diff
             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

@@ -288,10 +287,7 @@ def __call__(
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
```
Comment (author, contributor): Could be considered a bug correction, IMO.
```diff
             # call the callback, if provided
             if callback is not None and i % callback_steps == 0:
```
src/diffusers/schedulers/scheduling_lms_discrete.py

```diff
@@ -23,7 +23,7 @@
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput
-from .scheduling_utils import SchedulerMixin
+from .scheduling_utils import SchedulerMixin, SchedulerType


 @dataclass

@@ -66,6 +66,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
     """
+    type = SchedulerType.CONTINUOUS

     @register_to_config
     def __init__(

@@ -178,7 +179,8 @@ def step(
                 When returning a tuple, the first element is the sample tensor.
         """
-        sigma = self.sigmas[timestep]
+        index = (self.config.num_train_timesteps - timestep) // (self.config.num_train_timesteps // self.num_inference_steps)
```
Comment (contributor): 👍 I like this bit: having the scheduler responsible for figuring out what to do with the timestep instead of having the pipeline keep track of how schedulers interpret their arguments.

Comment (member): Since …
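A quick worked example of the timestep-to-index mapping introduced above (all values assumed for illustration; whether they match the scheduler's actual timestep spacing is not shown in this diff):

```python
# With num_train_timesteps = 1000 and num_inference_steps = 50, a training timestep near 999
# should map to position 0 in the (descending) sigma schedule, the next one to 1, and so on.
num_train_timesteps = 1000
num_inference_steps = 50
step_ratio = num_train_timesteps // num_inference_steps  # 20

for timestep in (999, 979, 959, 20):
    index = (num_train_timesteps - timestep) // step_ratio
    print(timestep, index)  # 999 -> 0, 979 -> 1, 959 -> 2, 20 -> 49
```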
```diff
+        sigma = self.sigmas[index]

         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
         pred_original_sample = sample - sigma * model_output

@@ -190,8 +192,8 @@ def step(
             self.derivatives.pop(0)

         # 3. Compute linear multistep coefficients
-        order = min(timestep + 1, order)
-        lms_coeffs = [self.get_lms_coefficient(order, timestep, curr_order) for curr_order in range(order)]
+        order = min(index + 1, order)
+        lms_coeffs = [self.get_lms_coefficient(order, index, curr_order) for curr_order in range(order)]

         # 4. Compute previous sample based on the derivatives path
         prev_sample = sample + sum(
```
Comment: A step in the right direction. I'd love to get rid of the `if` entirely, but as long as we have it, defining a `scheduler.type` enum is much preferable to `isinstance`!
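One possible way to drop the branch from the pipeline entirely, sketched here purely as an illustration of that wish and not as anything this PR proposes: give the base scheduler a scaling hook that defaults to a no-op, so the pipeline can call it unconditionally and only continuous schedulers override it.

```python
import torch

class SchedulerBaseSketch:
    # default hook: discrete schedulers (DDPM, DDIM, PNDM) need no input scaling
    def scale_(self, sample: torch.Tensor) -> torch.Tensor:
        return sample


class ContinuousSchedulerSketch(SchedulerBaseSketch):
    def __init__(self, current_sigma: float):
        self.current_sigma = current_sigma  # the real scheduler would track this per step

    def scale_(self, sample: torch.Tensor) -> torch.Tensor:
        # same rescaling as in the K-LMS formulation
        return sample / ((self.current_sigma**2 + 1) ** 0.5)
```

Under this sketch DDIM and DDPM would not need changes in their own files (they inherit the no-op), which is the constraint raised in the next comment.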
Comment: @anton-l can also be convinced here to change it to something that might be cleaner (again, if we don't have to force a function upon DDIM or DDPM).

Overall, I feel quite strongly about the following though: …