
Commit 359a0e3

kig (Ilmari Heikkinen) authored and patil-suraj committed
Add xformers attention to VAE (huggingface#1507)
* Add xformers attention to VAE
* Simplify VAE xformers code
* Update src/diffusers/models/attention.py

Co-authored-by: Ilmari Heikkinen <[email protected]>
Co-authored-by: Suraj Patil <[email protected]>
1 parent 74acc19 · commit 359a0e3

File tree

1 file changed (+46, -15 lines)


src/diffusers/models/attention.py

Lines changed: 46 additions & 15 deletions
@@ -286,6 +286,32 @@ def __init__(
         self.rescale_output_factor = rescale_output_factor
         self.proj_attn = nn.Linear(channels, channels, 1)
 
+        self._use_memory_efficient_attention_xformers = False
+
+    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+        if not is_xformers_available():
+            raise ModuleNotFoundError(
+                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+                " xformers",
+                name="xformers",
+            )
+        elif not torch.cuda.is_available():
+            raise ValueError(
+                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+                " available for GPU "
+            )
+        else:
+            try:
+                # Make sure we can run the memory efficient attention
+                _ = xformers.ops.memory_efficient_attention(
+                    torch.randn((1, 2, 40), device="cuda"),
+                    torch.randn((1, 2, 40), device="cuda"),
+                    torch.randn((1, 2, 40), device="cuda"),
+                )
+            except Exception as e:
+                raise e
+        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.num_heads
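Not part of the commit: a minimal usage sketch of how the toggle added above could be flipped on a VAE's attention blocks. The model id, the AttentionBlock import path, and the module walk are illustrative assumptions, not code from this diff.

from diffusers import AutoencoderKL
from diffusers.models.attention import AttentionBlock

# Hypothetical example: load a Stable Diffusion VAE and move it to the GPU.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").to("cuda")

# Walk the module tree and enable xformers on every AttentionBlock found;
# each call runs the availability checks defined in the method above.
for module in vae.modules():
    if isinstance(module, AttentionBlock):
        module.set_use_memory_efficient_attention_xformers(True)

Wrapping the switch in a dedicated setter lets callers opt in at runtime without changing the forward signature.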
@@ -320,21 +346,26 @@ def forward(self, hidden_states):
         key_proj = self.reshape_heads_to_batch_dim(key_proj)
         value_proj = self.reshape_heads_to_batch_dim(value_proj)
 
-        attention_scores = torch.baddbmm(
-            torch.empty(
-                query_proj.shape[0],
-                query_proj.shape[1],
-                key_proj.shape[1],
-                dtype=query_proj.dtype,
-                device=query_proj.device,
-            ),
-            query_proj,
-            key_proj.transpose(-1, -2),
-            beta=0,
-            alpha=scale,
-        )
-        attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
-        hidden_states = torch.bmm(attention_probs, value_proj)
+        if self._use_memory_efficient_attention_xformers:
+            # Memory efficient attention
+            hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
+            hidden_states = hidden_states.to(query_proj.dtype)
+        else:
+            attention_scores = torch.baddbmm(
+                torch.empty(
+                    query_proj.shape[0],
+                    query_proj.shape[1],
+                    key_proj.shape[1],
+                    dtype=query_proj.dtype,
+                    device=query_proj.device,
+                ),
+                query_proj,
+                key_proj.transpose(-1, -2),
+                beta=0,
+                alpha=scale,
+            )
+            attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
+            hidden_states = torch.bmm(attention_probs, value_proj)
 
         # reshape hidden_states
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
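Also not part of the commit: a small sketch that exercises both branches of the new forward path on random tensors and checks that they agree. It assumes a CUDA device with xformers installed; the shapes and tolerances are arbitrary.

import math
import torch
import xformers.ops

# Fabricated shapes: (batch * heads, sequence length, head dimension).
batch_heads, seq_len, head_dim = 2, 64, 40
query_proj = torch.randn((batch_heads, seq_len, head_dim), device="cuda", dtype=torch.float16)
key_proj = torch.randn_like(query_proj)
value_proj = torch.randn_like(query_proj)
scale = 1 / math.sqrt(head_dim)  # same value as the block's 1 / sqrt(channels / num_heads)

# xformers branch: fused attention; its default softmax scale is 1 / sqrt(head_dim).
out_xformers = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
out_xformers = out_xformers.to(query_proj.dtype)

# Fallback branch: explicit baddbmm -> softmax -> bmm, as in the else clause above.
attention_scores = torch.baddbmm(
    torch.empty(
        query_proj.shape[0],
        query_proj.shape[1],
        key_proj.shape[1],
        dtype=query_proj.dtype,
        device=query_proj.device,
    ),
    query_proj,
    key_proj.transpose(-1, -2),
    beta=0,
    alpha=scale,
)
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
out_reference = torch.bmm(attention_probs, value_proj)

# The two paths compute the same attention, up to fp16 numerical noise.
print(torch.allclose(out_xformers, out_reference, atol=1e-2, rtol=1e-2))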
