@@ -85,26 +85,28 @@ def __init__(
8585 )
8686
8787 if trained_betas is not None :
88- self .betas = np . asarray (trained_betas )
88+ self .betas = torch . from_numpy (trained_betas )
8989 if beta_schedule == "linear" :
90- self .betas = np .linspace (beta_start , beta_end , num_train_timesteps , dtype = np .float32 )
90+ self .betas = torch .linspace (beta_start , beta_end , num_train_timesteps , dtype = torch .float32 )
9191 elif beta_schedule == "scaled_linear" :
9292 # this schedule is very specific to the latent diffusion model.
9393 self .betas = (
94- np .linspace (beta_start ** 0.5 , beta_end ** 0.5 , num_train_timesteps , dtype = np .float32 ) ** 2
94+ torch .linspace (beta_start ** 0.5 , beta_end ** 0.5 , num_train_timesteps , dtype = torch .float32 ) ** 2
9595 )
9696 else :
9797 raise NotImplementedError (f"{ beta_schedule } does is not implemented for { self .__class__ } " )
9898
99- self .alphas = np . array ( 1.0 - self .betas , dtype = np . float32 )
100- self .alphas_cumprod = np .cumprod (self .alphas , axis = 0 )
99+ self .alphas = 1.0 - self .betas
100+ self .alphas_cumprod = torch .cumprod (self .alphas , dim = 0 )
101101
102- sigmas = ((1 - self .alphas_cumprod ) / self .alphas_cumprod ) ** 0.5
102+ sigmas = np .array (((1 - self .alphas_cumprod ) / self .alphas_cumprod ) ** 0.5 )
103+ sigmas = np .concatenate ([sigmas [::- 1 ], [0.0 ]]).astype (np .float32 )
103104 self .sigmas = torch .from_numpy (sigmas )
104105
105106 # setable values
106107 self .num_inference_steps = None
107- self .timesteps = np .arange (0 , num_train_timesteps )[::- 1 ]
108+ timesteps = np .linspace (0 , num_train_timesteps - 1 , num_train_timesteps , dtype = float )[::- 1 ].copy ()
109+ self .timesteps = torch .from_numpy (timesteps )
108110 self .derivatives = []
109111
110112 def get_lms_coefficient (self , order , t , current_order ):
@@ -138,16 +140,13 @@ def set_timesteps(self, num_inference_steps: int):
138140 the number of diffusion steps used when generating samples with a pre-trained model.
139141 """
140142 self .num_inference_steps = num_inference_steps
141- timesteps = np .linspace (self .config .num_train_timesteps - 1 , 0 , num_inference_steps , dtype = float )
142143
143- low_idx = np .floor (timesteps ).astype (int )
144- high_idx = np .ceil (timesteps ).astype (int )
145- frac = np .mod (timesteps , 1.0 )
144+ timesteps = np .linspace (0 , self .config .num_train_timesteps - 1 , num_inference_steps , dtype = float )[::- 1 ].copy ()
146145 sigmas = np .array (((1 - self .alphas_cumprod ) / self .alphas_cumprod ) ** 0.5 )
147- sigmas = ( 1 - frac ) * sigmas [ low_idx ] + frac * sigmas [ high_idx ]
146+ sigmas = np . interp ( timesteps , np . arange ( 0 , len ( sigmas )), sigmas )
148147 sigmas = np .concatenate ([sigmas , [0.0 ]]).astype (np .float32 )
149148 self .sigmas = torch .from_numpy (sigmas )
150- self .timesteps = timesteps
149+ self .timesteps = torch . from_numpy ( timesteps )
151150
152151 self .derivatives = []
153152
0 commit comments