
Commit f9bf148

perf: prefer batched matmuls for attention. added fast-path to Decoder when num_heads=1
1 parent ab1f01e commit f9bf148

File tree

1 file changed (+57, -35)


src/diffusers/models/attention.py

Lines changed: 57 additions & 35 deletions
@@ -284,22 +284,52 @@ def forward(self, hidden_states):
         key_proj = self.key(hidden_states)
         value_proj = self.value(hidden_states)

-        # transpose
-        query_states = self.transpose_for_scores(query_proj)
-        key_states = self.transpose_for_scores(key_proj)
-        value_states = self.transpose_for_scores(value_proj)
+        scale = 1 / math.sqrt(self.channels / self.num_heads)

         # get scores
-        scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
-        attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm
+        if self.num_heads > 1:
+            query_states = self.transpose_for_scores(query_proj)
+            key_states = self.transpose_for_scores(key_proj)
+            value_states = self.transpose_for_scores(value_proj)
+
+            # TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
+            # or reformulate this into a 3D problem?
+            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
+            # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
+            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
+            attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * scale
+        else:
+            query_states, key_states, value_states = query_proj, key_proj, value_proj
+
+            attention_scores = torch.baddbmm(
+                torch.empty(
+                    query_states.shape[0],
+                    query_states.shape[1],
+                    key_states.shape[1],
+                    dtype=query_states.dtype,
+                    device=query_states.device,
+                ),
+                query_states,
+                key_states.transpose(-1, -2),
+                beta=0,
+                alpha=scale,
+            )
+
         attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)

         # compute attention output
-        hidden_states = torch.matmul(attention_probs, value_states)
-
-        hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
-        new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
-        hidden_states = hidden_states.view(new_hidden_states_shape)
+        if self.num_heads > 1:
+            # TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
+            # or reformulate this into a 3D problem?
+            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
+            # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
+            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
+            hidden_states = torch.matmul(attention_probs, value_states)
+            hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
+            new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
+            hidden_states = hidden_states.view(new_hidden_states_shape)
+        else:
+            hidden_states = torch.bmm(attention_probs, value_states)

         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
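
For context: the num_heads == 1 fast path added above skips the 4D head reshape and folds the 1/sqrt(d) scaling into a fused torch.baddbmm (with beta=0, the uninitialized bias tensor is ignored rather than added). The sketch below is not part of the commit; it uses made-up shapes and hypothetical helper names (multihead_path, fastpath) to check that the 3D single-head fast path matches the 4D multi-head-style path when there is only one head.

import math
import torch

# Hypothetical shapes for illustration: (batch, seq_len, channels), single head.
batch, seq_len, channels, num_heads = 2, 64, 32, 1
q = torch.randn(batch, seq_len, channels)
k = torch.randn(batch, seq_len, channels)
v = torch.randn(batch, seq_len, channels)
scale = 1 / math.sqrt(channels / num_heads)

def multihead_path(q, k, v):
    # Multi-head-style path: reshape to (batch, heads, seq, head_dim), matmul, scale.
    def split(t):
        return t.view(batch, seq_len, num_heads, channels // num_heads).permute(0, 2, 1, 3)
    qs, ks, vs = split(q), split(k), split(v)
    scores = torch.matmul(qs, ks.transpose(-1, -2)) * scale
    probs = scores.softmax(dim=-1)
    out = torch.matmul(probs, vs)
    return out.permute(0, 2, 1, 3).reshape(batch, seq_len, channels)

def fastpath(q, k, v):
    # Single-head fast path: stay 3D and fold the scale into a fused baddbmm.
    scores = torch.baddbmm(
        torch.empty(batch, seq_len, seq_len, dtype=q.dtype, device=q.device),
        q,
        k.transpose(-1, -2),
        beta=0,  # beta=0 ignores the (uninitialized) input tensor
        alpha=scale,
    )
    probs = scores.softmax(dim=-1)
    return torch.bmm(probs, v)

print(torch.allclose(multihead_path(q, k, v), fastpath(q, k, v), atol=1e-5))  # expect True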
@@ -507,19 +537,17 @@ def forward(self, hidden_states, context=None, mask=None):
         return hidden_states

     def _attention(self, query, key, value):
-        # TODO: use baddbmm for better performance
-        if query.device.type == "mps":
-            # Better performance on mps (~20-25%)
-            attention_scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
-        else:
-            attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+        attention_scores = torch.baddbmm(
+            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            query,
+            key.transpose(-1, -2),
+            beta=0,
+            alpha=self.scale,
+        )
         attention_probs = attention_scores.softmax(dim=-1)
         # compute attention output

-        if query.device.type == "mps":
-            hidden_states = torch.einsum("b i j, b j d -> b i d", attention_probs, value)
-        else:
-            hidden_states = torch.matmul(attention_probs, value)
+        hidden_states = torch.bmm(attention_probs, value)

         # reshape hidden_states
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
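
The rewritten _attention collapses the MPS-specific einsum branch and the generic matmul branch into a single torch.baddbmm whose alpha argument fuses the scaling into the batched matmul, plus a torch.bmm for the probs @ value product. A standalone sketch (made-up shapes, not taken from the repo) showing the three score formulations agree numerically:

import torch

# Hypothetical 3D shapes: (batch * heads, seq_len, head_dim), as used for 3D attention.
bh, seq_len, head_dim = 8, 64, 40
scale = head_dim ** -0.5
query = torch.randn(bh, seq_len, head_dim)
key = torch.randn(bh, seq_len, head_dim)

# Old branches: plain matmul, or the einsum workaround used on MPS.
scores_matmul = torch.matmul(query, key.transpose(-1, -2)) * scale
scores_einsum = torch.einsum("b i d, b j d -> b i j", query, key) * scale

# New single path: baddbmm fuses the scaling (alpha) into the batched matmul;
# beta=0 means the empty "bias" tensor is ignored rather than added.
scores_baddbmm = torch.baddbmm(
    torch.empty(bh, seq_len, seq_len, dtype=query.dtype, device=query.device),
    query,
    key.transpose(-1, -2),
    beta=0,
    alpha=scale,
)

print(torch.allclose(scores_matmul, scores_baddbmm, atol=1e-6))  # expect True
print(torch.allclose(scores_einsum, scores_baddbmm, atol=1e-6))  # expect True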
@@ -534,21 +562,15 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
         for i in range(hidden_states.shape[0] // slice_size):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
-            if query.device.type == "mps":
-                # Better performance on mps (~20-25%)
-                attn_slice = (
-                    torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx])
-                    * self.scale
-                )
-            else:
-                attn_slice = (
-                    torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
-                ) # TODO: use baddbmm for better performance
+            attn_slice = torch.baddbmm(
+                torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+                query[start_idx:end_idx],
+                key[start_idx:end_idx].transpose(-1, -2),
+                beta=0,
+                alpha=self.scale,
+            )
             attn_slice = attn_slice.softmax(dim=-1)
-            if query.device.type == "mps":
-                attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
-            else:
-                attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
+            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

             hidden_states[start_idx:end_idx] = attn_slice
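
The sliced path now runs the same baddbmm/bmm pair per slice. A sketch with hypothetical full_attention / sliced_attention helpers and made-up shapes, illustrating that slicing over the batch-of-heads dimension reproduces the unsliced result while only materializing slice_size score matrices at a time:

import torch

# Hypothetical shapes: (batch * heads, seq_len, head_dim); slice over dim 0.
bh, seq_len, head_dim, slice_size = 8, 64, 40, 2
scale = head_dim ** -0.5
query = torch.randn(bh, seq_len, head_dim)
key = torch.randn(bh, seq_len, head_dim)
value = torch.randn(bh, seq_len, head_dim)

def full_attention(q, k, v):
    # Unsliced reference: fused scaled scores, softmax, then probs @ value.
    scores = torch.baddbmm(
        torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
        q,
        k.transpose(-1, -2),
        beta=0,
        alpha=scale,
    )
    return torch.bmm(scores.softmax(dim=-1), v)

def sliced_attention(q, k, v):
    # Peak memory for the (seq_len x seq_len) score matrices is bounded by
    # slice_size instead of the full batch-of-heads dimension.
    out = torch.empty_like(q)
    for i in range(q.shape[0] // slice_size):
        s, e = i * slice_size, (i + 1) * slice_size
        out[s:e] = full_attention(q[s:e], k[s:e], v[s:e])
    return out

print(torch.allclose(full_attention(query, key, value), sliced_attention(query, key, value)))  # expect True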
