1515# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
1616
1717import math
18- from typing import Optional , Tuple , Union , List
18+ from typing import List , Optional , Tuple , Union
1919
2020import numpy as np
2121import torch
@@ -151,7 +151,9 @@ def __init__(
151151 self .num_inference_steps = None
152152 timesteps = np .linspace (0 , num_train_timesteps - 1 , num_train_timesteps , dtype = np .float32 )[::- 1 ].copy ()
153153 self .timesteps = torch .from_numpy (timesteps )
154- self .model_outputs = [None ,] * self .solver_order
154+ self .model_outputs = [
155+ None ,
156+ ] * self .solver_order
155157 self .lower_order_nums = 0
156158
157159 def set_timesteps (self , num_inference_steps : int , device : Union [str , torch .device ] = None ):
@@ -165,16 +167,20 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
165167 the device to which the timesteps should be moved. If `None`, the timesteps are not moved.
166168 """
167169 self .num_inference_steps = num_inference_steps
168- timesteps = np .linspace (0 , self .num_train_timesteps - 1 , num_inference_steps + 1 ).round ()[::- 1 ][:- 1 ].copy ().astype (np .int64 )
170+ timesteps = (
171+ np .linspace (0 , self .num_train_timesteps - 1 , num_inference_steps + 1 )
172+ .round ()[::- 1 ][:- 1 ]
173+ .copy ()
174+ .astype (np .int64 )
175+ )
169176 self .timesteps = torch .from_numpy (timesteps ).to (device )
170- self .model_outputs = [None ,] * self .solver_order
177+ self .model_outputs = [
178+ None ,
179+ ] * self .solver_order
171180 self .lower_order_nums = 0
172181
173182 def convert_model_output (
174- self ,
175- model_output : torch .FloatTensor ,
176- timestep : int ,
177- sample : torch .FloatTensor
183+ self , model_output : torch .FloatTensor , timestep : int , sample : torch .FloatTensor
178184 ) -> torch .FloatTensor :
179185 """
180186 TODO
@@ -184,9 +190,11 @@ def convert_model_output(
184190 x0_pred = (sample - sigma_t * model_output ) / alpha_t
185191 if self .thresholding :
186192 # Dynamic thresholding in https://arxiv.org/abs/2205.11487
187- p = 0.995 # A hyperparameter in the paper of "Imagen" (https://arxiv.org/abs/2205.11487).
193+ p = 0.995 # A hyperparameter in the paper of "Imagen" (https://arxiv.org/abs/2205.11487).
188194 s = torch .quantile (torch .abs (x0_pred ).reshape ((x0_pred .shape [0 ], - 1 )), p , dim = 1 )
189- s = torch .maximum (s , self .sample_max_value * torch .ones_like (s ).to (s .device ))[(...,) + (None ,)* (x0_pred .ndim - 1 )]
195+ s = torch .maximum (s , self .sample_max_value * torch .ones_like (s ).to (s .device ))[
196+ (...,) + (None ,) * (x0_pred .ndim - 1 )
197+ ]
190198 x0_pred = torch .clamp (x0_pred , - s , s ) / s
191199 return x0_pred
192200 else :
@@ -207,15 +215,9 @@ def dpm_solver_first_order_update(
207215 sigma_t , sigma_s = self .sigma_t [prev_timestep ], self .sigma_t [timestep ]
208216 h = lambda_t - lambda_s
209217 if self .predict_x0 :
210- x_t = (
211- (sigma_t / sigma_s ) * sample
212- - (alpha_t * (torch .exp (- h ) - 1. )) * model_output
213- )
218+ x_t = (sigma_t / sigma_s ) * sample - (alpha_t * (torch .exp (- h ) - 1.0 )) * model_output
214219 else :
215- x_t = (
216- (alpha_t / alpha_s ) * sample
217- - (sigma_t * (torch .exp (h ) - 1. )) * model_output
218- )
220+ x_t = (alpha_t / alpha_s ) * sample - (sigma_t * (torch .exp (h ) - 1.0 )) * model_output
219221 return x_t
220222
221223 def multistep_dpm_solver_second_order_update (
@@ -235,32 +237,32 @@ def multistep_dpm_solver_second_order_update(
235237 sigma_t , sigma_s0 = self .sigma_t [t ], self .sigma_t [s0 ]
236238 h , h_0 = lambda_t - lambda_s0 , lambda_s0 - lambda_s1
237239 r0 = h_0 / h
238- D0 , D1 = m0 , (1. / r0 ) * (m0 - m1 )
240+ D0 , D1 = m0 , (1.0 / r0 ) * (m0 - m1 )
239241 if self .predict_x0 :
240- if self .solver_type == ' dpm_solver' :
242+ if self .solver_type == " dpm_solver" :
241243 x_t = (
242244 (sigma_t / sigma_s0 ) * sample
243- - (alpha_t * (torch .exp (- h ) - 1. )) * D0
244- - 0.5 * (alpha_t * (torch .exp (- h ) - 1. )) * D1
245+ - (alpha_t * (torch .exp (- h ) - 1.0 )) * D0
246+ - 0.5 * (alpha_t * (torch .exp (- h ) - 1.0 )) * D1
245247 )
246- elif self .solver_type == ' taylor' :
248+ elif self .solver_type == " taylor" :
247249 x_t = (
248250 (sigma_t / sigma_s0 ) * sample
249- - (alpha_t * (torch .exp (- h ) - 1. )) * D0
250- + (alpha_t * ((torch .exp (- h ) - 1. ) / h + 1. )) * D1
251+ - (alpha_t * (torch .exp (- h ) - 1.0 )) * D0
252+ + (alpha_t * ((torch .exp (- h ) - 1.0 ) / h + 1.0 )) * D1
251253 )
252254 else :
253- if self .solver_type == ' dpm_solver' :
255+ if self .solver_type == " dpm_solver" :
254256 x_t = (
255257 (alpha_t / alpha_s0 ) * sample
256- - (sigma_t * (torch .exp (h ) - 1. )) * D0
257- - 0.5 * (sigma_t * (torch .exp (h ) - 1. )) * D1
258+ - (sigma_t * (torch .exp (h ) - 1.0 )) * D0
259+ - 0.5 * (sigma_t * (torch .exp (h ) - 1.0 )) * D1
258260 )
259- elif self .solver_type == ' taylor' :
261+ elif self .solver_type == " taylor" :
260262 x_t = (
261263 (alpha_t / alpha_s0 ) * sample
262- - (sigma_t * (torch .exp (h ) - 1. )) * D0
263- - (sigma_t * ((torch .exp (h ) - 1. ) / h - 1. )) * D1
264+ - (sigma_t * (torch .exp (h ) - 1.0 )) * D0
265+ - (sigma_t * ((torch .exp (h ) - 1.0 ) / h - 1.0 )) * D1
264266 )
265267 return x_t
266268
@@ -276,28 +278,33 @@ def multistep_dpm_solver_third_order_update(
276278 """
277279 t , s0 , s1 , s2 = prev_timestep , timestep_list [- 1 ], timestep_list [- 2 ], timestep_list [- 3 ]
278280 m0 , m1 , m2 = model_output_list [- 1 ], model_output_list [- 2 ], model_output_list [- 3 ]
279- lambda_t , lambda_s0 , lambda_s1 , lambda_s2 = self .lambda_t [t ], self .lambda_t [s0 ], self .lambda_t [s1 ], self .lambda_t [s2 ]
281+ lambda_t , lambda_s0 , lambda_s1 , lambda_s2 = (
282+ self .lambda_t [t ],
283+ self .lambda_t [s0 ],
284+ self .lambda_t [s1 ],
285+ self .lambda_t [s2 ],
286+ )
280287 alpha_t , alpha_s0 = self .alpha_t [t ], self .alpha_t [s0 ]
281288 sigma_t , sigma_s0 = self .sigma_t [t ], self .sigma_t [s0 ]
282289 h , h_0 , h_1 = lambda_t - lambda_s0 , lambda_s0 - lambda_s1 , lambda_s1 - lambda_s2
283290 r0 , r1 = h_0 / h , h_1 / h
284291 D0 = m0
285- D1_0 , D1_1 = (1. / r0 ) * (m0 - m1 ), (1. / r1 ) * (m1 - m2 )
292+ D1_0 , D1_1 = (1.0 / r0 ) * (m0 - m1 ), (1.0 / r1 ) * (m1 - m2 )
286293 D1 = D1_0 + (r0 / (r0 + r1 )) * (D1_0 - D1_1 )
287- D2 = (1. / (r0 + r1 )) * (D1_0 - D1_1 )
294+ D2 = (1.0 / (r0 + r1 )) * (D1_0 - D1_1 )
288295 if self .predict_x0 :
289296 x_t = (
290297 (sigma_t / sigma_s0 ) * sample
291- - (alpha_t * (torch .exp (- h ) - 1. )) * D0
292- + (alpha_t * ((torch .exp (- h ) - 1. ) / h + 1. )) * D1
293- - (alpha_t * ((torch .exp (- h ) - 1. + h ) / h ** 2 - 0.5 )) * D2
298+ - (alpha_t * (torch .exp (- h ) - 1.0 )) * D0
299+ + (alpha_t * ((torch .exp (- h ) - 1.0 ) / h + 1.0 )) * D1
300+ - (alpha_t * ((torch .exp (- h ) - 1.0 + h ) / h ** 2 - 0.5 )) * D2
294301 )
295302 else :
296303 x_t = (
297304 (alpha_t / alpha_s0 ) * sample
298- - (sigma_t * (torch .exp (h ) - 1. )) * D0
299- - (sigma_t * ((torch .exp (h ) - 1. ) / h - 1. )) * D1
300- - (sigma_t * ((torch .exp (h ) - 1. - h ) / h ** 2 - 0.5 )) * D2
305+ - (sigma_t * (torch .exp (h ) - 1.0 )) * D0
306+ - (sigma_t * ((torch .exp (h ) - 1.0 ) / h - 1.0 )) * D1
307+ - (sigma_t * ((torch .exp (h ) - 1.0 - h ) / h ** 2 - 0.5 )) * D2
301308 )
302309 return x_t
303310
@@ -336,7 +343,7 @@ def step(
336343 denoise_final = (step_index == len (self .timesteps ) - 1 ) and self .denoise_final
337344 denoise_second = (step_index == len (self .timesteps ) - 2 ) and self .denoise_final
338345
339- model_output = self .convert_model_output (model_output , timestep , sample )
346+ model_output = self .convert_model_output (model_output , timestep , sample )
340347 for i in range (self .solver_order - 1 ):
341348 self .model_outputs [i ] = self .model_outputs [i + 1 ]
342349 self .model_outputs [- 1 ] = model_output
@@ -345,10 +352,14 @@ def step(
345352 prev_sample = self .dpm_solver_first_order_update (model_output , timestep , prev_timestep , sample )
346353 elif self .solver_order == 2 or self .lower_order_nums < 2 or denoise_second :
347354 timestep_list = [self .timesteps [step_index - 1 ], timestep ]
348- prev_sample = self .multistep_dpm_solver_second_order_update (self .model_outputs , timestep_list , prev_timestep , sample )
355+ prev_sample = self .multistep_dpm_solver_second_order_update (
356+ self .model_outputs , timestep_list , prev_timestep , sample
357+ )
349358 else :
350359 timestep_list = [self .timesteps [step_index - 2 ], self .timesteps [step_index - 1 ], timestep ]
351- prev_sample = self .multistep_dpm_solver_third_order_update (self .model_outputs , timestep_list , prev_timestep , sample )
360+ prev_sample = self .multistep_dpm_solver_third_order_update (
361+ self .model_outputs , timestep_list , prev_timestep , sample
362+ )
352363
353364 if self .lower_order_nums < self .solver_order :
354365 self .lower_order_nums += 1
0 commit comments