Skip to content

Commit 8a34fc7

Browse files
NarsilNouamaneTazi
andauthored
Removing autocast for 35-25% speedup. (autocast considered harmful). (huggingface#511)
* Removing `autocast` for `35-25% speedup`. * iQuality * Adding a slow test. * Fixing mps noise generation. * Raising error on wrong device, instead of just casting on behalf of user. * Quality. * fix merge Co-authored-by: Nouamane Tazi <[email protected]>
1 parent cdaa3b3 commit 8a34fc7

File tree

4 files changed

+21
-27
lines changed

4 files changed

+21
-27
lines changed

models/unet_2d_condition.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,12 @@ def forward(
266266
timesteps = timesteps.expand(sample.shape[0])
267267

268268
t_emb = self.time_proj(timesteps)
269-
emb = self.time_embedding(t_emb.to(self.dtype))
269+
270+
# timesteps does not contain any weights and will always return f32 tensors
271+
# but time_embedding might actually be running in fp16. so we need to cast here.
272+
# there might be better ways to encapsulate this.
273+
t_emb = t_emb.to(dtype=self.dtype)
274+
emb = self.time_embedding(t_emb)
270275

271276
# 2. pre-process
272277
sample = self.conv_in(sample)

pipelines/README.md

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,13 @@ logic including pre-processing, an unrolled diffusion loop, and post-processing
8686

8787
```python
8888
# make sure you're logged in with `huggingface-cli login`
89-
from torch import autocast
9089
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
9190

9291
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
9392
pipe = pipe.to("cuda")
9493

9594
prompt = "a photo of an astronaut riding a horse on mars"
96-
with autocast("cuda"):
97-
image = pipe(prompt).images[0]
95+
image = pipe(prompt).images[0]
9896

9997
image.save("astronaut_rides_horse.png")
10098
```
@@ -104,7 +102,6 @@ image.save("astronaut_rides_horse.png")
104102
The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
105103

106104
```python
107-
from torch import autocast
108105
import requests
109106
from PIL import Image
110107
from io import BytesIO
@@ -129,8 +126,7 @@ init_image = init_image.resize((768, 512))
129126

130127
prompt = "A fantasy landscape, trending on artstation"
131128

132-
with autocast("cuda"):
133-
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
129+
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
134130

135131
images[0].save("fantasy_landscape.png")
136132
```
@@ -148,7 +144,6 @@ The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by
148144
```python
149145
from io import BytesIO
150146

151-
from torch import autocast
152147
import requests
153148
import PIL
154149

@@ -173,8 +168,7 @@ pipe = StableDiffusionInpaintPipeline.from_pretrained(
173168
).to(device)
174169

175170
prompt = "a cat sitting on a bench"
176-
with autocast("cuda"):
177-
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
171+
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
178172

179173
images[0].save("cat_on_bench.png")
180174
```

pipelines/stable_diffusion/README.md

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,13 @@ pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-4")
5959

6060
```python
6161
# make sure you're logged in with `huggingface-cli login`
62-
from torch import autocast
6362
from diffusers import StableDiffusionPipeline
6463

6564
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
6665
pipe = pipe.to("cuda")
6766

6867
prompt = "a photo of an astronaut riding a horse on mars"
69-
with autocast("cuda"):
70-
image = pipe(prompt).images[0]
68+
image = pipe(prompt).sample[0]
7169

7270
image.save("astronaut_rides_horse.png")
7371
```
@@ -76,7 +74,6 @@ image.save("astronaut_rides_horse.png")
7674

7775
```python
7876
# make sure you're logged in with `huggingface-cli login`
79-
from torch import autocast
8077
from diffusers import StableDiffusionPipeline, DDIMScheduler
8178

8279
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
@@ -88,8 +85,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
8885
).to("cuda")
8986

9087
prompt = "a photo of an astronaut riding a horse on mars"
91-
with autocast("cuda"):
92-
image = pipe(prompt).images[0]
88+
image = pipe(prompt).sample[0]
9389

9490
image.save("astronaut_rides_horse.png")
9591
```
@@ -98,7 +94,6 @@ image.save("astronaut_rides_horse.png")
9894

9995
```python
10096
# make sure you're logged in with `huggingface-cli login`
101-
from torch import autocast
10297
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
10398

10499
lms = LMSDiscreteScheduler(
@@ -114,8 +109,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
114109
).to("cuda")
115110

116111
prompt = "a photo of an astronaut riding a horse on mars"
117-
with autocast("cuda"):
118-
image = pipe(prompt).images[0]
112+
image = pipe(prompt).sample[0]
119113

120114
image.save("astronaut_rides_horse.png")
121115
```

pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -260,19 +260,20 @@ def __call__(
260260
# Unlike in other pipelines, latents need to be generated in the target device
261261
# for 1-to-1 results reproducibility with the CompVis implementation.
262262
# However this currently doesn't work in `mps`.
263-
latents_device = "cpu" if self.device.type == "mps" else self.device
264263
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
264+
latents_dtype = text_embeddings.dtype
265265
if latents is None:
266-
latents = torch.randn(
267-
latents_shape,
268-
generator=generator,
269-
device=latents_device,
270-
dtype=text_embeddings.dtype,
271-
)
266+
if self.device.type == "mps":
267+
# randn does not exist on mps
268+
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
269+
self.device
270+
)
271+
else:
272+
latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
272273
else:
273274
if latents.shape != latents_shape:
274275
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
275-
latents = latents.to(latents_device)
276+
latents = latents.to(self.device)
276277

277278
# set timesteps
278279
self.scheduler.set_timesteps(num_inference_steps)

0 commit comments

Comments
 (0)