@@ -68,6 +68,8 @@ def _flash_attention_kernel(
     bkv_compute: int,
     bkv_compute_in: int,
     head_dim_v: int,
+    q_seq_len: int,
+    kv_seq_len: int,
 ):
   float32 = jnp.float32
   head_dim_v_repeats, rem = divmod(head_dim_v, NUM_SUBLANES)
@@ -84,17 +86,18 @@ def init():
     m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value)
     l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref)
 
-  def body(kv_compute_index, _):
+  def compute_body(kv_compute_index, _):
     # # with jax.named_scope("qk"):
-    slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute)
     m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...]
     assert m_prev.shape == (NUM_SUBLANES, bq)
     assert l_prev.shape == (NUM_SUBLANES, bq)
 
     q = q_ref[...]
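+    # The chunk length is factored out so the shape assertions below mirror
+    # last_compute_body, where the final chunk is shorter than bkv_compute.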
+    slice_k_len = bkv_compute
+    slice_k = pl.ds(kv_compute_index * bkv_compute, slice_k_len)
     k = k_ref[slice_k, :]
     qk = lax.dot_general(k, q, NT_DIM_NUMBERS, preferred_element_type=float32)
-    assert qk.shape == (bkv_compute, bq)
+    assert qk.shape == (slice_k_len, bq)
 
     # with jax.named_scope("softmax_qkv"):
     o_prev = o_scratch_ref[:]
@@ -130,9 +133,74 @@ def body(kv_compute_index, _):
     m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next
     o_scratch_ref[:] = o_prev
 
+  def last_compute_body(kv_compute_index):
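+    # Same online-softmax update as compute_body, but over the final partial
+    # KV chunk of length kv_seq_len % bkv_compute; the remaining rows are
+    # padding and are excluded from the slice.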
+    # # with jax.named_scope("qk"):
+    m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...]
+    assert m_prev.shape == (NUM_SUBLANES, bq)
+    assert l_prev.shape == (NUM_SUBLANES, bq)
+
+    # q padding doesn't matter here: the padded rows are truncated afterward.
+    # kv padding does matter, so slice only the valid rows.
+    q = q_ref[...]
+    slice_k_len = kv_seq_len % bkv_compute
+    slice_k = pl.ds(kv_compute_index * bkv_compute, slice_k_len)
+    k = k_ref[slice_k, :]
+    qk = lax.dot_general(k, q, NT_DIM_NUMBERS, preferred_element_type=float32)
+    assert qk.shape == (slice_k_len, bq)
+
+    # with jax.named_scope("softmax_qkv"):
+    o_prev = o_scratch_ref[:]
+
+    v = v_ref[slice_k, :].astype(float32)
+
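+    # Online softmax: fold this chunk's column-wise max into the running max.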
+    m_curr = qk.max(axis=0)[None, :]
+    assert m_curr.shape == (1, bq)
+
+    m_next = jnp.maximum(m_prev, m_curr)
+    assert m_next.shape == (NUM_SUBLANES, bq)
+
+    # exp2 lowers to two ops (vmul and vpow); the vmul is fused outside the kernel.
+    s_curr = exp2(qk - m_next[0:1])
+    # assert s_curr.shape == (slice_k_len, bq)
+
+    l_curr = s_curr.sum(axis=0, keepdims=True)
+    assert l_curr.shape == (1, bq)
+
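+    # Rescale the previous running sum by exp2(m_prev - m_next) before adding.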
+    alpha = jnp.exp2(m_prev - m_next)
+    l_next = l_curr + alpha * l_prev
+
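+    # Contract v and s_curr over the kv axis:
+    # o_curr[d, q] = sum_k v[k, d] * s_curr[k, q].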
+    sv_dims = (((0,), (0,)), ((), ()))
+    o_curr = lax.dot_general(v, s_curr, sv_dims)
+    alpha_o = alpha[0:1, ...]
+    o_prev = alpha_o * o_prev + o_curr
+
+    m_prev = m_next
+    l_prev = l_next
+
+    m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next
+    o_scratch_ref[:] = o_prev
+
+
   ###
+  assert bkv % bkv_compute == 0
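+  # Every block except the last is guaranteed to be full, so it runs a fixed
+  # number of full-size compute chunks.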
+  @pl.when(j != grid_width - 1)
+  def body():
+    lax.fori_loop(0, (bkv // bkv_compute), compute_body, None, unroll=True)
 
-  lax.fori_loop(0, (bkv // bkv_compute), body, None, unroll=True)
+  @pl.when(j == grid_width - 1)
+  def last_body():
+    if kv_seq_len % bkv == 0:
+      iter_num = (bkv // bkv_compute)
+      lax.fori_loop(0, iter_num, compute_body, None, unroll=True)
+    else:
+      # The last kv block may contain padding; handle that case separately.
+      remain_kv_seq_len = kv_seq_len % bkv
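+      # Number of bkv_compute-size chunks needed to cover the valid rows.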
+      iter_num = ((remain_kv_seq_len + bkv_compute - 1) // bkv_compute)
+      if remain_kv_seq_len % bkv_compute == 0:
+        lax.fori_loop(0, iter_num, compute_body, None, unroll=True)
+      else:
+        lax.fori_loop(0, iter_num - 1, compute_body, None, unroll=True)
+        last_compute_body(iter_num - 1)
 
   @pl.when(j == grid_width - 1)
   def end():
@@ -188,8 +256,10 @@ def v_index_map(h, i, j, *_):
       pl.BlockSpec((head_dim_v, bq), lambda *_: (0, 0)),
       pl.BlockSpec((None, head_dim_v, bq), out_index_map),
   ]
-  grid_width = kv_seq_len // bkv
-  grid = (num_q_heads, q_seq_len // bq, grid_width)
+  # kv_seq_len and q_seq_len are the unpadded (logical) sequence lengths.
+  grid_width = (kv_seq_len + bkv - 1) // bkv
+  grid_height = (q_seq_len + bq - 1) // bq
+  grid = (num_q_heads, grid_height, grid_width)
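+  # Rounding up means the last block along each axis may include padding,
+  # which the kernel slices away (kv) or truncates afterward (q).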
 
   all_out = pl.pallas_call(
       partial(
@@ -201,6 +271,8 @@ def v_index_map(h, i, j, *_):
           bkv_compute=bkv_compute,
           bkv_compute_in=bkv_compute_in,
           head_dim_v=head_dim_v,
+          q_seq_len=q_seq_len,
+          kv_seq_len=kv_seq_len,
       ),
       grid_spec=pltpu.PrefetchScalarGridSpec(
           num_scalar_prefetch=0,