Skip to content

Commit 83b28b3

Browse files
committed
Add DPMSolverScheduler trait
1 parent dea4d4a commit 83b28b3

File tree

3 files changed

+86
-39
lines changed

3 files changed

+86
-39
lines changed

src/schedulers/dpmsolver.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use tch::Tensor;
2+
13
use crate::schedulers::BetaSchedule;
24
use crate::schedulers::PredictionType;
35

@@ -65,3 +67,46 @@ impl Default for DPMSolverSchedulerConfig {
6567
}
6668
}
6769
}
70+
71+
/// Common interface shared by the DPM-Solver schedulers.
///
/// In this change it is implemented by `DPMSolverMultistepScheduler` and
/// `DPMSolverSinglestepScheduler`, so both expose the same set of methods.
pub trait DPMSolverScheduler {
72+
/// Constructs a scheduler from the number of inference steps and a configuration.
fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self;
73+
/// Converts the raw model output for the given timestep and sample.
///
/// NOTE(review): no implementation of this method is visible in this change;
/// presumably it turns the model output into the form the solver updates
/// consume — confirm against the implementors.
fn convert_model_output(
74+
&self,
75+
model_output: &Tensor,
76+
timestep: usize,
77+
sample: &Tensor,
78+
) -> Tensor;
79+
80+
/// One step for the first-order DPM-Solver (equivalent to DDIM).
/// See https://arxiv.org/abs/2206.00927 for the detailed derivation.
///
/// * `timestep` - current discrete timestep in the diffusion chain
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
/// * `sample` - current instance of sample being created by diffusion process
fn first_order_update(
81+
&self,
82+
model_output: Tensor,
83+
timestep: usize,
84+
prev_timestep: usize,
85+
sample: &Tensor,
86+
) -> Tensor;
87+
88+
/// One step for the second-order DPM-Solver.
///
/// * `timestep_list` - current and latter discrete timestep in the diffusion chain
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
/// * `sample` - current instance of sample being created by diffusion process
fn second_order_update(
89+
&self,
90+
model_output_list: &Vec<Tensor>,
91+
timestep_list: [usize; 2],
92+
prev_timestep: usize,
93+
sample: &Tensor,
94+
) -> Tensor;
95+
96+
/// One step for the third-order DPM-Solver.
///
/// Takes the same arguments as `second_order_update`, with three timesteps
/// in `timestep_list` instead of two.
fn third_order_update(
97+
&self,
98+
model_output_list: &Vec<Tensor>,
99+
timestep_list: [usize; 3],
100+
prev_timestep: usize,
101+
sample: &Tensor,
102+
) -> Tensor;
103+
104+
/// Advances the sampler by one step.
///
/// * `model_output` - direct output from learned diffusion model
/// * `timestep` - current discrete timestep in the diffusion chain
/// * `sample` - current instance of sample being created by diffusion process
///
/// Both implementors in this change look `timestep` up in their timestep
/// list and unwrap the result, so they panic when it is not one of
/// `timesteps()`.
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor;
105+
106+
/// Returns the scheduler's timesteps as a borrowed slice.
fn timesteps(&self) -> &[usize];
107+
/// Ensures interchangeability with schedulers that need to scale the denoising model input
/// depending on the current timestep.
fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Tensor;
108+
109+
110+
/// Combines `original_samples` with `noise` for the given `timestep`.
///
/// Both implementors in this change compute
/// `sqrt(alphas_cumprod[timestep]) * original_samples
///  + sqrt(1 - alphas_cumprod[timestep]) * noise`.
fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor;
111+
/// Returns the scheduler's `init_noise_sigma` value.
fn init_noise_sigma(&self) -> f64;
112+
}

src/schedulers/dpmsolver_multistep.rs

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
use super::{betas_for_alpha_bar, BetaSchedule, PredictionType, dpmsolver::{DPMSolverSchedulerConfig, DPMSolverAlgorithmType, DPMSolverType}};
1+
use super::{
2+
betas_for_alpha_bar,
3+
dpmsolver::{
4+
DPMSolverAlgorithmType, DPMSolverScheduler, DPMSolverSchedulerConfig, DPMSolverType,
5+
},
6+
BetaSchedule, PredictionType,
7+
};
28
use tch::{kind, Kind, Tensor};
39

410
pub struct DPMSolverMultistepScheduler {
@@ -15,8 +21,8 @@ pub struct DPMSolverMultistepScheduler {
1521
pub config: DPMSolverSchedulerConfig,
1622
}
1723

18-
impl DPMSolverMultistepScheduler {
19-
pub fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self {
24+
impl DPMSolverScheduler for DPMSolverMultistepScheduler {
25+
fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self {
2026
let betas = match config.beta_schedule {
2127
BetaSchedule::ScaledLinear => Tensor::linspace(
2228
config.beta_start.sqrt(),
@@ -117,7 +123,7 @@ impl DPMSolverMultistepScheduler {
117123

118124
/// One step for the first-order DPM-Solver (equivalent to DDIM).
119125
/// See https://arxiv.org/abs/2206.00927 for the detailed derivation.
120-
fn dpm_solver_first_order_update(
126+
fn first_order_update(
121127
&self,
122128
model_output: Tensor,
123129
timestep: usize,
@@ -139,7 +145,7 @@ impl DPMSolverMultistepScheduler {
139145
}
140146

141147
/// One step for the second-order multistep DPM-Solver.
142-
fn multistep_dpm_solver_second_order_update(
148+
fn second_order_update(
143149
&self,
144150
model_output_list: &Vec<Tensor>,
145151
timestep_list: [usize; 2],
@@ -192,7 +198,7 @@ impl DPMSolverMultistepScheduler {
192198
}
193199

194200
/// One step for the third-order multistep DPM-Solver
195-
fn multistep_dpm_solver_third_order_update(
201+
fn third_order_update(
196202
&self,
197203
model_output_list: &Vec<Tensor>,
198204
timestep_list: [usize; 3],
@@ -237,11 +243,11 @@ impl DPMSolverMultistepScheduler {
237243
}
238244
}
239245

240-
pub fn timesteps(&self) -> &[usize] {
246+
/// Returns the scheduler's stored timesteps as a borrowed slice.
fn timesteps(&self) -> &[usize] {
241247
self.timesteps.as_slice()
242248
}
243249

244-
pub fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor {
250+
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor {
245251
// https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py#L457
246252
let step_index = self.timesteps.iter().position(|&t| t == timestep).unwrap();
247253

@@ -266,24 +272,14 @@ impl DPMSolverMultistepScheduler {
266272
|| self.lower_order_nums < 1
267273
|| lower_order_final
268274
{
269-
self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample)
275+
self.first_order_update(model_output, timestep, prev_timestep, sample)
270276
} else if self.config.solver_order == 2 || self.lower_order_nums < 2 || lower_order_second {
271277
let timestep_list = [self.timesteps[step_index - 1], timestep];
272-
self.multistep_dpm_solver_second_order_update(
273-
&self.model_outputs,
274-
timestep_list,
275-
prev_timestep,
276-
sample,
277-
)
278+
self.second_order_update(&self.model_outputs, timestep_list, prev_timestep, sample)
278279
} else {
279280
let timestep_list =
280281
[self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep];
281-
self.multistep_dpm_solver_third_order_update(
282-
&self.model_outputs,
283-
timestep_list,
284-
prev_timestep,
285-
sample,
286-
)
282+
self.third_order_update(&self.model_outputs, timestep_list, prev_timestep, sample)
287283
};
288284

289285
if self.lower_order_nums < self.config.solver_order {
@@ -293,12 +289,16 @@ impl DPMSolverMultistepScheduler {
293289
prev_sample
294290
}
295291

296-
pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor {
292+
/// Combines `original_samples` with `noise` for the given `timestep`:
/// `sqrt(alphas_cumprod[timestep]) * original_samples
///  + sqrt(1 - alphas_cumprod[timestep]) * noise`.
///
/// NOTE(review): `alphas_cumprod` is indexed directly with `timestep`; presumably
/// this panics for an out-of-range timestep — confirm against the field's type.
fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor {
297293
self.alphas_cumprod[timestep].sqrt() * original_samples.to_owned()
298294
+ (1.0 - self.alphas_cumprod[timestep]).sqrt() * noise
299295
}
300296

301-
pub fn init_noise_sigma(&self) -> f64 {
297+
/// Returns the `init_noise_sigma` value stored on the scheduler.
fn init_noise_sigma(&self) -> f64 {
302298
self.init_noise_sigma
303299
}
300+
301+
/// Ensures interchangeability with schedulers that need to scale the denoising model input
/// depending on the current timestep.
///
/// This scheduler applies no scaling, so `sample` is returned unchanged. This
/// matches `DPMSolverSinglestepScheduler::scale_model_input` and replaces the
/// previous `todo!()` stub, which panicked on every call made through the
/// `DPMSolverScheduler` trait.
///
/// * `sample` - input sample, handed back as-is
/// * `_timestep` - current timestep; unused
fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
    sample
}
304304
}

src/schedulers/dpmsolver_singlestep.rs

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ use std::iter::repeat;
22

33
use super::{
44
betas_for_alpha_bar,
5-
dpmsolver::{DPMSolverAlgorithmType, DPMSolverSchedulerConfig, DPMSolverType},
5+
dpmsolver::{
6+
DPMSolverAlgorithmType, DPMSolverScheduler, DPMSolverSchedulerConfig, DPMSolverType,
7+
},
68
BetaSchedule, PredictionType,
79
};
810
use tch::{kind, Kind, Tensor};
@@ -23,8 +25,8 @@ pub struct DPMSolverSinglestepScheduler {
2325
pub config: DPMSolverSchedulerConfig,
2426
}
2527

26-
impl DPMSolverSinglestepScheduler {
27-
pub fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self {
28+
impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
29+
fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self {
2830
let betas = match config.beta_schedule {
2931
BetaSchedule::ScaledLinear => Tensor::linspace(
3032
config.beta_start.sqrt(),
@@ -141,9 +143,9 @@ impl DPMSolverSinglestepScheduler {
141143
/// * `timestep` - current discrete timestep in the diffusion chain
142144
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
143145
/// * `sample` - current instance of sample being created by diffusion process
144-
fn dpm_solver_first_order_update(
146+
fn first_order_update(
145147
&self,
146-
model_output: &Tensor,
148+
model_output: Tensor,
147149
timestep: usize,
148150
prev_timestep: usize,
149151
sample: &Tensor,
@@ -171,7 +173,7 @@ impl DPMSolverSinglestepScheduler {
171173
/// * `timestep_list` - current and latter discrete timestep in the diffusion chain
172174
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
173175
/// * `sample` - current instance of sample being created by diffusion process
174-
fn singlestep_dpm_solver_second_order_update(
176+
fn second_order_update(
175177
&self,
176178
model_output_list: &Vec<Tensor>,
177179
timestep_list: [usize; 2],
@@ -232,7 +234,7 @@ impl DPMSolverSinglestepScheduler {
232234
/// * `timestep_list` - current and latter discrete timestep in the diffusion chain
233235
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
234236
/// * `sample` - current instance of sample being created by diffusion process
235-
fn singlestep_dpm_solver_third_order_update(
237+
fn third_order_update(
236238
&self,
237239
model_output_list: &Vec<Tensor>,
238240
timestep_list: [usize; 3],
@@ -290,13 +292,13 @@ impl DPMSolverSinglestepScheduler {
290292
}
291293
}
292294

293-
pub fn timesteps(&self) -> &[usize] {
295+
/// Returns the scheduler's stored timesteps as a borrowed slice.
fn timesteps(&self) -> &[usize] {
294296
self.timesteps.as_slice()
295297
}
296298

297299
/// Ensures interchangeability with schedulers that need to scale the denoising model input
298300
/// depending on the current timestep.
299-
pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
301+
fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
300302
// No scaling is applied here: the sample is handed back unchanged and the
// timestep is ignored.
sample
301303
}
302304

@@ -307,7 +309,7 @@ impl DPMSolverSinglestepScheduler {
307309
/// * `model_output` - direct output from learned diffusion model
308310
/// * `timestep` - current discrete timestep in the diffusion chain
309311
/// * `sample` - current instance of sample being created by diffusion process
310-
pub fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor {
312+
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor {
311313
// https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py#L535
312314
let step_index: usize = self.timesteps.iter().position(|&t| t == timestep).unwrap();
313315

@@ -329,19 +331,19 @@ impl DPMSolverSinglestepScheduler {
329331
};
330332

331333
match order {
332-
1 => self.dpm_solver_first_order_update(
333-
&self.model_outputs[self.model_outputs.len() - 1],
334+
1 => self.first_order_update(
335+
model_output,
334336
timestep,
335337
prev_timestep,
336338
&self.sample.as_ref().unwrap(),
337339
),
338-
2 => self.singlestep_dpm_solver_second_order_update(
340+
2 => self.second_order_update(
339341
&self.model_outputs,
340342
[self.timesteps[step_index - 1], self.timesteps[step_index]],
341343
prev_timestep,
342344
&self.sample.as_ref().unwrap(),
343345
),
344-
3 => self.singlestep_dpm_solver_third_order_update(
346+
3 => self.third_order_update(
345347
&self.model_outputs,
346348
[
347349
self.timesteps[step_index - 2],
@@ -357,12 +359,12 @@ impl DPMSolverSinglestepScheduler {
357359
}
358360
}
359361

360-
pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor {
362+
/// Combines `original_samples` with `noise` for the given `timestep`:
/// `sqrt(alphas_cumprod[timestep]) * original_samples
///  + sqrt(1 - alphas_cumprod[timestep]) * noise`.
///
/// NOTE(review): `alphas_cumprod` is indexed directly with `timestep`; presumably
/// this panics for an out-of-range timestep — confirm against the field's type.
fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor {
361363
self.alphas_cumprod[timestep].sqrt() * original_samples.to_owned()
362364
+ (1.0 - self.alphas_cumprod[timestep]).sqrt() * noise
363365
}
364366

365-
pub fn init_noise_sigma(&self) -> f64 {
367+
/// Returns the `init_noise_sigma` value stored on the scheduler.
fn init_noise_sigma(&self) -> f64 {
366368
self.init_noise_sigma
367369
}
368370
}

0 commit comments

Comments
 (0)