Removing autocast for 35-25% speedup. (autocast considered harmful).
#511
Changes from all commits: 70abbb7, 6334170, e4f9388, 6b853af, ce66d60, 43c2d17, c207659, cf393eb
```diff
@@ -260,19 +260,20 @@ def __call__(
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_device = "cpu" if self.device.type == "mps" else self.device
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        latents_dtype = text_embeddings.dtype
         if latents is None:
-            latents = torch.randn(
-                latents_shape,
-                generator=generator,
-                device=latents_device,
-                dtype=text_embeddings.dtype,
-            )
+            if self.device.type == "mps":
+                # randn does not exist on mps
+                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+                    self.device
+                )
+            else:
+                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-        latents = latents.to(latents_device)
+            latents = latents.to(self.device)

         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
```

Contributor (on the `mps` branch above): @pcuenca could you take a look here?

Member: Yes, I commented in the other conversation, I think it's ok like this.
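For readers outside the diff context, here is a minimal standalone sketch (a hypothetical helper, not code from this PR) of the device-handling pattern the new branch uses: seeded `randn` is not supported on `mps`, so latents are sampled on CPU and moved over, while other devices sample directly on-device for reproducibility with the CompVis implementation.

```python
import torch

def sample_latents(shape, generator, device, dtype):
    # Hypothetical helper mirroring the diff above (not part of the PR).
    # On mps, sample on CPU with a CPU generator and move the result over,
    # since seeded randn is not available on mps; elsewhere, sample on-device.
    if device.type == "mps":
        return torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
    return torch.randn(shape, generator=generator, device=device, dtype=dtype)

# Usage sketch: a CUDA generator when sampling on CUDA.
device = torch.device("cuda")
generator = torch.Generator(device=device).manual_seed(0)
latents = sample_latents((1, 4, 64, 64), generator, device, torch.float32)
```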
```diff
@@ -1214,6 +1214,37 @@ def test_stable_diffusion_memory_chunking(self):
         assert mem_bytes > 3.75 * 10**9
         assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3

+    @slow
+    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+    def test_stable_diffusion_text2img_pipeline_fp16(self):
+        torch.cuda.reset_peak_memory_stats()
+        model_id = "CompVis/stable-diffusion-v1-4"
+        pipe = StableDiffusionPipeline.from_pretrained(
+            model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
+        ).to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        prompt = "a photograph of an astronaut riding a horse"
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        output_chunked = pipe(
+            [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
+        )
+        image_chunked = output_chunked.images
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        with torch.autocast(torch_device):
+            output = pipe(
+                [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
+            )
+        image = output.images
+
+        # Make sure results are close enough
+        diff = np.abs(image_chunked.flatten() - image.flatten())
+        # They ARE different since ops are not run always at the same precision
+        # however, they should be extremely close.
+        assert diff.mean() < 2e-2
+
     @slow
     @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
     def test_stable_diffusion_text2img_pipeline(self):
```

Contributor (on the `assert diff.mean() < 2e-2` tolerance): This test tolerance is a bit high to me => will play around with it a bit!

Contributor (Author): It's roughly the same difference as in #371 for pure fp16. Running some ops directly in f32 instead of f16 (which autocast will do) does change some patches.
Are you sure that this is faster? Using `autocast` currently gives (before this PR) a 2x boost in terms of generation speed. Will also test a bit locally on a GPU tomorrow.
This is extremely surprising, but I am also measuring a 2x speedup with autocast on `f32`. I am looking into it; I see two copies but not nearly the same amount, so there's probably a device-to-host / host-to-device transfer somewhere that kills performance, but I haven't found it yet.
Okay, I figured it out. `autocast` will actually use `fp16` for some ops based on heuristics: https://pytorch.org/docs/stable/amp.html#cuda-op-specific-behavior

So it's faster because it's running in fp16 even if the model was loaded in f32, and without autocast it's slower just because it's actually running f32.

If we enable real fp16 with a big performance boost, I feel like we shouldn't need autocast for f32 (which does make f32 slower, but also "more" correct). Some ops are kind of dangerous to actually run in fp16, but we should be able to tell them apart (and for now the generations still look good enough even when running everything in f16).

But it's still a nice way to get f16 running "for free". The heuristics they use seem OK (though, as they mention, they probably wouldn't work for gradients, I guess because fp16 overflows faster).
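To make the heuristic concrete, here is a small illustrative snippet (not from the PR; assumes a CUDA GPU): under `torch.autocast`, matmul-like ops run in fp16 even when their inputs are fp32, which is where the "free" speedup comes from.

```python
import torch

a = torch.randn(8, 8, device="cuda")  # fp32 tensors, as if the model were loaded in f32
b = torch.randn(8, 8, device="cuda")

with torch.autocast("cuda"):
    out = a @ b  # matmul is on autocast's fp16 op list

print(a.dtype, out.dtype)  # torch.float32 torch.float16 -- the op ran in half precision
```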
Commented extensively here: #511

Should we maybe change the example to load native `fp16` weights then?

So replace the current autocast-based example by one that loads the fp16 weights.
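The exact snippets from that comment are not reproduced above; as a rough sketch of what the swap presumably looks like (modeled on the fp16 test added in this PR, with the pipeline call shape assumed), it would be something like:

```python
import torch
from diffusers import StableDiffusionPipeline

prompt = "a photograph of an astronaut riding a horse"

# Before (sketch): fp32 weights, relying on autocast to run some ops in fp16.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", use_auth_token=True
).to("cuda")
with torch.autocast("cuda"):
    image = pipe(prompt).images[0]

# After (sketch): load the native fp16 weights directly; no autocast needed.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=True
).to("cuda")
image = pipe(prompt).images[0]
```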
That seems OK. Is `fp16` supported by enough GPU cards at this point? That would be my only point of concern. But given the speed difference, advocating for fp16 is definitely something we should do!
Yes, I think fp16 is widely supported now.