1414
1515from ...models import FlaxAutoencoderKL , FlaxUNet2DConditionModel
1616from ...pipeline_flax_utils import FlaxDiffusionPipeline
17- from ...schedulers import FlaxDDIMScheduler , FlaxLMSDiscreteScheduler , FlaxPNDMScheduler
17+ from ...schedulers import (
18+ FlaxDDIMScheduler ,
19+ FlaxDPMSolverMultistepScheduler ,
20+ FlaxLMSDiscreteScheduler ,
21+ FlaxPNDMScheduler ,
22+ )
1823from ...utils import logging
1924from . import FlaxStableDiffusionPipelineOutput
2025from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
@@ -43,7 +48,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
4348 unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
4449 scheduler ([`SchedulerMixin`]):
4550 A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
46- [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`].
51+ [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
52+ [`FlaxDPMSolverMultistepScheduler`].
4753 safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
4854 Classification module that estimates whether generated images could be considered offensive or harmful.
4955 Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
@@ -57,7 +63,9 @@ def __init__(
5763 text_encoder : FlaxCLIPTextModel ,
5864 tokenizer : CLIPTokenizer ,
5965 unet : FlaxUNet2DConditionModel ,
60- scheduler : Union [FlaxDDIMScheduler , FlaxPNDMScheduler , FlaxLMSDiscreteScheduler ],
66+ scheduler : Union [
67+ FlaxDDIMScheduler , FlaxPNDMScheduler , FlaxLMSDiscreteScheduler , FlaxDPMSolverMultistepScheduler
68+ ],
6169 safety_checker : FlaxStableDiffusionSafetyChecker ,
6270 feature_extractor : CLIPFeatureExtractor ,
6371 dtype : jnp .dtype = jnp .float32 ,
0 commit comments