File tree Expand file tree Collapse file tree 4 files changed +16
-18
lines changed Expand file tree Collapse file tree 4 files changed +16
-18
lines changed Original file line number Diff line number Diff line change @@ -300,11 +300,9 @@ def add_noise(
300300 noise : torch .FloatTensor ,
301301 timesteps : torch .IntTensor ,
302302 ) -> torch .FloatTensor :
303- if self .alphas_cumprod .device != original_samples .device :
304- self .alphas_cumprod = self .alphas_cumprod .to (original_samples .device )
305-
306- if timesteps .device != original_samples .device :
307- timesteps = timesteps .to (original_samples .device )
303+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
304+ self .alphas_cumprod = self .alphas_cumprod .to (device = original_samples .device , dtype = original_samples .dtype )
305+ timesteps = timesteps .to (original_samples .device )
308306
309307 sqrt_alpha_prod = self .alphas_cumprod [timesteps ] ** 0.5
310308 sqrt_alpha_prod = sqrt_alpha_prod .flatten ()
Original file line number Diff line number Diff line change @@ -294,11 +294,9 @@ def add_noise(
294294 noise : torch .FloatTensor ,
295295 timesteps : torch .IntTensor ,
296296 ) -> torch .FloatTensor :
297- if self .alphas_cumprod .device != original_samples .device :
298- self .alphas_cumprod = self .alphas_cumprod .to (original_samples .device )
299-
300- if timesteps .device != original_samples .device :
301- timesteps = timesteps .to (original_samples .device )
297+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
298+ self .alphas_cumprod = self .alphas_cumprod .to (device = original_samples .device , dtype = original_samples .dtype )
299+ timesteps = timesteps .to (original_samples .device )
302300
303301 sqrt_alpha_prod = self .alphas_cumprod [timesteps ] ** 0.5
304302 sqrt_alpha_prod = sqrt_alpha_prod .flatten ()
Original file line number Diff line number Diff line change @@ -257,9 +257,13 @@ def add_noise(
257257 noise : torch .FloatTensor ,
258258 timesteps : torch .FloatTensor ,
259259 ) -> torch .FloatTensor :
260- sigmas = self .sigmas .to (original_samples .device )
261- schedule_timesteps = self .timesteps .to (original_samples .device )
260+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
261+ self .sigmas = self .sigmas .to (device = original_samples .device , dtype = original_samples .dtype )
262+ self .timesteps = self .timesteps .to (original_samples .device )
262263 timesteps = timesteps .to (original_samples .device )
264+
265+ schedule_timesteps = self .timesteps
266+
263267 if isinstance (timesteps , torch .IntTensor ) or isinstance (timesteps , torch .LongTensor ):
264268 deprecate (
265269 "timesteps as indices" ,
@@ -273,7 +277,7 @@ def add_noise(
273277 else :
274278 step_indices = [(schedule_timesteps == t ).nonzero ().item () for t in timesteps ]
275279
276- sigma = sigmas [step_indices ].flatten ()
280+ sigma = self . sigmas [step_indices ].flatten ()
277281 while len (sigma .shape ) < len (original_samples .shape ):
278282 sigma = sigma .unsqueeze (- 1 )
279283
Original file line number Diff line number Diff line change @@ -400,11 +400,9 @@ def add_noise(
400400 noise : torch .FloatTensor ,
401401 timesteps : torch .IntTensor ,
402402 ) -> torch .Tensor :
403- if self .alphas_cumprod .device != original_samples .device :
404- self .alphas_cumprod = self .alphas_cumprod .to (original_samples .device )
405-
406- if timesteps .device != original_samples .device :
407- timesteps = timesteps .to (original_samples .device )
403+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
404+ self .alphas_cumprod = self .alphas_cumprod .to (device = original_samples .device , dtype = original_samples .dtype )
405+ timesteps = timesteps .to (original_samples .device )
408406
409407 sqrt_alpha_prod = self .alphas_cumprod [timesteps ] ** 0.5
410408 sqrt_alpha_prod = sqrt_alpha_prod .flatten ()
You can’t perform that action at this time.
0 commit comments