Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -1001,20 +1001,25 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_SOFT_MAX:
{
const int nth = MIN(32, ne00);
int nth = 32; // SIMD width

if (ne00%4 == 0) {
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
} else {
do {
nth *= 2;
} while (nth <= ne00 && nth <= 1024);
nth /= 2;
[encoder setComputePipelineState:ctx->pipeline_soft_max];
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];

[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_DIAG_MASK_INF:
{
Expand Down
129 changes: 101 additions & 28 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -184,36 +184,73 @@ kernel void kernel_soft_max(
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
const int64_t i03 = (tgpig) / (ne02*ne01);
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

// parallel max
float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;

for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
lmax = MAX(lmax, psrc0[i00]);
}
const float max = simd_max(lmax);

float max = simd_max(lmax);
if (tiisg == 0) {
buf[sgitg] = max;
}

threadgroup_barrier(mem_flags::mem_threadgroup);

// broadcast, simd group number is ntg / 32
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
if (tpitg < i) {
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
}
}

threadgroup_barrier(mem_flags::mem_threadgroup);

max = buf[0];

// parallel sum
float lsum = 0.0f;
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
const float exp_psrc0 = exp(psrc0[i00] - max);
lsum += exp_psrc0;
// Remember the result of exp here. exp is expensive, so we really do not
// whish to compute it twice.
// wish to compute it twice.
pdst[i00] = exp_psrc0;
}

const float sum = simd_sum(lsum);
float sum = simd_sum(lsum);
if (tiisg == 0) {
buf[sgitg] = sum;
}

threadgroup_barrier(mem_flags::mem_threadgroup);

// broadcast, simd group number is ntg / 32
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
if (tpitg < i) {
buf[tpitg] += buf[tpitg + i];
}
}

threadgroup_barrier(mem_flags::mem_threadgroup);

sum = buf[0];

for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
pdst[i00] /= sum;
}
}
Expand All @@ -224,37 +261,73 @@ kernel void kernel_soft_max_4(
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
const int64_t i03 = (tgpig) / (ne02*ne01);
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);

// parallel max
float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;

for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
lmax4 = fmax(lmax4, psrc4[i00]);
}
float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));

const float max = simd_max(lmax);
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
float max = simd_max(lmax);
if (tiisg == 0) {
buf[sgitg] = max;
}

threadgroup_barrier(mem_flags::mem_threadgroup);

// broadcast, simd group number is ntg / 32
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
if (tpitg < i) {
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
}
}

threadgroup_barrier(mem_flags::mem_threadgroup);

max = buf[0];

// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
const float4 exp_psrc4 = exp(psrc4[i00] - max);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];

const float sum = simd_sum(lsum);
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
float sum = simd_sum(lsum);
if (tiisg == 0) {
buf[sgitg] = sum;
}

threadgroup_barrier(mem_flags::mem_threadgroup);

// broadcast, simd group number is ntg / 32
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
if (tpitg < i) {
buf[tpitg] += buf[tpitg + i];
}
}

threadgroup_barrier(mem_flags::mem_threadgroup);

sum = buf[0];

for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
pdst4[i00] /= sum;
}
}
Expand All @@ -274,7 +347,7 @@ kernel void kernel_diag_mask_inf(
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
} else {
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
}
}
}

kernel void kernel_diag_mask_inf_8(
Expand Down