@@ -284,22 +284,52 @@ def forward(self, hidden_states):
         key_proj = self.key(hidden_states)
         value_proj = self.value(hidden_states)

-        # transpose
-        query_states = self.transpose_for_scores(query_proj)
-        key_states = self.transpose_for_scores(key_proj)
-        value_states = self.transpose_for_scores(value_proj)
+        scale = 1 / math.sqrt(self.channels / self.num_heads)

         # get scores
-        scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
-        attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)  # TODO: use baddbmm
+        if self.num_heads > 1:
+            query_states = self.transpose_for_scores(query_proj)
+            key_states = self.transpose_for_scores(key_proj)
+            value_states = self.transpose_for_scores(value_proj)
+
+            # TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
+            # or reformulate this into a 3D problem?
+            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
+            # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
+            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
+            attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * scale
+        else:
+            query_states, key_states, value_states = query_proj, key_proj, value_proj
+
+            attention_scores = torch.baddbmm(
+                torch.empty(
+                    query_states.shape[0],
+                    query_states.shape[1],
+                    key_states.shape[1],
+                    dtype=query_states.dtype,
+                    device=query_states.device,
+                ),
+                query_states,
+                key_states.transpose(-1, -2),
+                beta=0,
+                alpha=scale,
+            )
+
         attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)

         # compute attention output
-        hidden_states = torch.matmul(attention_probs, value_states)
-
-        hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
-        new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
-        hidden_states = hidden_states.view(new_hidden_states_shape)
+        if self.num_heads > 1:
+            # TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
+            # or reformulate this into a 3D problem?
+            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
+            # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
+            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
+            hidden_states = torch.matmul(attention_probs, value_states)
+            hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
+            new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
+            hidden_states = hidden_states.view(new_hidden_states_shape)
+        else:
+            hidden_states = torch.bmm(attention_probs, value_states)

         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
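
The key trick in the hunk above: with beta=0, torch.baddbmm ignores its input argument entirely (PyTorch documents that NaN/Inf in it are not propagated), so an uninitialized torch.empty buffer is safe to pass, and alpha folds the 1/sqrt(head_dim) scaling into the same fused batched-matmul kernel. A minimal standalone sketch, not part of the diff and with illustrative shapes, checking that this reproduces the scaled matmul it replaces:

import math
import torch

batch, seq, dim = 2, 16, 64
scale = 1 / math.sqrt(dim)
q = torch.randn(batch, seq, dim)
k = torch.randn(batch, seq, dim)

# reference: explicit batched matmul followed by scaling
ref = torch.matmul(q, k.transpose(-1, -2)) * scale

# fused form: beta=0 discards the uninitialized input, alpha applies the scale
out = torch.baddbmm(
    torch.empty(batch, seq, seq, dtype=q.dtype, device=q.device),
    q,
    k.transpose(-1, -2),
    beta=0,
    alpha=scale,
)

torch.testing.assert_close(out, ref)
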
@@ -507,19 +537,17 @@ def forward(self, hidden_states, context=None, mask=None):
         return hidden_states

     def _attention(self, query, key, value):
-        # TODO: use baddbmm for better performance
-        if query.device.type == "mps":
-            # Better performance on mps (~20-25%)
-            attention_scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
-        else:
-            attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+        attention_scores = torch.baddbmm(
+            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            query,
+            key.transpose(-1, -2),
+            beta=0,
+            alpha=self.scale,
+        )
         attention_probs = attention_scores.softmax(dim=-1)
         # compute attention output

-        if query.device.type == "mps":
-            hidden_states = torch.einsum("b i j, b j d -> b i d", attention_probs, value)
-        else:
-            hidden_states = torch.matmul(attention_probs, value)
+        hidden_states = torch.bmm(attention_probs, value)

         # reshape hidden_states
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
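
_attention can drop to plain 3D bmm/baddbmm here because the heads were already folded into the batch dimension before this method runs; the reshape_batch_dim_to_heads call above undoes that folding afterwards. A rough sketch of the idea, my own illustration rather than the library's verbatim helpers:

import torch

def heads_to_batch(x: torch.Tensor, heads: int) -> torch.Tensor:
    # (batch, seq, channels) -> (batch * heads, seq, channels // heads)
    batch, seq, channels = x.shape
    x = x.reshape(batch, seq, heads, channels // heads)
    return x.permute(0, 2, 1, 3).reshape(batch * heads, seq, channels // heads)

def batch_to_heads(x: torch.Tensor, heads: int) -> torch.Tensor:
    # (batch * heads, seq, head_dim) -> (batch, seq, heads * head_dim)
    batch_heads, seq, head_dim = x.shape
    x = x.reshape(batch_heads // heads, heads, seq, head_dim)
    return x.permute(0, 2, 1, 3).reshape(batch_heads // heads, seq, heads * head_dim)

x = torch.randn(2, 16, 64)           # (batch, seq, channels)
folded = heads_to_batch(x, heads=8)  # (16, 16, 8): ready for bmm/baddbmm
assert batch_to_heads(folded, heads=8).shape == x.shape
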
@@ -534,21 +562,15 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
         for i in range(hidden_states.shape[0] // slice_size):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
-            if query.device.type == "mps":
-                # Better performance on mps (~20-25%)
-                attn_slice = (
-                    torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx])
-                    * self.scale
-                )
-            else:
-                attn_slice = (
-                    torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
-                )  # TODO: use baddbmm for better performance
+            attn_slice = torch.baddbmm(
+                torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+                query[start_idx:end_idx],
+                key[start_idx:end_idx].transpose(-1, -2),
+                beta=0,
+                alpha=self.scale,
+            )
             attn_slice = attn_slice.softmax(dim=-1)
-            if query.device.type == "mps":
-                attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
-            else:
-                attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
+            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

             hidden_states[start_idx:end_idx] = attn_slice

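The sliced path gets the same baddbmm rewrite per chunk; slicing exists so that only one (slice_size, seq, seq) score matrix is materialized at a time, rather than the full (batch * heads, seq, seq) tensor. A toy sketch of the pattern with illustrative shapes, assuming slice_size divides the batch evenly as the loop above also requires:

import torch

batch_heads, seq, head_dim = 16, 128, 64
slice_size = 4
scale = head_dim ** -0.5

q = torch.randn(batch_heads, seq, head_dim)
k = torch.randn(batch_heads, seq, head_dim)
v = torch.randn(batch_heads, seq, head_dim)

out = torch.empty_like(q)
for i in range(batch_heads // slice_size):
    s, e = i * slice_size, (i + 1) * slice_size
    # only a (slice_size, seq, seq) score tensor is alive in this iteration
    scores = torch.baddbmm(
        torch.empty(slice_size, seq, seq, dtype=q.dtype, device=q.device),
        q[s:e],
        k[s:e].transpose(-1, -2),
        beta=0,
        alpha=scale,
    )
    out[s:e] = torch.bmm(scores.softmax(dim=-1), v[s:e])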