File tree Expand file tree Collapse file tree 3 files changed +12
-5
lines changed Expand file tree Collapse file tree 3 files changed +12
-5
lines changed Original file line number Diff line number Diff line change @@ -282,7 +282,12 @@ def add_noise(
282282 noise : torch .FloatTensor ,
283283 timesteps : torch .IntTensor ,
284284 ) -> torch .FloatTensor :
285- timesteps = timesteps .to (self .alphas_cumprod .device )
285+ if self .alphas_cumprod .device != original_samples .device :
286+ self .alphas_cumprod = self .alphas_cumprod .to (original_samples .device )
287+
288+ if timesteps .device != original_samples .device :
289+ timesteps = timesteps .to (original_samples .device )
290+
286291 sqrt_alpha_prod = self .alphas_cumprod [timesteps ] ** 0.5
287292 sqrt_alpha_prod = sqrt_alpha_prod .flatten ()
288293 while len (sqrt_alpha_prod .shape ) < len (original_samples .shape ):
Original file line number Diff line number Diff line change @@ -268,15 +268,19 @@ def add_noise(
268268 noise : torch .FloatTensor ,
269269 timesteps : torch .IntTensor ,
270270 ) -> torch .FloatTensor :
271- timesteps = timesteps .to (self .alphas_cumprod .device )
271+ if self .alphas_cumprod .device != original_samples .device :
272+ self .alphas_cumprod = self .alphas_cumprod .to (original_samples .device )
273+
274+ if timesteps .device != original_samples .device :
275+ timesteps = timesteps .to (original_samples .device )
272276
273277 sqrt_alpha_prod = self .alphas_cumprod [timesteps ] ** 0.5
274278 sqrt_alpha_prod = sqrt_alpha_prod .flatten ()
275279 while len (sqrt_alpha_prod .shape ) < len (original_samples .shape ):
276280 sqrt_alpha_prod = sqrt_alpha_prod .unsqueeze (- 1 )
277281
278282 sqrt_one_minus_alpha_prod = (1 - self .alphas_cumprod [timesteps ]) ** 0.5
279- sqrt_one_minus_alpha_prod .flatten ()
283+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod .flatten ()
280284 while len (sqrt_one_minus_alpha_prod .shape ) < len (original_samples .shape ):
281285 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod .unsqueeze (- 1 )
282286
Original file line number Diff line number Diff line change @@ -387,8 +387,6 @@ def add_noise(
387387 if timesteps .device != original_samples .device :
388388 timesteps = timesteps .to (original_samples .device )
389389
390- timesteps = timesteps .to (self .alphas_cumprod .device )
391-
392390 sqrt_alpha_prod = self .alphas_cumprod [timesteps ] ** 0.5
393391 sqrt_alpha_prod = sqrt_alpha_prod .flatten ()
394392 while len (sqrt_alpha_prod .shape ) < len (original_samples .shape ):
You can’t perform that action at this time.
0 commit comments