5 changes: 5 additions & 0 deletions src/diffusers/__init__.py
@@ -37,6 +37,11 @@
         PNDMScheduler,
         SchedulerMixin,
         ScoreSdeVeScheduler,
+        EulerDiscreteScheduler,
+        EulerAncestralDiscreteScheduler,
+        DPM2DiscreteScheduler,
+        DPM2AncestralDiscreteScheduler,
+        HeunDiscreteScheduler
     )
     from .training_utils import EMAModel
 else:
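With these exports in place, the five new schedulers should be importable straight from the top-level `diffusers` namespace, next to the schedulers that were already exposed. A minimal sketch (assumes the same torch/scipy availability the package already guards on):

```python
# Minimal sketch: top-level imports of the schedulers this PR adds.
from diffusers import (
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    DPM2DiscreteScheduler,
    DPM2AncestralDiscreteScheduler,
    HeunDiscreteScheduler,
)
```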
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -8,7 +8,9 @@
 from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
-from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...schedulers import DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, \
+    DPM2DiscreteScheduler, DPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, \
+    EulerAncestralDiscreteScheduler, HeunDiscreteScheduler
 from ...utils import deprecate, logging
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
@@ -51,7 +53,8 @@ def __init__(
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: Union[DDIMScheduler, PNDMScheduler, DPM2DiscreteScheduler, DPM2AncestralDiscreteScheduler,
+                         LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, HeunDiscreteScheduler],
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
     ):
@@ -320,7 +323,7 @@ def __call__(
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t, i)

             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -331,7 +334,7 @@ def __call__(
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, i, latents, **extra_step_kwargs).prev_sample

             # call the callback, if provided
             if callback is not None and i % callback_steps == 0:
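Taken together, the two loop changes in this file define a new calling convention: both `scale_model_input` and `step` now receive the positional loop index `i` in addition to the timestep `t`, because the k-diffusion-style schedulers look up their sigma schedule by position rather than by timestep value. A sketch of the resulting loop; the function itself is illustrative, with `unet`, `text_embeddings`, and the guidance handling assumed from the pipeline code above:

```python
import torch

def denoising_loop(scheduler, unet, latents, text_embeddings,
                   guidance_scale=7.5, do_classifier_free_guidance=True):
    # Sketch of the loop convention after this PR.
    for i, t in enumerate(scheduler.timesteps):
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        # New: the step index `i` is passed alongside the timestep `t`.
        latent_model_input = scheduler.scale_model_input(latent_model_input, t, i)

        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # New: `step` also takes the index so position-indexed schedulers can
        # recover their place in the sigma schedule.
        latents = scheduler.step(noise_pred, t, i, latents).prev_sample
    return latents
```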
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -10,7 +10,9 @@
 from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
-from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...schedulers import DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, \
+    DPM2DiscreteScheduler, DPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, \
+    EulerAncestralDiscreteScheduler, HeunDiscreteScheduler
 from ...utils import deprecate, logging
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
@@ -63,7 +65,12 @@ def __init__(
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: Union[
+            DDIMScheduler, PNDMScheduler, DPM2DiscreteScheduler,
+            DPM2AncestralDiscreteScheduler, LMSDiscreteScheduler,
+            EulerDiscreteScheduler, EulerAncestralDiscreteScheduler,
+            HeunDiscreteScheduler
+        ],
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
     ):
@@ -319,8 +326,13 @@ def __call__(
         init_timestep = int(num_inference_steps * strength) + offset
         init_timestep = min(init_timestep, num_inference_steps)

-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
+        if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler):
+            timesteps = torch.tensor(
+                [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
+            )
+        else:
+            timesteps = self.scheduler.timesteps[-init_timestep]
+            timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)

         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
@@ -346,7 +358,7 @@ def __call__(
         for i, t in enumerate(self.progress_bar(timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t, i)

             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -357,7 +369,7 @@ def __call__(
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, i, latents, **extra_step_kwargs).prev_sample

             # call the callback, if provided
             if callback is not None and i % callback_steps == 0:
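The branch added above exists because img2img starts partway through the schedule: DDIM and PNDM can be restarted from a timestep value read out of `scheduler.timesteps`, while the new schedulers need the positional index of that starting point so they can look up the matching sigma. The arithmetic, with illustrative values:

```python
# Illustrative values: 50 inference steps, strength 0.75, no offset.
num_inference_steps, strength, offset = 50, 0.75, 0

init_timestep = int(num_inference_steps * strength) + offset   # 37
init_timestep = min(init_timestep, num_inference_steps)        # 37

# Position-indexed schedulers (Euler/DPM2/Heun variants): start at index 13,
# i.e. 37 steps from the end of the schedule.
start_index = num_inference_steps - init_timestep              # 13

# DDIM/PNDM keep the old behaviour and read the timestep *value* at that
# position instead: timestep = scheduler.timesteps[-init_timestep]
```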
6 changes: 6 additions & 0 deletions src/diffusers/schedulers/__init__.py
@@ -41,5 +41,11 @@

 if is_scipy_available() and is_torch_available():
     from .scheduling_lms_discrete import LMSDiscreteScheduler
+    from .scheduling_dpm2_discrete import DPM2DiscreteScheduler
+    from .scheduling_dpm2_ancestral_discrete import DPM2AncestralDiscreteScheduler
+    from .scheduling_euler_discrete import EulerDiscreteScheduler
+    from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
+    from .scheduling_heun_discrete import HeunDiscreteScheduler
+
 else:
     from ..utils.dummy_torch_and_scipy_objects import *  # noqa F403
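Once exported here, a new scheduler can presumably be swapped into a pipeline the same way `LMSDiscreteScheduler` is in the existing examples. A hedged sketch: the model id is illustrative, and the constructor kwargs are assumed to mirror the LMS scheduler's Stable Diffusion configuration.

```python
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler

# Illustrative checkpoint; any Stable Diffusion model id should work.
model_id = "CompVis/stable-diffusion-v1-4"

# Assumption: the new scheduler accepts the same beta configuration that
# LMSDiscreteScheduler uses for Stable Diffusion.
scheduler = EulerDiscreteScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)

pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler)
image = pipe("a photograph of an astronaut riding a horse").images[0]
```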
5 changes: 4 additions & 1 deletion src/diffusers/schedulers/scheduling_ddim.py
@@ -151,7 +151,9 @@ def __init__(
         self.num_inference_steps = None
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

-    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+    def scale_model_input(
+        self, sample: torch.FloatTensor,
+        *args, **kwargs) -> torch.FloatTensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
@@ -200,6 +202,7 @@ def step(
         self,
         model_output: torch.FloatTensor,
         timestep: int,
+        step_index: int,
         sample: torch.FloatTensor,
         eta: float = 0.0,
         use_clipped_model_output: bool = False,
5 changes: 4 additions & 1 deletion src/diffusers/schedulers/scheduling_ddpm.py
@@ -145,7 +145,9 @@ def __init__(

         self.variance_type = variance_type

-    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+    def scale_model_input(
+        self, sample: torch.FloatTensor,
+        *args, **kwargs) -> torch.FloatTensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
@@ -211,6 +213,7 @@ def step(
         self,
         model_output: torch.FloatTensor,
         timestep: int,
+        step_index: int,
         sample: torch.FloatTensor,
         predict_epsilon=True,
         generator=None,
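In both DDIM and DDPM, the broadened `scale_model_input(self, sample, *args, **kwargs)` signature and the extra `step_index` parameter are compatibility shims: these schedulers scale nothing and ignore the index, but they now tolerate the arguments the unified pipeline loop passes to every scheduler. The pattern in isolation, as an illustrative stand-in rather than the actual classes:

```python
import torch

class PassthroughSchedulerSketch:
    """Illustrative stand-in for schedulers (DDIM/DDPM-style) that need no
    input scaling: extra arguments are accepted and ignored so one pipeline
    loop can drive both position-indexed and timestep-based schedulers."""

    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
        # No scaling required; (timestep, step_index) are tolerated purely
        # for interchangeability with the k-diffusion-style schedulers.
        return sample
```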