8k Stable Diffusion with tiled VAE #1441
```diff
@@ -96,6 +96,7 @@ def dummy_vqvae_and_unet(self):
         )
         return vqvae, unet
 
+    @slow
```
Member: Is this something we need?

Contributor: Yes, this test takes a minute and this model has pretty much zero usage, so disabling it for fast runs.

Member: But the decorator is on …
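(For context: `@slow` gates a test out of the fast CI run so it only executes in the slow/nightly job. Below is a minimal sketch of how such a decorator is typically implemented; the `RUN_SLOW` environment variable name is an assumption here, and the real helper lives in diffusers' testing utilities.)

```python
import os
import unittest


def slow(test_case):
    """Skip the decorated test unless slow tests are explicitly enabled.

    Assumes an environment flag named RUN_SLOW; the actual diffusers helper
    may use a different variable name or parsing logic.
    """
    run_slow = os.environ.get("RUN_SLOW", "0").lower() in ("1", "true", "yes")
    return unittest.skipUnless(run_slow, "test is slow")(test_case)
```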
```diff
     def test_audio_diffusion(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         mel = Mel()
```
```diff
@@ -419,6 +419,29 @@ def test_stable_diffusion_vae_slicing(self):
         # there is a small discrepancy at image borders vs. full batch decode
         assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 3e-3
 
+    def test_stable_diffusion_vae_tiling(self):
```
Contributor: nice test!
```diff
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+
+        # make sure here that pndm scheduler skips prk
+        components["safety_checker"] = None
+        sd_pipe = StableDiffusionPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+
+        # Test that tiled decode at 512x512 yields the same result as the non-tiled decode
+        generator = torch.Generator(device=device).manual_seed(0)
+        output_1 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
+
+        # make sure tiled vae decode yields the same result
+        sd_pipe.enable_vae_tiling()
+        generator = torch.Generator(device=device).manual_seed(0)
+        output_2 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
+
+        assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 5e-1
+
     def test_stable_diffusion_negative_prompt(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
```
```diff
@@ -699,6 +722,58 @@ def test_stable_diffusion_vae_slicing(self):
         # There is a small discrepancy at the image borders vs. a fully batched version.
         assert np.abs(image_sliced - image).max() < 1e-2
 
+    def test_stable_diffusion_vae_tiling(self):
+        torch.cuda.reset_peak_memory_stats()
+        model_id = "CompVis/stable-diffusion-v1-4"
+        pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing()
+        pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
+        pipe.vae = pipe.vae.to(memory_format=torch.channels_last)
+
+        prompt = "a photograph of an astronaut riding a horse"
+
+        # enable vae tiling
+        pipe.enable_vae_tiling()
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        with torch.autocast(torch_device):
+            output_chunked = pipe(
+                [prompt],
+                width=640,
+                height=640,
+                generator=generator,
+                guidance_scale=7.5,
+                num_inference_steps=2,
+                output_type="numpy",
+            )
+        image_chunked = output_chunked.images
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+        # make sure that less than 4 GB is allocated
+        assert mem_bytes < 4e9
+
+        # disable vae tiling
+        pipe.disable_vae_tiling()
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        with torch.autocast(torch_device):
+            output = pipe(
+                [prompt],
+                width=640,
+                height=640,
+                generator=generator,
+                guidance_scale=7.5,
+                num_inference_steps=2,
+                output_type="numpy",
+            )
+        image = output.images
+
+        # make sure that more than 4 GB is allocated
+        mem_bytes = torch.cuda.max_memory_allocated()
+        assert mem_bytes > 4e9
+        assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-2
+
     def test_stable_diffusion_fp16_vs_autocast(self):
         # this test makes sure that the original model with autocast
         # and the new model with fp16 yield the same result
```
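The memory assertions above are the point of the feature: with tiling enabled, the VAE decode stays under a fixed budget regardless of output size, which is what makes very large (4k–8k) renders feasible on a single GPU. A hedged end-user sketch, assuming the `enable_vae_tiling()` and `enable_attention_slicing()` hooks exercised in the test, and enough VRAM for the UNet pass at the chosen (illustrative) resolution:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")
pipe.enable_attention_slicing()  # keep UNet attention memory bounded
pipe.enable_vae_tiling()         # decode the latents tile by tile instead of all at once

image = pipe(
    "a photograph of an astronaut riding a horse",
    width=1920,
    height=1088,  # illustrative large resolution (multiples of 64)
    num_inference_steps=30,
).images[0]
image.save("astronaut_tiled_vae.png")
```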
Just a question: what use case does tiled encoding fulfill? Also, do we need blending during the encoding phase? If we used this as an autoencoder for a large image, I would have thought that blending during decoding would be enough to avoid seams between the tiles.
I guess it's just to save memory? Encoding a large image is pretty memory intensive, no?
Yeah, I meant that we don't need encoding for inference, and I don't think we'll train with very large images. It was just out of curiosity.
I did need it for something. IIRC img2img with large images.
Oh, you did img2img on huge images; cool, understood.
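For reference on the blending question in this thread: tiled decode avoids visible seams by decoding overlapping tiles and cross-fading the overlap before stitching. A minimal standalone sketch of that idea (illustrative only, not the helper added in this PR):

```python
import torch


def blend_h(left: torch.Tensor, right: torch.Tensor, overlap: int) -> torch.Tensor:
    """Cross-fade the last `overlap` columns of `left` into the first `overlap`
    columns of `right`, then stitch the two (C, H, W) tiles horizontally."""
    # Blend weight ramps from 0 -> 1 across the overlap region.
    alpha = torch.linspace(0, 1, overlap, device=left.device).view(1, 1, overlap)
    blended = left[..., -overlap:] * (1 - alpha) + right[..., :overlap] * alpha
    return torch.cat([left[..., :-overlap], blended, right[..., overlap:]], dim=-1)


# Two 3x64x64 decoded tiles with a 16-pixel overlap -> one 3x64x112 strip
strip = blend_h(torch.rand(3, 64, 64), torch.rand(3, 64, 64), overlap=16)
print(strip.shape)  # torch.Size([3, 64, 112])
```

And to make the img2img use case concrete: with a very large init image, both the VAE encode of the input and the final decode benefit from tiling. A hedged sketch, assuming the img2img pipeline exposes the same `enable_vae_tiling()` switch (argument names such as `image` vs. `init_image` vary across diffusers versions):

```python
import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")
pipe.enable_attention_slicing()
pipe.enable_vae_tiling()  # tiles both the encode of the init image and the final decode

init_image = Image.open("large_input.png").convert("RGB")  # e.g. a 4096x4096 photo

result = pipe(
    prompt="a detailed matte painting of a mountain valley",
    image=init_image,
    strength=0.6,
    num_inference_steps=30,
).images[0]
result.save("large_output.png")
```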