Commit c0dd0e9

use baddbmm instead of matmul in attention for better perf
1 parent cec5928

File tree: 1 file changed, +16 -3 lines changed

src/diffusers/models/attention.py

Lines changed: 16 additions & 3 deletions
@@ -73,7 +73,13 @@ def forward(self, hidden_states):
         # get scores
         scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
 
-        attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
+        attention_scores = torch.baddbmm(
+            torch.empty(query_states.shape[0], query_states.shape[1], key_states.shape[1], dtype=query_states.dtype, device=query_states.device),
+            query_states,
+            key_states.transpose(-1, -2),
+            beta=0,
+            alpha=scale,
+        )
         attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
 
         # compute attention output
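For context, torch.baddbmm(input, batch1, batch2, beta=0, alpha=s) computes beta * input + alpha * (batch1 @ batch2) in one fused batched kernel; with beta=0 the input tensor is ignored entirely (PyTorch documents that even NaN/inf in it are not propagated), which is why passing an uninitialized torch.empty(...) buffer is safe, and the scaling by alpha is folded into the matmul rather than running as a separate elementwise pass. A minimal sketch of the equivalence, with illustrative shapes and names:

import torch

# illustrative shapes: (batch * heads, seq_len, head_dim)
q = torch.randn(2, 4, 8)
k = torch.randn(2, 4, 8)
scale = 0.125

# separate matmul + elementwise scaling
ref = torch.matmul(q, k.transpose(-1, -2)) * scale

# fused version: beta=0 ignores the empty buffer, alpha applies the scale
fused = torch.baddbmm(
    torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
    q,
    k.transpose(-1, -2),
    beta=0,
    alpha=scale,
)

assert torch.allclose(ref, fused, atol=1e-5)

One caveat worth noting: the removed line in the hunk above scaled both operands by scale (a net factor of scale ** 2, i.e. the usual 1/sqrt(d) given the double-sqrt definition of scale), whereas alpha applies scale once; reproducing the old behavior exactly would require alpha=scale ** 2 or a redefined scale.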
@@ -272,7 +278,14 @@ def forward(self, hidden_states, context=None, mask=None):
         return self.to_out(hidden_states)
 
     def _attention(self, query, key, value):
-        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+        # attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+        attention_scores = torch.baddbmm(
+            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            query,
+            key.transpose(-1, -2),
+            beta=0,
+            alpha=self.scale,
+        )
         attention_probs = attention_scores.softmax(dim=-1)
         # compute attention output
         hidden_states = torch.matmul(attention_probs, value)
@@ -289,7 +302,7 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
         for i in range(hidden_states.shape[0] // slice_size):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
-            attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
+            attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale  # TODO: use baddbmm for better performance
             attn_slice = attn_slice.softmax(dim=-1)
             attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
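The sliced path above keeps the plain matmul and only gains a TODO. A hypothetical per-slice rewrite along the same lines (not part of this commit; names taken from the loop above) could look like:

# hypothetical sketch, not part of this commit: the same baddbmm
# pattern applied to one slice of the sliced-attention loop
q_slice = query[start_idx:end_idx]
k_slice = key[start_idx:end_idx]
attn_slice = torch.baddbmm(
    torch.empty(q_slice.shape[0], q_slice.shape[1], k_slice.shape[1], dtype=q_slice.dtype, device=q_slice.device),
    q_slice,
    k_slice.transpose(1, 2),
    beta=0,
    alpha=self.scale,
)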
