From 5db379dc24d977c50e3175e52aeff1098717101a Mon Sep 17 00:00:00 2001 From: mspronesti Date: Thu, 29 Dec 2022 23:17:55 +0100 Subject: [PATCH 1/4] feat(schedulers): implement dpm solver multistep scheduler (#17) --- src/schedulers/ddpm.rs | 2 + src/schedulers/dpmsolver_multistep.rs | 344 ++++++++++++++++++++++++++ src/schedulers/mod.rs | 1 + 3 files changed, 347 insertions(+) create mode 100644 src/schedulers/dpmsolver_multistep.rs diff --git a/src/schedulers/ddpm.rs b/src/schedulers/ddpm.rs index 68832f0..3522cfc 100644 --- a/src/schedulers/ddpm.rs +++ b/src/schedulers/ddpm.rs @@ -28,7 +28,9 @@ pub struct DDPMSchedulerConfig { pub clip_sample: bool, /// Option to clip the variance used when adding noise to the denoised sample. pub variance_type: DDPMVarianceType, + /// prediction type of the scheduler function pub prediction_type: PredictionType, + /// number of diffusion steps used to train the model. pub train_timesteps: usize, } diff --git a/src/schedulers/dpmsolver_multistep.rs b/src/schedulers/dpmsolver_multistep.rs new file mode 100644 index 0000000..b9886ec --- /dev/null +++ b/src/schedulers/dpmsolver_multistep.rs @@ -0,0 +1,344 @@ +use super::{BetaSchedule, PredictionType, betas_for_alpha_bar}; +use tch::{Tensor, kind, Kind}; + +/// The algorithm type for the solver. +/// +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub enum DPMSolverAlgorithmType { + /// Implements the algorithms defined in <https://arxiv.org/abs/2211.01095>. + #[default] + DPMSolverPlusPlus, + /// Implements the algorithms defined in <https://arxiv.org/abs/2206.00927>. + DPMSolver +} + +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub enum DPMSolverType { + #[default] + Midpoint, + Heun +} + +#[derive(Debug, Clone)] +pub struct DPMSolverMultistepSchedulerConfig { + /// The value of beta at the beginning of training. + pub beta_start: f64, + /// The value of beta at the end of training. + pub beta_end: f64, + /// How beta evolved during training. + pub beta_schedule: BetaSchedule, + /// number of diffusion steps used to train the model. 
+ pub train_timesteps: usize, + /// the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided + /// sampling, and `solver_order=3` for unconditional sampling. + pub solver_order: usize, + /// prediction type of the scheduler function + pub prediction_type: PredictionType, + /// The threshold value for dynamic thresholding. Valid only when `thresholding: true` and + /// `algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus`. + pub sample_max_value: f32, + /// The algorithm type for the solver + pub algorithm_type: DPMSolverAlgorithmType, + /// The solver type for the second-order solver. + pub solver_type: DPMSolverType, + /// Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically + /// find this can stabilize the sampling of DPM-Solver for `steps < 15`, especially for steps <= 10. + pub lower_order_final: bool +} + +impl Default for DPMSolverMultistepSchedulerConfig { + fn default() -> Self { + Self { + beta_start: 0.0001, + beta_end: 0.02, + beta_schedule: BetaSchedule::Linear, + train_timesteps: 1000, + solver_order: 2, + prediction_type: PredictionType::Epsilon, + sample_max_value: 1.0, + algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus, + solver_type: DPMSolverType::Midpoint, + lower_order_final: true + } + } +} + +pub struct DPMSolverMultistepScheduler { + alphas_cumprod: Vec<f64>, + alpha_t: Vec<f64>, + sigma_t: Vec<f64>, + lambda_t: Vec<f64>, + init_noise_sigma: f64, + lower_order_nums: usize, + model_outputs: Vec<Tensor>, + timesteps: Vec<usize>, + pub config: DPMSolverMultistepSchedulerConfig +} + + +impl DPMSolverMultistepScheduler { + pub fn new (inference_steps: usize, config: DPMSolverMultistepSchedulerConfig) -> Self { + let betas = match config.beta_schedule { + BetaSchedule::ScaledLinear => Tensor::linspace( + config.beta_start.sqrt(), + config.beta_end.sqrt(), + config.train_timesteps as i64, + kind::FLOAT_CPU, + ).square(), + BetaSchedule::Linear => Tensor::linspace( 
+ config.beta_start, + config.beta_end, + config.train_timesteps as i64, + kind::FLOAT_CPU, + ), + BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999), + }; + let alphas: Tensor = 1. - betas; + let alphas_cumprod = alphas.cumprod(0, Kind::Double); + + let alpha_t = alphas_cumprod.sqrt(); + let sigma_t = ((1. - &alphas_cumprod) as Tensor).sqrt(); + let lambda_t = alpha_t.log() - sigma_t.log(); + + let step = (config.train_timesteps - 1) as f64 / inference_steps as f64; + // np.linspace(0, train_steps - 1, inference_steps + 1).round()[::-1][:-1] + let timesteps: Vec<usize> = (0..inference_steps+1) + .map(|i| (i as f64 * step).round() as usize) + // discards the 0.0 element + .skip(1) + .rev() + .collect(); + + // tch::Tensor doesn't implement `Copy` + let mut model_outputs = Vec::<Tensor>::new(); + for _ in 0..config.solver_order { + model_outputs.push(Tensor::new()); + } + + Self { + alphas_cumprod: Vec::<f64>::from(alphas_cumprod), + alpha_t: Vec::<f64>::from(alpha_t), + sigma_t: Vec::<f64>::from(sigma_t), + lambda_t: Vec::<f64>::from(lambda_t), + init_noise_sigma: 1., + lower_order_nums: 0, + model_outputs, + timesteps, + config + } + } + + /// Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. + /// + /// DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to + /// discretize an integral of the data prediction model. So we need to first convert the model output to the + /// corresponding type to match the algorithm. + /// + /// Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or + /// DPM-Solver++ for both noise prediction model and data prediction model. 
+ fn convert_model_output( + &self, + model_output: &Tensor, + timestep: usize, + sample: &Tensor + ) -> Tensor { + match self.config.algorithm_type { + DPMSolverAlgorithmType::DPMSolverPlusPlus => { + match self.config.prediction_type { + PredictionType::Epsilon => { + let alpha_t = self.alpha_t[timestep]; + let sigma_t = self.sigma_t[timestep]; + (sample - sigma_t * model_output) / alpha_t + } + PredictionType::Sample => model_output.shallow_clone(), + PredictionType::VPrediction => { + let alpha_t = self.alpha_t[timestep]; + let sigma_t = self.sigma_t[timestep]; + alpha_t * sample - sigma_t * model_output + } + } + // TODO: implement Dynamic thresholding + // https://arxiv.org/abs/2205.11487 + } + DPMSolverAlgorithmType::DPMSolver => match self.config.prediction_type { + PredictionType::Epsilon => model_output.shallow_clone(), + PredictionType::Sample => { + let alpha_t = self.alpha_t[timestep]; + let sigma_t = self.sigma_t[timestep]; + (sample - alpha_t * model_output) / sigma_t + } + PredictionType::VPrediction => { + let alpha_t = self.alpha_t[timestep]; + let sigma_t = self.sigma_t[timestep]; + alpha_t * model_output + sigma_t * sample + } + } + } + } + + /// One step for the first-order DPM-Solver (equivalent to DDIM). + /// See https://arxiv.org/abs/2206.00927 for the detailed derivation. 
+ fn dpm_solver_first_order_update( + &self, + model_output: Tensor, + timestep: usize, + prev_timestep: usize, + sample: &Tensor + ) -> Tensor { + let (lambda_t, lambda_s) = (self.lambda_t[prev_timestep], self.lambda_t[timestep]); + let (alpha_t, alpha_s) = (self.alpha_t[prev_timestep], self.alpha_t[timestep]); + let (sigma_t, sigma_s) = (self.sigma_t[prev_timestep], self.sigma_t[timestep]); + let h = lambda_t - lambda_s; + match self.config.algorithm_type { + DPMSolverAlgorithmType::DPMSolverPlusPlus => + (sigma_t / sigma_s) * sample - (alpha_t * ((-h).exp() - 1.0)) * model_output, + DPMSolverAlgorithmType::DPMSolver => + (alpha_t / alpha_s) * sample - (sigma_t * (h.exp() - 1.0)) * model_output + } + } + + /// One step for the second-order multistep DPM-Solver. + fn multistep_dpm_solver_second_order_update( + &self, + model_output_list: &Vec<Tensor>, + timestep_list: [usize; 2], + prev_timestep: usize, + sample: &Tensor + ) -> Tensor { + + let (t, s0, s1) = (prev_timestep, timestep_list[timestep_list.len() - 1], timestep_list[timestep_list.len() - 2]); + let (m0, m1) = (model_output_list[model_output_list.len() - 1].as_ref(), model_output_list[model_output_list.len() - 2].as_ref()); + let (lambda_t, lambda_s0, lambda_s1) = (self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]); + let (alpha_t, alpha_s0) = (self.alpha_t[t], self.alpha_t[s0]); + let (sigma_t, sigma_s0) = (self.sigma_t[t], self.sigma_t[s0]); + let (h, h_0) = (lambda_t - lambda_s0, lambda_s0 - lambda_s1); + let r0 = h_0 / h; + let (d0, d1) = (m0, (1.0 / r0) * (m0 - m1)); + match self.config.algorithm_type { + DPMSolverAlgorithmType::DPMSolverPlusPlus => match self.config.solver_type { + // See https://arxiv.org/abs/2211.01095 for detailed derivations + DPMSolverType::Midpoint => { + (sigma_t / sigma_s0) * sample + - (alpha_t * ((-h).exp() - 1.0)) * d0 + - 0.5 * (alpha_t * ((-h).exp() - 1.0)) * d1 + } + DPMSolverType::Heun => { + (sigma_t / sigma_s0) * sample + - (alpha_t * ((-h).exp() - 1.0)) * d0 + 
+ (alpha_t * (((-h).exp() - 1.0) / h + 1.0)) * d1 + } + }, + DPMSolverAlgorithmType::DPMSolver => match self.config.solver_type { + // See https://arxiv.org/abs/2206.00927 for detailed derivations + DPMSolverType::Midpoint => { + (alpha_t / alpha_s0) * sample + - (sigma_t * (h.exp() - 1.0)) * d0 + - 0.5 * (sigma_t * (h.exp() - 1.0)) * d1 + } + DPMSolverType::Heun => { + (alpha_t / alpha_s0) * sample + - (sigma_t * (h.exp() - 1.0)) * d0 + - (sigma_t * ((h.exp() - 1.0) / h - 1.0)) * d1 + } + } + } + } + + /// One step for the third-order multistep DPM-Solver + fn multistep_dpm_solver_third_order_update( + &self, + model_output_list: &Vec<Tensor>, + timestep_list: [usize; 3], + prev_timestep: usize, + sample: &Tensor + ) -> Tensor { + let (t, s0, s1, s2) = + (prev_timestep, timestep_list[timestep_list.len() - 1], timestep_list[timestep_list.len() - 2], timestep_list[timestep_list.len() - 3]); + let (m0, m1, m2) = ( + model_output_list[model_output_list.len() - 1].as_ref(), + model_output_list[model_output_list.len() - 2].as_ref(), + model_output_list[model_output_list.len() - 3].as_ref() + ); + let (lambda_t, lambda_s0, lambda_s1, lambda_s2) = (self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1], self.lambda_t[s2]); + let (alpha_t, alpha_s0) = (self.alpha_t[t], self.alpha_t[s0]); + let (sigma_t, sigma_s0) = (self.sigma_t[t], self.sigma_t[s0]); + let (h, h_0, h_1) = (lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2); + let (r0, r1) = (h_0 / h, h_1 / h); + let d0 = m0; + let (d1_0, d1_1) = ((1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)); + let d1 = &d1_0 + (r0 / (r0 + r1)) * (&d1_0 - &d1_1); + let d2 = (1.0 / (r0 + r1)) * (d1_0 - d1_1); + + match self.config.algorithm_type { + DPMSolverAlgorithmType::DPMSolverPlusPlus => { + // See https://arxiv.org/abs/2206.00927 for detailed derivations + (sigma_t / sigma_s0) * sample + - (alpha_t * ((-h).exp() - 1.0)) * d0 + + (alpha_t * (((-h).exp() - 1.0) / h + 1.0)) * d1 + - (alpha_t * (((-h).exp() - 1.0 + h) / 
h.powi(2) - 0.5)) * d2 + } + DPMSolverAlgorithmType::DPMSolver => { + (alpha_t / alpha_s0) * sample + - (sigma_t * (h.exp() - 1.0)) * d0 + - (sigma_t * ((h.exp() - 1.0) / h - 1.0)) * d1 + - (sigma_t * ((h.exp() - 1.0 - h) / h.powi(2) - 0.5)) * d2 + } + } + } + + pub fn timesteps(&self) -> &[usize] { + self.timesteps.as_slice() + } + + pub fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor { + let step_index = self.timesteps + .iter() + .position(|&t| t == timestep) + .unwrap(); + + let prev_timestep = if step_index == self.timesteps.len() - 1 { 0 } else { self.timesteps[step_index + 1] }; + let lower_order_final = (step_index == self.timesteps.len() - 1) + && self.config.lower_order_final && self.timesteps.len() < 15; + let lower_order_second = (step_index == self.timesteps.len() - 2) + && self.config.lower_order_final && self.timesteps.len() < 15; + + let model_output = self.convert_model_output(model_output, timestep, sample); + for i in 0..self.config.solver_order - 1 { + self.model_outputs[i] = self.model_outputs[i + 1].shallow_clone(); + } + // self.model_outputs[-1] = model_output + let m = self.model_outputs.len(); + self.model_outputs[m - 1] = model_output.shallow_clone(); + + let prev_sample = if self.config.solver_order == 1 || self.lower_order_nums < 1 || lower_order_final { + self.dpm_solver_first_order_update( + model_output, timestep, prev_timestep, sample + ) + } else if self.config.solver_order == 2 || self.lower_order_nums < 2 || lower_order_second { + let timestep_list = [self.timesteps[step_index - 1], timestep]; + self.multistep_dpm_solver_second_order_update( + &self.model_outputs, timestep_list, prev_timestep, sample + ) + } else { + let timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]; + self.multistep_dpm_solver_third_order_update( + &self.model_outputs, timestep_list, prev_timestep, sample + ) + }; + + if self.lower_order_nums < self.config.solver_order { + 
+ self.lower_order_nums += 1; + } + + prev_sample + } + + pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor { + self.alphas_cumprod[timestep].sqrt() * original_samples.to_owned() + + (1.0 - self.alphas_cumprod[timestep]).sqrt() * noise + } + + pub fn init_noise_sigma(&self) -> f64 { + self.init_noise_sigma + } +} diff --git a/src/schedulers/mod.rs b/src/schedulers/mod.rs index 67571c5..df1dde1 100644 --- a/src/schedulers/mod.rs +++ b/src/schedulers/mod.rs @@ -7,6 +7,7 @@ use tch::Tensor; pub mod ddim; pub mod ddpm; +pub mod dpmsolver_multistep; /// This represents how beta ranges from its minimum value to the maximum /// during training. From 24b5985dead51a8afe9ad2dfcf213a0f172cae60 Mon Sep 17 00:00:00 2001 From: mspronesti Date: Thu, 29 Dec 2022 23:34:05 +0100 Subject: [PATCH 2/4] chore(schedulers): apply rustfmt --- src/schedulers/dpmsolver_multistep.rs | 106 ++++++++++++++++---------- 1 file changed, 65 insertions(+), 41 deletions(-) diff --git a/src/schedulers/dpmsolver_multistep.rs b/src/schedulers/dpmsolver_multistep.rs index b9886ec..ccb5124 100644 --- a/src/schedulers/dpmsolver_multistep.rs +++ b/src/schedulers/dpmsolver_multistep.rs @@ -1,5 +1,5 @@ -use super::{BetaSchedule, PredictionType, betas_for_alpha_bar}; -use tch::{Tensor, kind, Kind}; +use super::{betas_for_alpha_bar, BetaSchedule, PredictionType}; +use tch::{kind, Kind, Tensor}; /// The algorithm type for the solver. /// @@ -9,14 +9,14 @@ pub enum DPMSolverAlgorithmType { #[default] DPMSolverPlusPlus, /// Implements the algorithms defined in <https://arxiv.org/abs/2206.00927>. - DPMSolver + DPMSolver, } #[derive(Default, Debug, Clone, PartialEq, Eq)] pub enum DPMSolverType { #[default] Midpoint, - Heun + Heun, } #[derive(Debug, Clone)] @@ -43,7 +43,7 @@ pub struct DPMSolverMultistepSchedulerConfig { pub solver_type: DPMSolverType, /// Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. 
We empirically /// find this can stabilize the sampling of DPM-Solver for `steps < 15`, especially for steps <= 10. - pub lower_order_final: bool + pub lower_order_final: bool, } impl Default for DPMSolverMultistepSchedulerConfig { @@ -58,7 +58,7 @@ impl Default for DPMSolverMultistepSchedulerConfig { sample_max_value: 1.0, algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus, solver_type: DPMSolverType::Midpoint, - lower_order_final: true + lower_order_final: true, } } } @@ -72,19 +72,19 @@ pub struct DPMSolverMultistepScheduler { lower_order_nums: usize, model_outputs: Vec<Tensor>, timesteps: Vec<usize>, - pub config: DPMSolverMultistepSchedulerConfig + pub config: DPMSolverMultistepSchedulerConfig, } - impl DPMSolverMultistepScheduler { - pub fn new (inference_steps: usize, config: DPMSolverMultistepSchedulerConfig) -> Self { + pub fn new(inference_steps: usize, config: DPMSolverMultistepSchedulerConfig) -> Self { let betas = match config.beta_schedule { BetaSchedule::ScaledLinear => Tensor::linspace( config.beta_start.sqrt(), config.beta_end.sqrt(), config.train_timesteps as i64, kind::FLOAT_CPU, - ).square(), + ) + .square(), BetaSchedule::Linear => Tensor::linspace( config.beta_start, config.beta_end, @@ -100,9 +100,9 @@ impl DPMSolverMultistepScheduler { let sigma_t = ((1. 
- &alphas_cumprod) as Tensor).sqrt(); let lambda_t = alpha_t.log() - sigma_t.log(); - let step = (config.train_timesteps - 1) as f64 / inference_steps as f64; + let step = (config.train_timesteps - 1) as f64 / inference_steps as f64; // np.linspace(0, train_steps - 1, inference_steps + 1).round()[::-1][:-1] - let timesteps: Vec<usize> = (0..inference_steps+1) + let timesteps: Vec<usize> = (0..inference_steps + 1) .map(|i| (i as f64 * step).round() as usize) // discards the 0.0 element .skip(1) @@ -124,7 +124,7 @@ impl DPMSolverMultistepScheduler { lower_order_nums: 0, model_outputs, timesteps, - config + config, } } @@ -140,7 +140,7 @@ impl DPMSolverMultistepScheduler { &self, model_output: &Tensor, timestep: usize, - sample: &Tensor + sample: &Tensor, ) -> Tensor { match self.config.algorithm_type { DPMSolverAlgorithmType::DPMSolverPlusPlus => { @@ -183,17 +183,19 @@ impl DPMSolverMultistepScheduler { model_output: Tensor, timestep: usize, prev_timestep: usize, - sample: &Tensor + sample: &Tensor, ) -> Tensor { let (lambda_t, lambda_s) = (self.lambda_t[prev_timestep], self.lambda_t[timestep]); let (alpha_t, alpha_s) = (self.alpha_t[prev_timestep], self.alpha_t[timestep]); let (sigma_t, sigma_s) = (self.sigma_t[prev_timestep], self.sigma_t[timestep]); let h = lambda_t - lambda_s; match self.config.algorithm_type { - DPMSolverAlgorithmType::DPMSolverPlusPlus => - (sigma_t / sigma_s) * sample - (alpha_t * ((-h).exp() - 1.0)) * model_output, - DPMSolverAlgorithmType::DPMSolver => + DPMSolverAlgorithmType::DPMSolverPlusPlus => { + (sigma_t / sigma_s) * sample - (alpha_t * ((-h).exp() - 1.0)) * model_output + } + DPMSolverAlgorithmType::DPMSolver => { (alpha_t / alpha_s) * sample - (sigma_t * (h.exp() - 1.0)) * model_output + } } } @@ -203,12 +205,19 @@ impl DPMSolverMultistepScheduler { model_output_list: &Vec<Tensor>, timestep_list: [usize; 2], prev_timestep: usize, - sample: &Tensor + sample: &Tensor, ) -> Tensor { - - let (t, s0, s1) = (prev_timestep, timestep_list[timestep_list.len() - 
1], timestep_list[timestep_list.len() - 2]); - let (m0, m1) = (model_output_list[model_output_list.len() - 1].as_ref(), model_output_list[model_output_list.len() - 2].as_ref()); - let (lambda_t, lambda_s0, lambda_s1) = (self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]); + let (t, s0, s1) = ( + prev_timestep, + timestep_list[timestep_list.len() - 1], + timestep_list[timestep_list.len() - 2], + ); + let (m0, m1) = ( + model_output_list[model_output_list.len() - 1].as_ref(), + model_output_list[model_output_list.len() - 2].as_ref(), + ); + let (lambda_t, lambda_s0, lambda_s1) = + (self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]); let (alpha_t, alpha_s0) = (self.alpha_t[t], self.alpha_t[s0]); let (sigma_t, sigma_s0) = (self.sigma_t[t], self.sigma_t[s0]); let (h, h_0) = (lambda_t - lambda_s0, lambda_s0 - lambda_s1); @@ -223,8 +232,7 @@ impl DPMSolverMultistepScheduler { - 0.5 * (alpha_t * ((-h).exp() - 1.0)) * d1 } DPMSolverType::Heun => { - (sigma_t / sigma_s0) * sample - - (alpha_t * ((-h).exp() - 1.0)) * d0 + (sigma_t / sigma_s0) * sample - (alpha_t * ((-h).exp() - 1.0)) * d0 + (alpha_t * (((-h).exp() - 1.0) / h + 1.0)) * d1 } }, @@ -234,7 +242,7 @@ impl DPMSolverMultistepScheduler { (alpha_t / alpha_s0) * sample - (sigma_t * (h.exp() - 1.0)) * d0 - 0.5 * (sigma_t * (h.exp() - 1.0)) * d1 - } + }, DPMSolverType::Heun => { (alpha_t / alpha_s0) * sample - (sigma_t * (h.exp() - 1.0)) * d0 @@ -250,16 +258,21 @@ impl DPMSolverMultistepScheduler { model_output_list: &Vec<Tensor>, timestep_list: [usize; 3], prev_timestep: usize, - sample: &Tensor + sample: &Tensor, ) -> Tensor { - let (t, s0, s1, s2) = - (prev_timestep, timestep_list[timestep_list.len() - 1], timestep_list[timestep_list.len() - 2], timestep_list[timestep_list.len() - 3]); + let (t, s0, s1, s2) = ( + prev_timestep, + timestep_list[timestep_list.len() - 1], + timestep_list[timestep_list.len() - 2], + timestep_list[timestep_list.len() - 3], + ); let (m0, m1, m2) = ( 
model_output_list[model_output_list.len() - 1].as_ref(), model_output_list[model_output_list.len() - 2].as_ref(), - model_output_list[model_output_list.len() - 3].as_ref() + model_output_list[model_output_list.len() - 3].as_ref(), ); - let (lambda_t, lambda_s0, lambda_s1, lambda_s2) = (self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1], self.lambda_t[s2]); + let (lambda_t, lambda_s0, lambda_s1, lambda_s2) = + (self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1], self.lambda_t[s2]); let (alpha_t, alpha_s0) = (self.alpha_t[t], self.alpha_t[s0]); let (sigma_t, sigma_s0) = (self.sigma_t[t], self.sigma_t[s0]); let (h, h_0, h_1) = (lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2); @@ -296,11 +309,14 @@ impl DPMSolverMultistepScheduler { .position(|&t| t == timestep) .unwrap(); - let prev_timestep = if step_index == self.timesteps.len() - 1 { 0 } else { self.timesteps[step_index + 1] }; + let prev_timestep = + if step_index == self.timesteps.len() - 1 { 0 } else { self.timesteps[step_index + 1] }; let lower_order_final = (step_index == self.timesteps.len() - 1) - && self.config.lower_order_final && self.timesteps.len() < 15; + && self.config.lower_order_final + && self.timesteps.len() < 15; let lower_order_second = (step_index == self.timesteps.len() - 2) - && self.config.lower_order_final && self.timesteps.len() < 15; + && self.config.lower_order_final + && self.timesteps.len() < 15; let model_output = self.convert_model_output(model_output, timestep, sample); for i in 0..self.config.solver_order - 1 { @@ -310,19 +326,27 @@ impl DPMSolverMultistepScheduler { let m = self.model_outputs.len(); self.model_outputs[m - 1] = model_output.shallow_clone(); - let prev_sample = if self.config.solver_order == 1 || self.lower_order_nums < 1 || lower_order_final { - self.dpm_solver_first_order_update( - model_output, timestep, prev_timestep, sample - ) + let prev_sample = if self.config.solver_order == 1 + || self.lower_order_nums < 1 + || lower_order_final 
+ { + self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample) } else if self.config.solver_order == 2 || self.lower_order_nums < 2 || lower_order_second { let timestep_list = [self.timesteps[step_index - 1], timestep]; self.multistep_dpm_solver_second_order_update( - &self.model_outputs, timestep_list, prev_timestep, sample + &self.model_outputs, + timestep_list, + prev_timestep, + sample, ) } else { - let timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]; + let timestep_list = + [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]; self.multistep_dpm_solver_third_order_update( - &self.model_outputs, timestep_list, prev_timestep, sample + &self.model_outputs, + timestep_list, + prev_timestep, + sample, ) }; From 4530351f7362ee235201b1b005232db8d0d66852 Mon Sep 17 00:00:00 2001 From: mspronesti Date: Fri, 30 Dec 2022 16:33:31 +0100 Subject: [PATCH 3/4] chore(schedulers): fix rustfmt --- src/schedulers/dpmsolver_multistep.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/schedulers/dpmsolver_multistep.rs b/src/schedulers/dpmsolver_multistep.rs index ccb5124..9836d51 100644 --- a/src/schedulers/dpmsolver_multistep.rs +++ b/src/schedulers/dpmsolver_multistep.rs @@ -172,7 +172,7 @@ impl DPMSolverMultistepScheduler { let sigma_t = self.sigma_t[timestep]; alpha_t * model_output + sigma_t * sample } - } + }, } } @@ -242,13 +242,13 @@ impl DPMSolverMultistepScheduler { (alpha_t / alpha_s0) * sample - (sigma_t * (h.exp() - 1.0)) * d0 - 0.5 * (sigma_t * (h.exp() - 1.0)) * d1 - }, + } DPMSolverType::Heun => { (alpha_t / alpha_s0) * sample - (sigma_t * (h.exp() - 1.0)) * d0 - (sigma_t * ((h.exp() - 1.0) / h - 1.0)) * d1 } - } + }, } } @@ -285,8 +285,7 @@ impl DPMSolverMultistepScheduler { match self.config.algorithm_type { DPMSolverAlgorithmType::DPMSolverPlusPlus => { // See https://arxiv.org/abs/2206.00927 for detailed derivations - (sigma_t / sigma_s0) * 
sample - - (alpha_t * ((-h).exp() - 1.0)) * d0 + (sigma_t / sigma_s0) * sample - (alpha_t * ((-h).exp() - 1.0)) * d0 + (alpha_t * (((-h).exp() - 1.0) / h + 1.0)) * d1 - (alpha_t * (((-h).exp() - 1.0 + h) / h.powi(2) - 0.5)) * d2 } @@ -304,6 +303,7 @@ impl DPMSolverMultistepScheduler { } pub fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor { + #[rustfmt::skip] let step_index = self.timesteps .iter() .position(|&t| t == timestep) From 5b1de331c3072f73d5d0457fa6c6964ca5b94a01 Mon Sep 17 00:00:00 2001 From: mspronesti Date: Sat, 31 Dec 2022 17:12:07 +0100 Subject: [PATCH 4/4] chore(schedulers): reformat code * remove rustfmt:skip * add permalinks to the python implementation * re-run rustfmt --- src/schedulers/dpmsolver_multistep.rs | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/schedulers/dpmsolver_multistep.rs b/src/schedulers/dpmsolver_multistep.rs index 9836d51..7aa1cfc 100644 --- a/src/schedulers/dpmsolver_multistep.rs +++ b/src/schedulers/dpmsolver_multistep.rs @@ -12,6 +12,9 @@ pub enum DPMSolverAlgorithmType { DPMSolver, } +/// The solver type for the second-order solver. +/// The solver type slightly affects the sample quality, especially for +/// small number of steps. 
#[derive(Default, Debug, Clone, PartialEq, Eq)] pub enum DPMSolverType { #[default] @@ -101,7 +104,7 @@ impl DPMSolverMultistepScheduler { let lambda_t = alpha_t.log() - sigma_t.log(); let step = (config.train_timesteps - 1) as f64 / inference_steps as f64; - // np.linspace(0, train_steps - 1, inference_steps + 1).round()[::-1][:-1] + // https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py#L199-L204 let timesteps: Vec<usize> = (0..inference_steps + 1) .map(|i| (i as f64 * step).round() as usize) // discards the 0.0 element @@ -109,7 +112,6 @@ impl DPMSolverMultistepScheduler { .rev() .collect(); - // tch::Tensor doesn't implement `Copy` let mut model_outputs = Vec::<Tensor>::new(); for _ in 0..config.solver_order { model_outputs.push(Tensor::new()); @@ -303,11 +305,8 @@ impl DPMSolverMultistepScheduler { } pub fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor { - #[rustfmt::skip] - let step_index = self.timesteps - .iter() - .position(|&t| t == timestep) - .unwrap(); + // https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py#L457 + let step_index = self.timesteps.iter().position(|&t| t == timestep).unwrap(); let prev_timestep = if step_index == self.timesteps.len() - 1 { 0 } else { self.timesteps[step_index + 1] };