@@ -286,6 +286,32 @@ def __init__(
         self.rescale_output_factor = rescale_output_factor
         self.proj_attn = nn.Linear(channels, channels, 1)

+        self._use_memory_efficient_attention_xformers = False
+
+    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+        # Only validate the environment when enabling the feature; turning it
+        # off should never raise.
+        if use_memory_efficient_attention_xformers:
+            if not is_xformers_available():
+                raise ModuleNotFoundError(
+                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+                    " xformers",
+                    name="xformers",
+                )
+            elif not torch.cuda.is_available():
+                raise ValueError(
+                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention"
+                    " is only available for GPU"
+                )
+            else:
+                # Make sure we can actually run the memory efficient attention kernel.
+                _ = xformers.ops.memory_efficient_attention(
+                    torch.randn((1, 2, 40), device="cuda"),
+                    torch.randn((1, 2, 40), device="cuda"),
+                    torch.randn((1, 2, 40), device="cuda"),
+                )
+        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.num_heads
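For context, here is a minimal usage sketch, not part of the diff, showing how the new toggle might be driven from user code. The `AttentionBlock` import path and constructor arguments are assumptions for illustration:

import torch

from diffusers.models.attention import AttentionBlock  # assumed import path

block = AttentionBlock(channels=512, num_head_channels=64)  # hypothetical sizes
if torch.cuda.is_available():
    block = block.to("cuda")
    try:
        # Enabling runs the environment checks and the tiny test call above.
        block.set_use_memory_efficient_attention_xformers(True)
    except (ModuleNotFoundError, ValueError):
        # xformers missing or unusable: stay on the baddbmm/softmax path.
        block.set_use_memory_efficient_attention_xformers(False)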
@@ -320,21 +346,26 @@ def forward(self, hidden_states):
         key_proj = self.reshape_heads_to_batch_dim(key_proj)
         value_proj = self.reshape_heads_to_batch_dim(value_proj)

-        attention_scores = torch.baddbmm(
-            torch.empty(
-                query_proj.shape[0],
-                query_proj.shape[1],
-                key_proj.shape[1],
-                dtype=query_proj.dtype,
-                device=query_proj.device,
-            ),
-            query_proj,
-            key_proj.transpose(-1, -2),
-            beta=0,
-            alpha=scale,
-        )
-        attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
-        hidden_states = torch.bmm(attention_probs, value_proj)
+        if self._use_memory_efficient_attention_xformers:
+            # Memory efficient attention: the fused kernel applies the 1/sqrt(head_dim) scaling internally
+            hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
+            hidden_states = hidden_states.to(query_proj.dtype)
+        else:
+            attention_scores = torch.baddbmm(
+                torch.empty(
+                    query_proj.shape[0],
+                    query_proj.shape[1],
+                    key_proj.shape[1],
+                    dtype=query_proj.dtype,
+                    device=query_proj.device,
+                ),
+                query_proj,
+                key_proj.transpose(-1, -2),
+                beta=0,
+                alpha=scale,
+            )
+            attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
+            hidden_states = torch.bmm(attention_probs, value_proj)

         # reshape hidden_states
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
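As a sanity check for the two branches above, here is a standalone sketch, assuming xformers is installed and a CUDA device is available, that compares the fused kernel against the baddbmm/softmax reference on random data. Tolerances are loose because the kernels differ numerically:

import torch
import xformers.ops

q = torch.randn((2, 64, 40), device="cuda", dtype=torch.float16)
k = torch.randn((2, 64, 40), device="cuda", dtype=torch.float16)
v = torch.randn((2, 64, 40), device="cuda", dtype=torch.float16)
scale = 1 / (q.shape[-1] ** 0.5)

# Reference path: explicit scores plus softmax, as in the else branch above.
scores = torch.baddbmm(
    torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
    q,
    k.transpose(-1, -2),
    beta=0,
    alpha=scale,
)
probs = torch.softmax(scores.float(), dim=-1).type(scores.dtype)
reference = torch.bmm(probs, v)

# Fused path: xformers applies the 1/sqrt(head_dim) scaling internally by default.
fused = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)

assert torch.allclose(reference.float(), fused.float(), atol=1e-2, rtol=1e-2)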