Flax support for Stable Diffusion 2 #1423
Conversation
The documentation is not available anymore as the PR was closed or merged.
| "DownBlock2D", | ||
| ) | ||
| up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") | ||
| only_cross_attention: Union[bool, Tuple[bool]] = False, |
Small nit: stray `,` here (syntax error).
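Presumably the fix is just dropping the trailing comma; in a dataclass-style module field it would otherwise be a syntax error on older Python versions, or silently turn the default into a one-element tuple:

```python
only_cross_attention: Union[bool, Tuple[bool]] = False
```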
Thanks @pcuenca ♥ it is working 🎉
I converted it to Flax and put it at https://huggingface.co/flax/stable-diffusion-2 in case someone needs it.
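A minimal sketch of how those converted weights could be loaded (untested; assumes the repo layout matches what `FlaxStableDiffusionPipeline` expects, and the `bfloat16` dtype is just an example):

```python
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

# Flax pipelines return both the pipeline object and its parameter tree.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "flax/stable-diffusion-2", dtype=jnp.bfloat16
)
```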
Super cool! Should we maybe add one slow test for SD-2?
Slice values taken from my Ampere GPU.
Note that the expected values are taken from the PyTorch results. This ensures the Flax and PyTorch versions are not too far off.
Co-authored-by: Patrick von Platen <[email protected]>
I added a couple of integration tests to ensure that the output from the Flax UNet is close enough to the output from PyTorch, within a tolerance. The tests may fail on the real testing hardware (V100 and TPU), though; I'll adapt if that's the case.
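Not the PR's actual code, just a rough illustration of how a PyTorch reference slice could be captured; the shapes, seed handling, and slicing convention here are assumptions:

```python
import torch
from diffusers import UNet2DConditionModel

# Load the SD2 UNet in fp16 on GPU, mirroring the fp16 test configuration.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2", subfolder="unet", torch_dtype=torch.float16
).to("cuda")

# Deterministic dummy inputs (shapes are assumptions: 96x96 latents, 1024-dim text embeddings).
generator = torch.Generator(device="cuda").manual_seed(83)
latents = torch.randn((4, 4, 96, 96), generator=generator, device="cuda", dtype=torch.float16)
encoder_hidden_states = torch.randn((4, 77, 1024), generator=generator, device="cuda", dtype=torch.float16)

with torch.no_grad():
    sample = unet(latents, timestep=4, encoder_hidden_states=encoder_hidden_states).sample

# Record a small slice of the output as the expected reference for the Flax test.
expected_slice = sample.flatten()[-8:].float().cpu().tolist()
```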
```python
[83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]],
[17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]],
[8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]],
[3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]],
```
These are the same slices as in the torch float16 GPU test. The JAX results are within reasonable tolerance on the same GPU, despite being computed in bfloat16.
In this case, using `parameterized` is slow because each test triggers a new compilation. If we want to make it faster, we should replace them with a single test and a loop (rough sketch below).
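Something along these lines; this is a sketch only, and the helper names `get_unet_model`, `get_latents`, `get_encoder_hidden_states`, the shapes, and the `atol` are assumptions rather than the PR's actual code:

```python
import jax
import jax.numpy as jnp

def test_stabilityai_sd_v2_flax(self):
    # Load the model once; the jitted apply function is compiled on the first call
    # and reused for every case, instead of recompiling per parameterized test.
    model, params = self.get_unet_model(model_id="stabilityai/stable-diffusion-2")

    @jax.jit
    def apply(params, latents, timestep, encoder_hidden_states):
        return model.apply({"params": params}, latents, timestep, encoder_hidden_states).sample

    cases = [
        [83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]],
        [17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]],
    ]
    for seed, timestep, expected_slice in cases:
        latents = self.get_latents(seed, shape=(4, 4, 96, 96))
        encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024))
        # Cast the timestep so int and float cases share one compiled signature.
        sample = apply(params, latents, jnp.array(timestep, dtype=jnp.float32), encoder_hidden_states)
        output_slice = jnp.asarray(jax.device_get(sample.flatten()[-8:]), dtype=jnp.float32)
        assert jnp.allclose(output_slice, jnp.array(expected_slice), atol=1e-2)
```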
Cool!
```python
)
@require_torch_gpu
def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice):
    model = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
```
Thanks a lot for adding this!
patrickvonplaten left a comment:
Super nice PR! Thanks for adding all those tests :-)
Merging now to deploy the backend. Thanks!
* Flax: start adapting to Stable Diffusion 2
* More changes.
* attention_head_dim can be a tuple.
* Fix typos
* Add simple SD 2 integration test. Slice values taken from my Ampere GPU.
* Add simple UNet integration tests for Flax. Note that the expected values are taken from the PyTorch results. This ensures the Flax and PyTorch versions are not too far off.
* Apply suggestions from code review. Co-authored-by: Patrick von Platen <[email protected]>
* Typos and style
* Tests: verify jax is available.
* Style
* Make flake happy
* Remove typo.
* Simple Flax SD 2 pipeline tests.
* Import order
* Remove unused import.

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: @camenduru
No description provided.