Conversation

@wishstudio
Contributor

This PR implements graph plan APIs for the CUDA backend, and adds code in ggml-backend.cpp to actually use the graph plan APIs when a backend supports them.

The main functional improvement is support for CUDA graphs when the graph is split (e.g. for hybrid inference). Currently the graph update and reuse logic (ggml_backend_sched_update_plans) is a simple heuristic: previous graphs are only updated when the number of splits and their corresponding backends are the same as in the previous run (see the sketch below). As the benchmarks show, this universally accelerates hybrid inference tg performance, by up to 30%.
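For illustration, the reuse heuristic amounts to roughly the following sketch (field names follow this PR's ggml_backend_sched_plan struct, but this is not the literal code):

```cpp
// Sketch only: reuse the previous graph plans iff the split topology is unchanged,
// i.e. the same number of splits and the same backend for each split.
static bool can_reuse_plans(const ggml_backend_sched * sched) {
    if (sched->n_plans != sched->n_splits) {
        return false; // split count changed -> rebuild all plans
    }
    for (int i = 0; i < sched->n_splits; i++) {
        if (sched->splits[i].backend_id != sched->plans[i].backend_id) {
            return false; // a split moved to a different backend -> rebuild
        }
    }
    return true; // otherwise delegate to each backend's graph plan update routine
}
```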

The CUDA graph execution code is refactored and cleaned up. Two of the three original CUDA graph fail paths are removed: disable_due_to_failed_graph_capture and disable_due_to_too_many_updates. The former because I found no code setting it to true; the latter because I currently have no idea what its semantics should be in a split-graph scenario, and removing it does not seem to degrade performance at all. Interestingly, I found that on my rig, even repeatedly building a graph and then executing it only once is always faster than launching the kernels individually. I suspect this is why performance increased in my tests even for CUDA-only workloads, which this PR's optimization does not target. This of course needs to be verified on more hardware configurations.
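For reference, "building a graph and then executing it once" boils down to something like this minimal sketch (CUDA 12 runtime API, not the actual ggml-cuda code):

```cpp
// Capture all kernel launches of one split into a CUDA graph, then replay it
// with a single cudaGraphLaunch instead of one launch call per kernel.
cudaGraph_t     graph    = nullptr;
cudaGraphExec_t instance = nullptr;

CUDA_CHECK(cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed));
// ... enqueue the split's kernels on `stream` here ...
CUDA_CHECK(cudaStreamEndCapture(stream, &graph));

CUDA_CHECK(cudaGraphInstantiate(&instance, graph, 0)); // CUDA 12 signature
CUDA_CHECK(cudaGraphLaunch(instance, stream));         // one launch for the whole split
```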

Performance comparison:
RTX 5090 + 13700k, 128GB 6400 MT/s RAM

| model | n_cpu_moe | test | t/s master | t/s pr | Speedup |
|-------|-----------|------|------------|--------|---------|
| gpt-oss 20B MXFP4 MoE | 0 | pp512 | 9070.14 | 8768.06 | 0.97 |
| gpt-oss 20B MXFP4 MoE | 0 | tg128 | 273.99 | 278.43 | 1.02 |
| gpt-oss 20B MXFP4 MoE | 99 | pp512 | 916.16 | 931.72 | 1.02 |
| gpt-oss 20B MXFP4 MoE | 99 | tg128 | 42.4 | 47.2 | 1.11 |
| gpt-oss 120B MXFP4 MoE | 24 | pp512 | 150.76 | 150.31 | 1.00 |
| gpt-oss 120B MXFP4 MoE | 24 | tg128 | 36.73 | 45.04 | 1.23 |
| gpt-oss 120B MXFP4 MoE | 99 | pp512 | 187.69 | 186.21 | 0.99 |
| gpt-oss 120B MXFP4 MoE | 99 | tg128 | 28.24 | 31.7 | 1.12 |
| glm4moe 106B.A12B Q4_K | 34 | pp512 | 81.4 | 79.9 | 0.98 |
| glm4moe 106B.A12B Q4_K | 34 | tg128 | 18.69 | 21.72 | 1.16 |
| glm4moe 106B.A12B Q4_K | 99 | pp512 | 114.85 | 114.34 | 1.00 |
| glm4moe 106B.A12B Q4_K | 99 | tg128 | 14.59 | 16.01 | 1.10 |
| glm4moe 355B.A32B Q2_K | 99 | pp512 | 25.3 | 26.96 | 1.07 |
| glm4moe 355B.A32B Q2_K | 99 | tg128 | 7.73 | 8.74 | 1.13 |
| qwen3moe 235B.A22B Q3_K | 80 | pp512 | 59.99 | 61.26 | 1.02 |
| qwen3moe 235B.A22B Q3_K | 80 | tg128 | 10.66 | 11.77 | 1.10 |
| qwen3moe 235B.A22B Q3_K | 99 | pp512 | 85.38 | 88.45 | 1.04 |
| qwen3moe 235B.A22B Q3_K | 99 | tg128 | 9.27 | 10.21 | 1.10 |
| qwen3moe 30B.A3B Q4_K | 0 | pp512 | 6806.34 | 7295.45 | 1.07 |
| qwen3moe 30B.A3B Q4_K | 0 | tg128 | 246.12 | 273 | 1.11 |
| qwen3moe 30B.A3B Q4_K | 99 | pp512 | 531.72 | 560.52 | 1.05 |
| qwen3moe 30B.A3B Q4_K | 99 | tg128 | 36.1 | 48.4 | 1.34 |
| qwen3 8B Q8_0 | | pp512 | 10002.11 | 11383.42 | 1.14 |
| qwen3 8B Q8_0 | | tg128 | 124.07 | 136.75 | 1.10 |
| llama 13B Q6_K | | pp512 | 3886.45 | 4175.62 | 1.07 |
| llama 13B Q6_K | | tg128 | 65.89 | 72.19 | 1.10 |
| llama 8B Q8_0 | | pp512 | 10092.17 | 11702.73 | 1.16 |
| llama 8B Q8_0 | | tg128 | 127.96 | 142.67 | 1.11 |
| gemma3 12B Q8_0 | | pp512 | 6696.89 | 7920.99 | 1.18 |
| gemma3 12B Q8_0 | | tg128 | 79.01 | 88.73 | 1.12 |
| nemotron_h 9B Q8_0 | | pp512 | 7243.57 | 7808.74 | 1.08 |
| nemotron_h 9B Q8_0 | | tg128 | 114.51 | 122.93 | 1.07 |

@github-actions github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on Oct 13, 2025
@AesSedai

I did a sweep on my setup (12ch DDR5 6000, epyc 9355, two 3090s) using a 208.24 GiB (5.01 BPW) quant of GLM-4.6.

Barring the small region below roughly 4096 context where the master branch is faster, there are some small TG speed improvements across the context band up to the 32k that I tested.

./build/bin/llama-sweep-bench -m /mnt/srv/slush/gguf/GLM-4.6-GGUF/GLM-4.6-Q8_0-Q4_K-Q4_K-Q5_K.gguf --ctx-size 36000 --threads 54 --flash-attn 1 --n-gpu-layers 999 --batch-size 2048 --ubatch-size 2048 --override-tensor "blk\.(0|1|2)\.ffn_.*=CUDA0" --override-tensor "blk\..*_exps\.=CPU"
[chart: GLM-4.6 sweep-bench, master vs cuda_graph_plan]

Data:

## llama.cpp [master, 7adc79c03]
|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  2048 |    512 |      0 |   10.778 |   190.03 |   34.731 |    14.74 |
|  2048 |    512 |   2048 |   11.032 |   185.64 |   35.848 |    14.28 |
|  2048 |    512 |   4096 |   11.278 |   181.60 |   45.684 |    11.21 |
|  2048 |    512 |   6144 |   11.564 |   177.10 |   48.760 |    10.50 |
|  2048 |    512 |   8192 |   11.883 |   172.35 |   50.487 |    10.14 |
|  2048 |    512 |  10240 |   12.091 |   169.38 |   52.143 |     9.82 |
|  2048 |    512 |  12288 |   12.342 |   165.94 |   53.575 |     9.56 |
|  2048 |    512 |  14336 |   12.684 |   161.46 |   54.839 |     9.34 |
|  2048 |    512 |  16384 |   13.013 |   157.38 |   56.572 |     9.05 |
|  2048 |    512 |  18432 |   13.455 |   152.21 |   57.690 |     8.88 |
|  2048 |    512 |  20480 |   13.967 |   146.63 |   58.853 |     8.70 |
|  2048 |    512 |  22528 |   14.293 |   143.29 |   60.505 |     8.46 |
|  2048 |    512 |  24576 |   15.116 |   135.48 |   61.862 |     8.28 |
|  2048 |    512 |  26624 |   15.800 |   129.62 |   63.202 |     8.10 |
|  2048 |    512 |  28672 |   16.189 |   126.51 |   64.354 |     7.96 |
|  2048 |    512 |  30720 |   16.723 |   122.46 |   65.799 |     7.78 |
|  2048 |    512 |  32768 |   17.266 |   118.62 |   67.169 |     7.62 |

## llama.cpp [cuda_graph_plan, c17f8b5bd]
|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  2048 |    512 |      0 |   10.814 |   189.38 |   45.181 |    11.33 |
|  2048 |    512 |   2048 |   11.060 |   185.17 |   44.959 |    11.39 |
|  2048 |    512 |   4096 |   11.325 |   180.83 |   45.945 |    11.14 |
|  2048 |    512 |   6144 |   11.583 |   176.82 |   47.814 |    10.71 |
|  2048 |    512 |   8192 |   11.854 |   172.78 |   48.996 |    10.45 |
|  2048 |    512 |  10240 |   12.159 |   168.44 |   50.569 |    10.12 |
|  2048 |    512 |  12288 |   12.399 |   165.17 |   52.207 |     9.81 |
|  2048 |    512 |  14336 |   12.695 |   161.32 |   53.232 |     9.62 |
|  2048 |    512 |  16384 |   13.048 |   156.96 |   54.649 |     9.37 |
|  2048 |    512 |  18432 |   13.525 |   151.43 |   55.856 |     9.17 |
|  2048 |    512 |  20480 |   14.021 |   146.06 |   57.113 |     8.96 |
|  2048 |    512 |  22528 |   14.335 |   142.86 |   58.808 |     8.71 |
|  2048 |    512 |  24576 |   15.223 |   134.54 |   60.189 |     8.51 |
|  2048 |    512 |  26624 |   15.867 |   129.08 |   61.226 |     8.36 |
|  2048 |    512 |  28672 |   16.317 |   125.51 |   62.301 |     8.22 |
|  2048 |    512 |  30720 |   16.757 |   122.22 |   64.025 |     8.00 |
|  2048 |    512 |  32768 |   17.275 |   118.56 |   65.629 |     7.80 |

@IMbackK
Collaborator

IMbackK commented Oct 15, 2025

One thing to mention here is that on HIP the generation of the graph is relatively more expensive than on CUDA, which may make the removal of disable_due_to_too_many_updates have an outsized negative effect there. I will try to find the time to investigate.

@wishstudio
Contributor Author

@AesSedai Could you give gpt-oss-120b a try? I found that, for some reason, this PR applied to the latest master is much slower for GLM 4.5 Air. I will need to investigate further. My testing shows gpt-oss-120b still has the same speedup.

@IMbackK My wild guess is that if the graph gets reused a few times, it should pay off the graph-building overhead. Without this PR it is understandable to have the disable_due_to_too_many_updates condition, because in a split graph nothing gets reused at all. But with this PR I'm curious in what scenarios we would expect a very high update frequency.

@AesSedai

I'll have to download it and set it up, but I should be able to run it later tonight.

@DocShotgun
Contributor

DocShotgun commented Oct 16, 2025

Ran some sweeps on my system for CPU+GPU inference - 5th gen Xeon with 8ch DDR5-5600 (it's a dual socket server but I bound to a single socket due to llama.cpp not having NUMA parallelism) + RTX PRO 6000.

On my setup it is maybe minimally slower on GLM 4.6 and minimally faster on Deepseek V3 0324 (or the difference is small enough to be within error, hard to say), but overall not significantly beneficial.

GLM 4.6 5.01bpw
~/numactl-bind-socket.sh --socket 0 --mode all ~/llama.cpp/build/bin/llama-sweep-bench -m /mnt/data/models/GGUF/GLM-4.6-Q8_0-Q4_K-Q4_K-Q5_K.gguf -c 32768 -ngl 999 -fa on -t 64 -b 4096 -ub 4096 --no-mmap --numa numactl -ot "blk\.(3[2-9]|[4-9][0-9])\.ffn_.*_exps.=CPU"

## llama.cpp [master, 7adc79c03]
|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  4096 |   1024 |      0 |    4.627 |   885.15 |   83.217 |    12.31 |
|  4096 |   1024 |   4096 |    4.941 |   829.05 |   86.452 |    11.84 |
|  4096 |   1024 |   8192 |    5.259 |   778.81 |   87.578 |    11.69 |
|  4096 |   1024 |  12288 |    5.647 |   725.35 |   89.944 |    11.38 |
|  4096 |   1024 |  16384 |    5.970 |   686.06 |   91.353 |    11.21 |
|  4096 |   1024 |  20480 |    6.319 |   648.22 |   93.331 |    10.97 |
|  4096 |   1024 |  24576 |    6.682 |   612.98 |   95.518 |    10.72 |
|  4096 |   1024 |  28672 |    7.111 |   575.97 |   96.989 |    10.56 |

## llama.cpp [cuda_graph_plan, c17f8b5bd]
|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  4096 |   1024 |      0 |    4.633 |   884.15 |   84.082 |    12.18 |
|  4096 |   1024 |   4096 |    4.939 |   829.39 |   87.077 |    11.76 |
|  4096 |   1024 |   8192 |    5.264 |   778.08 |   88.174 |    11.61 |
|  4096 |   1024 |  12288 |    5.620 |   728.83 |   90.342 |    11.33 |
|  4096 |   1024 |  16384 |    5.965 |   686.63 |   91.545 |    11.19 |
|  4096 |   1024 |  20480 |    6.313 |   648.82 |   93.665 |    10.93 |
|  4096 |   1024 |  24576 |    6.681 |   613.04 |   95.739 |    10.70 |
|  4096 |   1024 |  28672 |    7.115 |   575.69 |   97.026 |    10.55 |

Deepseek V3 0324 4.93bpw
~/numactl-bind-socket.sh --socket 0 --mode all ~/llama.cpp/build/bin/llama-sweep-bench -m /mnt/data/models/GGUF/DeepSeek-V3-0324-Q8_0-Q4_K-Q4_K-Q5_K/DeepSeek-V3-0324-Q8_0-Q4_K-Q4_K-Q5_K-00001-of-00009.gguf -c 32768 -ngl 999 -fa on -t 64 -b 4096 -ub 4096 --no-mmap --numa numactl -ot "blk\.(1[3-9]|[2-9][0-9])\.ffn_.*_exps.=CPU"

## llama.cpp [master, 7adc79c03]
|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  4096 |   1024 |      0 |    9.208 |   444.85 |  108.699 |     9.42 |
|  4096 |   1024 |   4096 |   10.458 |   391.64 |  109.773 |     9.33 |
|  4096 |   1024 |   8192 |   11.751 |   348.57 |  110.851 |     9.24 |
|  4096 |   1024 |  12288 |   13.056 |   313.73 |  113.031 |     9.06 |
|  4096 |   1024 |  16384 |   14.359 |   285.25 |  113.515 |     9.02 |
|  4096 |   1024 |  20480 |   15.721 |   260.54 |  113.881 |     8.99 |
|  4096 |   1024 |  24576 |   16.998 |   240.97 |  115.775 |     8.84 |
|  4096 |   1024 |  28672 |   18.169 |   225.44 |  116.941 |     8.76 |

## llama.cpp [cuda_graph_plan, c17f8b5bd]
|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  4096 |   1024 |      0 |    9.151 |   447.61 |  107.455 |     9.53 |
|  4096 |   1024 |   4096 |   10.361 |   395.32 |  107.960 |     9.48 |
|  4096 |   1024 |   8192 |   11.632 |   352.14 |  109.440 |     9.36 |
|  4096 |   1024 |  12288 |   12.910 |   317.28 |  111.556 |     9.18 |
|  4096 |   1024 |  16384 |   14.166 |   289.15 |  112.299 |     9.12 |
|  4096 |   1024 |  20480 |   15.481 |   264.58 |  113.112 |     9.05 |
|  4096 |   1024 |  24576 |   16.753 |   244.49 |  115.117 |     8.90 |
|  4096 |   1024 |  28672 |   18.022 |   227.28 |  115.524 |     8.86 |

@wishstudio
Contributor Author

Looks like I screwed up the previous commits, as the graph update logic was completely broken. I can't even reproduce my previous results :(

Anyway, the issues should now be fixed. I refactored a bit and added back the disable_due_to_too_many_updates condition. It now works at the CUDA graph level: if a single CUDA graph changes too frequently, we disable all future graph creation. It is not identical to the previous behavior for non-split graphs, but should be close.

@AesSedai @DocShotgun Thank you for your tests! Please check if this commit (3afbd9f) works.

Performance comparison
| model | n_cpu_moe | test | t/s master (7adc79c) | t/s pr (3afbd9f) | Speedup |
|-------|-----------|------|----------------------|------------------|---------|
| gpt-oss 20B MXFP4 MoE | 0 | pp512 | 9927.91 | 10180.03 | 1.03 |
| gpt-oss 20B MXFP4 MoE | 0 | tg128 | 297.95 | 298.2 | 1.00 |
| gpt-oss 20B MXFP4 MoE | 99 | pp512 | 987.16 | 921.35 | 0.93 |
| gpt-oss 20B MXFP4 MoE | 99 | tg128 | 42.14 | 47.47 | 1.13 |
| gpt-oss 120B MXFP4 MoE | 24 | pp512 | 305.3 | 282.15 | 0.92 |
| gpt-oss 120B MXFP4 MoE | 24 | tg128 | 37.59 | 46.05 | 1.23 |
| gpt-oss 120B MXFP4 MoE | 99 | pp512 | 216.95 | 202.49 | 0.93 |
| gpt-oss 120B MXFP4 MoE | 99 | tg128 | 28.58 | 31.93 | 1.12 |
| glm4moe 106B.A12B Q4_K | 34 | pp512 | 166.5 | 147.04 | 0.88 |
| glm4moe 106B.A12B Q4_K | 34 | tg128 | 19.27 | 22.38 | 1.16 |
| glm4moe 106B.A12B Q4_K | 99 | pp512 | 132.63 | 118.56 | 0.89 |
| glm4moe 106B.A12B Q4_K | 99 | tg128 | 14.95 | 16.4 | 1.10 |
| glm4moe 355B.A32B Q2_K | 99 | pp512 | 44.17 | 46.51 | 1.05 |
| glm4moe 355B.A32B Q2_K | 99 | tg128 | 7.32 | 9.17 | 1.25 |
| qwen3moe 235B.A22B Q3_K | 80 | pp512 | 106.98 | 107.55 | 1.01 |
| qwen3moe 235B.A22B Q3_K | 80 | tg128 | 10.82 | 11.84 | 1.09 |
| qwen3moe 235B.A22B Q3_K | 99 | pp512 | 89.78 | 93.14 | 1.04 |
| qwen3moe 235B.A22B Q3_K | 99 | tg128 | 9.61 | 10.26 | 1.07 |
| qwen3moe 30B.A3B Q4_K | 0 | pp512 | 7260.62 | 7336.87 | 1.01 |
| qwen3moe 30B.A3B Q4_K | 0 | tg128 | 267.7 | 276.21 | 1.03 |
| qwen3moe 30B.A3B Q4_K | 99 | pp512 | 564.09 | 579.92 | 1.03 |
| qwen3moe 30B.A3B Q4_K | 99 | tg128 | 37.18 | 48.29 | 1.30 |
| qwen3 8B Q8_0 | | pp512 | 10786.99 | 11638.03 | 1.08 |
| qwen3 8B Q8_0 | | tg128 | 136.15 | 136.59 | 1.00 |
| llama 13B Q6_K | | pp512 | 4208.3 | 4162.58 | 0.99 |
| llama 13B Q6_K | | tg128 | 72.28 | 71.62 | 0.99 |
| llama 8B Q8_0 | | pp512 | 10505.52 | 11916.53 | 1.13 |
| llama 8B Q8_0 | | tg128 | 140.28 | 141.31 | 1.01 |
| gemma3 12B Q8_0 | | pp512 | 7688.6 | 7593.39 | 0.99 |
| gemma3 12B Q8_0 | | tg128 | 88.47 | 88.22 | 1.00 |
| nemotron_h 9B Q8_0 | | pp512 | 8032.68 | 7850.61 | 0.98 |
| nemotron_h 9B Q8_0 | | tg128 | 132.5 | 131.67 | 0.99 |

@DocShotgun
Contributor

Unfortunately it seems with the new commit, I'm getting worse performance on GLM 4.6.

GLM 4.6 5.01bpw
~/numactl-bind-socket.sh --socket 0 --mode all ~/llama.cpp/build/bin/llama-sweep-bench -m /mnt/data/models/GGUF/GLM-4.6-Q8_0-Q4_K-Q4_K-Q5_K.gguf -c 32768 -ngl 999 -fa on -t 64 -b 4096 -ub 4096 --no-mmap --numa numactl -ot "blk\.(3[2-9]|[4-9][0-9])\.ffn_.*_exps.=CPU"

## llama.cpp [master, 7adc79c03]
|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  4096 |   1024 |      0 |    4.627 |   885.15 |   83.217 |    12.31 |
|  4096 |   1024 |   4096 |    4.941 |   829.05 |   86.452 |    11.84 |
|  4096 |   1024 |   8192 |    5.259 |   778.81 |   87.578 |    11.69 |
|  4096 |   1024 |  12288 |    5.647 |   725.35 |   89.944 |    11.38 |
|  4096 |   1024 |  16384 |    5.970 |   686.06 |   91.353 |    11.21 |
|  4096 |   1024 |  20480 |    6.319 |   648.22 |   93.331 |    10.97 |
|  4096 |   1024 |  24576 |    6.682 |   612.98 |   95.518 |    10.72 |
|  4096 |   1024 |  28672 |    7.111 |   575.97 |   96.989 |    10.56 |

## llama.cpp [cuda_graph_plan, 3afbd9f32]
|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  4096 |   1024 |      0 |    4.635 |   883.69 |   84.188 |    12.16 |
|  4096 |   1024 |   4096 |    4.937 |   829.72 |   87.603 |    11.69 |
|  4096 |   1024 |   8192 |    5.268 |   777.46 |   88.877 |    11.52 |
|  4096 |   1024 |  12288 |    5.628 |   727.75 |   91.121 |    11.24 |
|  4096 |   1024 |  16384 |    5.972 |   685.81 |   92.545 |    11.06 |
|  4096 |   1024 |  20480 |    6.322 |   647.85 |   94.822 |    10.80 |
|  4096 |   1024 |  24576 |    6.689 |   612.37 |   96.780 |    10.58 |
|  4096 |   1024 |  28672 |    7.118 |   575.45 |   98.212 |    10.43 |

@AesSedai

Agreed, slightly worse performance and a really, really weird dip at 18k for GLM-4.6.
## llama.cpp [cuda_graph_plan, 3afbd9f32]
|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  2048 |    512 |      0 |   10.834 |   189.03 |   45.156 |    11.34 |
|  2048 |    512 |   2048 |   11.075 |   184.92 |   46.242 |    11.07 |
|  2048 |    512 |   4096 |   11.324 |   180.86 |   47.159 |    10.86 |
|  2048 |    512 |   6144 |   11.623 |   176.20 |   48.776 |    10.50 |
|  2048 |    512 |   8192 |   11.875 |   172.46 |   50.976 |    10.04 |
|  2048 |    512 |  10240 |   12.133 |   168.80 |   52.443 |     9.76 |
|  2048 |    512 |  12288 |   12.431 |   164.75 |   54.257 |     9.44 |
|  2048 |    512 |  14336 |   12.720 |   161.01 |   55.691 |     9.19 |
|  2048 |    512 |  16384 |   13.075 |   156.63 |   56.794 |     9.01 |
|  2048 |    512 |  18432 |   13.507 |   151.63 |   63.246 |     8.10 |
|  2048 |    512 |  20480 |   13.982 |   146.47 |   62.431 |     8.20 |
|  2048 |    512 |  22528 |   14.385 |   142.37 |   63.260 |     8.09 |
|  2048 |    512 |  24576 |   15.275 |   134.07 |   64.259 |     7.97 |
|  2048 |    512 |  26624 |   15.903 |   128.78 |   64.457 |     7.94 |
|  2048 |    512 |  28672 |   16.325 |   125.45 |   64.579 |     7.93 |
|  2048 |    512 |  30720 |   16.774 |   122.10 |   65.588 |     7.81 |
|  2048 |    512 |  32768 |   17.383 |   117.82 |   67.330 |     7.60 |

@AesSedai

Here's GPT-OSS-120b, Q6_K 58.93 GiB (4.33 BPW)

./build/bin/llama-sweep-bench -m /mnt/srv/slush/gguf/openai_gpt-oss-120b-GGUF/openai_gpt-oss-120b-Q6_K/openai_gpt-oss-120b-Q6_K-00001-of-00002.gguf --ctx-size 36000 --threads 54 --flash-attn 1 --n-gpu-layers 999 --batch-size 2048 --ubatch-size 2048 --override-tensor "blk\.(0|1|2|3|4|5|6|7|8|9)\.ffn_.*=CUDA0" --override-tensor "blk\.(1[0-9])\.ffn_.*=CUDA1" --override-tensor "blk\..*_exps\.=CPU"

Data:

## llama.cpp [master, 7adc79c03]
|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  2048 |    512 |      0 |    1.457 |  1405.46 |    7.680 |    66.67 |
|  2048 |    512 |   2048 |    1.468 |  1394.65 |    7.556 |    67.76 |
|  2048 |    512 |   4096 |    1.492 |  1373.09 |    7.738 |    66.17 |
|  2048 |    512 |   6144 |    1.516 |  1350.92 |    7.896 |    64.84 |
|  2048 |    512 |   8192 |    1.544 |  1326.65 |    7.752 |    66.04 |
|  2048 |    512 |  10240 |    1.575 |  1300.30 |    8.170 |    62.67 |
|  2048 |    512 |  12288 |    1.595 |  1284.28 |    8.609 |    59.47 |
|  2048 |    512 |  14336 |    1.623 |  1262.13 |    8.674 |    59.03 |
|  2048 |    512 |  16384 |    1.645 |  1244.91 |    9.215 |    55.56 |
|  2048 |    512 |  18432 |    1.679 |  1219.62 |    9.523 |    53.77 |
|  2048 |    512 |  20480 |    1.705 |  1201.47 |    9.777 |    52.37 |
|  2048 |    512 |  22528 |    1.734 |  1181.07 |   11.414 |    44.86 |
|  2048 |    512 |  24576 |    1.762 |  1162.57 |   10.055 |    50.92 |
|  2048 |    512 |  26624 |    1.764 |  1160.98 |   10.544 |    48.56 |
|  2048 |    512 |  28672 |    1.789 |  1144.79 |   11.875 |    43.12 |
|  2048 |    512 |  30720 |    1.822 |  1123.95 |   11.337 |    45.16 |
|  2048 |    512 |  32768 |    1.867 |  1097.12 |   11.464 |    44.66 |

## llama.cpp [cuda_graph_plan, 3afbd9f32]
|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  2048 |    512 |      0 |    1.450 |  1412.23 |    7.234 |    70.78 |
|  2048 |    512 |   2048 |    1.468 |  1395.31 |    7.162 |    71.49 |
|  2048 |    512 |   4096 |    1.487 |  1377.37 |    7.176 |    71.35 |
|  2048 |    512 |   6144 |    1.511 |  1355.67 |    7.227 |    70.85 |
|  2048 |    512 |   8192 |    1.537 |  1332.19 |    7.490 |    68.36 |
|  2048 |    512 |  10240 |    1.572 |  1302.69 |    7.428 |    68.93 |
|  2048 |    512 |  12288 |    1.592 |  1286.48 |    7.703 |    66.46 |
|  2048 |    512 |  14336 |    1.621 |  1263.10 |    7.812 |    65.54 |
|  2048 |    512 |  16384 |    1.647 |  1243.46 |    7.936 |    64.51 |
|  2048 |    512 |  18432 |    1.695 |  1208.44 |    7.917 |    64.67 |
|  2048 |    512 |  20480 |    1.708 |  1199.09 |    8.117 |    63.08 |
|  2048 |    512 |  22528 |    1.736 |  1179.57 |    8.034 |    63.73 |
|  2048 |    512 |  24576 |    1.756 |  1166.31 |    8.738 |    58.59 |
|  2048 |    512 |  26624 |    1.758 |  1164.95 |    8.389 |    61.03 |
|  2048 |    512 |  28672 |    1.792 |  1142.98 |   10.209 |    50.15 |
|  2048 |    512 |  30720 |    1.831 |  1118.34 |    9.938 |    51.52 |
|  2048 |    512 |  32768 |    1.859 |  1101.79 |   11.045 |    46.35 |

@DocShotgun
Contributor

DocShotgun commented Oct 16, 2025

Interesting. Seems like it's potentially beneficial in certain cases but not in others (and potentially a regression in some cases). Sadly not the free lunch I was hoping for lol.

@ubergarm

@wishstudio

> Could you give gpt-oss-120b a try?

I just rebuilt fresh and did a 3-way comparison between mainline lcpp master, this PR, and ik_llama.cpp@main (including ikawrakow/ik_llama.cpp#829 which helps a lot with -ooae over there).

tl;dr: this PR is possibly giving a bit over +2% TG over master with basically the same PP in quick llama-sweep-bench testing on my hybrid GPU+CPU gaming rig.

[chart: sweep-bench, gpt-oss-120b]

Details:
model=/mnt/ai/models/ggml-org/gpt-oss-120b-GGUF/gpt-oss-120b-mxfp4-00001-of-00003.gguf
./build/bin/llama-sweep-bench \
  --model "$model" \
  -fa 1 \
  -c 20480 \
  -ub 4096 -b 4096 \
  -ngl 99 \
  -ot "blk\.(0|1|2|3|4|5|6|7|8)\.ffn.*=CUDA0" \
  -ot "ffn_(up|gate|down)_exps\.weight=CPU" \
  --threads 16 \
  --no-mmap

# for ik_llama.cpp add
  -fmoe \
  -ooae \
  --warmup-batch

ik_llama.cpp/main@dbfd1515 -fmoe -ooae --warmup-batch

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|------|------|-------|--------|----------|--------|----------|
| 4096 | 1024 | 0 | 2.020 | 2027.71 | 24.496 | 41.80 |
| 4096 | 1024 | 4096 | 2.081 | 1968.45 | 24.744 | 41.38 |
| 4096 | 1024 | 8192 | 2.155 | 1900.88 | 25.003 | 40.96 |
| 4096 | 1024 | 12288 | 2.219 | 1846.10 | 25.133 | 40.74 |
| 4096 | 1024 | 16384 | 2.318 | 1767.15 | 25.437 | 40.26 |

llama.cpp/master@683fa6ba

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|------|------|-------|--------|----------|--------|----------|
| 4096 | 1024 | 0 | 2.022 | 2026.09 | 25.375 | 40.35 |
| 4096 | 1024 | 4096 | 2.070 | 1979.06 | 25.972 | 39.43 |
| 4096 | 1024 | 8192 | 2.158 | 1897.63 | 26.271 | 38.98 |
| 4096 | 1024 | 12288 | 2.253 | 1818.00 | 26.629 | 38.45 |
| 4096 | 1024 | 16384 | 2.349 | 1743.98 | 26.990 | 37.94 |

llama.cpp/wishstudio:cuda_graph_plan PR16548 3afbd9f rebased to 683fa6b + ug/port-sweep-bench

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|------|------|-------|--------|----------|--------|----------|
| 4096 | 1024 | 0 | 2.026 | 2021.82 | 24.763 | 41.35 |
| 4096 | 1024 | 4096 | 2.078 | 1970.74 | 25.348 | 40.40 |
| 4096 | 1024 | 8192 | 2.164 | 1892.79 | 25.678 | 39.88 |
| 4096 | 1024 | 12288 | 2.241 | 1827.42 | 26.030 | 39.34 |
| 4096 | 1024 | 16384 | 2.352 | 1741.49 | 26.337 | 38.88 |

@aendk
Contributor

aendk commented Oct 17, 2025

Super interesting PR!

Off-topic to this PR itself, but relevant to the discussion: What is llama-sweep-bench? How do I build it?
I found nothing about it in the master branch when building by default, and nothing online.

@wishstudio
Contributor Author

@DocShotgun I guess it's expected that this optimization is model and machine specific 😄

@DocShotgun @AesSedai @ubergarm Looks like all of you are running this on Linux while I'm running Win11. PP performance is very unstable in my tests, but it looks much more stable in yours. During my tests I also see lots of random latencies and slowdowns during kernel launches. This could explain why this PR removes much more overhead on my system than on yours.

But anyway, even if this PR does not help much on your systems, I didn't expect it to cause any slowdown. Theoretically, the only thing this PR changes for these MoE models is making a single consolidated cudaGraphLaunch call instead of a series of kernel launches. The model you tested is out of reach for my current system; I tested a smaller quant (unsloth/GLM-4.6-UD-IQ1_S-00001-of-00002.gguf) and can still measure a performance gain on my system (6.14 to 6.71 tg). The only thing I can think of is that, because you were using a larger quant than mine, it may have involved different kernels that did not play nicely with graph updates or triggered some bug. But without detailed profiling information I'm kind of stuck here.

@aendk I spent like 10 minutes searching for it and ended up finding an example in ik_llama.

@ORippler
Contributor

> @DocShotgun @AesSedai @ubergarm Looks like all of you are running this on Linux while I'm running Win11. PP performance is very unstable in my tests, but it looks much more stable in yours. During my tests I also see lots of random latencies and slowdowns during kernel launches. This could explain why this PR removes much more overhead on my system than on yours.

I can confirm that kernel launch overheads are significantly higher and more variable on Win11 compared to Linux for NVIDIA GPUs, due to their batched dispatch via WDDM2.

@DocShotgun
Contributor

I've heard the same thing about kernel launch overhead being significantly higher on Windows than Linux. I notice you mention in the OP that you found building a graph and executing it only once to be always faster than calling the kernels individually - perhaps the lower kernel launch overhead on Linux makes this not necessarily true there?

@ubergarm

@aendk

> Off-topic to this PR itself, but relevant to the discussion: What is llama-sweep-bench? How do I build it?
> I found nothing about it in the master branch when building by default, and nothing online.

It's a long story, but I believe llama-sweep-bench was inspired by an old unmerged PR here that got bounced around on another fork, and I keep a rough port here: https://github.com/ubergarm/llama.cpp/tree/ug/port-sweep-bench

Some folks use it to sample PP/TG speeds across a given kv-cache depth and use vibe-coded Python scripts to plot/visualize the results across multiple runs. I'd be happy to see it added officially, but I haven't fixed up the arguments enough to open a proper PR. Go ask over on the https://huggingface.co/BeaverAI Discord if you want more details or example usage.

@wishstudio
Contributor Author

> I can confirm that kernel launch overheads are significantly higher and more variable on Win11 compared to Linux for NVIDIA GPUs, due to their batched dispatch via WDDM2.

Another reason to switch to Linux 🥲

> I've heard the same thing about kernel launch overhead being significantly higher on Windows than Linux. I notice you mention in the OP that you found building a graph and executing it only once to be always faster than calling the kernels individually - perhaps the lower kernel launch overhead on Linux makes this not necessarily true there?

Could be. But in the last commit I reintroduced the disable_due_to_too_many_updates guard, so graph creation is completely disabled after 4 consecutive graph updates. So unless there is a weird change pattern (like the graph updating 1-3 times and then getting reused once, repeatedly), it should no longer apply. In my tests with a lower quant of GLM 4.6 I can see the graphs getting reused most of the time, updating only about once every ~128 tokens. Another possibility is that (overhead of graph manipulation) / 128 > (overhead reduction of graph launch).

@timkhronos

I tested this on Windows, on GLM 4.6, on a 5090 + 9950X3D. I saw an average loss of about 8% compared to mainline, taking me from 5.2 to 4.7 tok/s using this quant.

Contributor

@ORippler ORippler left a comment


While I welcome the proposal, I feel there are issues that still need to be addressed, namely:

  1. stricter equivalence checking when updating graph plans, and
  2. fixing the logic in the cuda backend

To shed some light on the perf gains, here are some profiles (with a locally hacked fix of point 2 raised above), where I offloaded a single FFN with:

sudo nsys profile --output=gp_offload --trace=cuda,nvtx,osrt --sample=cpu --gpu-metrics-devices=all --cuda-graph-trace=node ./build-x64-linux-gcc-reldbg/bin/llama-cli -m models/gemma-3n-E2B-it-Q8_0.gguf -fa 1 -p "Hello World" --top-k 1 -ot "blk\.1.ffn*=CPU" --no-conversation -n 10
[nsys profile timeline]

One can see that CUDA graphs nicely speed up the GPU portion of the workload (CUDA kernels run 100% of the time, as we eliminate kernel-launch overheads). However, the CPU workload does not benefit at all. Given that we are on heterogeneous systems, optimizing the GPU workload may actually induce thermal throttling on the CPU, as the CPU will be active more frequently in a given timeframe. Thus, slowdowns can occur in E2E tests.

More performant approaches for models that are too large to fit in VRAM would be ones where weights are streamed asynchronously onto the GPU (e.g. while the previous workload is running either on the GPU or the CPU), though such approaches are also more complex.
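Purely as an illustration of that streaming idea (not part of this PR; d_weights_next, h_weights_next and weights_bytes are hypothetical placeholders, and the host buffer would need to be pinned for the copy to truly overlap):

```cpp
// Copy the next layer's expert weights on a side stream while the current
// layer computes, then make the compute stream wait on the copy.
cudaStream_t compute_stream, copy_stream;
cudaEvent_t  weights_ready;
CUDA_CHECK(cudaStreamCreate(&compute_stream));
CUDA_CHECK(cudaStreamCreate(&copy_stream));
CUDA_CHECK(cudaEventCreateWithFlags(&weights_ready, cudaEventDisableTiming));

CUDA_CHECK(cudaMemcpyAsync(d_weights_next, h_weights_next, weights_bytes,
                           cudaMemcpyHostToDevice, copy_stream));  // overlaps with compute
CUDA_CHECK(cudaEventRecord(weights_ready, copy_stream));
CUDA_CHECK(cudaStreamWaitEvent(compute_stream, weights_ready, 0)); // gate the next layer
// ... launch the next layer's kernels on `compute_stream` ...
```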

Comment on lines 3129 to 3142
// check if we are doing a graph update
if (cuda_graph->instance == nullptr && use_cuda_graph                              // no graph -> graph
    || cuda_graph->instance != nullptr && !use_cuda_graph                          // graph -> no graph
    || use_cuda_graph && is_cuda_graph_update_required(cuda_graph, cgraph)) {      // graph property mismatch
    cuda_graph->number_consecutive_updates++;
    if (cuda_graph->number_consecutive_updates >= 4) {
        cuda_ctx->disable_graph_due_to_too_many_updates = true;
        use_cuda_graph = false;
    } else {
        cuda_graph_update_required = true;
    }
} else {
    cuda_graph->number_consecutive_updates = 0;
}
Contributor

The logic here is faulty. Before, the update check on the CUDA graph was performed for every decoded token, whereas now it's done only when sched->plan_dirty = true (which is set in ggml_backend_sched_split_graph, which is invoked only in ggml_backend_sched_alloc_graph/ggml_backend_sched_reserve). As a consequence, we no longer call the function when the CUDA graph does not need to be updated, leading to number_consecutive_updates being increased and CUDA graphs being disabled eventually.

Contributor Author

Good catch!

Contributor Author

I added a new number_consecutive_computes tracker, maintained in ggml_backend_cuda_graph_plan_compute, which clears the number_consecutive_updates tracker when a graph is used more than once.
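Roughly, the interaction between the two counters looks like this self-contained sketch (illustrative only; the real fields live in the CUDA backend's graph state):

```cpp
// Illustrative sketch: a graph that is actually reused resets the
// "too many consecutive updates" tracker, so reuse keeps graphs enabled.
struct graph_update_state {
    int number_consecutive_updates  = 0; // bumped whenever the graph had to be updated/rebuilt
    int number_consecutive_computes = 0; // bumped on every graph_plan_compute
};

void on_graph_plan_compute(graph_update_state & st) {
    st.number_consecutive_computes++;
    if (st.number_consecutive_computes > 1) {
        // the graph was reused at least once, so the earlier updates paid off
        st.number_consecutive_updates = 0;
    }
}
```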

struct ggml_backend_sched_plan * plans;
int n_plans;
int plans_capacity;
bool plan_dirty;
Contributor

Suggested change
bool plan_dirty;
bool plan_needs_update;

Nit: not sure if dirty is the appropriate wording here.

Contributor Author

Changed.

Comment on lines +1464 to +1467
if (sched->splits[i].backend_id != sched->plans[i].backend_id) {
create_new_plans = true;
break;
}
Contributor

imo, we should actually check for the full equivalence of the generated splits (i.e. the ggml_cgraph objects contained in the splits & plans instead of just their backends)

Contributor Author

The current check is a bare minimum for ensuring correctness. It basically says: if the number of splits or the backend of any split changes, we rebuild everything. If only the interior of a split changes, handling the change is delegated to the backend's graph update routine (the code is just below this check). In the current CUDA backend, if the graph change is too large, cudaGraphExecUpdate will fail and cudaGraphInstantiate will be called to reinstantiate a new executable graph. I don't think this generates much overhead, since cudaGraphInstantiate is inevitable in that case and only the cudaGraphExecUpdate call could be saved.
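For reference, the fallback amounts to roughly this (CUDA 12 runtime API; illustrative rather than the literal ggml-cuda code):

```cpp
// Try to patch the existing executable graph in place; if the topology changed
// too much, throw it away and instantiate a fresh one.
cudaGraphExecUpdateResultInfo result_info;
if (cudaGraphExecUpdate(instance, graph, &result_info) != cudaSuccess) {
    (void) cudaGetLastError();            // clear the sticky error state
    CUDA_CHECK(cudaGraphExecDestroy(instance));
    CUDA_CHECK(cudaGraphInstantiate(&instance, graph, 0));
}
```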

To improve on this, I guess something like the enhanced API in #14514 needs to be implemented. It passes more information to help the backend make an informed decision on recreate vs. update. Anyway, I believe most of the scaffolding is still missing, so this would be out of scope for this PR.

Contributor Author

@wishstudio wishstudio Oct 23, 2025

By the way, if the split graph is indeed identical, the CUDA backend will find that out via the is_cuda_graph_update_required fast path. I don't think it's a good idea to move the checks up into the backend scheduler, as different backends could have different definitions of equivalence depending on various factors, as discussed in #14514. We may compute these factors in the scheduler, but the ultimate decision is still better made in the backend.

@wishstudio
Contributor Author

Added the ggml_backend_supports_graph_plan and ggml_backend_supports_graph_plan_update APIs to ggml-backend.h to make CI happy.
