From af21a6bc3417a2914bf3f91e791515d33088f1b9 Mon Sep 17 00:00:00 2001 From: Yuyan Peng Date: Thu, 30 Oct 2025 07:23:08 +0000 Subject: [PATCH 1/2] clean the experiment comment codes --- exp/custom_splash_attention.py | 467 +-------------------------------- 1 file changed, 3 insertions(+), 464 deletions(-) diff --git a/exp/custom_splash_attention.py b/exp/custom_splash_attention.py index c3fad1e42865..0871067793cf 100644 --- a/exp/custom_splash_attention.py +++ b/exp/custom_splash_attention.py @@ -75,9 +75,6 @@ def _flash_attention_kernel( raise NotImplementedError( f"{head_dim_v=} should be a multiple of {NUM_SUBLANES}" ) - # head_dim_v_repeats, rem = divmod(head_dim_v, NUM_LANES) - # if rem != 0: - # raise NotImplementedError(f"{head_dim_v=} should be a multiple of {NUM_LANES}") h, i, j = pl.program_id(0), pl.program_id(1), pl.program_id(2) @@ -87,450 +84,7 @@ def init(): m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value) l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref) - ### - - # # with jax.named_scope("qk"): - # q = q_ref[...] - # k = k_ref[...] - - # qk_all = lax.dot_general(q, k, NT_DIM_NUMBERS, preferred_element_type=float32) - # assert qk_all.shape == (bq, bkv) - - # step = bkv_compute - # assert step % NUM_LANES == 0 - # assert bkv % step == 0 - # m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] - # for i in range(0, bkv, step): - # qk = qk_all[:,i:i+step] - # # qk = lax.dot_general(k[i:i+step], q, NT_DIM_NUMBERS, preferred_element_type=float32) - # # assert qk.shape == (step, bq) - # # with jax.named_scope("qk"): - # assert m_prev.shape == (bq, NUM_LANES) - # assert l_prev.shape == (bq, NUM_LANES) - - # # with jax.named_scope("softmax"): - # # with jax.named_scope("qk_max"): - # m_curr = qk.max(axis=1)[:, None] - # assert m_curr.shape == (bq, 1) - # # with jax.named_scope("qk_maximum"): - # m_next = jnp.maximum(m_prev, m_curr) - # assert m_next.shape == (bq, NUM_LANES) - - # bkv_repeats, rem = divmod(bkv_compute, NUM_LANES) - # if rem != 0: - # raise NotImplementedError( - # f"{bkv_compute=} should be a multiple of {NUM_LANES}" - # ) - - # s_curr = jnp.exp(qk - pltpu.repeat(m_next, bkv_repeats, axis=1)) - # # assert s_curr.shape == (bq, bkv_compute) - # # # with jax.named_scope("qk_exp"): - # # s_diff = qk - m_next[:,0:1] - # # s_curr = jnp.exp(s_diff) - # assert s_curr.shape == (bq, step) - - # # with jax.named_scope("qk_sum"): - # l_curr = s_curr.sum(axis=1, keepdims=True) - # assert l_curr.shape == (bq, 1) - - # # with jax.named_scope("qk_alpha"): - # m_diff = m_prev - m_next - # alpha = jnp.exp(m_diff) - - # l_next = l_curr + alpha * l_prev - # m_prev, l_prev = m_next, l_next - - # # with jax.named_scope("qkv"): - # v = v_ref[i:i+step].astype(float32) - # sv_dims = (((1,), (0,)), ((), ())) - # o_curr = lax.dot_general(s_curr, v, sv_dims) - # # alpha_o = alpha[:, 0:1] - # alpha_o = pltpu.repeat(alpha, head_dim_v_repeats, axis=1) - # o_scratch_ref[:] = alpha_o * o_scratch_ref[:] + o_curr - - # m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next - - ### - - # with jax.named_scope("qk"): - # q = q_ref[...] - # k = k_ref[...] - # qk_all = lax.dot_general(k, q, NT_DIM_NUMBERS, preferred_element_type=float32) - # assert qk_all.shape == (bkv, bq) - # # qk_all = lax.dot_general(q, k, NT_DIM_NUMBERS, preferred_element_type=float32) - # # assert qk_all.shape == (bq, bkv) - - # step = bkv_compute - # assert step % NUM_SUBLANES == 0 - # assert bkv % step == 0 - # m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] 
- # for i in range(0, bkv, step): - # qk = qk_all[i:i+step] - # # qk = lax.dot_general(k[i:i+step], q, NT_DIM_NUMBERS, preferred_element_type=float32) - # # assert qk.shape == (step, bq) - # # with jax.named_scope("qk"): - # assert m_prev.shape == (NUM_SUBLANES, bq) - # assert l_prev.shape == (NUM_SUBLANES, bq) - - # # with jax.named_scope("softmax"): - # # with jax.named_scope("qk_max"): - # m_curr = qk.max(axis=0)[None, :] - # assert m_curr.shape == (1, bq) - # # with jax.named_scope("qk_maximum"): - # m_next = jnp.maximum(m_prev, m_curr) - # assert m_next.shape == (NUM_SUBLANES, bq) - - # # with jax.named_scope("qk_exp"): - # s_diff = qk - m_next[0:1] - # s_curr = jnp.exp(s_diff) - # assert s_curr.shape == (step, bq) - - # # with jax.named_scope("qk_sum"): - # l_curr = s_curr.sum(axis=0, keepdims=True) - # assert l_curr.shape == (1, bq) - - # # with jax.named_scope("qk_alpha"): - # m_diff = m_prev - m_next - # alpha = jnp.exp(m_diff) - - # l_next = l_curr + alpha * l_prev - # m_prev, l_prev = m_next, l_next - - # # with jax.named_scope("qkv"): - # v = v_ref[i:i+step].astype(float32) - # sv_dims = (((0,), (0,)), ((), ())) - # o_curr = lax.dot_general(v, s_curr, sv_dims) # (head_dim, bk) @ (bk, bq) - # alpha_o = alpha[0:1, ...] - # o_scratch_ref[:] = alpha_o * o_scratch_ref[:] + o_curr - - # m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next - - ### - - # m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] - # for kv_compute_index in range(0, (bkv // bkv_compute)): - # # with jax.named_scope("qk"): - # assert m_prev.shape == (NUM_SUBLANES, bq) - # assert l_prev.shape == (NUM_SUBLANES, bq) - - # with jax.named_scope("qk"): - # slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute) - - # q = q_ref[...] - # k = k_ref[slice_k, :] - # qk = lax.dot_general(k, q, NT_DIM_NUMBERS, preferred_element_type=float32) - # assert qk.shape == (bkv_compute, bq) - - # with jax.named_scope("softmax"): - # # with jax.named_scope("qk_max"): - # m_curr = qk.max(axis=0)[None, :] - # assert m_curr.shape == (1, bq) - # # with jax.named_scope("qk_maximum"): - # m_next = jnp.maximum(m_prev, m_curr) - # assert m_next.shape == (NUM_SUBLANES, bq) - - # # with jax.named_scope("qk_exp"): - # s_curr = jnp.exp(qk - m_next[0:1]) - # assert s_curr.shape == (bkv_compute, bq) - - # # with jax.named_scope("qk_sum"): - # l_curr = s_curr.sum(axis=0, keepdims=True) - # assert l_curr.shape == (1, bq) - - # # with jax.named_scope("qk_alpha"): - # alpha = jnp.exp(m_prev - m_next) - # l_next = l_curr + alpha * l_prev - # m_prev, l_prev = m_next, l_next - - # with jax.named_scope("qkv"): - # v = v_ref[slice_k, :].astype(float32) - # sv_dims = (((0,), (0,)), ((), ())) - # o_curr = lax.dot_general(v, s_curr, sv_dims) - # alpha_o = alpha[0:1, ...] - # o_scratch_ref[:] = alpha_o * o_scratch_ref[:] + o_curr - - # m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next - - ### - - # assert bkv % bkv_compute == 0 - # qk_next = None - # m_curr_next = None - # m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] - # for kv_compute_index in range(0, (bkv // bkv_compute) + 1): - # # nonlocal qk_pre - # # with jax.named_scope("qk"): - # assert m_prev.shape == (NUM_SUBLANES, bq) - # assert l_prev.shape == (NUM_SUBLANES, bq) - # if kv_compute_index < (bkv // bkv_compute): - # with jax.named_scope("qk"): - # slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute) - - # q = q_ref[...] 
- # k = k_ref[slice_k, :] - # qk = lax.dot_general(k, q, NT_DIM_NUMBERS, preferred_element_type=float32) - # assert qk.shape == (bkv_compute, bq) - - # # with jax.named_scope("softmax"): - # # with jax.named_scope("qk_max"): - # m_curr = qk.max(axis=0)[None, :] - # assert m_curr.shape == (1, bq) - - # m_curr, m_curr_next = m_curr_next, m_curr - # qk_next, qk = qk, qk_next - # if kv_compute_index == 0: - # continue - - # # with jax.named_scope("qk_maximum"): - # m_next = jnp.maximum(m_prev, m_curr) - # assert m_next.shape == (NUM_SUBLANES, bq) - - # # with jax.named_scope("qk_exp"): - # s_curr = jnp.exp(qk - m_next[0:1]) - # assert s_curr.shape == (bkv_compute, bq) - - # # with jax.named_scope("qk_sum"): - # l_curr = s_curr.sum(axis=0, keepdims=True) - # assert l_curr.shape == (1, bq) - - # # with jax.named_scope("qk_alpha"): - # alpha = jnp.exp(m_prev - m_next) - # l_next = l_curr + alpha * l_prev - # m_prev, l_prev = m_next, l_next - - # slice_k = pl.ds((kv_compute_index - 1) * bkv_compute, bkv_compute) - - # with jax.named_scope("qkv"): - # v = v_ref[slice_k, :].astype(float32) - # sv_dims = (((0,), (0,)), ((), ())) - # o_curr = lax.dot_general(v, s_curr, sv_dims) - # alpha_o = alpha[0:1, ...] - # o_scratch_ref[:] = alpha_o * o_scratch_ref[:] + o_curr - - # m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next - - ### - - # assert bkv % bkv_compute == 0 - # qk_next = None - # # def body(kv_compute_index, _): - # m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] - # for kv_compute_index in range(0, (bkv // bkv_compute) + 1): - # # nonlocal qk_pre - # # with jax.named_scope("qk"): - # assert m_prev.shape == (NUM_SUBLANES, bq) - # assert l_prev.shape == (NUM_SUBLANES, bq) - # if kv_compute_index < (bkv // bkv_compute): - # with jax.named_scope("qk"): - # slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute) - - # q = q_ref[...] - # k = k_ref[slice_k, :] - # qk = lax.dot_general(k, q, NT_DIM_NUMBERS, preferred_element_type=float32) - # assert qk.shape == (bkv_compute, bq) - - # qk_next, qk = qk, qk_next - # if kv_compute_index == 0: - # continue - - # with jax.named_scope("softmax"): - # # with jax.named_scope("qk_max"): - # m_curr = qk.max(axis=0)[None, :] - # assert m_curr.shape == (1, bq) - # # with jax.named_scope("qk_maximum"): - # m_next = jnp.maximum(m_prev, m_curr) - # assert m_next.shape == (NUM_SUBLANES, bq) - - # # with jax.named_scope("qk_exp"): - # s_curr = jnp.exp(qk - m_next[0:1]) - # assert s_curr.shape == (bkv_compute, bq) - - # # with jax.named_scope("qk_sum"): - # l_curr = s_curr.sum(axis=0, keepdims=True) - # assert l_curr.shape == (1, bq) - - # # with jax.named_scope("qk_alpha"): - # alpha = jnp.exp(m_prev - m_next) - # l_next = l_curr + alpha * l_prev - # m_prev, l_prev = m_next, l_next - - # slice_k = pl.ds((kv_compute_index - 1) * bkv_compute, bkv_compute) - - # with jax.named_scope("qkv"): - # v = v_ref[slice_k, :].astype(float32) - # sv_dims = (((0,), (0,)), ((), ())) - # o_curr = lax.dot_general(v, s_curr, sv_dims) - # alpha_o = alpha[0:1, ...] - # o_scratch_ref[:] = alpha_o * o_scratch_ref[:] + o_curr - - # m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next - - ### - - # assert bkv % bkv_compute == 0 - # qk_next = None - # s_curr_next = None - # alpha_next = None - # # def body(kv_compute_index, _): - # m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] 
- # for kv_compute_index in range(0, (bkv // bkv_compute) + 2): - # # nonlocal qk_pre - # # with jax.named_scope("qk"): - # assert m_prev.shape == (NUM_SUBLANES, bq) - # assert l_prev.shape == (NUM_SUBLANES, bq) - # if kv_compute_index < (bkv // bkv_compute): - # with jax.named_scope("qk"): - # slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute) - - # q = q_ref[...] - # k = k_ref[slice_k, :] - # qk = lax.dot_general(k, q, NT_DIM_NUMBERS, preferred_element_type=float32) - # assert qk.shape == (bkv_compute, bq) - - # qk_next, qk = qk, qk_next - # if kv_compute_index == 0: - # continue - # if kv_compute_index < (bkv // bkv_compute) + 1: - # with jax.named_scope("softmax"): - # # with jax.named_scope("qk_max"): - # m_curr = qk.max(axis=0)[None, :] - # assert m_curr.shape == (1, bq) - # # with jax.named_scope("qk_maximum"): - # m_next = jnp.maximum(m_prev, m_curr) - # assert m_next.shape == (NUM_SUBLANES, bq) - - # # with jax.named_scope("qk_exp"): - # s_curr = jnp.exp(qk - m_next[0:1]) - # assert s_curr.shape == (bkv_compute, bq) - - # # with jax.named_scope("qk_sum"): - # l_curr = s_curr.sum(axis=0, keepdims=True) - # assert l_curr.shape == (1, bq) - - # # with jax.named_scope("qk_alpha"): - # alpha = jnp.exp(m_prev - m_next) - # l_next = l_curr + alpha * l_prev - # m_prev, l_prev = m_next, l_next - - # s_curr, s_curr_next = s_curr_next, s_curr - # alpha, alpha_next = alpha_next, alpha - # if kv_compute_index == 1: - # continue - - # slice_k = pl.ds((kv_compute_index - 2) * bkv_compute, bkv_compute) - - # with jax.named_scope("qkv"): - # v = v_ref[slice_k, :].astype(float32) - # sv_dims = (((0,), (0,)), ((), ())) - # o_curr = lax.dot_general(v, s_curr, sv_dims) - # alpha_o = alpha[0:1, ...] - # o_scratch_ref[:] = alpha_o * o_scratch_ref[:] + o_curr - - # m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next - - ### - def body(kv_compute_index, _): - - # # with jax.named_scope("qk"): - # slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute) - # m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] - # assert m_prev.shape == (NUM_SUBLANES, bq) - # assert l_prev.shape == (NUM_SUBLANES, bq) - - # q = q_ref[...] - # k = k_ref[slice_k, :] - # qk = lax.dot_general(k, q, NT_DIM_NUMBERS, preferred_element_type=float32) - # assert qk.shape == (bkv_compute, bq) - - # # with jax.named_scope("qk_max"): - # m_curr_list = [] - # s_curr_list = [] - # step = bkv_compute_in - # assert qk.shape[0] % step == 0 - # for i in range(0, qk.shape[0], step): - # m_curr = qk[i:i+step].max(axis=0)[None, :] - # # m_curr = qk[0:1] - # assert m_curr.shape == (1, bq) - # m_curr_list.append(m_curr) - - # m_next = jnp.maximum(m_prev, m_curr) - # assert m_next.shape == (NUM_SUBLANES, bq) - - # s_curr = jnp.exp(qk[i:i+step] - m_curr[0:1]) - # # assert s_curr.shape == (bkv_compute, bq) - # s_curr_list.append(s_curr) - - # m_curr = jnp.concatenate(m_curr_list, axis=0) - # m_curr = jnp.exp(m_curr - m_next[0:1]) - - # for i in range(len(s_curr_list)): - # s_curr_list[i] = s_curr_list[i] * m_curr[i:i+1] - - # s_curr = jnp.concatenate(s_curr_list, axis=0) - # assert s_curr.shape == (bkv_compute, bq) - - # l_curr = s_curr.sum(axis=0, keepdims=True) - # assert l_curr.shape == (1, bq) - - # alpha = jnp.exp(m_prev - m_next) - # l_next = l_curr + alpha * l_prev - # m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next - - # v = v_ref[slice_k, :].astype(float32) - # sv_dims = (((0,), (0,)), ((), ())) - # o_curr = lax.dot_general(v, s_curr, sv_dims) - # alpha_o = alpha[0:1, ...] 
- # o_scratch_ref[:] = alpha_o * o_scratch_ref[:] + o_curr - - ### - - # # with jax.named_scope("qk"): - # slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute) - # m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] - # assert m_prev.shape == (NUM_SUBLANES, bq) - # assert l_prev.shape == (NUM_SUBLANES, bq) - - # q = q_ref[...] - # k = k_ref[slice_k, :] - # qk = lax.dot_general(k, q, NT_DIM_NUMBERS, preferred_element_type=float32) - # assert qk.shape == (bkv_compute, bq) - - # # with jax.named_scope("softmax"): - # # with jax.named_scope("qk_max"): - # m_curr = qk.max(axis=0)[None, :] - # # m_curr = qk[-1:, :] - # # m_curr = qk[0:1, :] - # # m_ub = jnp.zeros((1,bq), dtype=float32) - # assert m_curr.shape == (1, bq) - # # with jax.named_scope("qk_maximum"): - # m_next = jnp.maximum(m_prev, m_curr) - # assert m_next.shape == (NUM_SUBLANES, bq) - - # # with jax.named_scope("qk_exp"): - # s_curr = jnp.exp(qk - m_next[0:1]) - # # s_curr = jnp.exp(qk - m_prev[0:1]) - # # s_curr = s_curr * jnp.exp(m_prev - m_next)[0:1] - # assert s_curr.shape == (bkv_compute, bq) - - # # with jax.named_scope("qk_sum"): - # l_curr = s_curr.sum(axis=0, keepdims=True) - # assert l_curr.shape == (1, bq) - - # # with jax.named_scope("qk_alpha"): - # alpha = jnp.exp(m_prev - m_next) - # l_next = l_curr + alpha * l_prev - # m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next - - # # with jax.named_scope("qkv"): - # v = v_ref[slice_k, :].astype(float32) - # sv_dims = (((0,), (0,)), ((), ())) - # o_curr = lax.dot_general(v, s_curr, sv_dims) - # alpha_o = alpha[0:1, ...] - # o_scratch_ref[:] = alpha_o * o_scratch_ref[:] + o_curr - - ### - # # with jax.named_scope("qk"): slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute) m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] @@ -634,23 +188,6 @@ def v_index_map(h, i, j, *_): pl.BlockSpec((head_dim_v, bq), lambda *_: (0, 0)), pl.BlockSpec((None, head_dim_v, bq), out_index_map), ] - # in_specs = [ - # pl.BlockSpec((None, bq, head_dim_qk), q_index_map), - # pl.BlockSpec((None, bkv, head_dim_qk), k_index_map), - # pl.BlockSpec((None, bkv, head_dim_v), v_index_map), - # ] - # out_shapes = [ - # jax.ShapeDtypeStruct((bq, NUM_LANES), jnp.float32), - # jax.ShapeDtypeStruct((bq, NUM_LANES), jnp.float32), - # jax.ShapeDtypeStruct((bq, head_dim_v), jnp.float32), - # jax.ShapeDtypeStruct((num_q_heads, q_seq_len, head_dim_v), q.dtype), - # ] - # out_specs = [ - # pl.BlockSpec((bq, NUM_LANES), lambda *_: (0, 0)), - # pl.BlockSpec((bq, NUM_LANES), lambda *_: (0, 0)), - # pl.BlockSpec((bq, head_dim_v), lambda *_: (0, 0)), - # pl.BlockSpec((None, bq, head_dim_v), out_index_map), - # ] grid_width = kv_seq_len // bkv grid = (num_q_heads, q_seq_len // bq, grid_width) @@ -708,7 +245,9 @@ def _tpu_splash_attention(query, key, value, mesh): def _attention_on_slices(q, k, v): scale_factor = 1.0 / math.sqrt(q.shape[-1]) - q = q * scale_factor + # fuse the ops of exp in softmax here + _LOG2_E = 1.44269504 + q = q * scale_factor * _LOG2_E def pad_to_multiple(x, multiple, axis): seq_len = x.shape[axis] From a1ce152e2957d43fba53bd803cfe2a105895b425 Mon Sep 17 00:00:00 2001 From: Yuyan Peng Date: Fri, 31 Oct 2025 06:29:44 +0000 Subject: [PATCH 2/2] handle padding in attention kernel It's hard to eliminate kv padding without segment id. Strip the padding inside the kernel. 
--- exp/custom_splash_attention.py | 84 +++++++++++++++++++++++++++++++--- exp/wan2p2_benchmark.py | 28 +++--------- 2 files changed, 84 insertions(+), 28 deletions(-) diff --git a/exp/custom_splash_attention.py b/exp/custom_splash_attention.py index 0871067793cf..c05048c881c9 100644 --- a/exp/custom_splash_attention.py +++ b/exp/custom_splash_attention.py @@ -68,6 +68,8 @@ def _flash_attention_kernel( bkv_compute: int, bkv_compute_in: int, head_dim_v: int, + q_seq_len: int, + kv_seq_len: int, ): float32 = jnp.float32 head_dim_v_repeats, rem = divmod(head_dim_v, NUM_SUBLANES) @@ -84,17 +86,18 @@ def init(): m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value) l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref) - def body(kv_compute_index, _): + def compute_body(kv_compute_index, _): # # with jax.named_scope("qk"): - slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute) m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] assert m_prev.shape == (NUM_SUBLANES, bq) assert l_prev.shape == (NUM_SUBLANES, bq) q = q_ref[...] + slice_k_len = bkv_compute + slice_k = pl.ds(kv_compute_index * bkv_compute, slice_k_len) k = k_ref[slice_k, :] qk = lax.dot_general(k, q, NT_DIM_NUMBERS, preferred_element_type=float32) - assert qk.shape == (bkv_compute, bq) + assert qk.shape == (slice_k_len, bq) # with jax.named_scope("softmax_qkv"): o_prev = o_scratch_ref[:] @@ -130,9 +133,74 @@ def body(kv_compute_index, _): m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next o_scratch_ref[:] = o_prev + def last_compute_body(kv_compute_index): + # # with jax.named_scope("qk"): + m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] + assert m_prev.shape == (NUM_SUBLANES, bq) + assert l_prev.shape == (NUM_SUBLANES, bq) + + # q padding needs no handling here: rows computed for padded queries are truncated afterward. + # kv padding does need handling, since padded keys would pollute the softmax. + q = q_ref[...] + slice_k_len = kv_seq_len % bkv_compute + slice_k = pl.ds(kv_compute_index * bkv_compute, slice_k_len) + k = k_ref[slice_k, :] + qk = lax.dot_general(k, q, NT_DIM_NUMBERS, preferred_element_type=float32) + assert qk.shape == (slice_k_len, bq) + + # with jax.named_scope("softmax_qkv"): + o_prev = o_scratch_ref[:] + + v = v_ref[slice_k, :].astype(float32) + + m_curr = qk.max(axis=0)[None, :] + assert m_curr.shape == (1, bq) + + m_next = jnp.maximum(m_prev, m_curr) + assert m_next.shape == (NUM_SUBLANES, bq) + + # exp lowers to two ops (vmul and vpow); the vmul (the log2(e) scale) is already fused into q outside the kernel, so exp2 suffices here. + s_curr = exp2(qk - m_next[0:1]) + # assert s_curr.shape == (bkv_compute, bq) + + l_curr = s_curr.sum(axis=0, keepdims=True) + assert l_curr.shape == (1, bq) + + alpha = jnp.exp2(m_prev - m_next) + l_next = l_curr + alpha * l_prev + + sv_dims = (((0,), (0,)), ((), ())) + o_curr = lax.dot_general(v, s_curr, sv_dims) + alpha_o = alpha[0:1, ...] + o_prev = alpha_o * o_prev + o_curr + + m_prev = m_next + l_prev = l_next + + m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next + o_scratch_ref[:] = o_prev + + ### + assert bkv % bkv_compute == 0 + @pl.when(j != grid_width - 1) + def body(): + lax.fori_loop(0, (bkv // bkv_compute), compute_body, None, unroll=True) - lax.fori_loop(0, (bkv // bkv_compute), body, None, unroll=True) + @pl.when(j == grid_width - 1) + def last_body(): + if kv_seq_len % bkv == 0: + iter_num = (bkv // bkv_compute) + lax.fori_loop(0, iter_num, compute_body, None, unroll=True) + else: + # the last iteration may contain padding.
Handle that case separately. + remain_kv_seq_len = kv_seq_len % bkv + iter_num = ((remain_kv_seq_len + bkv_compute - 1) // bkv_compute) + if remain_kv_seq_len % bkv_compute == 0: + lax.fori_loop(0, iter_num, compute_body, None, unroll=True) + else: + lax.fori_loop(0, iter_num - 1, compute_body, None, unroll=True) + last_compute_body(iter_num - 1) @pl.when(j == grid_width - 1) def end(): @@ -188,8 +256,10 @@ def v_index_map(h, i, j, *_): pl.BlockSpec((head_dim_v, bq), lambda *_: (0, 0)), pl.BlockSpec((None, head_dim_v, bq), out_index_map), ] - grid_width = kv_seq_len // bkv - grid = (num_q_heads, q_seq_len // bq, grid_width) + # kv_seq_len and q_seq_len are the original (unpadded) lengths; use ceiling division so the grid covers any partial trailing block. + grid_width = (kv_seq_len + bkv - 1) // bkv + grid_height = (q_seq_len + bq - 1) // bq + grid = (num_q_heads, grid_height, grid_width) all_out = pl.pallas_call( partial( @@ -201,6 +271,8 @@ def v_index_map(h, i, j, *_): bkv_compute=bkv_compute, bkv_compute_in=bkv_compute_in, head_dim_v=head_dim_v, + q_seq_len=q_seq_len, + kv_seq_len=kv_seq_len, ), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, diff --git a/exp/wan2p2_benchmark.py b/exp/wan2p2_benchmark.py index ba6b3d7c7c81..e9b6f076da72 100644 --- a/exp/wan2p2_benchmark.py +++ b/exp/wan2p2_benchmark.py @@ -306,35 +306,19 @@ def kernel_3d(q_3d, k_3d, v_3d): kv_seq_len = k_3d.shape[1] num_heads_on_device = q_3d.shape[0] - # self attention - if k_3d.shape[1] > 10000: - # Pad q, k, v to next multiple of BQSIZE/BKVSIZE - q_3d_padded, q_orig_len = pad_to_multiple(q_3d, BQSIZE, axis=1) - k_3d_padded, k_orig_len = pad_to_multiple(k_3d, BKVSIZE, axis=1) - v_3d_padded, v_orig_len = pad_to_multiple(v_3d, BKVSIZE, axis=1) - else: - # do not padding on kv in cross attention. kv length is 512 - q_3d_padded, q_orig_len = pad_to_multiple(q_3d, BQSIZE, axis=1) - k_3d_padded, k_orig_len = k_3d, k_3d.shape[1] - v_3d_padded, v_orig_len = v_3d, v_3d.shape[1] - - padded_q_seq_len = q_3d_padded.shape[1] - padded_kv_seq_len = k_3d_padded.shape[1] - block_sizes = splash_attention.BlockSizes( - block_q=min(BQSIZE, padded_q_seq_len), - block_kv=min(BKVSIZE, padded_kv_seq_len), - block_kv_compute=min(BKVCOMPUTESIZE, padded_kv_seq_len), + block_q=min(BQSIZE, q_seq_len), + block_kv=min(BKVSIZE, kv_seq_len), + block_kv_compute=min(BKVCOMPUTESIZE, kv_seq_len), ) splash_kernel = custom_splash_attention.make_splash_mha( block_sizes=block_sizes, bkv_compute_in=BKVCOMPUTEINSIZE ) - out = splash_kernel(q_3d_padded, k_3d_padded, v_3d_padded).astype( - q_3d_padded.dtype + out = splash_kernel(q_3d, k_3d, v_3d).astype( + q_3d.dtype ) - # Remove padding if any out = jnp.swapaxes(out, 1, 2) - return out[:, :q_orig_len, ...] + return out # Map the kernel over the batch dimension. vmapped_kernel = jax.vmap(kernel_3d, in_axes=(0, 0, 0), out_axes=0)
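
For reference, the _LOG2_E factor introduced in the first patch and the exp2/jnp.exp2 calls in the second rely on the identity exp(x) = exp2(x * log2(e)): folding the log2(e) multiply into the query scaling that already happens outside the kernel leaves a single exp2 per element inside the hot loop. A minimal sketch of why the two forms produce the same softmax (plain jax.numpy, outside the Pallas kernel; function names here are illustrative, not part of the patch):

import jax.numpy as jnp

_LOG2_E = 1.44269504  # log2(e), the constant folded into the query scaling

def softmax_exp(scores):
    # Standard numerically stable softmax, computed with exp.
    m = scores.max(axis=-1, keepdims=True)
    p = jnp.exp(scores - m)
    return p / p.sum(axis=-1, keepdims=True)

def softmax_exp2(scores):
    # Same result when the logits are pre-scaled by log2(e) and exp2 is used,
    # because exp(x) == exp2(x * log2(e)) and the running max scales by the same factor.
    scores = scores * _LOG2_E
    m = scores.max(axis=-1, keepdims=True)
    p = jnp.exp2(scores - m)
    return p / p.sum(axis=-1, keepdims=True)

scores = jnp.array([[0.3, -1.2, 2.5, 0.0]])
assert jnp.allclose(softmax_exp(scores), softmax_exp2(scores), atol=1e-6)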
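
The second patch's tail handling (ceiling-division grid plus last_compute_body slicing only kv_seq_len % bkv_compute keys) can likewise be checked against a plain reference: pad kv to a multiple of the block size, walk the blocks with an online softmax, and give the final block only the true remainder so the zero padding is never read. The sketch below is a jax.numpy-level equivalence check under those assumptions (blockwise_attention and the chosen sizes are illustrative, not the kernel's code path):

import math
import jax
import jax.numpy as jnp

def blockwise_attention(q, k, v, bkv):
    # Online-softmax reference over kv blocks. kv is padded to a multiple of
    # bkv, but the last block is sliced to the true remainder, mirroring how
    # last_compute_body limits slice_k_len to kv_seq_len % bkv_compute.
    kv_len = k.shape[0]                       # true (unpadded) kv length
    pad = (-kv_len) % bkv
    k_pad = jnp.pad(k, ((0, pad), (0, 0)))
    v_pad = jnp.pad(v, ((0, pad), (0, 0)))
    scale = 1.0 / math.sqrt(q.shape[-1])

    m = jnp.full((q.shape[0], 1), -jnp.inf)   # running max
    l = jnp.zeros((q.shape[0], 1))            # running normalizer
    o = jnp.zeros((q.shape[0], v.shape[-1]))  # running unnormalized output
    num_blocks = (kv_len + bkv - 1) // bkv    # ceiling division, as in the grid setup
    for j in range(num_blocks):
        last = j == num_blocks - 1 and kv_len % bkv != 0
        width = kv_len % bkv if last else bkv
        kb = k_pad[j * bkv : j * bkv + width]
        vb = v_pad[j * bkv : j * bkv + width]
        s = (q * scale) @ kb.T
        m_new = jnp.maximum(m, s.max(axis=-1, keepdims=True))
        p = jnp.exp(s - m_new)
        alpha = jnp.exp(m - m_new)
        l = alpha * l + p.sum(axis=-1, keepdims=True)
        o = alpha * o + p @ vb
        m = m_new
    return o / l

q = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
k = jax.random.normal(jax.random.PRNGKey(1), (10, 8))
v = jax.random.normal(jax.random.PRNGKey(2), (10, 8))
ref = jax.nn.softmax((q / math.sqrt(8)) @ k.T, axis=-1) @ v
assert jnp.allclose(blockwise_attention(q, k, v, bkv=4), ref, atol=1e-5)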