
CUDA: Optimize reduce_rows_f32 kernel, leading up to 25x perf improvement on kernel-level and 10% perf increase for Gemma3n #15132


Open · wants to merge 10 commits into master from osimons/optimize_reduce_rows_f32

Conversation

@ORippler (Contributor) commented Aug 6, 2025

Investigation of Gemma3n perf on NVGPUs identified the reduce_rows_f32 kernel as a major performance bottleneck. Profiling revealed the kernel to be severely latency-limited in the regime run by Gemma3n (nrows ~10, ncols in [2048, 8192]).

This PR addresses this issue, hiding the latency by a combination of:

  1. Manual loop unrolling, getting the compiler to request all unrolled data points at once instead of fetching data sequentially (#pragma unroll unfortunately did not do the trick); a minimal sketch is shown right after this list.
  2. Increasing the number of threads processing a row, where 512 threads are used for the low-parallelization regime (i.e. processing only a single row). This gives the SM 16 full warps to cycle through, further pipelining data fetching.
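
A minimal sketch of that unrolled accumulation is shown below. The signature, indexing, and final block reduction are simplified here and differ from the actual kernel in this PR:

```cpp
#include <cuda_runtime.h>

// Minimal sketch of the manually unrolled accumulation described above; the
// signature, indexing, and block reduction are simplified relative to the PR.
static __global__ void reduce_rows_f32_sketch(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
    const int row = blockIdx.x;
    const int col = threadIdx.x;
    constexpr int num_unroll = 8;

    float sum_temp[num_unroll] = { 0.0f };
    for (int i = col; i < ncols;) {
        for (int j = 0; j < num_unroll; ++j) {
            if (i < ncols) {
                // All num_unroll loads are issued back-to-back, so their memory
                // latencies overlap instead of being paid one after another.
                sum_temp[j] += x[(size_t) row * ncols + i];
            }
            i += blockDim.x;
        }
    }

    // Collapse the per-thread partial sums.
    float sum = 0.0f;
    for (int j = 0; j < num_unroll; ++j) {
        sum += sum_temp[j];
    }

    // Two-step block reduction: shuffle within each warp, shared memory across warps.
    __shared__ float s_sum[32];
    const int warp_id = threadIdx.x / 32;
    const int lane_id = threadIdx.x % 32;
    for (int offset = 16; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    if (lane_id == 0) {
        s_sum[warp_id] = sum;
    }
    __syncthreads();
    if (warp_id == 0) {
        sum = (lane_id < blockDim.x / 32) ? s_sum[lane_id] : 0.0f;
        for (int offset = 16; offset > 0; offset >>= 1) {
            sum += __shfl_down_sync(0xffffffff, sum, offset);
        }
        if (lane_id == 0) {
            dst[row] = sum;
        }
    }
}
```

The effect of the manual unroll is that the compiler emits the num_unroll loads (LDGs) as one batch followed by the adds (FADDs), rather than an alternating load/add chain whose latencies accumulate.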

Since perf regressions were identified in the high-parallelization regime (nrows >= 2x SM count), we use:

  • 128 threads for medium-to-large columns, effectively letting each SM process a single row (an SM can execute 4 warps × 32 threads = 128 threads concurrently).
  • As perf regressions were still observed for small columns (< 1024 cols, i.e. less than one unrolled pass of a threadblock with 128 threads and 8 unrolls), the thread count was further reduced to 32 threads for small columns. An alternative to this would have been to template the number of unrolls based on the column size. However, this would lead to an increased binary size due to the required compilation of multiple kernels, and was thus not pursued further.

The high/low parallelization threshold was empirically determined:

| GPU Model                     | Nrow/SM-count multiple where 128 threads beat 512 |
| ----------------------------- | ------------------------------------------------: |
| RTX 4000 SFF ADA              | 2.0x                                               |
| RTX 6000 ADA                  | 2.5x                                               |
| RTX PRO 6000 Blackwell Max-Q  | 3.04x                                              |
| RTX PRO 4500 Blackwell        | 3.15x                                              |
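
Put together, the launch heuristic behaves roughly like the sketch below (the helper name and the exact SM-count multiple are assumptions for illustration, not the literal code of this PR):

```cpp
// Sketch of the block-size heuristic described above; the function name and the
// exact nrows-to-SM-count threshold are assumptions, not the code in this PR.
static int reduce_rows_block_size(const int nrows, const int ncols, const int sm_count) {
    if (nrows < 2 * sm_count) {   // threshold assumed; see the break-even table above
        return 512;               // low parallelization: 16 warps per SM to hide load latency
    }
    if (ncols < 1024) {
        return 32;                // small columns: 128 threads x 8 unrolls would exceed one pass
    }
    return 128;                   // roughly one row per SM (4 warps x 32 threads in flight)
}
```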

In total, up to ~25x perf improvement was observed at the kernel level.
(Figure: speedup_comparison_multiple)
Moreover, no regression was observed in any of the investigated combinations.
(Figure: speedup_comparison_fractional)

As a consequence of this general kernel optimization, Gemma3n achieves a ~10% perf increase, going from 130 to 145 tok/s on an RTX PRO 6000 Blackwell Max-Q with batch size 1.

Naive:

  Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl | n_batch |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | --------------: | -------------------: |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |       1 |           pp100 |        147.27 ± 0.82 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |       1 |           tg100 |        130.75 ± 0.28 |

Optimized:

  Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl | n_batch |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | --------------: | -------------------: |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |       1 |           pp100 |        168.41 ± 0.37 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |       1 |           tg100 |        146.68 ± 0.68 |

Side note: Similar tendencies were observed for rms_norm_f32, and we intend to optimize said kernel in a separate PR.

github-actions bot added the `testing` (Everything test related), `Nvidia GPU` (Issues specific to Nvidia GPUs), and `ggml` (changes relating to the ggml tensor library for machine learning) labels on Aug 6, 2025
This increases iteration cycle speed by not having to recompile
every kernel all the time
1. Increase threadblock size to better hide latency of memory requests.
   As a consequence of bigger threadblocks, do 2-step summation, using
   shared memory to communicate results between invocations
2. Use sum_temp array to reduce waits on sum
3. Adjust num_unroll to reflect bigger threadblock
4. Improve default block_dims, increase support for more block_dims
Break-even point was the minimum of the following multiples.

| GPU Model                     | Nrow SM Count Multiple |
| -----------                   | -----------            |
| RTX 4000 SFF ADA              | 2.0x                   |
| RTX 6000 ADA                  | 2.5x                   |
| RTX PRO 6000 Blackwell Max-Q  | 3.04x                  |
| RTX PRO 4500 Blackwell        | 3.15x                  |
Alternative to this, one could have also made the number of unrollings
template-able, but that would require compiling the kernel multiple
times, increasing binary size unnecessarily
@ORippler force-pushed the osimons/optimize_reduce_rows_f32 branch from c6ed8cc to 9296d1f on August 7, 2025 07:46
@ORippler (Contributor, Author) commented Aug 7, 2025

Rebased on current master, resolving conflicts along the way. Reran E2E perf tests for gemma3n, and we continue to see perf gains. Nice to see some other optimizations for tg were made in master 😃

Naive:

  Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl | n_batch |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | --------------: | -------------------: |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |       1 |           pp100 |        146.89 ± 0.12 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |       1 |           tg100 |        145.86 ± 0.13 |

Optimized:

  Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl | n_batch |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | --------------: | -------------------: |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |       1 |           pp100 |        167.47 ± 0.29 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |       1 |           tg100 |        167.28 ± 0.35 |

float sum_temp[num_unroll] = { 0.0f };
for (int i = col; i < ncols;) {
    for (int j = 0; j < num_unroll; ++j) {
        if (i < ncols) {
Collaborator:
My intuition would have been that it is faster not to add the inner loop due to this conditional statement. Just to be sure: did you test both versions?

@ORippler (Contributor, Author) replied Aug 7, 2025:

> My intuition would have been that it is faster not to add the inner loop due to this conditional statement. Just to be sure: did you test both versions?

We shared that intuition and, as mentioned in the PR description, one of the first things we tried was hinting the compiler to unroll the outer loop with #pragma unroll. Unfortunately, the compiler did not comply, and we were still seeing a lot of long scoreboard stalls caused by sequential iteration through the for loop (see the following two screenshots).

(Screenshot 2025-08-07 at 13:28:59)

Only by explicitly unrolling the loop did we get the compiler to comply and pre-fetch the data, effectively hiding the memory latency (see 8 sequential FADDs preceded by 8 sequential LDGs):
(Screenshot 2025-08-07 at 13:28:39)

Collaborator:

I mean, the reason the outer loop cannot be unrolled is simply because the number of iterations isn't known at compile time, right? The inner loop has a fixed size and can therefore be unrolled.

@ORippler (Contributor, Author) replied:

> I mean, the reason the outer loop cannot be unrolled is simply because the number of iterations isn't known at compile time, right? The inner loop has a fixed size and can therefore be unrolled.

In this case, the only prerequisite for loop unrolling followed by instruction reordering is an unaliased pointer, which we declare via __restrict__. nvcc did unroll the loop, but it did not reorder the instructions/batch the LDGs. We manually nudged it in the right direction by unrolling the loop ourselves, which makes the path to optimize clearer to the compiler. Knowing the number of iterations at compile time is another example of such a nudge 😃
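
For illustration, the aliasing guarantee in question is just the qualifier on the kernel's pointer arguments (a generic sketch, not the exact signature from this PR):

```cpp
// With __restrict__ the compiler may assume x and dst never alias, so it is free
// to hoist and batch the loads from x instead of serializing them around stores.
static __global__ void reduce_rows_f32(const float * __restrict__ x,
                                       float       * __restrict__ dst,
                                       const int ncols);
```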

Collaborator:

I would have written this kernel differently. I would have made the CUDA block size a template parameter and increased it as long as it reduces the number of iterations needed (as is done in e.g. mmv.cu/mmvf.cu).

@ORippler (Contributor, Author) replied:

Templating the kernel also crossed our minds (see general PR description). However, templating would have led to an increased size of the generated binaries and was thus not our preferred option, given that it did not yield significant speed-ups in internal tests.

Collaborator:

Is Gemma3n being bottlenecked by GGML_SUM or GGML_MEAN? The reason I'm asking is because GGML_SUM uses CUB while GGML_MEAN does not. I would welcome a better general kernel for reducing rows in ggml but I would assume that such a kernel would not be faster than CUB.

@ORippler (Contributor, Author) replied:

Gemma3n is bottlenecked by GGML_OP_MEAN and GGML_OP_SUM_ROWS operations. We did not benchmark GGML_OP_SUM's reduce_rows_f32 and CUB execution paths against one another, so we cannot say which would be faster.
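
For reference, a block-wide row sum with CUB looks roughly like the sketch below; this is a generic illustration of the CUB approach mentioned above, not ggml's GGML_OP_SUM code, and the kernel name and signature are made up:

```cpp
#include <cub/block/block_reduce.cuh>

// Generic sketch: one threadblock per row, block-wide sum via cub::BlockReduce.
template <int block_size>
static __global__ void sum_rows_cub_sketch(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
    using BlockReduce = cub::BlockReduce<float, block_size>;
    __shared__ typename BlockReduce::TempStorage tmp_storage;

    const int row = blockIdx.x;
    float sum = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += block_size) {
        sum += x[(size_t) row * ncols + col];
    }
    sum = BlockReduce(tmp_storage).Sum(sum); // aggregate is valid in thread 0 only
    if (threadIdx.x == 0) {
        dst[row] = sum;
    }
}
```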

Member:

If you are interested in optimizing this model further, fusing these operations is likely to result in much better performance than optimizing the individual operations:

llama.cpp/src/llama-model.cpp

Lines 10729 to 10731 in 9a96389

ggml_tensor * calc_magnitude(ggml_tensor * x) {
    return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x)));
}

llama.cpp/src/llama-model.cpp

Lines 10799 to 10807 in 9a96389

ggml_tensor * gaussian_topk(ggml_tensor * x) {
    ggml_tensor * mean = ggml_mean(ctx0, x);
    ggml_tensor * std  = ggml_sqrt(ctx0, ggml_scale(ctx0,
        ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))),
        1.0f / (float)(x->ne[0] - 1)
    ));
    ggml_tensor * cutoff_x = ggml_add(ctx0, mean, ggml_scale(ctx0, std, f_sparsity_std_mul));
    return ggml_relu(ctx0, ggml_sub(ctx0, x, cutoff_x));
}

Fused operations in the CUDA backend are handled here:

static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion) {
    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
        ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
        i++;
        continue;
    }
    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
        i += 2;
        ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
        continue;
    }
}
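
A hypothetical additional case for the calc_magnitude() chain above could follow the same pattern (a sketch only: ggml_cuda_op_sqr_sum_rows_sqrt_fused does not exist, and whether this exact op sequence is fusable has not been verified):

```cpp
// Hypothetical: fuse sqr -> sum_rows -> sqrt from calc_magnitude() above.
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SQR, GGML_OP_SUM_ROWS, GGML_OP_SQRT }, {})) {
    i += 2;
    ggml_cuda_op_sqr_sum_rows_sqrt_fused(*cuda_ctx, cgraph->nodes[i], node);
    continue;
}
```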

Requested by @JohannesGaessler, and should fix remaining CI issues as a
side-effect
@JohannesGaessler (Collaborator) commented:
Thank you for answering my questions (even though I could have gotten the answers by reading the PR description more carefully). If you test using CUB for GGML_MEAN this PR would essentially be good to merge from my side.

@IMbackK (Collaborator) commented Aug 7, 2025:

Quick tests show this PR is also broadly performance-positive on CDNA and performance-neutral on RDNA2.
