Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ typedef struct {
uint64_t nb33;
int32_t ne1;
int32_t ne2;
int32_t ne3;
float scale;
float max_bias;
float m0;
Expand All @@ -257,6 +258,11 @@ typedef struct {
float logit_softcap;
} ggml_metal_kargs_flash_attn_ext;

typedef struct {
int32_t nrows;
int32_t ne20;
} ggml_metal_kargs_flash_attn_ext_reduce;

typedef struct {
int32_t ne00;
int32_t ne02;
Expand Down
127 changes: 111 additions & 16 deletions ggml/src/ggml-metal/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
Expand Down Expand Up @@ -575,6 +579,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE,
GGML_METAL_KERNEL_TYPE_SET_I32,
GGML_METAL_KERNEL_TYPE_SET_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
Expand Down Expand Up @@ -1324,6 +1329,10 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2, mul_mv_ext_f32_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3, mul_mv_ext_f32_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4, mul_mv_ext_f32_f32_r1_4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5, mul_mv_ext_f32_f32_r1_5, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
Expand Down Expand Up @@ -1609,6 +1618,7 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE, flash_attn_ext_reduce, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
Expand Down Expand Up @@ -3385,15 +3395,16 @@ static int ggml_metal_encode_node(

// find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel
const int ne11_mm_min = 4;
const int ne11_mm_min = 8;

// first try to use small-batch mat-mv kernels
// these should be efficient for BS [2, ~8]
if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
if (src1t == GGML_TYPE_F32 && (ne00%128 == 0) &&
(
(
(
src0t == GGML_TYPE_F16 || // TODO: helper function
src0t == GGML_TYPE_F32 || // TODO: helper function
src0t == GGML_TYPE_F16 ||
src0t == GGML_TYPE_Q4_0 ||
src0t == GGML_TYPE_Q4_1 ||
src0t == GGML_TYPE_Q5_0 ||
Expand Down Expand Up @@ -3421,7 +3432,17 @@ static int ggml_metal_encode_node(
// values and there can be some tail effects when nsg is high. need to confirm this
//
const int nsg = 2; // num simdgroups per threadgroup
const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup

// num threads along row per simdgroup
int nxpsg = 0;
if (ne00 % 256 == 0 && ne11 < 3) {
nxpsg = 16;
} else if (ne00 % 128 == 0) {
nxpsg = 8;
} else {
nxpsg = 4;
}

const int nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup
int r1ptg = 4; // num src1 rows per threadgroup
Expand All @@ -3444,6 +3465,14 @@ static int ggml_metal_encode_node(
id<MTLComputePipelineState> pipeline = nil;

switch (src0->type) {
case GGML_TYPE_F32:
switch (r1ptg) {
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2].pipeline; break;
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3].pipeline; break;
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4].pipeline; break;
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5].pipeline; break;
default: GGML_ABORT("not implemented");
} break;
case GGML_TYPE_F16:
switch (r1ptg) {
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
Expand Down Expand Up @@ -3598,7 +3627,7 @@ static int ggml_metal_encode_node(
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
Expand Down Expand Up @@ -5482,6 +5511,7 @@ static int ggml_metal_encode_node(
/*.nb33 =*/ nb33,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.scale =*/ scale,
/*.max_bias =*/ max_bias,
/*.m0 =*/ m0,
Expand All @@ -5505,7 +5535,6 @@ static int ggml_metal_encode_node(
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];

if (!use_vec_kernel) {
// half8x8 kernel
Expand All @@ -5531,7 +5560,7 @@ static int ggml_metal_encode_node(

while (true) {
const size_t smem = FATTN_SMEM(nsgmax);
if (smem > device.maxThreadgroupMemoryLength) {
if (smem > device.maxThreadgroupMemoryLength/2) {
break;
}
nsgmax *= 2;
Expand All @@ -5543,15 +5572,18 @@ static int ggml_metal_encode_node(

const size_t smem = FATTN_SMEM(nsg);

[encoder setBuffer:id_dst offset:offs_dst atIndex:6];

//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
[encoder setThreadgroupMemoryLength:smem atIndex:0];
#undef FATTN_SMEM
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
#undef FATTN_SMEM
} else {
// half4x4 kernel
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
const int64_t nkpsg = 1*ncpsg; // TODO: make adjustable

GGML_ASSERT(nqptg <= 32);
GGML_ASSERT(nqptg % 1 == 0);
Expand All @@ -5561,37 +5593,100 @@ static int ggml_metal_encode_node(
// for each query, we load it as f16 in shared memory (ne00)
// and store the soft_max values and the mask
//
// ne00*(nsg)
// ne20*(nsg)
// each simdgroup has a full f32 head vector in shared mem to accumulate results
//
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
//#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)))*(sizeof(float)/2), 16))

int64_t nsgmax = 2;
while (true) {
const size_t smem = FATTN_SMEM(nsgmax);
if (smem > device.maxThreadgroupMemoryLength) {
// avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
if (smem > device.maxThreadgroupMemoryLength/2) {
break;
}
nsgmax *= 2;
}
nsgmax /= 2;

// simdgroups per threadgroup (a.k.a. warps)
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));

int64_t nsg = 1;
while (nsg <= nsgt) {
nsg *= 2;
}
nsg /= 2;

const size_t smem = FATTN_SMEM(nsg);
// workgroups
// each workgroup handles nsg*nkpsg cache values
uint16_t nwg = 1;
if (4*nsg*nkpsg >= ne11) {
const size_t smem = FATTN_SMEM(nsg);

//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
[encoder setThreadgroupMemoryLength:smem atIndex:0];
//printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);

// using 1 workgroup -> write the result directly into dst
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
[encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];

[encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
} else {
nwg = 32;
nsg = MIN(4, nsg);

const size_t smem = FATTN_SMEM(nsg);

//printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);

// sanity checks
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));

const int32_t nrows = ne1*ne2*ne3;

// temp buffer for writing the results from each workgroup
// - ne20: the size of the head vector
// - + 2: the S and M values for each intermediate result
const size_t s_tmp = ggml_type_size(GGML_TYPE_F32)*(nrows*nwg*(ne20 + 2));
id<MTLBuffer> h_tmp = ggml_metal_mem_pool_alloc(mem_pool, s_tmp);
if (!h_tmp) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tmp);
return 0;
}

//printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
//printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);

[encoder setBuffer:h_tmp offset:0 atIndex:6];
[encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];

[encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];

// reduce the results from the workgroups
{
ggml_metal_kargs_flash_attn_ext_reduce args0 = {
nrows,
ne20,
};

id<MTLComputePipelineState> pipeline0 = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE].pipeline;

[encoder setComputePipelineState:pipeline0];
[encoder setBytes:&args0 length:sizeof(args0) atIndex:0];
[encoder setBuffer:h_tmp offset:0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];

//printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*32, 1, 1)];
}
}
#undef FATTN_SMEM
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
}
} break;
case GGML_OP_DUP:
Expand Down
Loading
Loading