
Commit a1ce152

handle padding in attention kernel
It's hard to eliminate KV padding without segment IDs, so strip the padding inside the kernel instead.
1 parent af21a6b commit a1ce152
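
The idea: rather than requiring the host to pad K/V up to a multiple of the block size, the kernel receives the true kv_seq_len and processes the last, possibly ragged, KV block itself. A minimal sketch of the block arithmetic involved (the sizes below are hypothetical, chosen only to illustrate a ragged last block):

# Hypothetical sizes, chosen only to illustrate a ragged last KV block.
kv_seq_len = 1000    # true (unpadded) KV length
bkv = 256            # KV rows per grid step
bkv_compute = 128    # KV rows per inner-loop iteration

grid_width = (kv_seq_len + bkv - 1) // bkv   # 4: ceil division, no host-side padding needed
remain = kv_seq_len % bkv                    # 232 valid rows in the last grid step
full_iters = remain // bkv_compute           # 1 full 128-row inner iteration
tail_rows = remain % bkv_compute             # 104-row ragged tail, handled by a dedicated body
print(grid_width, remain, full_iters, tail_rows)  # -> 4 232 1 104

Query-side padding, by contrast, only produces extra output rows that are sliced off afterwards, so only the KV side needs special handling inside the kernel.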

2 files changed: +84, -28 lines

exp/custom_splash_attention.py

Lines changed: 78 additions & 6 deletions
@@ -68,6 +68,8 @@ def _flash_attention_kernel(
     bkv_compute: int,
     bkv_compute_in: int,
     head_dim_v: int,
+    q_seq_len: int,
+    kv_seq_len: int,
 ):
     float32 = jnp.float32
     head_dim_v_repeats, rem = divmod(head_dim_v, NUM_SUBLANES)
@@ -84,17 +86,18 @@ def init():
         m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value)
         l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref)

-    def body(kv_compute_index, _):
+    def compute_body(kv_compute_index, _):
         # # with jax.named_scope("qk"):
-        slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute)
         m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...]
         assert m_prev.shape == (NUM_SUBLANES, bq)
         assert l_prev.shape == (NUM_SUBLANES, bq)

         q = q_ref[...]
+        slice_k_len = bkv_compute
+        slice_k = pl.ds(kv_compute_index * bkv_compute, slice_k_len)
         k = k_ref[slice_k, :]
         qk = lax.dot_general(k, q, NT_DIM_NUMBERS, preferred_element_type=float32)
-        assert qk.shape == (bkv_compute, bq)
+        assert qk.shape == (slice_k_len, bq)

         # with jax.named_scope("softmax_qkv"):
         o_prev = o_scratch_ref[:]
@@ -130,9 +133,74 @@ def body(kv_compute_index, _):
         m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next
         o_scratch_ref[:] = o_prev

+    def last_compute_body(kv_compute_index):
+        # # with jax.named_scope("qk"):
+        m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...]
+        assert m_prev.shape == (NUM_SUBLANES, bq)
+        assert l_prev.shape == (NUM_SUBLANES, bq)
+
+        # We don't care about q padding since it doesn't matter and truncated afterward
+        # We care about kv padding
+        q = q_ref[...]
+        slice_k_len = kv_seq_len % bkv_compute
+        slice_k = pl.ds(kv_compute_index * bkv_compute, slice_k_len)
+        k = k_ref[slice_k, :]
+        qk = lax.dot_general(k, q, NT_DIM_NUMBERS, preferred_element_type=float32)
+        assert qk.shape == (slice_k_len, bq)
+
+        # with jax.named_scope("softmax_qkv"):
+        o_prev = o_scratch_ref[:]
+
+        v = v_ref[slice_k, :].astype(float32)
+
+        m_curr = qk.max(axis=0)[None, :]
+        assert m_curr.shape == (1, bq)
+
+        m_next = jnp.maximum(m_prev, m_curr)
+        assert m_next.shape == (NUM_SUBLANES, bq)
+
+        # the exp two ops: vmul and vpow. Fuse the vmul outside of kernel.
+        s_curr = exp2(qk - m_next[0:1])
+        # assert s_curr.shape == (bkv_compute, bq)
+
+        l_curr = s_curr.sum(axis=0, keepdims=True)
+        assert l_curr.shape == (1, bq)
+
+        alpha = jnp.exp2(m_prev - m_next)
+        l_next = l_curr + alpha * l_prev
+
+        sv_dims = (((0,), (0,)), ((), ()))
+        o_curr = lax.dot_general(v, s_curr, sv_dims)
+        alpha_o = alpha[0:1, ...]
+        o_prev = alpha_o * o_prev + o_curr
+
+        m_prev = m_next
+        l_prev = l_next
+
+        m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next
+        o_scratch_ref[:] = o_prev
+
+
     ###
+    assert bkv % bkv_compute == 0
+    @pl.when(j != grid_width - 1)
+    def body():
+        lax.fori_loop(0, (bkv // bkv_compute), compute_body, None, unroll=True)

-    lax.fori_loop(0, (bkv // bkv_compute), body, None, unroll=True)
+    @pl.when(j == grid_width - 1)
+    def last_body():
+        if kv_seq_len % bkv == 0:
+            iter_num = (bkv // bkv_compute)
+            lax.fori_loop(0, iter_num, compute_body, None, unroll=True)
+        else:
+            # the last iter may contain padding. Separate the case
+            remain_kv_seq_len = kv_seq_len % bkv
+            iter_num = ((remain_kv_seq_len + bkv_compute - 1) // bkv_compute)
+            if remain_kv_seq_len % bkv_compute == 0:
+                lax.fori_loop(0, iter_num, compute_body, None, unroll=True)
+            else:
+                lax.fori_loop(0, iter_num - 1, compute_body, None, unroll=True)
+                last_compute_body(iter_num - 1)

     @pl.when(j == grid_width - 1)
     def end():
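
To make the dispatch in last_body concrete, here is a pure-Python mirror of the same branching (block sizes and sequence lengths are hypothetical; inside the kernel these are static Python ints, so the if/else resolves at trace time rather than at run time):

# Pure-Python mirror of the last_body dispatch above, for illustration only.
# Returns (number of full compute_body iterations, whether a ragged
# last_compute_body call follows).
def last_block_plan(kv_seq_len: int, bkv: int, bkv_compute: int):
    if kv_seq_len % bkv == 0:
        return (bkv // bkv_compute, False)        # last block is exactly bkv rows
    remain = kv_seq_len % bkv
    iter_num = (remain + bkv_compute - 1) // bkv_compute
    if remain % bkv_compute == 0:
        return (iter_num, False)                  # full iterations, no ragged tail
    return (iter_num - 1, True)                   # full iterations plus one ragged call

print(last_block_plan(1024, 256, 128))  # (2, False): no padding in the last block
print(last_block_plan(1000, 256, 128))  # (1, True):  232 rows -> one 128-row loop + 104-row tail
print(last_block_plan(896, 256, 128))   # (1, False): 128 rows -> one full loop, no tail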
@@ -188,8 +256,10 @@ def v_index_map(h, i, j, *_):
         pl.BlockSpec((head_dim_v, bq), lambda *_: (0, 0)),
         pl.BlockSpec((None, head_dim_v, bq), out_index_map),
     ]
-    grid_width = kv_seq_len // bkv
-    grid = (num_q_heads, q_seq_len // bq, grid_width)
+    # kv_seq_len and q_seq_len are not padding.
+    grid_width = (kv_seq_len + bkv - 1) // bkv
+    grid_height = (q_seq_len + bq - 1) // bq
+    grid = (num_q_heads, grid_height, grid_width)

     all_out = pl.pallas_call(
         partial(
@@ -201,6 +271,8 @@ def v_index_map(h, i, j, *_):
             bkv_compute=bkv_compute,
             bkv_compute_in=bkv_compute_in,
             head_dim_v=head_dim_v,
+            q_seq_len=q_seq_len,
+            kv_seq_len=kv_seq_len,
         ),
         grid_spec=pltpu.PrefetchScalarGridSpec(
             num_scalar_prefetch=0,
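
The true lengths reach the kernel as plain Python ints through functools.partial, so the remainder arithmetic and the pl.ds slice sizes above are compile-time constants. A toy sketch of that pattern, not the real kernel (the kernel body, shapes, and valid_len here are made up for illustration):

from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def _toy_kernel(x_ref, o_ref, *, valid_len: int):
    # valid_len is a Python int captured by functools.partial, so it is a
    # static compile-time constant inside the kernel: pl.ds gets a static
    # size and the padded tail of x_ref is simply never read.
    rows = x_ref[pl.ds(0, valid_len), :]           # read only the valid rows
    o_ref[...] = jnp.sum(rows, axis=0, keepdims=True)


x = jnp.ones((8, 128), jnp.float32)                # pretend rows 5..7 are padding
out = pl.pallas_call(
    partial(_toy_kernel, valid_len=5),
    out_shape=jax.ShapeDtypeStruct((1, 128), x.dtype),
    interpret=True,                                # interpreter mode, runs off-TPU
)(x)
print(out[0, 0])                                   # 5.0: padded rows are ignored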

exp/wan2p2_benchmark.py

Lines changed: 6 additions & 22 deletions
@@ -306,35 +306,19 @@ def kernel_3d(q_3d, k_3d, v_3d):
     kv_seq_len = k_3d.shape[1]
     num_heads_on_device = q_3d.shape[0]

-    # self attention
-    if k_3d.shape[1] > 10000:
-        # Pad q, k, v to next multiple of BQSIZE/BKVSIZE
-        q_3d_padded, q_orig_len = pad_to_multiple(q_3d, BQSIZE, axis=1)
-        k_3d_padded, k_orig_len = pad_to_multiple(k_3d, BKVSIZE, axis=1)
-        v_3d_padded, v_orig_len = pad_to_multiple(v_3d, BKVSIZE, axis=1)
-    else:
-        # do not padding on kv in cross attention. kv length is 512
-        q_3d_padded, q_orig_len = pad_to_multiple(q_3d, BQSIZE, axis=1)
-        k_3d_padded, k_orig_len = k_3d, k_3d.shape[1]
-        v_3d_padded, v_orig_len = v_3d, v_3d.shape[1]
-
-    padded_q_seq_len = q_3d_padded.shape[1]
-    padded_kv_seq_len = k_3d_padded.shape[1]
-
     block_sizes = splash_attention.BlockSizes(
-        block_q=min(BQSIZE, padded_q_seq_len),
-        block_kv=min(BKVSIZE, padded_kv_seq_len),
-        block_kv_compute=min(BKVCOMPUTESIZE, padded_kv_seq_len),
+        block_q=min(BQSIZE, q_seq_len),
+        block_kv=min(BKVSIZE, kv_seq_len),
+        block_kv_compute=min(BKVCOMPUTESIZE, kv_seq_len),
     )
     splash_kernel = custom_splash_attention.make_splash_mha(
         block_sizes=block_sizes, bkv_compute_in=BKVCOMPUTEINSIZE
     )
-    out = splash_kernel(q_3d_padded, k_3d_padded, v_3d_padded).astype(
-        q_3d_padded.dtype
+    out = splash_kernel(q_3d, k_3d, v_3d).astype(
+        q_3d.dtype
     )
-    # Remove padding if any
     out = jnp.swapaxes(out, 1, 2)
-    return out[:, :q_orig_len, ...]
+    return out

 # Map the kernel over the batch dimension.
 vmapped_kernel = jax.vmap(kernel_3d, in_axes=(0, 0, 0), out_axes=0)
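
The dropped host-side path relied on a pad_to_multiple helper whose definition is not part of this diff. A plausible sketch of such a helper, given only as an assumption to show the work the kernel now absorbs:

import jax.numpy as jnp


def pad_to_multiple(x, multiple, axis):
    # Hypothetical reconstruction; the real helper is defined elsewhere in
    # the repo. Zero-pads `axis` of `x` up to the next multiple of `multiple`
    # and returns the padded array plus the original length.
    orig_len = x.shape[axis]
    padded_len = ((orig_len + multiple - 1) // multiple) * multiple
    pad_widths = [(0, 0)] * x.ndim
    pad_widths[axis] = (0, padded_len - orig_len)
    return jnp.pad(x, pad_widths), orig_len

With the ragged tail handled inside the kernel, this pad-run-slice round trip (and the q_orig_len bookkeeping) is no longer needed in the benchmark.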
