Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -1617,6 +1617,13 @@ extern "C" {
float scale,
float max_bias);

GGML_API struct ggml_tensor * ggml_soft_max_ext_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * mask,
float scale,
float max_bias);

GGML_API void ggml_soft_max_add_sinks(
struct ggml_tensor * a,
struct ggml_tensor * sinks);
Expand Down
536 changes: 466 additions & 70 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp

Large diffs are not rendered by default.

87 changes: 87 additions & 0 deletions ggml/src/ggml-webgpu/wgsl-shaders/neg.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#define(VARIANTS)

[
{
"REPLS": {
"TYPE": "f32",
}
},
{
"REPLS": {
"TYPE": "f16",
}
},
]

#end(VARIANTS)

#define(SHADER)
enable f16;

@group(0) @binding(0)
var<storage, read_write> src: array<{{TYPE}}>;

@group(0) @binding(1)
var<storage, read_write> dst: array<{{TYPE}}>;

struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements

// Strides (in elements) — may be permuted
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,

stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,

// Logical shapes
src_ne0: u32,
src_ne1: u32,
src_ne2: u32,

dst_ne0: u32,
dst_ne1: u32,
dst_ne2: u32
};

@group(0) @binding(2)
var<uniform> params: Params;

override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
}

var i = gid.x;
let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
let i2 = i / (params.src_ne1 * params.src_ne0);
i = i % (params.src_ne1 * params.src_ne0);
let i1 = i / params.src_ne0;
let i0 = i % params.src_ne0;

var j = gid.x;
let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
let j2 = j / (params.dst_ne1 * params.dst_ne0);
j = j % (params.dst_ne1 * params.dst_ne0);
let j1 = j / params.dst_ne0;
let j0 = j % params.dst_ne0;

let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
i2 * params.stride_src2 + i3 * params.stride_src3;

let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
j2 * params.stride_dst2 + j3 * params.stride_dst3;

dst[params.offset_dst + dst_idx] = -((src[params.offset_src + src_idx]));
}
#end(SHADER)
84 changes: 84 additions & 0 deletions ggml/src/ggml-webgpu/wgsl-shaders/neg_in_place.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#define(VARIANTS)

[
{
"REPLS": {
"TYPE": "f32",
}
},
{
"REPLS": {
"TYPE": "f16",
}
},
]

#end(VARIANTS)

#define(SHADER)
enable f16;

@group(0) @binding(0)
var<storage, read_write> src: array<{{TYPE}}>;

struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements

// Strides (in elements) — may be permuted
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,

stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,

// Logical shapes
src_ne0: u32,
src_ne1: u32,
src_ne2: u32,

dst_ne0: u32,
dst_ne1: u32,
dst_ne2: u32
};

@group(0) @binding(1)
var<uniform> params: Params;

override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
}

var i = gid.x;
let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
let i2 = i / (params.src_ne1 * params.src_ne0);
i = i % (params.src_ne1 * params.src_ne0);
let i1 = i / params.src_ne0;
let i0 = i % params.src_ne0;

var j = gid.x;
let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
let j2 = j / (params.dst_ne1 * params.dst_ne0);
j = j % (params.dst_ne1 * params.dst_ne0);
let j1 = j / params.dst_ne0;
let j0 = j % params.dst_ne0;

let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
i2 * params.stride_src2 + i3 * params.stride_src3;

let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
j2 * params.stride_dst2 + j3 * params.stride_dst3;

dst[params.offset_dst + dst_idx] = -((src[params.offset_src + src_idx]));
}
#end(SHADER)
43 changes: 34 additions & 9 deletions ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -71,28 +71,53 @@ var<storage, read_write> src: array<f32>;
DECLS

override wg_size: u32;
var<workgroup> scratch: array<f32, wg_size>;

@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne1 * params.ne2 * params.ne3) {
return;
}
fn main(@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {

// one thread per row
var i = gid.x;
var i = wid.x;
let i3 = i / (params.ne2 * params.ne1);
i = i % (params.ne2 * params.ne1);
let i2 = i / params.ne1;
let i1 = i % params.ne1;
let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
let i_dst_row = params.offset_src + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;

let elems = (params.ne0 + wg_size - 1) / wg_size;

var sum = 0.0f;
for (var j: u32 = 0; j < params.ne0; j++) {
sum += src[i_src_row + j] * src[i_src_row + j];
var col = lid.x;
for (var j: u32 = 0; j < elems; j++) {
if (col >= params.ne0) {
break;
}
sum += pow(src[i_src_row + col], 2.0);
col += wg_size;
}

scratch[lid.x] = sum;
workgroupBarrier();
var offset = wg_size / 2;
while (offset > 0) {
if (lid.x < offset) {
scratch[lid.x] += scratch[lid.x + offset];
}
offset = offset / 2;
workgroupBarrier();
}
sum = scratch[0];

let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
for (var j: u32 = 0; j < params.ne0; j++) {
update(i_src_row + j, i_dst_row + j, scale);
col = lid.x;
for (var j: u32 = 0; j < elems; j++) {
if (col >= params.ne0) {
break;
}
update(i_src_row + col, i_dst_row + col, scale);
col += wg_size;
}
}
#end(SHADER)
Loading