
Commit 70abbb7

Removing autocast for 25-35% speedup.
1 parent 83a7bb2 commit 70abbb7

File tree

6 files changed: 29 additions & 71 deletions
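For context on the commit message: the Stable Diffusion weights can be loaded directly in fp16, at which point wrapping the call in `torch.autocast` only adds per-op casting overhead. Below is a minimal timing sketch of the two patterns; it assumes a CUDA GPU, the `fp16` branch of the checkpoint, and a completed `huggingface-cli login`, and the absolute numbers will vary with hardware and step count.

```python
import time
import torch
from diffusers import StableDiffusionPipeline

# Load the fp16 weights directly; no autocast wrapper is required around the call.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=True,
).to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"

def timed(fn):
    torch.cuda.synchronize()
    start = time.time()
    result = fn()
    torch.cuda.synchronize()
    return result, time.time() - start

# Old pattern: autocast re-casts activations op by op on top of the fp16 weights.
with torch.autocast("cuda"):
    _, t_autocast = timed(lambda: pipe(prompt).images[0])

# New pattern after this commit: call the fp16 pipeline directly.
_, t_native = timed(lambda: pipe(prompt).images[0])

print(f"autocast: {t_autocast:.1f}s | native fp16: {t_native:.1f}s")
```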


README.md

Lines changed: 8 additions & 21 deletions
@@ -76,15 +76,13 @@ You need to accept the model license before downloading or using the Stable Diff
 
 ```python
 # make sure you're logged in with `huggingface-cli login`
-from torch import autocast
 from diffusers import StableDiffusionPipeline
 
 pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
 pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
 ```
 
 **Note**: If you don't want to use the token, you can also simply download the model weights
@@ -104,8 +102,7 @@ pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-4")
 pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
 ```
 
 If you are limited by GPU memory, you might want to consider using the model in `fp16` as
@@ -123,8 +120,7 @@ pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
 pipe.enable_attention_slicing()
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
 ```
 
 Finally, if you wish to use a different scheduler, you can simply instantiate
@@ -149,8 +145,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
 pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
 
 image.save("astronaut_rides_horse.png")
 ```
@@ -160,7 +155,6 @@ image.save("astronaut_rides_horse.png")
 The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
 
 ```python
-from torch import autocast
 import requests
 import torch
 from PIL import Image
@@ -190,8 +184,7 @@ init_image = init_image.resize((768, 512))
 
 prompt = "A fantasy landscape, trending on artstation"
 
-with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
+images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
 
 images[0].save("fantasy_landscape.png")
 ```
@@ -204,7 +197,6 @@ The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by
 ```python
 from io import BytesIO
 
-from torch import autocast
 import torch
 import requests
 import PIL
@@ -234,8 +226,7 @@ pipe = StableDiffusionInpaintPipeline.from_pretrained(
 pipe = pipe.to(device)
 
 prompt = "a cat sitting on a bench"
-with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
+images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
 
 images[0].save("cat_on_bench.png")
 ```
@@ -258,7 +249,6 @@ If you want to run the code yourself 💻, you can try out:
 - [Text-to-Image Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256)
 ```python
 # !pip install diffusers transformers
-from torch import autocast
 from diffusers import DiffusionPipeline
 
 device = "cuda"
@@ -270,16 +260,14 @@ ldm = ldm.to(device)
 
 # run pipeline in inference (sample random noise and denoise)
 prompt = "A painting of a squirrel eating a burger"
-with autocast(device):
-    image = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images[0]
+image = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images[0]
 
 # save image
 image.save("squirrel.png")
 ```
 - [Unconditional Diffusion with discrete scheduler](https://huggingface.co/google/ddpm-celebahq-256)
 ```python
 # !pip install diffusers
-from torch import autocast
 from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline
 
 model_id = "google/ddpm-celebahq-256"
@@ -290,8 +278,7 @@ ddpm = DDPMPipeline.from_pretrained(model_id) # you can replace DDPMPipeline wi
 ddpm.to(device)
 
 # run pipeline in inference (sample random noise and denoise)
-with autocast("cuda"):
-    image = ddpm().images[0]
+image = ddpm().images[0]
 
 # save image
 image.save("ddpm_generated_image.png")
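The attention-slicing hunk above only shows the changed lines; for completeness, here is a sketch of the full low-memory setup after this commit, assuming the `fp16` revision of the checkpoint and an access token:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=True,
).to("cuda")

# Trades a little speed for a much lower peak VRAM footprint.
pipe.enable_attention_slicing()

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]  # no autocast context needed
image.save("astronaut_rides_horse.png")
```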

src/diffusers/models/resnet.py

Lines changed: 0 additions & 4 deletions
@@ -331,8 +331,6 @@ def __init__(
     def forward(self, x, temb):
         hidden_states = x
 
-        # make sure hidden states is in float32
-        # when running in half-precision
         hidden_states = self.norm1(hidden_states).type(hidden_states.dtype)
         hidden_states = self.nonlinearity(hidden_states)
 
@@ -349,8 +347,6 @@ def forward(self, x, temb):
         temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
         hidden_states = hidden_states + temb
 
-        # make sure hidden states is in float32
-        # when running in half-precision
         hidden_states = self.norm2(hidden_states).type(hidden_states.dtype)
         hidden_states = self.nonlinearity(hidden_states)
 
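Only stale comments are deleted here; the code is untouched. They described a float32 round-trip that the block no longer performs: when the model runs natively in fp16, the group norm already returns tensors in the input dtype, so the trailing `.type(hidden_states.dtype)` is effectively a no-op. A quick check of that assumption (CUDA only, since half-precision GroupNorm is not implemented on every CPU build):

```python
import torch
from torch import nn

if torch.cuda.is_available():
    # A GroupNorm layer converted to half precision, as it would be after model.half().
    norm = nn.GroupNorm(num_groups=32, num_channels=64).cuda().half()
    hidden_states = torch.randn(1, 64, 8, 8, device="cuda", dtype=torch.float16)

    out = norm(hidden_states)
    # The output dtype matches the input dtype, so .type(hidden_states.dtype) changes nothing.
    assert out.dtype == hidden_states.dtype == torch.float16
```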
src/diffusers/models/unet_2d_condition.py

Lines changed: 6 additions & 3 deletions
@@ -222,6 +222,11 @@ def forward(
         timesteps = timesteps.expand(sample.shape[0])
 
         t_emb = self.time_proj(timesteps)
+
+        # timesteps does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=sample.dtype)
         emb = self.time_embedding(t_emb)
 
         # 2. pre-process
@@ -258,9 +263,7 @@ def forward(
                 sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)
 
         # 6. post-process
-        # make sure hidden states is in float32
-        # when running in half-precision
-        sample = self.conv_norm_out(sample.float()).type(sample.dtype)
+        sample = self.conv_norm_out(sample).type(sample.dtype)
         sample = self.conv_act(sample)
         sample = self.conv_out(sample)
 
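The added cast addresses a real mismatch once autocast is gone: the timestep projection is pure `arange`/`sin`/`cos` math and therefore always produces float32, while the time-embedding MLP may hold fp16 weights after the model is halved, and an fp16 linear layer rejects a float32 input. The sketch below reproduces the pattern with a plain `nn.Linear` standing in for the real `time_embedding` module; the dimensions are illustrative.

```python
import math
import torch
from torch import nn

if torch.cuda.is_available():
    # Stand-in for the first layer of `self.time_embedding`, running in fp16.
    time_embedding = nn.Linear(320, 1280).cuda().half()

    # Sinusoidal timestep projection: built from arange/exp/sin/cos, so it comes out float32.
    timesteps = torch.tensor([10.0], device="cuda")
    half_dim = 160
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half_dim, device="cuda") / half_dim)
    args = timesteps[:, None] * freqs[None, :]
    t_emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    assert t_emb.dtype == torch.float32

    # Mirror of `t_emb = t_emb.to(dtype=sample.dtype)` above; without it the fp16
    # layer would raise a dtype mismatch on the float32 input.
    emb = time_embedding(t_emb.to(dtype=torch.float16))
```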
src/diffusers/pipelines/README.md

Lines changed: 3 additions & 33 deletions
@@ -86,15 +86,13 @@ logic including pre-processing, an unrolled diffusion loop, and post-processing
 
 ```python
 # make sure you're logged in with `huggingface-cli login`
-from torch import autocast
 from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
 
 pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
 pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
 
 image.save("astronaut_rides_horse.png")
 ```
@@ -104,33 +102,7 @@ image.save("astronaut_rides_horse.png")
 The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
 
 ```python
-from torch import autocast
-import requests
-from PIL import Image
-from io import BytesIO
-
-from diffusers import StableDiffusionImg2ImgPipeline
-
-# load the pipeline
-device = "cuda"
-pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
-    "CompVis/stable-diffusion-v1-4",
-    revision="fp16",
-    torch_dtype=torch.float16,
-    use_auth_token=True
-).to(device)
-
-# let's download an initial image
-url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
-
-response = requests.get(url)
-init_image = Image.open(BytesIO(response.content)).convert("RGB")
-init_image = init_image.resize((768, 512))
-
-prompt = "A fantasy landscape, trending on artstation"
-
-with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
+images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
 
 images[0].save("fantasy_landscape.png")
 ```
@@ -148,7 +120,6 @@ The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by
 ```python
 from io import BytesIO
 
-from torch import autocast
 import requests
 import PIL
 
@@ -173,8 +144,7 @@ pipe = StableDiffusionInpaintPipeline.from_pretrained(
 ).to(device)
 
 prompt = "a cat sitting on a bench"
-with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
+images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
 
 images[0].save("cat_on_bench.png")
 ```
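The deleted image-to-image block in this README duplicated the top-level README (and referenced `torch.float16` without importing `torch`). For reference, the surviving version of that example, with the autocast wrapper removed, reads roughly as follows; it assumes the `fp16` revision and an access token:

```python
import requests
import torch
from PIL import Image
from io import BytesIO

from diffusers import StableDiffusionImg2ImgPipeline

# load the pipeline in fp16; no autocast wrapper around the call
device = "cuda"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=True,
).to(device)

# download an initial image and resize it to the expected resolution
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((768, 512))

prompt = "A fantasy landscape, trending on artstation"
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
images[0].save("fantasy_landscape.png")
```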

src/diffusers/pipelines/stable_diffusion/README.md

Lines changed: 3 additions & 9 deletions
@@ -59,15 +59,13 @@ pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-4")
 
 ```python
 # make sure you're logged in with `huggingface-cli login`
-from torch import autocast
 from diffusers import StableDiffusionPipeline
 
 pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
 pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).sample[0]
+image = pipe(prompt).sample[0]
 
 image.save("astronaut_rides_horse.png")
 ```
@@ -76,7 +74,6 @@ image.save("astronaut_rides_horse.png")
 
 ```python
 # make sure you're logged in with `huggingface-cli login`
-from torch import autocast
 from diffusers import StableDiffusionPipeline, DDIMScheduler
 
 scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
@@ -88,8 +85,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
 ).to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).sample[0]
+image = pipe(prompt).sample[0]
 
 image.save("astronaut_rides_horse.png")
 ```
@@ -98,7 +94,6 @@ image.save("astronaut_rides_horse.png")
 
 ```python
 # make sure you're logged in with `huggingface-cli login`
-from torch import autocast
 from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
 
 lms = LMSDiscreteScheduler(
@@ -114,8 +109,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
 ).to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).sample[0]
+image = pipe(prompt).sample[0]
 
 image.save("astronaut_rides_horse.png")
 ```
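Assembled from the context lines above, the K-LMS example in this README ends up looking roughly like the sketch below. The scheduler keyword arguments beyond those visible in the hunks, and passing `scheduler=lms` through `from_pretrained`, follow the surrounding documentation rather than this diff, so treat them as assumptions.

```python
# make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler

lms = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
)

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    scheduler=lms,
    use_auth_token=True,
).to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).sample[0]

image.save("astronaut_rides_horse.png")
```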

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 9 additions & 1 deletion
@@ -205,16 +205,18 @@ def __call__(
         # However this currently doesn't work in `mps`.
         latents_device = "cpu" if self.device.type == "mps" else self.device
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        latents_dtype = text_embeddings.dtype
         if latents is None:
             latents = torch.randn(
                 latents_shape,
                 generator=generator,
                 device=latents_device,
+                dtype=latents_dtype
             )
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            latents = latents.to(self.device)
+        latents = latents.to(self.device)
 
         # set timesteps
         accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
@@ -268,6 +270,12 @@ def __call__(
 
         # run safety checker
         safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+
+        # XXX: it might be better to check against the actual dtype of the safety checker since
+        # it might want to run in a different precision, but the safety checker does not expose
+        # a `dtype` /`precision` itself, so this is a good enough proxy for running pipelines in
+        # both f16 / f32
+        safety_cheker_input.pixel_values = safety_cheker_input.pixel_values.to(dtype=latents_dtype)
         image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
 
         if output_type == "pil":
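Both additions apply the same idea as the UNet change: the working dtype is taken from the text embeddings up front, the initial noise is sampled directly in that dtype, and the safety checker's pixel values are cast to it, so nothing depends on an autocast context to reconcile precisions. A reduced sketch of the latents part; shape, device, and seed here are illustrative rather than the pipeline's defaults:

```python
import torch

# In the pipeline, the dtype comes from the text embeddings and the shape from
# (batch_size, unet.in_channels, height // 8, width // 8).
device = "cuda" if torch.cuda.is_available() else "cpu"
latents_dtype = torch.float16 if device == "cuda" else torch.float32
latents_shape = (1, 4, 64, 64)

generator = torch.Generator(device=device).manual_seed(0)

# Sample the initial noise directly in the target dtype so the first UNet call
# sees matching input and weight precision without an autocast context.
latents = torch.randn(latents_shape, generator=generator, device=device, dtype=latents_dtype)
assert latents.dtype == latents_dtype
```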
