@@ -318,23 +318,24 @@ def __call__(
         _, k_tokens, _ = key.shape
         qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
 
-        if self.chunk_threshold_bytes is None or qk_matmul_size_bytes > self.chunk_threshold_bytes:
-            hidden_states = efficient_dot_product_attention(
-                query,
-                key,
-                value,
-                query_chunk_size=self.query_chunk_size,
-                kv_chunk_size=self.kv_chunk_size,
-                kv_chunk_size_min=self.kv_chunk_size_min,
-                use_checkpoint=attn.training,
-            )
-        else:
-            # the big matmul fits into our memory limit; compute via unchunked attention (it's faster)
-            attention_probs = attn.get_attention_scores(
-                query,
-                key,
-            )
-            hidden_states = torch.bmm(attention_probs, value)
+        query_chunk_size = self.query_chunk_size
+        kv_chunk_size = self.kv_chunk_size
+
+        if self.chunk_threshold_bytes is not None and qk_matmul_size_bytes <= self.chunk_threshold_bytes:
+            # the big matmul fits into our memory limit; do everything in 1 chunk,
+            # i.e. send it down the unchunked fast-path
+            query_chunk_size = q_tokens
+            kv_chunk_size = k_tokens
+
+        hidden_states = efficient_dot_product_attention(
+            query,
+            key,
+            value,
+            query_chunk_size=query_chunk_size,
+            kv_chunk_size=kv_chunk_size,
+            kv_chunk_size_min=self.kv_chunk_size_min,
+            use_checkpoint=attn.training,
+        )
 
         hidden_states = hidden_states.to(dtype)
 
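For context, here is a minimal standalone sketch of the threshold decision the new code makes before always routing through `efficient_dot_product_attention`. The helper name `choose_chunk_sizes`, the tensor shapes, and the way the per-element byte size is derived from the dtype are illustrative assumptions, not part of this diff; in the real module `bytes_per_token` and the chunk sizes come from the processor's own state.

```python
from typing import Optional, Tuple

import torch


def choose_chunk_sizes(
    query: torch.Tensor,
    key: torch.Tensor,
    query_chunk_size: int,
    kv_chunk_size: int,
    chunk_threshold_bytes: Optional[int],
) -> Tuple[int, int]:
    """Return (query_chunk_size, kv_chunk_size), collapsing both to a single
    chunk when the full QK^T score matrix would fit under the byte threshold."""
    batch_x_heads, q_tokens, _ = query.shape
    _, k_tokens, _ = key.shape
    # assumption: one score element per (q_token, k_token) pair, sized by the query dtype
    bytes_per_element = torch.finfo(query.dtype).bits // 8
    qk_matmul_size_bytes = batch_x_heads * bytes_per_element * q_tokens * k_tokens
    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
        # the big matmul fits in the budget: one chunk each way == the unchunked fast path
        return q_tokens, k_tokens
    return query_chunk_size, kv_chunk_size


# example: float16, 8 batch*heads, 4096 query and 4096 key tokens
# -> 8 * 2 * 4096 * 4096 bytes ≈ 268 MB for the score matrix
q = torch.zeros(8, 4096, 40, dtype=torch.float16)
k = torch.zeros(8, 4096, 40, dtype=torch.float16)
print(choose_chunk_sizes(q, k, query_chunk_size=1024, kv_chunk_size=4096,
                         chunk_threshold_bytes=512 * 1024**2))  # -> (4096, 4096)
print(choose_chunk_sizes(q, k, query_chunk_size=1024, kv_chunk_size=4096,
                         chunk_threshold_bytes=128 * 1024**2))  # -> (1024, 4096)
```

With a 512 MiB threshold the ~268 MB score matrix fits, so both chunk sizes collapse to the full token counts (the single-chunk fast path); with a 128 MiB threshold the configured chunk sizes are kept and attention stays chunked.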