@@ -290,11 +290,19 @@ def __init__(
         self.rescale_output_factor = rescale_output_factor
         self.proj_attn = nn.Linear(channels, channels, 1)
 
-    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
-        new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
-        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
-        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
-        return new_projection
+    def reshape_heads_to_batch_dim(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.num_heads
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        return tensor
+
+    def reshape_batch_dim_to_heads(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.num_heads
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+        return tensor
 
     def forward(self, hidden_states):
         residual = hidden_states
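
For illustration, here is a minimal sketch (not part of the commit; the standalone functions and example shapes are assumptions) showing how the two new helpers fold the head dimension into the batch dimension and how the second helper inverts that fold:

import torch

# Hypothetical standalone counterparts of the methods added above, for shape-checking only.
def reshape_heads_to_batch_dim(tensor, head_size):
    batch_size, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
    return tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)

def reshape_batch_dim_to_heads(tensor, head_size):
    batch_size, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
    return tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)

x = torch.randn(2, 64, 512)                       # (batch, tokens, channels), hypothetical sizes
folded = reshape_heads_to_batch_dim(x, 8)         # -> (16, 64, 64): 8 heads folded into the batch dim
restored = reshape_batch_dim_to_heads(folded, 8)  # -> (2, 64, 512): round trip recovers the input
assert torch.equal(x, restored)

Folding the heads into the batch dimension keeps every tensor 3-D, which is what lets the score and output products use bmm/baddbmm instead of the 4-D matmul path removed in the hunk below.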
@@ -312,50 +320,28 @@ def forward(self, hidden_states):
 
         scale = 1 / math.sqrt(self.channels / self.num_heads)
 
-        # get scores
-        if self.num_heads > 1:
-            query_states = self.transpose_for_scores(query_proj)
-            key_states = self.transpose_for_scores(key_proj)
-            value_states = self.transpose_for_scores(value_proj)
-
-            # TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
-            # or reformulate this into a 3D problem?
-            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
-            # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
-            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
-            attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * scale
-        else:
-            query_states, key_states, value_states = query_proj, key_proj, value_proj
-
-            attention_scores = torch.baddbmm(
-                torch.empty(
-                    query_states.shape[0],
-                    query_states.shape[1],
-                    key_states.shape[1],
-                    dtype=query_states.dtype,
-                    device=query_states.device,
-                ),
-                query_states,
-                key_states.transpose(-1, -2),
-                beta=0,
-                alpha=scale,
-            )
+        query_proj = self.reshape_heads_to_batch_dim(query_proj)
+        key_proj = self.reshape_heads_to_batch_dim(key_proj)
+        value_proj = self.reshape_heads_to_batch_dim(value_proj)
 
+        attention_scores = torch.baddbmm(
+            torch.empty(
+                query_proj.shape[0],
+                query_proj.shape[1],
+                key_proj.shape[1],
+                dtype=query_proj.dtype,
+                device=query_proj.device,
+            ),
+            query_proj,
+            key_proj.transpose(-1, -2),
+            beta=0,
+            alpha=scale,
+        )
         attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
+        hidden_states = torch.bmm(attention_probs, value_proj)
 
-        # compute attention output
-        if self.num_heads > 1:
-            # TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
-            # or reformulate this into a 3D problem?
-            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
-            # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
-            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
-            hidden_states = torch.matmul(attention_probs, value_states)
-            hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
-            new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
-            hidden_states = hidden_states.view(new_hidden_states_shape)
-        else:
-            hidden_states = torch.bmm(attention_probs, value_states)
+        # reshape hidden_states
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
 
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
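
As a quick sanity check (my own sketch, not part of the commit; the shapes are hypothetical), the baddbmm call with beta=0 and alpha=scale should match the plain scaled matmul it replaces, since beta=0 makes baddbmm ignore the uninitialized input tensor:

import torch

q = torch.randn(16, 64, 64)  # (batch * heads, tokens, head_dim), hypothetical sizes
k = torch.randn(16, 64, 64)
scale = 1 / (64 ** 0.5)

scores_baddbmm = torch.baddbmm(
    torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
    q,
    k.transpose(-1, -2),
    beta=0,      # the input tensor is ignored when beta=0
    alpha=scale,
)
scores_matmul = torch.matmul(q, k.transpose(-1, -2)) * scale
assert torch.allclose(scores_baddbmm, scores_matmul, atol=1e-5)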