import functools
from jax.experimental.pallas.ops.tpu import splash_attention
from jax.experimental.shard_map import shard_map
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh
from jax.experimental import mesh_utils

import jax
import jax.numpy as jnp
import math
import time
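
# Benchmark sketch: sweep splash-attention block sizes (block_q, block_kv,
# block_kv_compute, plus the custom kernel's bkv_compute_in knob) over a
# sharded TPU mesh and print the wall-clock time of each valid combination.
# `custom_splash_attention` is assumed to be a local module exposing
# make_splash_mha(block_sizes=..., bkv_compute_in=...), as called below.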

# import ringattention_pallas_tpu_splash
import custom_splash_attention


# Copied from wan_tx_splash_attn.py.
@functools.partial(
    jax.jit,
    static_argnames=("mesh", "bqsize", "bkvsize", "bkvcomputesize", "bkvcomputesinize"),
)
def _tpu_splash_attention(
    query,
    key,
    value,
    mesh,
    bqsize,
    bkvsize,
    bkvcomputesize,
    bkvcomputesinize,
    scale=None,
    is_causal=False,
    window_size=None,
):
    num_heads = query.shape[1]

    # The function that will be sharded across devices.
    def _attention_on_slices(q, k, v):

        # Scale the query tensor. This happens on each device with its slice of data.
        scale_factor = 1.0 / math.sqrt(q.shape[-1]) if scale is None else scale
        q = q * scale_factor

        def pad_to_multiple2(x, multiple, axis):
            # Unused no-op variant, kept for experimenting with padding outside
            # the kernel instead of inside it.
            return x, x.shape[axis]

        # Helper: pad `axis` of x up to the next multiple of `multiple`,
        # returning the padded array and the original length.
        def pad_to_multiple(x, multiple, axis):
            seq_len = x.shape[axis]
            pad_len = (multiple - seq_len % multiple) % multiple
            if pad_len == 0:
                return x, seq_len
            pad_width = [(0, 0)] * x.ndim
            pad_width[axis] = (0, pad_len)
            return jnp.pad(x, pad_width), seq_len
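
        # Example of the padding helper (illustrative numbers only): with a
        # 75600-token sequence and bqsize=3072, pad_len = (3072 - 75600 % 3072)
        # % 3072 = 1200, so the padded length is 76800, i.e. 25 blocks of 3072.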

        # This function operates on a single item from the batch.
        def kernel_3d(q_3d, k_3d, v_3d):
            q_seq_len = q_3d.shape[1]
            kv_seq_len = k_3d.shape[1]
            num_heads_on_device = q_3d.shape[0]

            # Pad q, k, v to the next multiple of bqsize / bkvsize.
            q_3d_padded, q_orig_len = pad_to_multiple(q_3d, bqsize, axis=1)
            k_3d_padded, k_orig_len = pad_to_multiple(k_3d, bkvsize, axis=1)
            v_3d_padded, v_orig_len = pad_to_multiple(v_3d, bkvsize, axis=1)

            padded_q_seq_len = q_3d_padded.shape[1]
            padded_kv_seq_len = k_3d_padded.shape[1]

            block_sizes = splash_attention.BlockSizes(
                block_q=min(bqsize, padded_q_seq_len),
                block_kv=min(bkvsize, padded_kv_seq_len),
                block_kv_compute=min(bkvcomputesize, padded_kv_seq_len),
            )
            splash_kernel = custom_splash_attention.make_splash_mha(
                block_sizes=block_sizes, bkv_compute_in=bkvcomputesinize
            )
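            # Roughly: block_q tiles the (padded) query sequence, block_kv is
            # the key/value tile held per grid step, and block_kv_compute is
            # the chunk processed per inner step (main() only sweeps values
            # that divide bkvsize). bkv_compute_in is an extra knob specific
            # to the custom kernel.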
            out = splash_kernel(q_3d_padded, k_3d_padded, v_3d_padded)
            # Rearrange the output axes, then trim any query padding.
            out = jnp.swapaxes(out, 1, 2)
            return out[:, :q_orig_len, ...]

        # Map the kernel over the batch dimension.
        vmapped_kernel = jax.vmap(kernel_3d, in_axes=(0, 0, 0), out_axes=0)
        return vmapped_kernel(q, k, v)

    # Determine the partitioning spec based on the number of heads.
    if num_heads < mesh.size:
        # Replicated case (e.g. the VAE): every device gets the full tensor.
        q_partition_spec = P()
        kv_partition_spec = P()
    else:
        # Sharded case (the transformer): split along the heads axis.
        if key.shape[2] > 10000:
            # Attn1 (self-attention): the kv sequence is long, so keep key/value
            # sharded over heads instead of gathering them.
            q_partition_spec = P("dp", "axis", "sp", None)
            kv_partition_spec = P("dp", "axis", None, None)
        else:
            # Attn2 (cross-attention): the kv sequence is short, so all-gathering
            # key/value costs less; shard the query over its sequence instead.
            q_partition_spec = P("dp", None, ("axis", "sp"), None)
            kv_partition_spec = P("dp", None, None, None)
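
    # For example (hypothetical 8-device run with dp = sp = 1): the attn1 spec
    # above splits the (1, 40, 75600, 128) query from main() into per-device
    # slices of shape (1, 5, 75600, 128), i.e. 40 heads / 8 devices = 5 each.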

    # Always use shard_map; the partition specs above control whether the
    # computation is sharded or replicated.
    sharded_fn = shard_map(
        _attention_on_slices,
        mesh=mesh,
        in_specs=(q_partition_spec, kv_partition_spec, kv_partition_spec),
        out_specs=q_partition_spec,
        check_rep=False,
    )
    out = sharded_fn(query, key, value)
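    # The constraint below re-lays the result as sequence-sharded over
    # ("axis", "sp"), matching how main() places its inputs.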
    out = jax.lax.with_sharding_constraint(out, P("dp", None, ("axis", "sp"), None))
    return out


def main():
    query = jnp.ones((1, 40, 75600, 128))
    key = jnp.ones((1, 40, 75600, 128))
    value = jnp.ones((1, 40, 75600, 128))

    # Earlier sweeps, kept for reference:
    # bqsizes = (1512,)
    # bqsizes = (600, 630, 675, 700, 720, 756, 840, 900, 945, 1008, 1050, 1080, 1200, 1260, 1350, 1400, 1512, 1575, 1680, 1800, 1890, 2100, 2160, 2520, 2700, 2800, 3024, 3150, 3600, 3780, 4200)
    bqsizes = range(2560, 4096, 256)
    bkvsizes = range(2560, 4096, 256)
    bkvcomputesizes = range(256, 4096, 256)
    # bkvcomputesinizes = range(64, 4096, 64)
    bkvcomputesinizes = range(256, 4096, 256)

    # bqsizes = list(range(512, 4096, 128))
    # bkvsizes = (3072,)
    # bkvcomputesizes = (1024,)

    # BQSIZE = 2816  # 2240 # 3024 # 2520
    # BKVSIZE = 3840
    # BKVCOMPUTESIZE = 256

    # bqsizes = (512,)
    # bkvsizes = (2048,)
    # bkvcomputesizes = (256,)

    tp_dim = jax.device_count()
    dp_dim = 1
    sp_dim = 1
    print("sp, bqsize, bkvsize, bkvcomputesize, bkvcomputesinize, time (s), padded_key_size")
    while tp_dim >= 1:
        mesh_devices = mesh_utils.create_device_mesh(
            (tp_dim, dp_dim, sp_dim), allow_split_physical_axes=True
        )
        mesh = Mesh(mesh_devices, ("axis", "dp", "sp"))
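        # First pass: "axis" spans every device and dp = sp = 1. Later passes,
        # only reached if the break at the bottom of the loop is removed, halve
        # "axis" and double "sp".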

        query = jax.device_put(
            query, NamedSharding(mesh, P("dp", None, ("axis", "sp"), None))
        )
        key = jax.device_put(
            key, NamedSharding(mesh, P("dp", None, ("axis", "sp"), None))
        )
        value = jax.device_put(
            value, NamedSharding(mesh, P("dp", None, ("axis", "sp"), None))
        )
        with mesh:
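            # Skip block-size combinations where a compute tile does not evenly
            # divide its parent tile: bkvcomputesize must divide bkvsize, and
            # bkvcomputesinize must divide bkvcomputesize.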
            for bqsize in bqsizes:
                for bkvsize in bkvsizes:
                    for bkvcomputesize in bkvcomputesizes:
                        for bkvcomputesinize in bkvcomputesinizes:
                            if (
                                bkvsize < bkvcomputesize
                                or bkvsize % bkvcomputesize != 0
                            ):
                                continue

                            if (
                                bkvcomputesize < bkvcomputesinize
                                or bkvcomputesize % bkvcomputesinize != 0
                            ):
                                continue

                            try:
                                # Padding the key/value here is intentionally a
                                # no-op: each per-device slice is padded inside
                                # the kernel (see pad_to_multiple above).
                                def pad_to_multiple(x, multiple, axis):
                                    return x

                                padded_query = pad_to_multiple(query, bqsize, axis=2)
                                padded_key = pad_to_multiple(key, bkvsize, axis=2)
                                padded_value = pad_to_multiple(value, bkvsize, axis=2)

                                jax.block_until_ready(
                                    _tpu_splash_attention(
                                        padded_query,
                                        padded_key,
                                        padded_value,
                                        mesh,
                                        bqsize,
                                        bkvsize,
                                        bkvcomputesize,
                                        bkvcomputesinize,
                                    )
                                )
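                                # The call above triggers compilation and warms
                                # up the kernel; the timed call below measures
                                # steady-state execution of the same configuration.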
                                start = time.perf_counter()
                                jax.block_until_ready(
                                    _tpu_splash_attention(
                                        padded_query,
                                        padded_key,
                                        padded_value,
                                        mesh,
                                        bqsize,
                                        bkvsize,
                                        bkvcomputesize,
                                        bkvcomputesinize,
                                    )
                                )
                                end = time.perf_counter()
                                print(
                                    f"{sp_dim=}, {bqsize}, {bkvsize}, {bkvcomputesize}, {bkvcomputesinize}, {end - start}, {padded_key.shape[2]}"
                                )
                            except KeyboardInterrupt:
                                raise
                            except Exception:
                                # Invalid block-size combination; skip it.
                                # Uncomment the raise below when debugging.
                                # raise
                                continue
            # Stop after the first mesh configuration; remove this break to also
            # sweep the smaller "axis" / larger "sp" splits.
            break
        # Smaller sp_dim is better.
        tp_dim //= 2
        sp_dim *= 2


if __name__ == "__main__":
    main()