@@ -73,7 +73,13 @@ def forward(self, hidden_states):
         # get scores
         scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
 
-        attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
+        attention_scores = torch.baddbmm(
+            torch.empty(query_states.shape[0], query_states.shape[1], key_states.shape[1], dtype=query_states.dtype, device=query_states.device),
+            query_states,
+            key_states.transpose(-1, -2),
+            beta=0,
+            alpha=scale * scale,  # the removed line scaled both query and key by `scale`, so both factors must fold into alpha
+        )
         attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
 
         # compute attention output
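For reference, `torch.baddbmm(input, batch1, batch2, beta=beta, alpha=alpha)` computes `beta * input + alpha * (batch1 @ batch2)` in a single batched kernel. With `beta=0`, PyTorch never reads `input` (even NaNs in it are ignored), so an uninitialized `torch.empty` of the right shape, dtype, and device suffices, and `alpha` fuses the scaling into the matmul instead of launching separate elementwise multiplies. Note that the removed line above applied `scale` to both operands, hence `alpha=scale * scale`; in `_attention` below the scale is applied once, so `alpha=self.scale` matches exactly. A minimal sketch of the equivalence, with illustrative shapes not taken from the commit:

```python
import torch

# Illustrative shapes: (batch * heads, seq_len, head_dim)
query = torch.randn(2, 16, 64)
key = torch.randn(2, 16, 64)
scale = 1 / 64 ** 0.5

# Unfused: one elementwise kernel for the scaling, then a batched matmul
reference = torch.matmul(query, key.transpose(-1, -2)) * scale

# Fused: beta=0 makes baddbmm ignore the uninitialized first argument,
# and alpha applies the scaling inside the batched matmul itself
fused = torch.baddbmm(
    torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
    query,
    key.transpose(-1, -2),
    beta=0,
    alpha=scale,
)

assert torch.allclose(reference, fused, atol=1e-5)
```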
@@ -272,7 +278,14 @@ def forward(self, hidden_states, context=None, mask=None):
         return self.to_out(hidden_states)
 
     def _attention(self, query, key, value):
-        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+        # attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+        attention_scores = torch.baddbmm(
+            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            query,
+            key.transpose(-1, -2),
+            beta=0,
+            alpha=self.scale,
+        )
         attention_probs = attention_scores.softmax(dim=-1)
         # compute attention output
         hidden_states = torch.matmul(attention_probs, value)
@@ -289,7 +302,7 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
         for i in range(hidden_states.shape[0] // slice_size):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
-            attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
+            attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale  # TODO: use baddbmm for better performance
             attn_slice = attn_slice.softmax(dim=-1)
             attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
 
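The sliced path keeps the unfused matmul and only leaves a TODO. A hypothetical sketch of what that follow-up could look like, with the helper name `sliced_attention_scores` and all shapes invented for illustration (this is not code from the commit):

```python
import torch

def sliced_attention_scores(query, key, scale, slice_size):
    # Hypothetical realization of the TODO above: per-slice baddbmm
    # instead of matmul followed by an elementwise scale
    scores = []
    for start_idx in range(0, query.shape[0], slice_size):
        end_idx = start_idx + slice_size
        query_slice = query[start_idx:end_idx]
        key_slice = key[start_idx:end_idx]
        scores.append(
            torch.baddbmm(
                torch.empty(query_slice.shape[0], query_slice.shape[1], key_slice.shape[1], dtype=query_slice.dtype, device=query_slice.device),
                query_slice,
                key_slice.transpose(1, 2),
                beta=0,
                alpha=scale,
            )
        )
    return torch.cat(scores)

# Agrees with the unfused slice computation from the diff
query, key = torch.randn(4, 16, 64), torch.randn(4, 16, 64)
expected = torch.matmul(query, key.transpose(1, 2)) * 0.125
assert torch.allclose(sliced_attention_scores(query, key, 0.125, slice_size=2), expected, atol=1e-5)
```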