 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings
 from typing import Callable, Optional, Union

 import torch
@@ -72,7 +73,8 @@ def __init__(
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax

-        self.scale = dim_head**-0.5 if scale_qk else 1.0
+        self.scale_qk = scale_qk
+        self.scale = dim_head**-0.5 if self.scale_qk else 1.0

         self.heads = heads
         # for slice_size > 0 the attention score computation
@@ -140,7 +142,7 @@ def __init__(
         # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
         if processor is None:
             processor = (
-                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else AttnProcessor()
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
             )
         self.set_processor(processor)

@@ -176,6 +178,11 @@ def set_use_memory_efficient_attention_xformers(
                     "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                     " only available for GPU "
                 )
+            elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
+                warnings.warn(
+                    "You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. "
+                    "We will default to PyTorch's native efficient flash attention implementation provided by PyTorch 2.0."
+                )
             else:
                 try:
                     # Make sure we can run the memory efficient attention
@@ -229,7 +236,15 @@ def set_use_memory_efficient_attention_xformers(
                 if hasattr(self.processor, "to_k_custom_diffusion"):
                     processor.to(self.processor.to_k_custom_diffusion.weight.device)
             else:
-                processor = AttnProcessor()
+                # set attention processor
+                # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+                # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+                # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+                processor = (
+                    AttnProcessor2_0()
+                    if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
+                    else AttnProcessor()
+                )

         self.set_processor(processor)

@@ -244,7 +259,13 @@ def set_attention_slice(self, slice_size):
         elif self.added_kv_proj_dim is not None:
             processor = AttnAddedKVProcessor()
         else:
-            processor = AttnProcessor()
+            # set attention processor
+            # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+            # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+            # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )

         self.set_processor(processor)
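For context, the processor-selection pattern this diff repeats in three places can be summarized as a minimal standalone sketch. It is not part of the diff: the AttnProcessor and AttnProcessor2_0 names refer to the diffusers classes shown above, but the stub classes and the default_attn_processor helper below are hypothetical stand-ins used only to illustrate the check against F.scaled_dot_product_attention (available from PyTorch 2.0) combined with the scale_qk flag.

import torch.nn.functional as F


class AttnProcessor:
    """Stand-in for the classic attention processor referenced in the diff."""


class AttnProcessor2_0:
    """Stand-in for the processor that calls torch.nn.functional.scaled_dot_product_attention."""


def default_attn_processor(scale_qk: bool = True):
    # Hypothetical helper mirroring the ternary used in the diff: prefer the
    # SDPA-based processor only when torch exposes scaled_dot_product_attention
    # (i.e. torch >= 2.0) and the default scale is in use (scale_qk=True).
    if hasattr(F, "scaled_dot_product_attention") and scale_qk:
        return AttnProcessor2_0()
    return AttnProcessor()


if __name__ == "__main__":
    # On torch >= 2.0 this prints "AttnProcessor2_0", otherwise "AttnProcessor".
    print(type(default_attn_processor()).__name__)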