From f9bf1481763f53d1c24b1e42b6f36a75ce8186df Mon Sep 17 00:00:00 2001
From: Alex Birch
Date: Wed, 9 Nov 2022 00:49:49 +0000
Subject: [PATCH] perf: prefer batched matmuls for attention. added fast-path
 to Decoder when num_heads=1
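
Score computation previously did a matmul followed by a separate
multiply by the scale, with einsum special-casing for the MPS device.
It now goes through torch.baddbmm with beta=0 and alpha=scale, so the
scaling is fused into the same batched GEMM and the torch.empty()
argument only supplies the output shape; the probs @ value product
likewise moves from torch.matmul/einsum to torch.bmm. When
num_heads == 1, the Decoder skips transpose_for_scores and the 4D
permute/reshape and stays on the 3D batched path.

Illustrative sketch of the baddbmm equivalence (not part of the patch;
shapes assumed to be [batch * heads, seq_len, dim_head]):

    import torch

    q = torch.randn(2, 64, 40)
    k = torch.randn(2, 64, 40)
    scale = q.shape[-1] ** -0.5

    # unfused: one kernel for q @ k^T, then another for the scale
    unfused = torch.matmul(q, k.transpose(-1, -2)) * scale

    # fused: beta=0 ignores the uninitialised input tensor (it only
    # fixes the output shape/dtype/device); alpha applies the scale
    # inside the same batched GEMM
    fused = torch.baddbmm(
        torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
        q,
        k.transpose(-1, -2),
        beta=0,
        alpha=scale,
    )

    assert torch.allclose(unfused, fused, atol=1e-5)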

---
 src/diffusers/models/attention.py | 92 +++++++++++++++++++------------
 1 file changed, 57 insertions(+), 35 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index be9203b4d699..69522f76b09c 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -284,22 +284,52 @@ def forward(self, hidden_states):
         key_proj = self.key(hidden_states)
         value_proj = self.value(hidden_states)
 
-        # transpose
-        query_states = self.transpose_for_scores(query_proj)
-        key_states = self.transpose_for_scores(key_proj)
-        value_states = self.transpose_for_scores(value_proj)
+        scale = 1 / math.sqrt(self.channels / self.num_heads)
 
         # get scores
-        scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
-        attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)  # TODO: use baddmm
+        if self.num_heads > 1:
+            query_states = self.transpose_for_scores(query_proj)
+            key_states = self.transpose_for_scores(key_proj)
+            value_states = self.transpose_for_scores(value_proj)
+
+            # TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
+            # or reformulate this into a 3D problem?
+            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
+            # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
+            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
+            attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * scale
+        else:
+            query_states, key_states, value_states = query_proj, key_proj, value_proj
+
+            attention_scores = torch.baddbmm(
+                torch.empty(
+                    query_states.shape[0],
+                    query_states.shape[1],
+                    key_states.shape[1],
+                    dtype=query_states.dtype,
+                    device=query_states.device,
+                ),
+                query_states,
+                key_states.transpose(-1, -2),
+                beta=0,
+                alpha=scale,
+            )
+
         attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
 
         # compute attention output
-        hidden_states = torch.matmul(attention_probs, value_states)
-
-        hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
-        new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
-        hidden_states = hidden_states.view(new_hidden_states_shape)
+        if self.num_heads > 1:
+            # TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
+            # or reformulate this into a 3D problem?
+            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
+            # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
+            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
+            hidden_states = torch.matmul(attention_probs, value_states)
+            hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
+            new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
+            hidden_states = hidden_states.view(new_hidden_states_shape)
+        else:
+            hidden_states = torch.bmm(attention_probs, value_states)
 
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
@@ -507,19 +537,17 @@ def forward(self, hidden_states, context=None, mask=None):
         return hidden_states
 
     def _attention(self, query, key, value):
-        # TODO: use baddbmm for better performance
-        if query.device.type == "mps":
-            # Better performance on mps (~20-25%)
-            attention_scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
-        else:
-            attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+        attention_scores = torch.baddbmm(
+            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            query,
+            key.transpose(-1, -2),
+            beta=0,
+            alpha=self.scale,
+        )
         attention_probs = attention_scores.softmax(dim=-1)
 
         # compute attention output
-        if query.device.type == "mps":
-            hidden_states = torch.einsum("b i j, b j d -> b i d", attention_probs, value)
-        else:
-            hidden_states = torch.matmul(attention_probs, value)
+        hidden_states = torch.bmm(attention_probs, value)
 
         # reshape hidden_states
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
@@ -534,21 +562,15 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
         for i in range(hidden_states.shape[0] // slice_size):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
-            if query.device.type == "mps":
-                # Better performance on mps (~20-25%)
-                attn_slice = (
-                    torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx])
-                    * self.scale
-                )
-            else:
-                attn_slice = (
-                    torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
-                )  # TODO: use baddbmm for better performance
+            attn_slice = torch.baddbmm(
+                torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+                query[start_idx:end_idx],
+                key[start_idx:end_idx].transpose(-1, -2),
+                beta=0,
+                alpha=self.scale,
+            )
             attn_slice = attn_slice.softmax(dim=-1)
-            if query.device.type == "mps":
-                attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
-            else:
-                attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
+            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
             hidden_states[start_idx:end_idx] = attn_slice