92 changes: 57 additions & 35 deletions src/diffusers/models/attention.py
@@ -284,22 +284,52 @@ def forward(self, hidden_states):
         key_proj = self.key(hidden_states)
         value_proj = self.value(hidden_states)
 
-        # transpose
-        query_states = self.transpose_for_scores(query_proj)
-        key_states = self.transpose_for_scores(key_proj)
-        value_states = self.transpose_for_scores(value_proj)
+        scale = 1 / math.sqrt(self.channels / self.num_heads)
 
-        # get scores
-        scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
-        attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)  # TODO: use baddmm
+        if self.num_heads > 1:
+            query_states = self.transpose_for_scores(query_proj)
+            key_states = self.transpose_for_scores(key_proj)
+            value_states = self.transpose_for_scores(value_proj)
+
+            # TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
+            #       or reformulate this into a 3D problem?
+            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
+            #       as some matmuls can be 1.94x slower than an equivalent einsum on MPS
+            #       https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
+            attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * scale
Comment on lines +291 to +300
Contributor

It's possible to formulate this as a 3D problem, the same way we do it in CrossAttention, by merging the batch and heads dimensions before the bmm and then splitting them again. For example:

def reshape_heads_to_batch_dim(tensor, heads=2):
    batch_size, seq_len, dim = tensor.shape
    head_size = heads
    tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
    tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
    return tensor

def reshape_batch_dim_to_heads(tensor, heads=2):
    batch_size, seq_len, dim = tensor.shape
    head_size = heads
    tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
    # tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
    return tensor

query_states = reshape_heads_to_batch_dim(query_proj)
key_states = reshape_heads_to_batch_dim(key_proj)
value_states = reshape_heads_to_batch_dim(value_proj)

attention_scores = torch.baddbmm(
    torch.empty(
        query_states.shape[0],
        query_states.shape[1],
        key_states.shape[1],
        dtype=query_states.dtype,
        device=query_states.device,
    ),
    query_states,
    key_states.transpose(-1, -2),
    beta=0,
    alpha=1,
)

attention_scores = reshape_batch_dim_to_heads(attention_scores)
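
A quick way to sanity-check the proposed reshape (a standalone sketch, not part of the PR or the comment above; the shapes and the `heads=2` helper are illustrative) is to compare the merged-batch `bmm` against a 4D `matmul` path analogous to `transpose_for_scores`:

```python
import torch

def reshape_heads_to_batch_dim(tensor, heads=2):
    batch_size, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size, seq_len, heads, dim // heads)
    tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * heads, seq_len, dim // heads)
    return tensor

batch, heads, seq, dim = 4, 2, 16, 32
q = torch.randn(batch, seq, dim)
k = torch.randn(batch, seq, dim)

# 4D path: split heads into their own dimension; matmul broadcasts over batch and heads
q4 = q.reshape(batch, seq, heads, dim // heads).permute(0, 2, 1, 3)
k4 = k.reshape(batch, seq, heads, dim // heads).permute(0, 2, 1, 3)
scores_4d = torch.matmul(q4, k4.transpose(-1, -2))  # [batch, heads, seq, seq]

# 3D path: fold heads into the batch dimension so bmm / baddbmm can be used
q3 = reshape_heads_to_batch_dim(q, heads)
k3 = reshape_heads_to_batch_dim(k, heads)
scores_3d = torch.bmm(q3, k3.transpose(-1, -2))  # [batch * heads, seq, seq]

# both formulations produce the same attention scores
assert torch.allclose(scores_4d, scores_3d.reshape(batch, heads, seq, seq), atol=1e-6)
```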

+        else:
+            query_states, key_states, value_states = query_proj, key_proj, value_proj
+
+            attention_scores = torch.baddbmm(
+                torch.empty(
+                    query_states.shape[0],
+                    query_states.shape[1],
+                    key_states.shape[1],
+                    dtype=query_states.dtype,
+                    device=query_states.device,
+                ),
+                query_states,
+                key_states.transpose(-1, -2),
+                beta=0,
+                alpha=scale,
+            )

         attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
 
-        # compute attention output
-        hidden_states = torch.matmul(attention_probs, value_states)
-
-        hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
-        new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
-        hidden_states = hidden_states.view(new_hidden_states_shape)
+        if self.num_heads > 1:
+            # TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
+            #       or reformulate this into a 3D problem?
+            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
+            #       as some matmuls can be 1.94x slower than an equivalent einsum on MPS
+            #       https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
+            hidden_states = torch.matmul(attention_probs, value_states)
+            hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
+            new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
+            hidden_states = hidden_states.view(new_hidden_states_shape)
Comment on lines +322 to +330
Contributor

Same comment as above: it should be possible to make this 3D as well, using the logic from the comment above.
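
For the output-side matmul, the same merge applies (again a standalone sketch with illustrative shapes, not code from the PR): run `bmm` on the `[batch * heads, ...]` tensors, then split the heads back out and fold them into the channel dimension:

```python
import torch

batch, heads, seq, dim_per_head = 4, 2, 16, 16
channels = heads * dim_per_head

# attention_probs and value_states with heads already merged into the batch dimension
attention_probs = torch.softmax(torch.randn(batch * heads, seq, seq), dim=-1)
value_states = torch.randn(batch * heads, seq, dim_per_head)

# 3D batched matmul over batch * heads
hidden_states = torch.bmm(attention_probs, value_states)  # [batch * heads, seq, dim_per_head]

# split the heads back out, then fold them into the channel dimension
hidden_states = hidden_states.reshape(batch, heads, seq, dim_per_head)
hidden_states = hidden_states.permute(0, 2, 1, 3).reshape(batch, seq, channels)

assert hidden_states.shape == (batch, seq, channels)
```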

+        else:
+            hidden_states = torch.bmm(attention_probs, value_states)
 
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
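
One subtlety in this hunk: the old code scaled both `query_states` and `key_states` by `1 / sqrt(sqrt(channels / num_heads))`, while the new code applies `1 / sqrt(channels / num_heads)` once, either directly or via `alpha`. The two are the same overall scaling of the scores; a small check with illustrative shapes:

```python
import math
import torch

channels, num_heads, seq = 64, 1, 8
d = channels // num_heads

q = torch.randn(seq, d)
k = torch.randn(seq, d)

# old: scale query and key each by 1 / d**0.25, so their product picks up 1 / sqrt(d)
old_scale = 1 / math.sqrt(math.sqrt(d))
old_scores = torch.matmul(q * old_scale, k.transpose(-1, -2) * old_scale)

# new: scale the scores once by 1 / sqrt(d)
new_scale = 1 / math.sqrt(d)
new_scores = torch.matmul(q, k.transpose(-1, -2)) * new_scale

assert torch.allclose(old_scores, new_scores, atol=1e-6)
```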
@@ -507,19 +537,17 @@ def forward(self, hidden_states, context=None, mask=None):
         return hidden_states
 
     def _attention(self, query, key, value):
-        # TODO: use baddbmm for better performance
-        if query.device.type == "mps":
-            # Better performance on mps (~20-25%)
-            attention_scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
-        else:
-            attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+        attention_scores = torch.baddbmm(
+            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            query,
+            key.transpose(-1, -2),
+            beta=0,
+            alpha=self.scale,
+        )
         attention_probs = attention_scores.softmax(dim=-1)
         # compute attention output
 
-        if query.device.type == "mps":
-            hidden_states = torch.einsum("b i j, b j d -> b i d", attention_probs, value)
-        else:
-            hidden_states = torch.matmul(attention_probs, value)
+        hidden_states = torch.bmm(attention_probs, value)
 
         # reshape hidden_states
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
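
The key substitution in this hunk is that `torch.baddbmm(empty, q, k.transpose(-1, -2), beta=0, alpha=scale)` computes the same scores as the old matmul-then-scale path in a single fused batched GEMM; with `beta=0` the uninitialized `empty` input is ignored rather than added. A minimal equivalence check, with illustrative shapes:

```python
import torch

b, seq_q, seq_k, d = 2, 8, 8, 64
scale = d ** -0.5

q = torch.randn(b, seq_q, d)
k = torch.randn(b, seq_k, d)

# old formulation: batched matmul, then scale the scores
ref = torch.matmul(q, k.transpose(-1, -2)) * scale

# new formulation: fused into a single baddbmm; beta=0 means the empty tensor is never read
out = torch.baddbmm(
    torch.empty(b, seq_q, seq_k, dtype=q.dtype, device=q.device),
    q,
    k.transpose(-1, -2),
    beta=0,
    alpha=scale,
)

assert torch.allclose(ref, out, atol=1e-5)
```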
@@ -534,21 +562,15 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
         for i in range(hidden_states.shape[0] // slice_size):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
-            if query.device.type == "mps":
-                # Better performance on mps (~20-25%)
-                attn_slice = (
-                    torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx])
-                    * self.scale
-                )
-            else:
-                attn_slice = (
-                    torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
-                )  # TODO: use baddbmm for better performance
+            attn_slice = torch.baddbmm(
+                torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+                query[start_idx:end_idx],
+                key[start_idx:end_idx].transpose(-1, -2),
+                beta=0,
+                alpha=self.scale,
+            )
Comment on lines +565 to +571
Contributor

Looks good!

             attn_slice = attn_slice.softmax(dim=-1)
-            if query.device.type == "mps":
-                attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
-            else:
-                attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
+            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
 
             hidden_states[start_idx:end_idx] = attn_slice
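
For context on why `_sliced_attention` loops over `slice_size` chunks, here is a standalone illustration (made-up shapes, not part of this diff): slicing along the `batch * heads` dimension bounds the size of the intermediate score tensor while producing the same result as one full batched attention.

```python
import torch

bh, seq, d = 8, 16, 64      # batch * heads, sequence length, head dim
slice_size = 2
scale = d ** -0.5

query = torch.randn(bh, seq, d)
key = torch.randn(bh, seq, d)
value = torch.randn(bh, seq, d)

# reference: attention over the full batch * heads dimension at once
ref = torch.softmax(
    torch.baddbmm(torch.empty(bh, seq, seq), query, key.transpose(-1, -2), beta=0, alpha=scale),
    dim=-1,
).bmm(value)

# sliced: same computation, but never materialising more than slice_size score matrices
out = torch.empty_like(ref)
for i in range(bh // slice_size):
    s, e = i * slice_size, (i + 1) * slice_size
    attn_slice = torch.baddbmm(
        torch.empty(slice_size, seq, seq),
        query[s:e],
        key[s:e].transpose(-1, -2),
        beta=0,
        alpha=scale,
    ).softmax(dim=-1)
    out[s:e] = torch.bmm(attn_slice, value[s:e])

assert torch.allclose(ref, out, atol=1e-5)
```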
