Skip to content

Commit dfcee10

Browse files
committed
add add_noise method for dpmsolver
1 parent 98a309e commit dfcee10

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

src/diffusers/schedulers/scheduling_dpmsolver_discrete.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,5 +440,28 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch
440440
"""
441441
return sample
442442

443+
def add_noise(
444+
self,
445+
original_samples: torch.FloatTensor,
446+
noise: torch.FloatTensor,
447+
timesteps: torch.IntTensor,
448+
) -> torch.FloatTensor:
449+
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
450+
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
451+
timesteps = timesteps.to(original_samples.device)
452+
453+
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
454+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
455+
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
456+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
457+
458+
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
459+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
460+
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
461+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
462+
463+
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
464+
return noisy_samples
465+
443466
def __len__(self):
444467
return self.config.num_train_timesteps

src/diffusers/schedulers/scheduling_dpmsolver_discrete_flax.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import jax.numpy as jnp
2424

2525
from ..configuration_utils import ConfigMixin, register_to_config
26-
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
26+
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
2727

2828

2929
def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
@@ -533,5 +533,22 @@ def scale_model_input(
533533
"""
534534
return sample
535535

536+
def add_noise(
537+
self,
538+
original_samples: jnp.ndarray,
539+
noise: jnp.ndarray,
540+
timesteps: jnp.ndarray,
541+
) -> jnp.ndarray:
542+
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
543+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
544+
sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)
545+
546+
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.0
547+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
548+
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
549+
550+
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
551+
return noisy_samples
552+
536553
def __len__(self):
537554
return self.config.num_train_timesteps

0 commit comments

Comments
 (0)