Skip to content

Commit fb2fbab

Browse files
authored
Allow dtype to be specified in Flax pipeline (#600)
* Fix typo in docstring. * Allow dtype to be overridden on model load. This may be a temporary solution until #567 is addressed. * Create latents in float32 The denoising loop always computes the next step in float32, so this would fail when using `bfloat16`.
1 parent fb03aad commit fb2fbab

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

src/diffusers/configuration_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,12 @@ def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], ret
154154
155155
"""
156156
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
157-
158157
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
159158

159+
# Allow dtype to be specified on initialization
160+
if "dtype" in unused_kwargs:
161+
init_dict["dtype"] = unused_kwargs.pop("dtype")
162+
160163
model = cls(**init_dict)
161164

162165
if return_unused_kwargs:

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
3030
Tokenizer of class
3131
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
3232
unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
33-
scheduler ([`FlaxSchedulerMixin`]):
33+
scheduler ([`SchedulerMixin`]):
3434
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
3535
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`].
3636
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
@@ -157,7 +157,7 @@ def __call__(
157157
self.unet.sample_size,
158158
)
159159
if latents is None:
160-
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype)
160+
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
161161
else:
162162
if latents.shape != latents_shape:
163163
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

0 commit comments

Comments
 (0)