Skip to content

Commit 3dcc75c

Browse files
NarsilNouamaneTazi
andauthored
Removing autocast for 35-25% speedup. (autocast considered harmful). (#511)
* Removing `autocast` for `35-25% speedup`. * iQuality * Adding a slow test. * Fixing mps noise generation. * Raising error on wrong device, instead of just casting on behalf of user. * Quality. * fix merge Co-authored-by: Nouamane Tazi <[email protected]>
1 parent 6b09f37 commit 3dcc75c

File tree

6 files changed

+60
-48
lines changed

6 files changed

+60
-48
lines changed

README.md

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,13 @@ You need to accept the model license before downloading or using the Stable Diff
7676

7777
```python
7878
# make sure you're logged in with `huggingface-cli login`
79-
from torch import autocast
8079
from diffusers import StableDiffusionPipeline
8180

8281
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
8382
pipe = pipe.to("cuda")
8483

8584
prompt = "a photo of an astronaut riding a horse on mars"
86-
with autocast("cuda"):
87-
image = pipe(prompt).images[0]
85+
image = pipe(prompt).images[0]
8886
```
8987

9088
**Note**: If you don't want to use the token, you can also simply download the model weights
@@ -104,8 +102,7 @@ pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-4")
104102
pipe = pipe.to("cuda")
105103

106104
prompt = "a photo of an astronaut riding a horse on mars"
107-
with autocast("cuda"):
108-
image = pipe(prompt).images[0]
105+
image = pipe(prompt).images[0]
109106
```
110107

111108
If you are limited by GPU memory, you might want to consider using the model in `fp16` as
@@ -123,8 +120,7 @@ pipe = pipe.to("cuda")
123120

124121
prompt = "a photo of an astronaut riding a horse on mars"
125122
pipe.enable_attention_slicing()
126-
with autocast("cuda"):
127-
image = pipe(prompt).images[0]
123+
image = pipe(prompt).images[0]
128124
```
129125

130126
Finally, if you wish to use a different scheduler, you can simply instantiate
@@ -149,8 +145,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
149145
pipe = pipe.to("cuda")
150146

151147
prompt = "a photo of an astronaut riding a horse on mars"
152-
with autocast("cuda"):
153-
image = pipe(prompt).images[0]
148+
image = pipe(prompt).images[0]
154149

155150
image.save("astronaut_rides_horse.png")
156151
```
@@ -160,7 +155,6 @@ image.save("astronaut_rides_horse.png")
160155
The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
161156

162157
```python
163-
from torch import autocast
164158
import requests
165159
import torch
166160
from PIL import Image
@@ -190,8 +184,7 @@ init_image = init_image.resize((768, 512))
190184

191185
prompt = "A fantasy landscape, trending on artstation"
192186

193-
with autocast("cuda"):
194-
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
187+
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
195188

196189
images[0].save("fantasy_landscape.png")
197190
```
@@ -204,7 +197,6 @@ The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by
204197
```python
205198
from io import BytesIO
206199

207-
from torch import autocast
208200
import torch
209201
import requests
210202
import PIL
@@ -234,8 +226,7 @@ pipe = StableDiffusionInpaintPipeline.from_pretrained(
234226
pipe = pipe.to(device)
235227

236228
prompt = "a cat sitting on a bench"
237-
with autocast("cuda"):
238-
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
229+
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
239230

240231
images[0].save("cat_on_bench.png")
241232
```
@@ -258,7 +249,6 @@ If you want to run the code yourself 💻, you can try out:
258249
- [Text-to-Image Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256)
259250
```python
260251
# !pip install diffusers transformers
261-
from torch import autocast
262252
from diffusers import DiffusionPipeline
263253

264254
device = "cuda"
@@ -270,16 +260,14 @@ ldm = ldm.to(device)
270260

271261
# run pipeline in inference (sample random noise and denoise)
272262
prompt = "A painting of a squirrel eating a burger"
273-
with autocast(device):
274-
image = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images[0]
263+
image = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images[0]
275264

276265
# save image
277266
image.save("squirrel.png")
278267
```
279268
- [Unconditional Diffusion with discrete scheduler](https://huggingface.co/google/ddpm-celebahq-256)
280269
```python
281270
# !pip install diffusers
282-
from torch import autocast
283271
from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline
284272

285273
model_id = "google/ddpm-celebahq-256"
@@ -290,8 +278,7 @@ ddpm = DDPMPipeline.from_pretrained(model_id) # you can replace DDPMPipeline wi
290278
ddpm.to(device)
291279

292280
# run pipeline in inference (sample random noise and denoise)
293-
with autocast("cuda"):
294-
image = ddpm().images[0]
281+
image = ddpm().images[0]
295282

296283
# save image
297284
image.save("ddpm_generated_image.png")

src/diffusers/models/unet_2d_condition.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,12 @@ def forward(
266266
timesteps = timesteps.expand(sample.shape[0])
267267

268268
t_emb = self.time_proj(timesteps)
269-
emb = self.time_embedding(t_emb.to(self.dtype))
269+
270+
# timesteps does not contain any weights and will always return f32 tensors
271+
# but time_embedding might actually be running in fp16. so we need to cast here.
272+
# there might be better ways to encapsulate this.
273+
t_emb = t_emb.to(dtype=self.dtype)
274+
emb = self.time_embedding(t_emb)
270275

271276
# 2. pre-process
272277
sample = self.conv_in(sample)

src/diffusers/pipelines/README.md

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,13 @@ logic including pre-processing, an unrolled diffusion loop, and post-processing
8686

8787
```python
8888
# make sure you're logged in with `huggingface-cli login`
89-
from torch import autocast
9089
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
9190

9291
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
9392
pipe = pipe.to("cuda")
9493

9594
prompt = "a photo of an astronaut riding a horse on mars"
96-
with autocast("cuda"):
97-
image = pipe(prompt).images[0]
95+
image = pipe(prompt).images[0]
9896

9997
image.save("astronaut_rides_horse.png")
10098
```
@@ -104,7 +102,6 @@ image.save("astronaut_rides_horse.png")
104102
The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
105103

106104
```python
107-
from torch import autocast
108105
import requests
109106
from PIL import Image
110107
from io import BytesIO
@@ -129,8 +126,7 @@ init_image = init_image.resize((768, 512))
129126

130127
prompt = "A fantasy landscape, trending on artstation"
131128

132-
with autocast("cuda"):
133-
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
129+
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
134130

135131
images[0].save("fantasy_landscape.png")
136132
```
@@ -148,7 +144,6 @@ The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by
148144
```python
149145
from io import BytesIO
150146

151-
from torch import autocast
152147
import requests
153148
import PIL
154149

@@ -173,8 +168,7 @@ pipe = StableDiffusionInpaintPipeline.from_pretrained(
173168
).to(device)
174169

175170
prompt = "a cat sitting on a bench"
176-
with autocast("cuda"):
177-
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
171+
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
178172

179173
images[0].save("cat_on_bench.png")
180174
```

src/diffusers/pipelines/stable_diffusion/README.md

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,13 @@ pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-4")
5959

6060
```python
6161
# make sure you're logged in with `huggingface-cli login`
62-
from torch import autocast
6362
from diffusers import StableDiffusionPipeline
6463

6564
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
6665
pipe = pipe.to("cuda")
6766

6867
prompt = "a photo of an astronaut riding a horse on mars"
69-
with autocast("cuda"):
70-
image = pipe(prompt).images[0]
68+
image = pipe(prompt).sample[0]
7169

7270
image.save("astronaut_rides_horse.png")
7371
```
@@ -76,7 +74,6 @@ image.save("astronaut_rides_horse.png")
7674

7775
```python
7876
# make sure you're logged in with `huggingface-cli login`
79-
from torch import autocast
8077
from diffusers import StableDiffusionPipeline, DDIMScheduler
8178

8279
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
@@ -88,8 +85,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
8885
).to("cuda")
8986

9087
prompt = "a photo of an astronaut riding a horse on mars"
91-
with autocast("cuda"):
92-
image = pipe(prompt).images[0]
88+
image = pipe(prompt).sample[0]
9389

9490
image.save("astronaut_rides_horse.png")
9591
```
@@ -98,7 +94,6 @@ image.save("astronaut_rides_horse.png")
9894

9995
```python
10096
# make sure you're logged in with `huggingface-cli login`
101-
from torch import autocast
10297
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
10398

10499
lms = LMSDiscreteScheduler(
@@ -114,8 +109,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
114109
).to("cuda")
115110

116111
prompt = "a photo of an astronaut riding a horse on mars"
117-
with autocast("cuda"):
118-
image = pipe(prompt).images[0]
112+
image = pipe(prompt).sample[0]
119113

120114
image.save("astronaut_rides_horse.png")
121115
```

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -260,19 +260,20 @@ def __call__(
260260
# Unlike in other pipelines, latents need to be generated in the target device
261261
# for 1-to-1 results reproducibility with the CompVis implementation.
262262
# However this currently doesn't work in `mps`.
263-
latents_device = "cpu" if self.device.type == "mps" else self.device
264263
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
264+
latents_dtype = text_embeddings.dtype
265265
if latents is None:
266-
latents = torch.randn(
267-
latents_shape,
268-
generator=generator,
269-
device=latents_device,
270-
dtype=text_embeddings.dtype,
271-
)
266+
if self.device.type == "mps":
267+
# randn does not exist on mps
268+
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
269+
self.device
270+
)
271+
else:
272+
latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
272273
else:
273274
if latents.shape != latents_shape:
274275
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
275-
latents = latents.to(latents_device)
276+
latents = latents.to(self.device)
276277

277278
# set timesteps
278279
self.scheduler.set_timesteps(num_inference_steps)

tests/test_pipelines.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,6 +1214,37 @@ def test_stable_diffusion_memory_chunking(self):
12141214
assert mem_bytes > 3.75 * 10**9
12151215
assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3
12161216

1217+
@slow
1218+
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
1219+
def test_stable_diffusion_text2img_pipeline_fp16(self):
1220+
torch.cuda.reset_peak_memory_stats()
1221+
model_id = "CompVis/stable-diffusion-v1-4"
1222+
pipe = StableDiffusionPipeline.from_pretrained(
1223+
model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
1224+
).to(torch_device)
1225+
pipe.set_progress_bar_config(disable=None)
1226+
1227+
prompt = "a photograph of an astronaut riding a horse"
1228+
1229+
generator = torch.Generator(device=torch_device).manual_seed(0)
1230+
output_chunked = pipe(
1231+
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
1232+
)
1233+
image_chunked = output_chunked.images
1234+
1235+
generator = torch.Generator(device=torch_device).manual_seed(0)
1236+
with torch.autocast(torch_device):
1237+
output = pipe(
1238+
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
1239+
)
1240+
image = output.images
1241+
1242+
# Make sure results are close enough
1243+
diff = np.abs(image_chunked.flatten() - image.flatten())
1244+
# They ARE different since ops are not run always at the same precision
1245+
# however, they should be extremely close.
1246+
assert diff.mean() < 2e-2
1247+
12171248
@slow
12181249
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
12191250
def test_stable_diffusion_text2img_pipeline(self):

0 commit comments

Comments
 (0)