@@ -21,43 +21,41 @@ def flash_context_attention(
 ):
     num_q_heads, dim = query_states.shape[1:3]
     num_kv_heads = value_states.shape[1]
-    batch = q_start_loc.shape[0]
 
-    for i in range(batch):
-        if torch.equal(q_seq_len[i], kv_seq_len[i]):
-            ext_ops.context_attention(
-                query_states,
-                key_states,
-                value_states,
-                q_start_loc[i:i + 1],
-                q_seq_len[i:i + 1],
-                num_q_heads,
-                num_kv_heads,
-                attn_mask=context.attention_mask[i:i + 1],
-                attn_output=attn_output,
-            )
-        else:
-            key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
-            value_cache = value_cache.reshape(1, kv_cache_len,
-                                              num_kv_heads * dim)
-            ext_ops.paged_prefill_attention(
-                query_states,
-                key_cache,
-                value_cache,
-                block_offsets,
-                block_size,
-                q_start_loc[i:i + 1],
-                q_seq_len[i:i + 1],
-                kv_seq_len[i:i + 1],
-                num_q_heads,
-                num_kv_heads,
-                attn_mask=context.attention_mask[i:i + 1],
-                attn_output=attn_output,
-            )
+    if context.is_unpaged_prefill:
+        ext_ops.prefill_attention(
+            query_states,
+            key_states,
+            value_states,
+            q_start_loc,
+            q_seq_len,
+            context.max_q_seq_length,
+            num_q_heads,
+            num_kv_heads,
+            attn_mask=context.attention_mask,
+            attn_output=attn_output,
+        )
+    else:
+        key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
+        value_cache = value_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
+        ext_ops.paged_prefill_attention(
+            query_states,
+            key_cache,
+            value_cache,
+            block_offsets,
+            block_size,
+            q_start_loc,
+            q_seq_len,
+            kv_seq_len,
+            num_q_heads,
+            num_kv_heads,
+            attn_mask=context.attention_mask,
+            attn_output=attn_output,
+        )
 
 
 def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
-                          block_offsets, block_size):
+                          max_kv_seq_len, block_offsets, block_size):
     num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1]
     ext_ops.paged_decode_attention(
         q,
@@ -66,6 +64,7 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
         block_offsets,
         block_size,
         kv_seq_len,
+        max_kv_seq_len,
         num_q_heads,
         num_kv_heads,
         attn_output=attn_output.view(q.shape),
@@ -115,6 +114,7 @@ def paged_attention_fwd(
             v,
             attn_output,
             kv_seqlens,
+            context.max_kv_seq_length,
             block_offsets,
             block_size,
         )
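Note: the removed loop chose a kernel per sequence based on torch.equal(q_seq_len[i], kv_seq_len[i]); the patch hoists that decision into a single batched dispatch on context.is_unpaged_prefill. Below is a minimal sketch, not part of the patch, of the condition that flag is assumed to capture; the helper name infer_is_unpaged_prefill is hypothetical.

import torch

def infer_is_unpaged_prefill(q_seq_len: torch.Tensor,
                             kv_seq_len: torch.Tensor) -> bool:
    # Hypothetical helper: True when every sequence is a pure prefill, i.e.
    # each query length equals its KV length and nothing needs to be read
    # from the paged cache. This is the batched form of the per-sequence
    # torch.equal(q_seq_len[i], kv_seq_len[i]) check in the removed loop.
    return bool(torch.equal(q_seq_len, kv_seq_len))

When this returns True, the contiguous key_states/value_states can be fed to ext_ops.prefill_attention directly; otherwise the reshaped key_cache/value_cache go through ext_ops.paged_prefill_attention, matching the two branches in the new code.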