
Commit 520281a

patil-suraj authored and Thomas Capelle committed
simplyfy AttentionBlock (huggingface#1492)
1 parent 5ffefe8 commit 520281a

File tree: 1 file changed, +32 -46 lines changed


src/diffusers/models/attention.py

Lines changed: 32 additions & 46 deletions
@@ -290,11 +290,19 @@ def __init__(
         self.rescale_output_factor = rescale_output_factor
         self.proj_attn = nn.Linear(channels, channels, 1)

-    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
-        new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
-        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
-        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
-        return new_projection
+    def reshape_heads_to_batch_dim(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.num_heads
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        return tensor
+
+    def reshape_batch_dim_to_heads(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.num_heads
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+        return tensor

     def forward(self, hidden_states):
         residual = hidden_states
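
Note: the two new helpers fold the head dimension into the batch dimension and back, which is what lets the rest of forward work on plain 3-D tensors. A minimal standalone sketch of that round trip, using illustrative shapes and free functions rather than the module's methods:

import torch

# Illustrative shapes (not taken from the commit); dim must be divisible by num_heads.
batch, seq_len, dim, num_heads = 4, 16, 32, 2


def reshape_heads_to_batch_dim(tensor, head_size):
    # (B, T, H*D) -> (B, T, H, D) -> (B, H, T, D) -> (B*H, T, D)
    b, t, d = tensor.shape
    tensor = tensor.reshape(b, t, head_size, d // head_size)
    return tensor.permute(0, 2, 1, 3).reshape(b * head_size, t, d // head_size)


def reshape_batch_dim_to_heads(tensor, head_size):
    # (B*H, T, D) -> (B, H, T, D) -> (B, T, H, D) -> (B, T, H*D)
    bh, t, d = tensor.shape
    tensor = tensor.reshape(bh // head_size, head_size, t, d)
    return tensor.permute(0, 2, 1, 3).reshape(bh // head_size, t, d * head_size)


x = torch.randn(batch, seq_len, dim)
y = reshape_heads_to_batch_dim(x, num_heads)
assert y.shape == (batch * num_heads, seq_len, dim // num_heads)
assert torch.equal(reshape_batch_dim_to_heads(y, num_heads), x)  # exact round trip
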
@@ -312,50 +320,28 @@ def forward(self, hidden_states):

         scale = 1 / math.sqrt(self.channels / self.num_heads)

-        # get scores
-        if self.num_heads > 1:
-            query_states = self.transpose_for_scores(query_proj)
-            key_states = self.transpose_for_scores(key_proj)
-            value_states = self.transpose_for_scores(value_proj)
-
-            # TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
-            # or reformulate this into a 3D problem?
-            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
-            # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
-            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
-            attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * scale
-        else:
-            query_states, key_states, value_states = query_proj, key_proj, value_proj
-
-            attention_scores = torch.baddbmm(
-                torch.empty(
-                    query_states.shape[0],
-                    query_states.shape[1],
-                    key_states.shape[1],
-                    dtype=query_states.dtype,
-                    device=query_states.device,
-                ),
-                query_states,
-                key_states.transpose(-1, -2),
-                beta=0,
-                alpha=scale,
-            )
+        query_proj = self.reshape_heads_to_batch_dim(query_proj)
+        key_proj = self.reshape_heads_to_batch_dim(key_proj)
+        value_proj = self.reshape_heads_to_batch_dim(value_proj)

+        attention_scores = torch.baddbmm(
+            torch.empty(
+                query_proj.shape[0],
+                query_proj.shape[1],
+                key_proj.shape[1],
+                dtype=query_proj.dtype,
+                device=query_proj.device,
+            ),
+            query_proj,
+            key_proj.transpose(-1, -2),
+            beta=0,
+            alpha=scale,
+        )
         attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
+        hidden_states = torch.bmm(attention_probs, value_proj)

-        # compute attention output
-        if self.num_heads > 1:
-            # TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
-            # or reformulate this into a 3D problem?
-            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
-            # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
-            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
-            hidden_states = torch.matmul(attention_probs, value_states)
-            hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
-            new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
-            hidden_states = hidden_states.view(new_hidden_states_shape)
-        else:
-            hidden_states = torch.bmm(attention_probs, value_states)
+        # reshape hidden_states
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)

         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
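
Taken together, the new forward pass runs scaled dot-product attention entirely on 3-D tensors for any number of heads: fold the heads into the batch dimension, compute the scores with a single baddbmm (beta=0 discards the uninitialized empty tensor, alpha applies the scale), take the softmax in float32, multiply by the values with bmm, and unfold the heads again. Below is a self-contained sketch of that computation, checked against a 4-D matmul reference; the function names and toy shapes are illustrative, not part of the diff.

import math

import torch


def attention_3d(query_proj, key_proj, value_proj, num_heads, channels):
    """Sketch of the simplified path: heads folded into the batch dim, scores via baddbmm."""
    scale = 1 / math.sqrt(channels / num_heads)

    def to_batch_dim(t):
        b, s, d = t.shape
        t = t.reshape(b, s, num_heads, d // num_heads)
        return t.permute(0, 2, 1, 3).reshape(b * num_heads, s, d // num_heads)

    def to_head_dim(t):
        bh, s, d = t.shape
        t = t.reshape(bh // num_heads, num_heads, s, d)
        return t.permute(0, 2, 1, 3).reshape(bh // num_heads, s, d * num_heads)

    q, k, v = (to_batch_dim(t) for t in (query_proj, key_proj, value_proj))

    # beta=0 makes baddbmm ignore the uninitialized tensor, so this is just alpha * (q @ k^T).
    attention_scores = torch.baddbmm(
        torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
        q,
        k.transpose(-1, -2),
        beta=0,
        alpha=scale,
    )
    attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
    return to_head_dim(torch.bmm(attention_probs, v))


# Toy check against the pre-commit 4-D matmul formulation (illustrative shapes, not from the diff).
torch.manual_seed(0)
channels, num_heads, batch, seq_len = 32, 4, 2, 64
q = torch.randn(batch, seq_len, channels)
k = torch.randn(batch, seq_len, channels)
v = torch.randn(batch, seq_len, channels)

out = attention_3d(q, k, v, num_heads, channels)


def split_heads(t):
    # (B, T, H*D) -> (B, H, T, D), as the old transpose_for_scores did.
    return t.reshape(batch, seq_len, num_heads, channels // num_heads).permute(0, 2, 1, 3)


scale = 1 / math.sqrt(channels / num_heads)
ref_probs = torch.softmax(torch.matmul(split_heads(q), split_heads(k).transpose(-1, -2)) * scale, dim=-1)
ref = torch.matmul(ref_probs, split_heads(v)).permute(0, 2, 1, 3).reshape(batch, seq_len, channels)
assert torch.allclose(out, ref, atol=1e-5)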
