@@ -99,11 +99,14 @@ def __init__(
9999 self .alphas = 1.0 - self .betas
100100 self .alphas_cumprod = torch .cumprod (self .alphas , dim = 0 )
101101
102- self .sigmas = ((1 - self .alphas_cumprod ) / self .alphas_cumprod ) ** 0.5
102+ sigmas = np .array (((1 - self .alphas_cumprod ) / self .alphas_cumprod ) ** 0.5 )
103+ sigmas = np .concatenate ([sigmas [::- 1 ], [0.0 ]]).astype (np .float32 )
104+ self .sigmas = torch .from_numpy (sigmas )
103105
104106 # setable values
105107 self .num_inference_steps = None
106- self .timesteps = np .arange (0 , num_train_timesteps )[::- 1 ] # to be consistent has to be smaller than sigmas by 1
108+ timesteps = np .linspace (0 , num_train_timesteps - 1 , num_train_timesteps , dtype = float )[::- 1 ].copy ()
109+ self .timesteps = torch .from_numpy (timesteps )
107110 self .derivatives = []
108111
109112 def get_lms_coefficient (self , order , t , current_order ):
@@ -137,17 +140,14 @@ def set_timesteps(self, num_inference_steps: int):
137140 the number of diffusion steps used when generating samples with a pre-trained model.
138141 """
139142 self .num_inference_steps = num_inference_steps
140- timesteps = np .linspace (self .config .num_train_timesteps - 1 , 0 , num_inference_steps , dtype = float )
141143
142- low_idx = np .floor (timesteps ).astype (int )
143- high_idx = np .ceil (timesteps ).astype (int )
144- frac = np .mod (timesteps , 1.0 )
144+ timesteps = np .linspace (0 , self .config .num_train_timesteps - 1 , num_inference_steps , dtype = float )[::- 1 ].copy ()
145145 sigmas = np .array (((1 - self .alphas_cumprod ) / self .alphas_cumprod ) ** 0.5 )
146- sigmas = ( 1 - frac ) * sigmas [ low_idx ] + frac * sigmas [ high_idx ]
146+ sigmas = np . interp ( timesteps , np . arange ( 0 , len ( sigmas )), sigmas )
147147 sigmas = np .concatenate ([sigmas , [0.0 ]]).astype (np .float32 )
148148 self .sigmas = torch .from_numpy (sigmas )
149+ self .timesteps = torch .from_numpy (timesteps )
149150
150- self .timesteps = timesteps .astype (int )
151151 self .derivatives = []
152152
153153 def step (
0 commit comments