Skip to content

Commit 9ebaea5

Browse files
authored
Optimize Stable Diffusion (#371)
* initial commit * make UNet stream capturable * try to fix noise_pred value * remove cuda graph and keep NB * non blocking unet with PNDMScheduler * make timesteps np arrays for pndm scheduler because lists don't get formatted to tensors in `self.set_format` * make max async in pndm * use channel last format in unet * avoid moving timesteps device in each unet call * avoid memcpy op in `get_timestep_embedding` * add `channels_last` kwarg to `DiffusionPipeline.from_pretrained` * update TODO * replace `channels_last` kwarg with `memory_format` for more generality * revert the channels_last changes to leave it for another PR * remove non_blocking when moving input ids to device * remove blocking from all .to() operations at beginning of pipeline * fix merging * fix merging * model can run in other precisions without autocast * attn refactoring * Revert "attn refactoring" This reverts commit 0c70c0e. * remove restriction to run conv_norm in fp32 * use `baddbmm` instead of `matmul`for better in attention for better perf * removing all reshapes to test perf * Revert "removing all reshapes to test perf" This reverts commit 006ccb8. * add shapes comments * hardcore whats needed for jitting * Revert "hardcore whats needed for jitting" This reverts commit 2fa9c69. * Revert "remove restriction to run conv_norm in fp32" This reverts commit cec5928. * revert using baddmm in attention's forward * cleanup comment * remove restriction to run conv_norm in fp32. no quality loss was noticed This reverts commit cc9bc13. * add more optimizations techniques to docs * Revert "add shapes comments" This reverts commit 31c58ea. * apply suggestions * make quality * apply suggestions * styling * `scheduler.timesteps` are now arrays so we dont need .to() * remove useless .type() * use mean instead of max in `test_stable_diffusion_inpaint_pipeline_k_lms` * move scheduler timestamps to correct device if tensors * add device to `set_timesteps` in LMSD scheduler * `self.scheduler.set_timesteps` now uses device arg for schedulers that accept it * quick fix * styling * remove kwargs from schedulers `set_timesteps` * revert to using max in K-LMS inpaint pipeline test * Revert "`self.scheduler.set_timesteps` now uses device arg for schedulers that accept it" This reverts commit 00d5a51. * move timesteps to correct device before loop in SD pipeline * apply previous fix to other SD pipelines * UNet now accepts tensor timesteps even on wrong device, to avoid errors - it shouldnt affect performance if timesteps are alrdy on correct device - it does slow down performance if they're on the wrong device * fix pipeline when timesteps are arrays with strides
1 parent a7058f4 commit 9ebaea5

File tree

9 files changed

+244
-23
lines changed

9 files changed

+244
-23
lines changed

docs/source/optimization/fp16.mdx

Lines changed: 195 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,64 @@ specific language governing permissions and limitations under the License.
1414

1515
We present some techniques and ideas to optimize 🤗 Diffusers _inference_ for memory or speed.
1616

17-
## CUDA `autocast`
17+
<table>
18+
<tr>
19+
<td>
20+
<td>Latency
21+
<td>Speedup
22+
<tr>
23+
<tr>
24+
<td>original
25+
<td>9.50s
26+
<td>x1
27+
<tr>
28+
<tr>
29+
<td>cuDNN auto-tuner
30+
<td>9.37s
31+
<td>x1.01
32+
<tr>
33+
<td>autocast (fp16)
34+
<td>5.47s
35+
<td>x1.91
36+
<tr>
37+
<td>fp16
38+
<td>3.61s
39+
<td>x2.91
40+
<tr>
41+
<td>channels last
42+
<td>3.30s
43+
<td>x2.87
44+
<tr>
45+
<tr>
46+
<td>traced UNet
47+
<td>3.21s
48+
<td>x2.96
49+
</table>
50+
<em>obtained on NVIDIA TITAN RTX by generating a single image of size 512x512 from the prompt "a photo of an astronaut riding a horse on mars" with 50 DDIM steps.</em>
51+
52+
## Enable cuDNN auto-tuner
53+
54+
[NVIDIA cuDNN](https://developer.nvidia.com/cudnn) supports many algorithms to compute a convolution. Autotuner runs a short benchmark and selects the kernel with the best performance on a given hardware for a given input size.
55+
56+
Since we’re using **convolutional networks** (other types currently not supported), we can enable cuDNN autotuner before launching the inference by setting:
57+
58+
```python
59+
import torch
60+
61+
torch.backends.cudnn.benchmark = True
62+
```
63+
64+
### Use tf32 instead of fp32 (on Ampere and later CUDA devices)
65+
66+
On Ampere and later CUDA devices matrix multiplications and convolutions can use the TensorFloat32 (TF32) mode for faster but slightly less accurate computations. By default PyTorch enables TF32 mode for convolutions but not matrix multiplications, and unless a network requires full float32 precision we recommend enabling this setting for matrix multiplications, too. It can significantly speed up computations with typically negligible loss of numerical accuracy. You can read more about it [here](https://huggingface.co/docs/transformers/v4.18.0/en/performance#tf32). All you need to do is to add this before your inference:
67+
68+
```python
69+
import torch
70+
71+
torch.backends.cuda.matmul.allow_tf32 = True
72+
```
73+
74+
## Automatic mixed precision (AMP)
1875

1976
If you use a CUDA GPU, you can take advantage of `torch.autocast` to perform inference roughly twice as fast at the cost of slightly lower precision. All you need to do is put your inference call inside an `autocast` context manager. The following example shows how to do it using Stable Diffusion text-to-image generation as an example:
2077

@@ -47,7 +104,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
47104

48105
## Sliced attention for additional memory savings
49106

50-
For even additional memory savings, you can use a sliced version of attention that performs the computation in steps instead of all at once.
107+
For even additional memory savings, you can use a sliced version of attention that performs the computation in steps instead of all at once.
51108

52109
<Tip>
53110
Attention slicing is useful even if a batch size of just 1 is used - as long as the model uses more than one attention head. If there is more than one attention head the *QK^T* attention matrix can be computed sequentially for each head which can save a significant amount of memory.
@@ -73,4 +130,139 @@ with torch.autocast("cuda"):
73130
image = pipe(prompt).images[0]
74131
```
75132

76-
There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM!
133+
There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM!
134+
135+
## Using Channels Last memory format
136+
137+
Channels last memory format is an alternative way of ordering NCHW tensors in memory preserving dimensions ordering. Channels last tensors ordered in such a way that channels become the densest dimension (aka storing images pixel-per-pixel). Since not all operators currently support channels last format it may result in a worst performance, so it's better to try it and see if it works for your model.
138+
139+
For example, in order to set the UNet model in our pipeline to use channels last format, we can use the following:
140+
141+
```python
142+
print(pipe.unet.conv_out.state_dict()["weight"].stride()) # (2880, 9, 3, 1)
143+
pipe.unet.to(memory_format=torch.channels_last) # in-place operation
144+
print(
145+
pipe.unet.conv_out.state_dict()["weight"].stride()
146+
) # (2880, 1, 960, 320) haveing a stride of 1 for the 2nd dimension proves that it works
147+
```
148+
149+
## Tracing
150+
151+
Tracing runs an example input tensor through your model, and captures the operations that are invoked as that input makes its way through the model's layers so that an executable or `ScriptFunction` is returned that will be optimized using just-in-time compilation.
152+
153+
To trace our UNet model, we can use the following:
154+
155+
```python
156+
import time
157+
import torch
158+
from diffusers import StableDiffusionPipeline
159+
import functools
160+
161+
# torch disable grad
162+
torch.set_grad_enabled(False)
163+
164+
# set variables
165+
n_experiments = 2
166+
unet_runs_per_experiment = 50
167+
168+
# load inputs
169+
def generate_inputs():
170+
sample = torch.randn(2, 4, 64, 64).half().cuda()
171+
timestep = torch.rand(1).half().cuda() * 999
172+
encoder_hidden_states = torch.randn(2, 77, 768).half().cuda()
173+
return sample, timestep, encoder_hidden_states
174+
175+
176+
pipe = StableDiffusionPipeline.from_pretrained(
177+
"CompVis/stable-diffusion-v1-4",
178+
# scheduler=scheduler,
179+
use_auth_token=True,
180+
revision="fp16",
181+
torch_dtype=torch.float16,
182+
).to("cuda")
183+
unet = pipe.unet
184+
unet.eval()
185+
unet.to(memory_format=torch.channels_last) # use channels_last memory format
186+
unet.forward = functools.partial(unet.forward, return_dict=False) # set return_dict=False as default
187+
188+
# warmup
189+
for _ in range(3):
190+
with torch.inference_mode():
191+
inputs = generate_inputs()
192+
orig_output = unet(*inputs)
193+
194+
# trace
195+
print("tracing..")
196+
unet_traced = torch.jit.trace(unet, inputs)
197+
unet_traced.eval()
198+
print("done tracing")
199+
200+
201+
# warmup and optimize graph
202+
for _ in range(5):
203+
with torch.inference_mode():
204+
inputs = generate_inputs()
205+
orig_output = unet_traced(*inputs)
206+
207+
208+
# benchmarking
209+
with torch.inference_mode():
210+
for _ in range(n_experiments):
211+
torch.cuda.synchronize()
212+
start_time = time.time()
213+
for _ in range(unet_runs_per_experiment):
214+
orig_output = unet_traced(*inputs)
215+
torch.cuda.synchronize()
216+
print(f"unet traced inference took {time.time() - start_time:.2f} seconds")
217+
for _ in range(n_experiments):
218+
torch.cuda.synchronize()
219+
start_time = time.time()
220+
for _ in range(unet_runs_per_experiment):
221+
orig_output = unet(*inputs)
222+
torch.cuda.synchronize()
223+
print(f"unet inference took {time.time() - start_time:.2f} seconds")
224+
225+
# save the model
226+
unet_traced.save("unet_traced.pt")
227+
```
228+
229+
Then we can replace the `unet` attribute of the pipeline with the traced model like the following
230+
231+
```python
232+
from diffusers import StableDiffusionPipeline
233+
import torch
234+
from dataclasses import dataclass
235+
236+
237+
@dataclass
238+
class UNet2DConditionOutput:
239+
sample: torch.FloatTensor
240+
241+
242+
pipe = StableDiffusionPipeline.from_pretrained(
243+
"CompVis/stable-diffusion-v1-4",
244+
# scheduler=scheduler,
245+
use_auth_token=True,
246+
revision="fp16",
247+
torch_dtype=torch.float16,
248+
).to("cuda")
249+
250+
# use jitted unet
251+
unet_traced = torch.jit.load("unet_traced.pt")
252+
# del pipe.unet
253+
class TracedUNet(torch.nn.Module):
254+
def __init__(self):
255+
super().__init__()
256+
self.in_channels = pipe.unet.in_channels
257+
self.device = pipe.unet.device
258+
259+
def forward(self, latent_model_input, t, encoder_hidden_states):
260+
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
261+
return UNet2DConditionOutput(sample=sample)
262+
263+
264+
pipe.unet = TracedUNet()
265+
266+
with torch.inference_mode():
267+
image = pipe([prompt] * 1, num_inference_steps=50).images[0]
268+
```

src/diffusers/models/attention.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,7 @@ def forward(self, hidden_states):
7272

7373
# get scores
7474
scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
75-
76-
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
75+
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm
7776
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
7877

7978
# compute attention output
@@ -275,7 +274,13 @@ def forward(self, hidden_states, context=None, mask=None):
275274
return self.to_out(hidden_states)
276275

277276
def _attention(self, query, key, value):
278-
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
277+
attention_scores = torch.baddbmm(
278+
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
279+
query,
280+
key.transpose(-1, -2),
281+
beta=0,
282+
alpha=self.scale,
283+
)
279284
attention_probs = attention_scores.softmax(dim=-1)
280285
# compute attention output
281286
hidden_states = torch.matmul(attention_probs, value)
@@ -292,7 +297,9 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
292297
for i in range(hidden_states.shape[0] // slice_size):
293298
start_idx = i * slice_size
294299
end_idx = (i + 1) * slice_size
295-
attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
300+
attn_slice = (
301+
torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
302+
) # TODO: use baddbmm for better performance
296303
attn_slice = attn_slice.softmax(dim=-1)
297304
attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
298305

src/diffusers/models/embeddings.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@ def get_timestep_embedding(
3737
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
3838

3939
half_dim = embedding_dim // 2
40-
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32)
40+
exponent = -math.log(max_period) * torch.arange(
41+
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
42+
)
4143
exponent = exponent / (half_dim - downscale_freq_shift)
4244

43-
emb = torch.exp(exponent).to(device=timesteps.device)
45+
emb = torch.exp(exponent)
4446
emb = timesteps[:, None].float() * emb[None, :]
4547

4648
# scale embeddings

src/diffusers/models/resnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def forward(self, x, temb):
331331

332332
# make sure hidden states is in float32
333333
# when running in half-precision
334-
hidden_states = self.norm1(hidden_states).type(hidden_states.dtype)
334+
hidden_states = self.norm1(hidden_states)
335335
hidden_states = self.nonlinearity(hidden_states)
336336

337337
if self.upsample is not None:
@@ -349,7 +349,7 @@ def forward(self, x, temb):
349349

350350
# make sure hidden states is in float32
351351
# when running in half-precision
352-
hidden_states = self.norm2(hidden_states).type(hidden_states.dtype)
352+
hidden_states = self.norm2(hidden_states)
353353
hidden_states = self.nonlinearity(hidden_states)
354354

355355
hidden_states = self.dropout(hidden_states)

src/diffusers/models/unet_2d_condition.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -230,16 +230,16 @@ def forward(
230230
# 1. time
231231
timesteps = timestep
232232
if not torch.is_tensor(timesteps):
233+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
233234
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
234235
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
235-
timesteps = timesteps.to(dtype=torch.float32)
236-
timesteps = timesteps[None].to(device=sample.device)
236+
timesteps = timesteps[None].to(sample.device)
237237

238238
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
239239
timesteps = timesteps.expand(sample.shape[0])
240240

241241
t_emb = self.time_proj(timesteps)
242-
emb = self.time_embedding(t_emb)
242+
emb = self.time_embedding(t_emb.to(self.dtype))
243243

244244
# 2. pre-process
245245
sample = self.conv_in(sample)
@@ -279,7 +279,7 @@ def forward(
279279
# 6. post-process
280280
# make sure hidden states is in float32
281281
# when running in half-precision
282-
sample = self.conv_norm_out(sample.float()).type(sample.dtype)
282+
sample = self.conv_norm_out(sample)
283283
sample = self.conv_act(sample)
284284
sample = self.conv_out(sample)
285285

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,15 +225,23 @@ def __call__(
225225
latents_shape,
226226
generator=generator,
227227
device=latents_device,
228+
dtype=text_embeddings.dtype,
228229
)
229230
else:
230231
if latents.shape != latents_shape:
231232
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
232-
latents = latents.to(self.device)
233+
latents = latents.to(latents_device)
233234

234235
# set timesteps
235236
self.scheduler.set_timesteps(num_inference_steps)
236237

238+
# Some schedulers like PNDM have timesteps as arrays
239+
# It's more optimzed to move all timesteps to correct device beforehand
240+
if torch.is_tensor(self.scheduler.timesteps):
241+
timesteps_tensor = self.scheduler.timesteps.to(self.device)
242+
else:
243+
timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device)
244+
237245
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
238246
if isinstance(self.scheduler, LMSDiscreteScheduler):
239247
latents = latents * self.scheduler.sigmas[0]
@@ -247,7 +255,7 @@ def __call__(
247255
if accepts_eta:
248256
extra_step_kwargs["eta"] = eta
249257

250-
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
258+
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
251259
# expand the latents if we are doing classifier free guidance
252260
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
253261
if isinstance(self.scheduler, LMSDiscreteScheduler):
@@ -278,7 +286,9 @@ def __call__(
278286

279287
# run safety checker
280288
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
281-
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
289+
image, has_nsfw_concept = self.safety_checker(
290+
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
291+
)
282292

283293
if output_type == "pil":
284294
image = self.numpy_to_pil(image)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,11 @@ def __call__(
265265
latents = init_latents
266266

267267
t_start = max(num_inference_steps - init_timestep + offset, 0)
268-
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])):
268+
# Some schedulers like PNDM have timesteps as arrays
269+
# It's more optimzed to move all timesteps to correct device beforehand
270+
timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
271+
272+
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
269273
t_index = t_start + i
270274

271275
# expand the latents if we are doing classifier free guidance

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,11 @@ def __call__(
298298

299299
latents = init_latents
300300
t_start = max(num_inference_steps - init_timestep + offset, 0)
301-
for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
301+
# Some schedulers like PNDM have timesteps as arrays
302+
# It's more optimzed to move all timesteps to correct device beforehand
303+
timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
304+
305+
for i, t in tqdm(enumerate(timesteps_tensor)):
302306
t_index = t_start + i
303307
# expand the latents if we are doing classifier free guidance
304308
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

0 commit comments

Comments
 (0)