-
Notifications
You must be signed in to change notification settings - Fork 6.5k
Allow dtype to be specified in Flax pipeline #600
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This may be a temporary solution until #567 is addressed.
The denoising loop always computes the next step in float32, so this would fail when using `bfloat16`.
|
The documentation is not available anymore as the PR was closed or merged. |
| ) | ||
| if latents is None: | ||
| latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype) | ||
| latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The alternative to this, I think, would be to prepare the scheduler parameters using the same dtype as the model. We can do that in a follow-up PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good, and agree we should check model dtype here.
| ) | ||
| if latents is None: | ||
| latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype) | ||
| latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good, and agree we should check model dtype here.
* Fix typo in docstring. * Allow dtype to be overridden on model load. This may be a temporary solution until huggingface#567 is addressed. * Create latents in float32 The denoising loop always computes the next step in float32, so this would fail when using `bfloat16`.
This replaces #581, which was reviewed by @patil-suraj.