diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 6ad0af18c1c9..31d1f1ff418a 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -286,6 +286,32 @@ def __init__(
         self.rescale_output_factor = rescale_output_factor
         self.proj_attn = nn.Linear(channels, channels, 1)
 
+        self._use_memory_efficient_attention_xformers = False
+
+    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+        if not is_xformers_available():
+            raise ModuleNotFoundError(
+                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+                " xformers",
+                name="xformers",
+            )
+        elif not torch.cuda.is_available():
+            raise ValueError(
+                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+                " available for GPU "
+            )
+        else:
+            try:
+                # Make sure we can run the memory efficient attention
+                _ = xformers.ops.memory_efficient_attention(
+                    torch.randn((1, 2, 40), device="cuda"),
+                    torch.randn((1, 2, 40), device="cuda"),
+                    torch.randn((1, 2, 40), device="cuda"),
+                )
+            except Exception as e:
+                raise e
+        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.num_heads
@@ -320,21 +346,26 @@ def forward(self, hidden_states):
         key_proj = self.reshape_heads_to_batch_dim(key_proj)
         value_proj = self.reshape_heads_to_batch_dim(value_proj)
 
-        attention_scores = torch.baddbmm(
-            torch.empty(
-                query_proj.shape[0],
-                query_proj.shape[1],
-                key_proj.shape[1],
-                dtype=query_proj.dtype,
-                device=query_proj.device,
-            ),
-            query_proj,
-            key_proj.transpose(-1, -2),
-            beta=0,
-            alpha=scale,
-        )
-        attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
-        hidden_states = torch.bmm(attention_probs, value_proj)
+        if self._use_memory_efficient_attention_xformers:
+            # Memory efficient attention
+            hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
+            hidden_states = hidden_states.to(query_proj.dtype)
+        else:
+            attention_scores = torch.baddbmm(
+                torch.empty(
+                    query_proj.shape[0],
+                    query_proj.shape[1],
+                    key_proj.shape[1],
+                    dtype=query_proj.dtype,
+                    device=query_proj.device,
+                ),
+                query_proj,
+                key_proj.transpose(-1, -2),
+                beta=0,
+                alpha=scale,
+            )
+            attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
+            hidden_states = torch.bmm(attention_probs, value_proj)
 
         # reshape hidden_states
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
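
A minimal usage sketch of the new toggle, for context (not part of the patch). The AttentionBlock constructor arguments, tensor shape, and dtype below are illustrative assumptions; only set_use_memory_efficient_attention_xformers() and forward() come from this diff, and both xformers and a CUDA device are required:

    # Hypothetical usage sketch; assumes xformers is installed and a CUDA GPU is available.
    import torch
    from diffusers.models.attention import AttentionBlock  # class patched in this diff

    # Constructor arguments are assumptions for illustration, not taken from the diff.
    block = AttentionBlock(channels=320).half().cuda()

    # Opt in to xformers' memory-efficient attention; raises if xformers or CUDA is missing.
    block.set_use_memory_efficient_attention_xformers(True)

    # AttentionBlock consumes spatial feature maps of shape (batch, channels, height, width).
    sample = torch.randn(1, 320, 64, 64, device="cuda", dtype=torch.float16)
    out = block(sample)  # forward() now routes through xformers.ops.memory_efficient_attention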