1414
1515from ...models import FlaxAutoencoderKL , FlaxUNet2DConditionModel
1616from ...pipeline_flax_utils import FlaxDiffusionPipeline
17- from ...schedulers import FlaxDDIMScheduler , FlaxLMSDiscreteScheduler , FlaxPNDMScheduler
17+ from ...schedulers import FlaxDDIMScheduler , FlaxLMSDiscreteScheduler , FlaxPNDMScheduler , FlaxDPMSolverDiscreteScheduler
1818from ...utils import logging
1919from . import FlaxStableDiffusionPipelineOutput
2020from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
@@ -43,7 +43,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
4343 unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
4444 scheduler ([`SchedulerMixin`]):
4545 A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
46- [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler `].
46+ [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or [`FlaxDPMSolverDiscreteScheduler `].
4747 safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
4848 Classification module that estimates whether generated images could be considered offensive or harmful.
4949 Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
@@ -57,7 +57,7 @@ def __init__(
5757 text_encoder : FlaxCLIPTextModel ,
5858 tokenizer : CLIPTokenizer ,
5959 unet : FlaxUNet2DConditionModel ,
60- scheduler : Union [FlaxDDIMScheduler , FlaxPNDMScheduler , FlaxLMSDiscreteScheduler ],
60+ scheduler : Union [FlaxDDIMScheduler , FlaxPNDMScheduler , FlaxLMSDiscreteScheduler , FlaxDPMSolverDiscreteScheduler ],
6161 safety_checker : FlaxStableDiffusionSafetyChecker ,
6262 feature_extractor : CLIPFeatureExtractor ,
6363 dtype : jnp .dtype = jnp .float32 ,
0 commit comments