
Conversation

@pcuenca (Member) commented Nov 25, 2022

No description provided.

pcuenca marked this pull request as draft November 25, 2022 16:02
@HuggingFaceDocBuilderDev commented Nov 25, 2022

The documentation is not available anymore as the PR was closed or merged.

"DownBlock2D",
)
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
only_cross_attention: Union[bool, Tuple[bool]] = False,
Contributor commented on this code:

Small comma syntax error here.

@camenduru (Contributor) commented:

Thanks @pcuenca ♥ it's working 🎉

@camenduru (Contributor) commented:

I converted the weights to Flax and uploaded them to https://huggingface.co/flax/stable-diffusion-2 in case anyone needs them.
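
For anyone who wants to try those converted weights, here is a minimal, untested sketch of loading them with FlaxStableDiffusionPipeline. It assumes the repo layout matches what the pipeline expects and that the weights are stored in bfloat16.

import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline

# Load the converted Flax weights (repo id and dtype are assumptions).
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "flax/stable-diffusion-2", dtype=jnp.bfloat16
)

# One prompt per device; prepare_inputs tokenizes and pads them.
prompts = ["a photograph of an astronaut riding a horse"] * jax.device_count()
prompt_ids = shard(pipeline.prepare_inputs(prompts))

# Replicate the parameters and split the RNG so the jitted call can run pmapped.
params = replicate(params)
rng = jax.random.split(jax.random.PRNGKey(0), jax.device_count())

images = pipeline(prompt_ids, params, rng, num_inference_steps=50, jit=True).images

The images array comes back sharded per device and can be reshaped to (num_images, height, width, 3) before converting to PIL.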

pcuenca marked this pull request as ready for review November 27, 2022 22:02
@patrickvonplaten (Contributor) commented:

Super cool!

Should we maybe add one slow test for SD-2?

pcuenca and others added 3 commits November 28, 2022 17:28

* Add simple SD 2 integration test. Slice values taken from my Ampere GPU.
* Add simple UNet integration tests for Flax. Note that the expected values are taken from the PyTorch results. This ensures the Flax and PyTorch versions are not too far off.
* Apply suggestions from code review. Co-authored-by: Patrick von Platen <[email protected]>
@pcuenca (Member, Author) commented Nov 28, 2022

Should we maybe add one slow test for SD-2?

I added a couple of integration tests to ensure that the output from the Flax UNet is close enough to the output from PyTorch. This works within a tolerance of 1e-2 on my hardware, even though the PyTorch values were generated in float16 and the Flax ones in bfloat16, which suggests the UNet implementation is correct.

The tests may fail on the real testing hardware though (V100 and TPU). I'll adapt if that's the case.
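
For reference, here is a rough sketch of the kind of Flax vs. PyTorch parity check described above. It is not the test code from this PR; the fp16/bf16 revisions, the shared random inputs, and the output comparison are assumptions.

import jax.numpy as jnp
import numpy as np
import torch
from diffusers import FlaxUNet2DConditionModel, UNet2DConditionModel

model_id = "stabilityai/stable-diffusion-2"

# PyTorch reference in float16 on GPU (assumes an "fp16" revision exists).
pt_unet = UNet2DConditionModel.from_pretrained(
    model_id, subfolder="unet", revision="fp16", torch_dtype=torch.float16
).to("cuda")

# Flax model in bfloat16 (assumes a "bf16" revision with Flax weights exists).
flax_unet, flax_params = FlaxUNet2DConditionModel.from_pretrained(
    model_id, subfolder="unet", revision="bf16", dtype=jnp.bfloat16
)

# Shared random inputs: SD-2 latents have 4 channels, text embeddings are 1024-dim.
rng = np.random.default_rng(0)
latents = rng.standard_normal((1, 4, 96, 96)).astype(np.float32)
hidden_states = rng.standard_normal((1, 77, 1024)).astype(np.float32)
timestep = 4

with torch.no_grad():
    pt_sample = pt_unet(
        torch.tensor(latents, dtype=torch.float16, device="cuda"),
        timestep,
        encoder_hidden_states=torch.tensor(hidden_states, dtype=torch.float16, device="cuda"),
    ).sample

flax_sample = flax_unet.apply(
    {"params": flax_params},
    jnp.asarray(latents, dtype=jnp.bfloat16),
    jnp.array([timestep], dtype=jnp.int32),
    encoder_hidden_states=jnp.asarray(hidden_states, dtype=jnp.bfloat16),
).sample

# The outputs should agree within roughly 1e-2 despite the different dtypes.
np.testing.assert_allclose(
    np.asarray(flax_sample.astype(jnp.float32)),
    pt_sample.float().cpu().numpy(),
    atol=1e-2,
)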

Comment on lines +48 to +51
[83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]],
[17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]],
[8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]],
[3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]],
@pcuenca (Member, Author) commented:

These are the same slices from the torch float16 GPU test. The JAX results are within reasonable tolerance on the same GPU, despite being computed in bfloat16.

In this case, the use of parameterized is slow because each test triggers a compilation. If we want to make it faster, we should replace them with a single test and a loop.
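
A sketch of what that single-test alternative could look like. The helpers get_unet_model, get_latents, and get_encoder_hidden_states, as well as the exact output slice, are assumptions about the test class rather than code from this PR.

import jax.numpy as jnp

def test_stabilityai_sd_v2_flax_bf16(self):
    # Load the model once instead of once per parameterized case.
    model, params = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
    cases = [
        # (seed, timestep, expected slice recorded from the PyTorch float16 run)
        [83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]],
        [17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]],
        [8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]],
        [3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]],
    ]
    for seed, timestep, expected_slice in cases:
        latents = self.get_latents(seed, fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)

        sample = model.apply(
            {"params": params},
            latents,
            jnp.array(timestep),
            encoder_hidden_states=encoder_hidden_states,
        ).sample

        # Compare a fixed slice of the output (exact indexing is an assumption)
        # against the recorded reference values, within a 1e-2 tolerance.
        output_slice = jnp.asarray(sample.flatten()[-8:], dtype=jnp.float32)
        expected = jnp.asarray(expected_slice, dtype=jnp.float32)
        assert jnp.allclose(output_slice, expected, atol=1e-2)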

Contributor commented:

Cool!

)
@require_torch_gpu
def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
Contributor commented:

Thanks a lot for adding this!

@patrickvonplaten (Contributor) left a review comment:

Super nice PR! Thanks for adding all those tests :-)

@pcuenca (Member, Author) commented Nov 29, 2022

Merging, then, to deploy the backend. Thanks!

pcuenca merged commit 4d1e4e2 into main Nov 29, 2022
pcuenca deleted the flax-sd-2 branch November 29, 2022 11:33
sliard pushed a commit to sliard/diffusers that referenced this pull request Dec 21, 2022
* Flax: start adapting to Stable Diffusion 2

* More changes.

* attention_head_dim can be a tuple.

* Fix typos

* Add simple SD 2 integration test.

Slice values taken from my Ampere GPU.

* Add simple UNet integration tests for Flax.

Note that the expected values are taken from the PyTorch results. This
ensures the Flax and PyTorch versions are not too far off.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

* Typos and style

* Tests: verify jax is available.

* Style

* Make flake happy

* Remove typo.

* Simple Flax SD 2 pipeline tests.

* Import order

* Remove unused import.

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: @camenduru
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023 (same squashed commit message as above)