Skip to content

Commit 845a7d3

Browse files
committed
add jax/flax version dpm-solver
1 parent e110262 commit 845a7d3

File tree

6 files changed

+509
-5
lines changed

6 files changed

+509
-5
lines changed

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@
100100
from .schedulers import (
101101
FlaxDDIMScheduler,
102102
FlaxDDPMScheduler,
103+
FlaxDPMSolverDiscreteScheduler,
103104
FlaxKarrasVeScheduler,
104105
FlaxLMSDiscreteScheduler,
105106
FlaxPNDMScheduler,

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
1616
from ...pipeline_flax_utils import FlaxDiffusionPipeline
17-
from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler
17+
from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, FlaxDPMSolverDiscreteScheduler
1818
from ...utils import logging
1919
from . import FlaxStableDiffusionPipelineOutput
2020
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
@@ -43,7 +43,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
4343
unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
4444
scheduler ([`SchedulerMixin`]):
4545
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
46-
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`].
46+
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or [`FlaxDPMSolverDiscreteScheduler`].
4747
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
4848
Classification module that estimates whether generated images could be considered offensive or harmful.
4949
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
@@ -57,7 +57,7 @@ def __init__(
5757
text_encoder: FlaxCLIPTextModel,
5858
tokenizer: CLIPTokenizer,
5959
unet: FlaxUNet2DConditionModel,
60-
scheduler: Union[FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler],
60+
scheduler: Union[FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverDiscreteScheduler],
6161
safety_checker: FlaxStableDiffusionSafetyChecker,
6262
feature_extractor: CLIPFeatureExtractor,
6363
dtype: jnp.dtype = jnp.float32,

src/diffusers/schedulers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
if is_flax_available():
3737
from .scheduling_ddim_flax import FlaxDDIMScheduler
3838
from .scheduling_ddpm_flax import FlaxDDPMScheduler
39+
from .scheduling_dpmsolver_discrete_flax import FlaxDPMSolverDiscreteScheduler
3940
from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
4041
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
4142
from .scheduling_pndm_flax import FlaxPNDMScheduler

src/diffusers/schedulers/scheduling_dpmsolver_discrete.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,12 @@ def __init__(
104104
beta_end: float = 0.02,
105105
beta_schedule: str = "linear",
106106
trained_betas: Optional[np.ndarray] = None,
107-
solver_order: int = 3,
107+
solver_order: int = 2,
108108
predict_x0: bool = True,
109109
thresholding: bool = False,
110110
sample_max_value: float = 1.0,
111111
solver_type: str = "dpm_solver",
112-
denoise_final: bool = True,
112+
denoise_final: bool = False,
113113
):
114114
if trained_betas is not None:
115115
self.betas = torch.from_numpy(trained_betas)

0 commit comments

Comments
 (0)