
Commit 1f4dfb4

Merge pull request #1 from yuyanpeng-google/yuyan-dev
Wan2.2 poc on v6e-16
2 parents af76988 + b845082 commit 1f4dfb4

File tree

9 files changed: +1986 −19 lines

exp/README.md

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
# Wan-AI/Wan2.2-I2V-A14B-Diffusers Recipe


1. Export the GCP project environment variables
* Fill in PROJECT_ID, TPU_NAME, ZONE, and ACCOUNT (the SSH user that the `run` helper in later steps connects as)
```
### 1. export env of gcp ###

export PROJECT_ID=<project_id>
export TPU_NAME=<tpu_name>
export ZONE=<zone>
export ACCOUNT=<account>
export ACCELERATOR_TYPE=v6e-16
export RUNTIME_VERSION=v2-alpha-tpuv6e
```

2. Create the v6e-16 TPU VM on GCP
```
gcloud compute tpus tpu-vm create ${TPU_NAME} \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    --accelerator-type=${ACCELERATOR_TYPE} \
    --version=${RUNTIME_VERSION}
```

3. Prepare the Python environment on each TPU VM (a quick check of the `run` helper follows this block)
```
### 3. prepare env on each host ###

run()
{
    local command=$1
    local worker=${2:-all}
    gcloud compute tpus tpu-vm ssh --zone "${ZONE}" "${ACCOUNT}@${TPU_NAME}" --project "${PROJECT_ID}" --worker=${worker} --command="$command"
}

BRANCH_NAME=wan2.2-main

SETUP_COMMAND="\
set -x && \
curl -LsSf https://astral.sh/uv/install.sh | sh && \
source ~/.local/bin/env && \
uv venv -p 3.12 && \
source .venv/bin/activate && \
git clone -b ${BRANCH_NAME} https://github.com/yuyanpeng-google/diffusers.git || true && \
cd diffusers && \
uv pip install -e . && \
uv pip install transformers accelerate && \
uv pip install torch --index-url https://download.pytorch.org/whl/cpu && \
uv pip install -U jax[tpu] && \
uv pip install torchax && \
uv pip install flax && \
uv pip install ftfy imageio imageio-ffmpeg && \
true
"

run "${SETUP_COMMAND}"
```
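The `run` helper above simply fans a single command out over SSH, to every worker by default or to one worker when a second argument is given. As an optional sanity check before running the setup command (an addition to the recipe, not part of the original steps):
```
run "hostname"       # all workers
run "hostname" 0     # worker 0 only
```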

4. Run the Wan2.2 pipeline to generate videos
```
### 4. run wan2.2 pipeline ###

run()
{
    local command=$1
    local worker=${2:-all}
    gcloud compute tpus tpu-vm ssh --zone "${ZONE}" "${ACCOUNT}@${TPU_NAME}" --project "${PROJECT_ID}" --worker=${worker} --command="$command"
}

BRANCH_NAME=wan2.2-main

RUN_COMMAND="\
set -x && \
source .venv/bin/activate && \
killall -9 python || true && \
sleep 10 && \
export JAX_COMPILATION_CACHE_DIR="/dev/shm/jax_cache" && \
export JAX_PERSISTENT_CACHE_MIN_ENTRY_SIZE_BYTES=-1 && \
export JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=0 && \
export JAX_PERSISTENT_CACHE_ENABLE_XLA_CACHES='xla_gpu_per_fusion_autotune_cache_dir' && \
export HF_HUB_CACHE=/dev/shm/hf_cache && \
cd diffusers && \
git fetch && git reset --hard origin/${BRANCH_NAME} && \
cd exp && \
nohup python wan2p2_benchmark.py > $(date +%Y-%m-%d_%H-%M-%S).log 2>&1 &
true
"
run "${RUN_COMMAND}"
```

5. Check the results in stdout
```
...
output video done. 20251029_093753.mp4
Warmup and output video: 1961.571311s
...
Benchmark: 103.959559s
Done
```
Note that the first warmup run has to compile the graph, which is time-consuming; the sketch below shows one way to check progress.
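Because step 4 launches the benchmark with `nohup` from `~/diffusers/exp` and redirects stdout to a timestamped `.log` file, the output above ends up in that log rather than on your terminal. A minimal sketch using the same `run` helper (an addition to the recipe; worker 0 is used as a representative host, and the glob assumes the log naming from step 4):
```
### check benchmark progress on worker 0 ###
run "tail -n 20 diffusers/exp/*.log" 0
```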

6. Use scp to download the generated videos
```
VIDEO_NAME=20251029_093753.mp4 # from the step 5 stdout

gcloud compute tpus tpu-vm scp --zone "${ZONE}" "${TPU_NAME}:~/diffusers/exp/${VIDEO_NAME}" . --project "${PROJECT_ID}" --worker=0
```


# Install

Install the dependencies; set up a virtual environment first if required.

Tested with Python 3.12.

```sh
# install uv, python 3.12 and activate
curl -LsSf https://astral.sh/uv/install.sh | sh && \
source ~/.local/bin/env && \
uv venv -p 3.12 && \
source .venv/bin/activate
```

```sh
# install dependencies
# pwd=.
uv pip install -e . && \
uv pip install transformers accelerate && \
uv pip install torch --index-url https://download.pytorch.org/whl/cpu && \
uv pip install -U jax[tpu] && \
uv pip install torchax && \
uv pip install flax && \
uv pip install ftfy imageio imageio-ffmpeg
```
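After the dependencies are installed, it can help to confirm that JAX actually sees the TPU chips before launching the benchmark. This check is an addition to the original instructions and uses only standard JAX calls:

```sh
# optional: confirm JAX can see the TPU devices
python -c "import jax; print(jax.device_count(), jax.devices())"
```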

To run:

```sh
# cwd=exp
python wan2p2_benchmark.py
```

### Result

```
# python wan2p2_benchmark.py
Benchmark: 103.959559s
Done
```
Lines changed: 240 additions & 0 deletions
@@ -0,0 +1,240 @@
import functools
from jax.experimental.pallas.ops.tpu import splash_attention
from jax.experimental.shard_map import shard_map
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh
from jax.experimental import mesh_utils

import jax
import jax.numpy as jnp
import math
import time

# import ringattention_pallas_tpu_splash
import custom_splash_attention


# Copy from wan_tx_splash_attn.py
@functools.partial(
    jax.jit,
    static_argnames=("mesh", "bqsize", "bkvsize", "bkvcomputesize", "bkvcomputesinize"),
)
def _tpu_splash_attention(
    query,
    key,
    value,
    mesh,
    bqsize,
    bkvsize,
    bkvcomputesize,
    bkvcomputesinize,  # inner kv compute block size, forwarded to the custom kernel's bkv_compute_in
    scale=None,
    is_causal=False,
    window_size=None,
):
    num_heads = query.shape[1]

    # The function that will be sharded across devices.
    def _attention_on_slices(q, k, v):

        # Scale the query tensor. This happens on each device with its slice of data.
        scale_factor = 1.0 / math.sqrt(q.shape[-1]) if scale is None else scale
        q = q * scale_factor

        def pad_to_multiple2(x, multiple, axis):
            # No-op variant kept from experiments where padding was done outside the kernel.
            return x, x.shape[axis]

        # Helper to pad to next multiple
        def pad_to_multiple(x, multiple, axis):
            seq_len = x.shape[axis]
            pad_len = (multiple - seq_len % multiple) % multiple
            if pad_len == 0:
                return x, seq_len
            pad_width = [(0, 0)] * x.ndim
            pad_width[axis] = (0, pad_len)
            return jnp.pad(x, pad_width), seq_len

        # This function operates on a single item from the batch.
        def kernel_3d(q_3d, k_3d, v_3d):
            q_seq_len = q_3d.shape[1]
            kv_seq_len = k_3d.shape[1]
            num_heads_on_device = q_3d.shape[0]

            # Pad q, k, v to next multiple of BQSIZE/BKVSIZE
            q_3d_padded, q_orig_len = pad_to_multiple(q_3d, bqsize, axis=1)
            k_3d_padded, k_orig_len = pad_to_multiple(k_3d, bkvsize, axis=1)
            v_3d_padded, v_orig_len = pad_to_multiple(v_3d, bkvsize, axis=1)

            padded_q_seq_len = q_3d_padded.shape[1]
            padded_kv_seq_len = k_3d_padded.shape[1]

            block_sizes = splash_attention.BlockSizes(
                block_q=min(bqsize, padded_q_seq_len),
                block_kv=min(bkvsize, padded_kv_seq_len),
                block_kv_compute=min(bkvcomputesize, padded_kv_seq_len),
            )
            splash_kernel = custom_splash_attention.make_splash_mha(
                block_sizes=block_sizes, bkv_compute_in=bkvcomputesinize
            )
            out = splash_kernel(q_3d_padded, k_3d_padded, v_3d_padded)
            # Remove padding if any
            out = jnp.swapaxes(out, 1, 2)
            return out[:, :q_orig_len, ...]

        # Map the kernel over the batch dimension.
        vmapped_kernel = jax.vmap(kernel_3d, in_axes=(0, 0, 0), out_axes=0)
        return vmapped_kernel(q, k, v)

    # Determine the partitioning spec based on the number of heads.
    if num_heads < mesh.size:
        # Replicated case for VAE. All devices get the full tensor.
        q_partition_spec = P()
        kv_partition_spec = P()
    else:
        # Sharded case for Transformer. Split along the heads axis.
        # Attn1 (self-attention): key length is long.
        if key.shape[2] > 10000:
            q_partition_spec = P("dp", "axis", "sp", None)
            kv_partition_spec = P("dp", "axis", None, None)
        else:
            # Attn2 (cross-attention): kv sequence is shorter, so all-gathering key/value costs less.
            q_partition_spec = P("dp", None, ("axis", "sp"), None)
            kv_partition_spec = P("dp", None, None, None)

    # ALWAYS use shard_map. The partition_spec will control the behavior.
    sharded_fn = shard_map(
        _attention_on_slices,
        mesh=mesh,
        in_specs=(q_partition_spec, kv_partition_spec, kv_partition_spec),
        out_specs=q_partition_spec,
        check_rep=False,
    )
    out = sharded_fn(query, key, value)
    out = jax.lax.with_sharding_constraint(out, P("dp", None, ("axis", "sp"), None))
    return out


def main():
    query = jnp.ones((1, 40, 75600, 128))
    key = jnp.ones((1, 40, 75600, 128))
    value = jnp.ones((1, 40, 75600, 128))

    bqsizes = (1512,)

    # bqsizes = (600, 630, 675, 700, 720, 756, 840, 900, 945, 1008, 1050, 1080, 1200, 1260, 1350, 1400, 1512, 1575, 1680, 1800, 1890, 2100, 2160, 2520, 2700, 2800, 3024, 3150, 3600, 3780, 4200)
    bqsizes = range(2560, 4096, 256)
    bkvsizes = range(2560, 4096, 256)
    bkvcomputesizes = range(256, 4096, 256)
    # bkvcomputesinizes = range(64, 4096, 64)
    bkvcomputesinizes = range(256, 4096, 256)

    # bqsizes = list(range(512, 4096, 128))
    # bkvsizes = (3072,)
    # bkvcomputesizes = (1024,)

    # BQSIZE = 2816 # 2240 # 3024 #2520
    # BKVSIZE = 3840
    # BKVCOMPUTESIZE = 256

    # bqsizes = (512,)
    # bkvsizes = (2048,)
    # bkvcomputesizes = (256,)

    tp_dim = jax.device_count()
    dp_dim = 1
    sp_dim = 1
    print("sp, bqsize, bkvsize, bkvcomputesize, bkvcomputesinize, time (s), padded_key_size")
    while tp_dim >= 1:
        mesh_devices = mesh_utils.create_device_mesh(
            (tp_dim, dp_dim, sp_dim), allow_split_physical_axes=True
        )
        mesh = Mesh(mesh_devices, ("axis", "dp", "sp"))

        query = jax.device_put(
            query, NamedSharding(mesh, P("dp", None, ("axis", "sp"), None))
        )
        key = jax.device_put(
            key, NamedSharding(mesh, P("dp", None, ("axis", "sp"), None))
        )
        value = jax.device_put(
            value, NamedSharding(mesh, P("dp", None, ("axis", "sp"), None))
        )
        with mesh:
            for bqsize in bqsizes:
                for bkvsize in bkvsizes:
                    for bkvcomputesize in bkvcomputesizes:
                        for bkvcomputesinize in bkvcomputesinizes:
                            if (
                                bkvsize < bkvcomputesize
                                or bkvsize % bkvcomputesize != 0
                            ):
                                continue

                            if (
                                bkvcomputesize < bkvcomputesinize
                                or bkvcomputesize % bkvcomputesinize != 0
                            ):
                                continue

                            try:
                                # pad key value
                                def pad_to_multiple(x, multiple, axis):
                                    # Padding is handled inside the kernel, so this is a no-op;
                                    # the original padding code below is unreachable.
                                    return x
                                    seq_len = x.shape[axis]
                                    pad_len = (multiple - seq_len % multiple) % multiple
                                    if pad_len == 0:
                                        return x
                                    pad_width = [(0, 0)] * x.ndim
                                    pad_width[axis] = (0, pad_len)
                                    return jnp.pad(x, pad_width)

                                padded_query = pad_to_multiple(query, bqsize, axis=2)
                                padded_key = pad_to_multiple(key, bkvsize, axis=2)
                                padded_value = pad_to_multiple(value, bkvsize, axis=2)

                                # Warm up (compile) once, then time a second call.
                                jax.block_until_ready(
                                    _tpu_splash_attention(
                                        padded_query,
                                        padded_key,
                                        padded_value,
                                        mesh,
                                        bqsize,
                                        bkvsize,
                                        bkvcomputesize,
                                        bkvcomputesinize,
                                    )
                                )

                                start = time.perf_counter()
                                jax.block_until_ready(
                                    _tpu_splash_attention(
                                        padded_query,
                                        padded_key,
                                        padded_value,
                                        mesh,
                                        bqsize,
                                        bkvsize,
                                        bkvcomputesize,
                                        bkvcomputesinize,
                                    )
                                )
                                end = time.perf_counter()
                                print(
                                    f"{sp_dim=}, {bqsize}, {bkvsize}, {bkvcomputesize}, {bkvcomputesinize}, {end - start}, {padded_key.shape[2]}"
                                )
                            except KeyboardInterrupt:
                                raise
                            except Exception:
                                # raise
                                continue
        break  # stop after the first mesh configuration
        # smaller sp_dim better
        tp_dim //= 2
        sp_dim *= 2


if __name__ == "__main__":
    main()
