@@ -2,7 +2,9 @@ use std::iter::repeat;
22
33use super :: {
44 betas_for_alpha_bar,
5- dpmsolver:: { DPMSolverAlgorithmType , DPMSolverSchedulerConfig , DPMSolverType } ,
5+ dpmsolver:: {
6+ DPMSolverAlgorithmType , DPMSolverScheduler , DPMSolverSchedulerConfig , DPMSolverType ,
7+ } ,
68 BetaSchedule , PredictionType ,
79} ;
810use tch:: { kind, Kind , Tensor } ;
@@ -23,8 +25,8 @@ pub struct DPMSolverSinglestepScheduler {
2325 pub config : DPMSolverSchedulerConfig ,
2426}
2527
26- impl DPMSolverSinglestepScheduler {
27- pub fn new ( inference_steps : usize , config : DPMSolverSchedulerConfig ) -> Self {
28+ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
29+ fn new ( inference_steps : usize , config : DPMSolverSchedulerConfig ) -> Self {
2830 let betas = match config. beta_schedule {
2931 BetaSchedule :: ScaledLinear => Tensor :: linspace (
3032 config. beta_start . sqrt ( ) ,
@@ -141,9 +143,9 @@ impl DPMSolverSinglestepScheduler {
141143 /// * `timestep` - current discrete timestep in the diffusion chain
142144 /// * `prev_timestep` - previous discrete timestep in the diffusion chain
143145 /// * `sample` - current instance of sample being created by diffusion process
144- fn dpm_solver_first_order_update (
146+ fn first_order_update (
145147 & self ,
146- model_output : & Tensor ,
148+ model_output : Tensor ,
147149 timestep : usize ,
148150 prev_timestep : usize ,
149151 sample : & Tensor ,
@@ -171,7 +173,7 @@ impl DPMSolverSinglestepScheduler {
171173 /// * `timestep_list` - current and latter discrete timestep in the diffusion chain
172174 /// * `prev_timestep` - previous discrete timestep in the diffusion chain
173175 /// * `sample` - current instance of sample being created by diffusion process
174- fn singlestep_dpm_solver_second_order_update (
176+ fn second_order_update (
175177 & self ,
176178 model_output_list : & Vec < Tensor > ,
177179 timestep_list : [ usize ; 2 ] ,
@@ -232,7 +234,7 @@ impl DPMSolverSinglestepScheduler {
232234 /// * `timestep_list` - current and latter discrete timestep in the diffusion chain
233235 /// * `prev_timestep` - previous discrete timestep in the diffusion chain
234236 /// * `sample` - current instance of sample being created by diffusion process
235- fn singlestep_dpm_solver_third_order_update (
237+ fn third_order_update (
236238 & self ,
237239 model_output_list : & Vec < Tensor > ,
238240 timestep_list : [ usize ; 3 ] ,
@@ -290,13 +292,13 @@ impl DPMSolverSinglestepScheduler {
290292 }
291293 }
292294
293- pub fn timesteps ( & self ) -> & [ usize ] {
295+ fn timesteps ( & self ) -> & [ usize ] {
294296 self . timesteps . as_slice ( )
295297 }
296298
297299 /// Ensures interchangeability with schedulers that need to scale the denoising model input
298300 /// depending on the current timestep.
299- pub fn scale_model_input ( & self , sample : Tensor , _timestep : usize ) -> Tensor {
301+ fn scale_model_input ( & self , sample : Tensor , _timestep : usize ) -> Tensor {
300302 sample
301303 }
302304
@@ -307,7 +309,7 @@ impl DPMSolverSinglestepScheduler {
307309 /// * `model_output` - direct output from learned diffusion model
308310 /// * `timestep` - current discrete timestep in the diffusion chain
309311 /// * `sample` - current instance of sample being created by diffusion process
310- pub fn step ( & mut self , model_output : & Tensor , timestep : usize , sample : & Tensor ) -> Tensor {
312+ fn step ( & mut self , model_output : & Tensor , timestep : usize , sample : & Tensor ) -> Tensor {
311313 // https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py#L535
312314 let step_index: usize = self . timesteps . iter ( ) . position ( |& t| t == timestep) . unwrap ( ) ;
313315
@@ -329,19 +331,19 @@ impl DPMSolverSinglestepScheduler {
329331 } ;
330332
331333 match order {
332- 1 => self . dpm_solver_first_order_update (
333- & self . model_outputs [ self . model_outputs . len ( ) - 1 ] ,
334+ 1 => self . first_order_update (
335+ model_output ,
334336 timestep,
335337 prev_timestep,
336338 & self . sample . as_ref ( ) . unwrap ( ) ,
337339 ) ,
338- 2 => self . singlestep_dpm_solver_second_order_update (
340+ 2 => self . second_order_update (
339341 & self . model_outputs ,
340342 [ self . timesteps [ step_index - 1 ] , self . timesteps [ step_index] ] ,
341343 prev_timestep,
342344 & self . sample . as_ref ( ) . unwrap ( ) ,
343345 ) ,
344- 3 => self . singlestep_dpm_solver_third_order_update (
346+ 3 => self . third_order_update (
345347 & self . model_outputs ,
346348 [
347349 self . timesteps [ step_index - 2 ] ,
@@ -357,12 +359,12 @@ impl DPMSolverSinglestepScheduler {
357359 }
358360 }
359361
360- pub fn add_noise ( & self , original_samples : & Tensor , noise : Tensor , timestep : usize ) -> Tensor {
362+ fn add_noise ( & self , original_samples : & Tensor , noise : Tensor , timestep : usize ) -> Tensor {
361363 self . alphas_cumprod [ timestep] . sqrt ( ) * original_samples. to_owned ( )
362364 + ( 1.0 - self . alphas_cumprod [ timestep] ) . sqrt ( ) * noise
363365 }
364366
365- pub fn init_noise_sigma ( & self ) -> f64 {
367+ fn init_noise_sigma ( & self ) -> f64 {
366368 self . init_noise_sigma
367369 }
368370}
0 commit comments