-
Notifications
You must be signed in to change notification settings - Fork 6.5k
Add experimental Heun scheduler #1356
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -477,7 +477,6 @@ def __call__( | |
|
|
||
| # 4. Prepare timesteps | ||
| self.scheduler.set_timesteps(num_inference_steps, device=device) | ||
| timesteps = self.scheduler.timesteps | ||
|
|
||
| # 5. Prepare latent variables | ||
| num_channels_latents = self.unet.in_channels | ||
|
|
@@ -496,7 +495,9 @@ def __call__( | |
| extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) | ||
|
|
||
| # 7. Denoising loop | ||
| for i, t in enumerate(self.progress_bar(timesteps)): | ||
| i = 0 | ||
| t = self.scheduler.timesteps[0] | ||
| while t > 0: | ||
| # expand the latents if we are doing classifier free guidance | ||
| latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | ||
| latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | ||
|
|
@@ -510,12 +511,14 @@ def __call__( | |
| noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | ||
|
|
||
| # compute the previous noisy sample x_t -> x_t-1 | ||
| latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample | ||
| latents, t = self.scheduler.step(noise_pred, t, latents, return_dict=False, **extra_step_kwargs) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hmm this makes the step function a bit more black-box - not suuuper happy about it at this point of the library, but maybe at some point we have to let the scheduler adapt |
||
|
|
||
| # call the callback, if provided | ||
| if callback is not None and i % callback_steps == 0: | ||
| callback(i, t, latents) | ||
|
|
||
| i += 1 | ||
|
|
||
| # 8. Post-processing | ||
| image = self.decode_latents(latents) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,218 @@ | ||
| # Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from typing import Optional, Tuple, Union | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| from ..configuration_utils import ConfigMixin, register_to_config | ||
| from .scheduling_utils import SchedulerMixin, SchedulerOutput | ||
|
|
||
|
|
||
| class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): | ||
| """ | ||
| Args: | ||
| Implements Algorithm 2 (Heun steps) from Karras et al. (2022) for discrete beta schedules. Based on the original | ||
| k-diffusion implementation by Katherine Crowson: | ||
| https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L90 | ||
| [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` | ||
| function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. | ||
| [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and | ||
| [`~ConfigMixin.from_config`] functions. | ||
| num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the | ||
| starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): | ||
| the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from | ||
| `linear` or `scaled_linear`. | ||
| trained_betas (`np.ndarray`, optional): | ||
| option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. | ||
| options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, | ||
| `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. | ||
| tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. | ||
| """ | ||
|
|
||
| @register_to_config | ||
| def __init__( | ||
| self, | ||
| num_train_timesteps: int = 1000, | ||
| beta_start: float = 0.00085, # sensible defaults | ||
| beta_end: float = 0.012, | ||
| beta_schedule: str = "linear", | ||
| trained_betas: Optional[np.ndarray] = None, | ||
| ): | ||
| if trained_betas is not None: | ||
| self.betas = torch.from_numpy(trained_betas) | ||
| elif beta_schedule == "linear": | ||
| self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) | ||
| elif beta_schedule == "scaled_linear": | ||
| # this schedule is very specific to the latent diffusion model. | ||
| self.betas = ( | ||
| torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 | ||
| ) | ||
| else: | ||
| raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") | ||
|
|
||
| self.alphas = 1.0 - self.betas | ||
| self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) | ||
|
|
||
| # set all values | ||
| self.set_timesteps(num_train_timesteps, None, num_train_timesteps) | ||
|
|
||
| def scale_model_input( | ||
| self, | ||
| sample: torch.FloatTensor, | ||
| timestep: Union[float, torch.FloatTensor], | ||
| ) -> torch.FloatTensor: | ||
| """ | ||
| Args: | ||
| Ensures interchangeability with schedulers that need to scale the denoising model input depending on the | ||
| current timestep. | ||
| sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep | ||
| Returns: | ||
| `torch.FloatTensor`: scaled input sample | ||
| """ | ||
| step_index = (self.timesteps == timestep).nonzero().item() | ||
| sigma = self.sigmas[step_index] | ||
| sample = sample / ((sigma**2 + 1) ** 0.5) | ||
| return sample | ||
|
|
||
| def set_timesteps( | ||
| self, | ||
| num_inference_steps: int, | ||
| device: Union[str, torch.device] = None, | ||
| num_train_timesteps: Optional[int] = None, | ||
| ): | ||
| """ | ||
| Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. | ||
| Args: | ||
| num_inference_steps (`int`): | ||
| the number of diffusion steps used when generating samples with a pre-trained model. | ||
| device (`str` or `torch.device`, optional): | ||
| the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. | ||
| """ | ||
| self.num_inference_steps = num_inference_steps | ||
|
|
||
| num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps | ||
|
|
||
| timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() | ||
|
|
||
| sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) | ||
| sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) | ||
| sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) | ||
| self.sigmas = torch.from_numpy(sigmas).to(device=device) | ||
|
|
||
| timesteps = torch.from_numpy(timesteps) | ||
|
|
||
| # standard deviation of the initial noise distribution | ||
| self.init_noise_sigma = sigmas[0] | ||
|
|
||
| if str(device).startswith("mps"): | ||
| # mps does not support float64 | ||
| self.timesteps = timesteps.to(device, dtype=torch.float32) | ||
| else: | ||
| self.timesteps = timesteps.to(device=device) | ||
|
|
||
| # empty dt and derivative | ||
| self.prev_derivative = None | ||
| self.dt = None | ||
|
|
||
| @property | ||
| def state_in_first_order(self): | ||
| return self.dt is None | ||
|
|
||
| def step( | ||
| self, | ||
| model_output: Union[torch.FloatTensor, np.ndarray], | ||
| timestep: Union[float, torch.FloatTensor], | ||
| sample: Union[torch.FloatTensor, np.ndarray], | ||
| return_dict: bool = True, | ||
| ) -> Union[SchedulerOutput, Tuple]: | ||
| """ | ||
| Args: | ||
| Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion | ||
| process from the learned model outputs (most often the predicted noise). | ||
| model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep | ||
| (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`): | ||
| current instance of sample being created by diffusion process. | ||
| return_dict (`bool`): option for returning tuple rather than SchedulerOutput class | ||
| Returns: | ||
| [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: | ||
| [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When | ||
| returning a tuple, the first element is the sample tensor. | ||
| """ | ||
| step_index = (self.timesteps == timestep).nonzero().item() | ||
|
|
||
| if self.state_in_first_order: | ||
| sigma = self.sigmas[step_index] | ||
| step_index += 1 | ||
| sigma_next = self.sigmas[step_index] | ||
| sigma_hat = sigma | ||
| else: | ||
| # 2nd order / Heun's method | ||
| sigma = self.sigmas[step_index - 1] | ||
| sigma_next = self.sigmas[step_index] | ||
| sigma_hat = sigma_next | ||
|
|
||
| # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise | ||
| pred_original_sample = sample - sigma_hat * model_output | ||
|
|
||
| # 2. Convert to an ODE derivative | ||
| derivative = (sample - pred_original_sample) / sigma_hat | ||
| if self.state_in_first_order: | ||
| # 3. 1st order derivative | ||
| dt = sigma_next - sigma_hat | ||
|
|
||
| # store for 2nd order step | ||
| self.sample = sample | ||
| self.prev_derivative = derivative | ||
| self.dt = dt | ||
| else: | ||
| # 2. 2nd order / Heun's method | ||
| derivative = (self.prev_derivative + derivative) / 2 | ||
|
|
||
| # 3. Retrieve 1st order derivative | ||
| dt = self.dt | ||
|
|
||
| # free dt and derivative | ||
| # Note, this puts the scheduler in "first order mode" | ||
| self.prev_derivative = None | ||
| self.dt = None | ||
|
|
||
| prev_sample = self.sample + derivative * dt | ||
| print(f"step_index: {step_index}, state_in_first_order: {self.state_in_first_order}, sigma: {sigma}, sigma_next: {sigma_next}, sigma_hat: {sigma_hat}, dt: {dt}") | ||
|
|
||
| if not return_dict: | ||
| return (prev_sample, self.timesteps[step_index]) | ||
|
|
||
| return SchedulerOutput(prev_sample=prev_sample, timestep=self.timesteps[step_index]) | ||
|
Comment on lines
+196
to
+199
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similar implementation as in #1336 except that we return the next timestep |
||
|
|
||
| def add_noise( | ||
| self, | ||
| original_samples: Union[torch.FloatTensor, np.ndarray], | ||
| noise: Union[torch.FloatTensor, np.ndarray], | ||
| timesteps: Union[torch.IntTensor, np.ndarray], | ||
| ) -> Union[torch.FloatTensor, np.ndarray]: | ||
| # Make sure sigmas and timesteps have the same device and dtype as original_samples | ||
| self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) | ||
| self.timesteps = self.timesteps.to(original_samples.device) | ||
| sigma = self.sigmas[timesteps].flatten() | ||
| while len(sigma.shape) < len(original_samples.shape): | ||
| sigma = sigma.unsqueeze(-1) | ||
|
|
||
| noisy_samples = original_samples + noise * sigma | ||
| return noisy_samples | ||
|
|
||
| def __len__(self): | ||
| return self.config.num_train_timesteps | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need to keep `i` and `t`, which is very ugly. `i` is only required for the callback.