Metal kernel mv_f16_f32_l4 performance issue for long contexts, too many threads #6089

@izard

Prerequisites

Please answer the following questions for yourself before submitting an issue.

  • I am running the latest code. Development is very rapid so there are no tagged versions as of now.
  • I carefully followed the README.md.
  • I searched using keywords relevant to my issue to make sure that I am creating a new issue that is not already open (or closed).
  • I reviewed the Discussions, and have a new bug or useful enhancement to share.

Feature Description

This is more of a question: is anyone looking, or has anyone looked, at the following long-context performance issue in Metal? I did not find anything in the repo history, but I may have missed it.

When profiling long contexts (starting from about 25K tokens), I found that block processing latency becomes dominated by kernel_mul_mv_f16_f32_l4, called with a small width (e.g. 128), a large height (e.g. 32768 for a context length slightly below 32K), and an input vector length of 32768. This kernel takes ~80% of the total time, and its runtime is dominated by very low utilization of the GPU execution units: 32768 threads are launched, each doing a very small chunk of work. There is no memory pressure.

So the dispatch code
    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
and the kernel could be optimized so that the number of threads launched is not equal to the vector length, with the work chunked differently.
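To make the chunking idea concrete, here is a minimal CPU-side sketch in C++ (not Metal, and not the actual llama.cpp code). It assumes a hypothetical parameter rows_per_group, which does not exist in the current kernel, and emulates each threadgroup as one loop iteration that handles several matrix rows instead of exactly one. The names num_groups and mul_mv_chunked are made up for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Number of threadgroups needed to cover ne01 rows when each threadgroup
// handles rows_per_group rows. rows_per_group is a hypothetical parameter;
// the current dispatch corresponds to rows_per_group == 1.
std::size_t num_groups(std::size_t ne01, std::size_t rows_per_group) {
    assert(rows_per_group > 0);
    return (ne01 + rows_per_group - 1) / rows_per_group; // ceiling division
}

// CPU emulation of a chunked matrix-vector product.
// a is an ne01 x ne00 row-major matrix, y has ne00 elements.
std::vector<float> mul_mv_chunked(const std::vector<float> & a,
                                  const std::vector<float> & y,
                                  std::size_t ne00, std::size_t ne01,
                                  std::size_t rows_per_group) {
    std::vector<float> dst(ne01, 0.0f);
    const std::size_t groups = num_groups(ne01, rows_per_group);
    for (std::size_t g = 0; g < groups; ++g) { // one iteration per emulated threadgroup
        const std::size_t r_begin = g * rows_per_group;
        const std::size_t r_end   = std::min(ne01, r_begin + rows_per_group);
        for (std::size_t r = r_begin; r < r_end; ++r) {
            float sum = 0.0f;
            for (std::size_t i = 0; i < ne00; ++i) {
                sum += a[r * ne00 + i] * y[i];
            }
            dst[r] = sum;
        }
    }
    return dst;
}
```

With ne01 = 32768, rows_per_group = 1 yields 32768 groups, while rows_per_group = 8 yields 4096 groups that each do eight times the work; the result vector is the same either way. Whether this actually improves GPU utilization would have to be measured on the device.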

By the way, in this matrix-vector f16_f32_l4 kernel, the lines
    for (int r1 = 0; r1 < nrows; ++r1) {
        device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
        ..
    }
are redundant, since the kernel is always called with nrows == 1. When I replaced them with just
    device const float4 * y4 = (device const float4 *) (src1 + im*nb12);
nothing changed (except that the code became a tiny bit cleaner).
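The reason the simplification is safe can be checked with plain offset arithmetic. The following C++ sketch (hypothetical helper names, CPU-side only) lists the byte offsets into src1 that the original loop visits and compares them with the single offset used after removing the loop. It assumes, as stated above, that nrows is always 1.

```cpp
#include <cstddef>
#include <vector>

// Byte offsets into src1 visited by the original loop over r1.
std::vector<std::size_t> offsets_with_loop(std::size_t nrows, std::size_t nb11,
                                           std::size_t im, std::size_t nb12) {
    std::vector<std::size_t> out;
    for (std::size_t r1 = 0; r1 < nrows; ++r1) {
        out.push_back(r1 * nb11 + im * nb12);
    }
    return out;
}

// Byte offset used once the loop is removed (equivalent to r1 == 0).
std::size_t offset_without_loop(std::size_t im, std::size_t nb12) {
    return im * nb12;
}
```

For nrows == 1 the loop visits only r1 == 0, so the r1*nb11 term vanishes and both versions read from the same address. If the kernel were ever dispatched with nrows > 1, the loop would be needed again.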

Motivation

The significant performance drop for long-context prompts in Metal is caused by inefficient Metal thread scheduling. Once this is fixed, I expect a smaller time increase for longer contexts.
For example, this is what I measured for one of the common models running on an M3 Max:
context:323, t/s: 7.2
context:2248, t/s: 6.3
context:5908, t/s: 5.1
context:10314, t/s: 4.4
context:15112, t/s: 3.65
context:20556, t/s: 3.1
context:24588, t/s: 3

Possible Implementation

Change kernel_mul_mv_f16_f32_l4, or possibly add a kernel_mul_mv_f16_f32_l4_long_vector variant that is dispatched with different threadgroup and threadsPerThreadgroup counts.
