From 2c82462f4073df0576463554d948da5e9e804cc2 Mon Sep 17 00:00:00 2001 From: Neha Abbas Date: Fri, 17 Oct 2025 22:08:50 -0700 Subject: [PATCH 1/4] updated optimization, fixed errors --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 22 ++- .../wgsl-shaders/set_rows.tmpl.wgsl | 163 ++++++++++++++++++ .../ggml-webgpu/wgsl-shaders/set_rows.wgsl | 81 --------- 3 files changed, 181 insertions(+), 85 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index b4558a9e3f1d2..e1324a2f3b090 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -248,7 +248,7 @@ struct webgpu_context_struct { webgpu_pipeline memset_pipeline; webgpu_pipeline mul_mat_pipeline[30][2]; - webgpu_pipeline set_rows_pipeline; + webgpu_pipeline set_rows_pipeline[1][2]; // dst->type, vectorized (0 for vectorized, 1 for non vectorized) webgpu_pipeline get_rows_pipeline[30]; webgpu_pipeline get_rows_f32_no_vec_pipeline; webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type @@ -767,9 +767,20 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, }; size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size; + // number of threads needed with vec4 = (total number of rows in matrix) * (number of elements in a row / 4) + uint32_t threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4); + + webgpu_pipeline pipeline = ctx->set_rows_pipeline[0][0]; + // if not evenly divisble by 4, use the non-vectorized version + if (src->ne[0] % 4 != 0) { + pipeline = ctx->set_rows_pipeline[0][1]; + // threads = number of rows + threads = src->ne[1] * src->ne[2] * src->ne[3]; + } + + uint32_t wg_x = (threads + max_wg_size - 1) / max_wg_size; - return ggml_backend_webgpu_build(ctx, ctx->set_rows_pipeline, params, entries, wg_x, error_bufs); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, error_bufs); } static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, @@ -1620,7 +1631,10 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", + // create_pipeline(device, pipeline, shader_code, label, constants) + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][1], wgsl_set_rows_f16, "set_rows_f16", + ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][0], wgsl_set_rows_f16_vec, "set_rows_f16_vec", ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl new file mode 100644 index 0000000000000..d4626bbd0b6b1 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl @@ -0,0 +1,163 @@ +#define(VARIANTS) + +[ + { + "SHADER_SUFFIX": "f16_vec", + "REPLS": { + "TYPE" : "vec4", + "DST_TYPE": "vec4", + "BLOCK_SIZE": 4 + }, + "DECLS": ["F16_VEC"] + }, + { + "SHADER_SUFFIX": "f16", + "REPLS": { + "TYPE" : "f32", + "DST_TYPE": "f16", + "BLOCK_SIZE": 1 + }, + "DECLS": ["F16"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(F16_VEC) +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let src_vec_index = (src_base + offset) / {{BLOCK_SIZE}}; + let dst_vec_index = (dst_base + offset) / {{BLOCK_SIZE}}; + dst[dst_vec_index] = vec4(src[src_vec_index]); +} +#enddecl(F16_VEC) + +#decl(F16) +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + dst[dst_base + offset] = f16(src[src_base + offset]); +} +#enddecl(F16) + +#end(DECLS) + +#define(SHADER) + +enable f16; + +DECLS + +@group(0) @binding(0) +var src: array<{{TYPE}}>; + +@group(0) @binding(1) +var idx: array; + +@group(0) @binding(2) +var dst: array<{{DST_TYPE}}>; + +@group(0) @binding(3) +var error: atomic; + +struct Params { + offset_src: u32, // in elements + offset_idx: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_idx0: u32, + stride_idx1: u32, + stride_idx2: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Shape of src + ne0: u32, + n_rows: u32, // n_rows = ne1 = rows per slice + ne2: u32, + ne3: u32, + + // Shape of idx + idx1: u32, + idx2: u32, +}; + +@group(0) @binding(4) +var params: Params; + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + + // Determine the total number of threads based on mode + var max_threads: u32; + var i: u32; + if {{BLOCK_SIZE}} > 1 { + // Vectorized: one thread per vector of elements + // # of total rows to go through * (# of threads per row) + max_threads = (params.n_rows * params.ne2 * params.ne3) * (params.ne0 / {{BLOCK_SIZE}}); + + // calculations are based off i being row, but when vectorized, it corresponds to a vector in a row + // getting the row from gid + i = gid.x / (params.ne0 / {{BLOCK_SIZE}}); + } else { + // Non-vectorized: one thread per row + // # of total rows in the matrix + max_threads = params.n_rows * params.ne2 * params.ne3; + i = gid.x; // i corresponds to the row + } + + if (gid.x >= max_threads) { + return; + } + + + let i_src3 = i / (params.ne2 * params.n_rows); + + i = i % (params.ne2 * params.n_rows); + let i_src2 = i / params.n_rows; + let i_src1 = i % params.n_rows; + + let i_idx2 = i_src3 % params.idx2; + let i_idx1 = i_src2 % params.idx1; + let i_idx0 = i_src1; + + let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2; + + let idx_high_val = idx[idx_high]; + let idx_low_val = idx[idx_high + 1]; + + if (idx_low_val != 0) { + // Upper bits of index are not zero, output will be incorrect + atomicStore(&error, 1); + return; + } + + let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; + let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3; + + if {{BLOCK_SIZE}} > 1 { + // Vectorized: one thread per vector of elements + + // starts at what element of that row? + let element_offset = (gid.x % (params.ne0 / {{BLOCK_SIZE}})) * {{BLOCK_SIZE}}; + copy_elements(i_src_row, i_dst_row, element_offset); + + } else { + // Non-vectorized: go through each element in row, copy one by one + for (var i: u32 = 0; i < params.ne0; i++) { + copy_elements(i_src_row, i_dst_row, i); + } + } + + +} + +#end(SHADER) + diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl deleted file mode 100644 index 3567713dc215c..0000000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +++ /dev/null @@ -1,81 +0,0 @@ -enable f16; - -@group(0) @binding(0) -var src: array; - -@group(0) @binding(1) -var idx: array; - -@group(0) @binding(2) -var dst: array; - -@group(0) @binding(3) -var error: atomic; - -struct Params { - offset_src: u32, // in elements - offset_idx: u32, // in elements - offset_dst: u32, // in elements - - // Strides (in elements) - stride_src1: u32, - stride_src2: u32, - stride_src3: u32, - - stride_idx0: u32, - stride_idx1: u32, - stride_idx2: u32, - - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // Shape of src - ne0: u32, - n_rows: u32, - ne2: u32, - ne3: u32, - - // Shape of idx - idx1: u32, - idx2: u32, -}; - -@group(0) @binding(4) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.n_rows * params.ne2 * params.ne3) { - return; - } - var i = gid.x; - let i_src3 = i / (params.ne2 * params.n_rows); - - i = i % (params.ne2 * params.n_rows); - let i_src2 = i / params.n_rows; - let i_src1 = i % params.n_rows; - - let i_idx2 = i_src3 % params.idx2; - let i_idx1 = i_src2 % params.idx1; - let i_idx0 = i_src1; - - let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2; - - let idx_high_val = idx[idx_high]; - let idx_low_val = idx[idx_high + 1]; - - if (idx_low_val != 0) { - // Upper bits of index are not zero, output will be incorrect - atomicStore(&error, 1); - return; - } - - let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; - let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3; - - for (var i: u32 = 0; i < params.ne0; i++) { - dst[i_dst_row + i] = f16(src[i_src_row + i]); - } -} From 66a80fc51a6acfcb7b36e35bed84106ec88c4761 Mon Sep 17 00:00:00 2001 From: Neha Abbas Date: Fri, 24 Oct 2025 13:48:02 -0700 Subject: [PATCH 2/4] non vectorized version now dispatches one thread per element --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 4 +-- .../wgsl-shaders/set_rows.tmpl.wgsl | 32 ++++++------------- 2 files changed, 12 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index e1324a2f3b090..774b9a8354dc6 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -774,8 +774,8 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, // if not evenly divisble by 4, use the non-vectorized version if (src->ne[0] % 4 != 0) { pipeline = ctx->set_rows_pipeline[0][1]; - // threads = number of rows - threads = src->ne[1] * src->ne[2] * src->ne[3]; + // threads = number of elements + threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; } uint32_t wg_x = (threads + max_wg_size - 1) / max_wg_size; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl index d4626bbd0b6b1..e46d44a4365ac 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl @@ -97,26 +97,24 @@ fn main(@builtin(global_invocation_id) gid: vec3) { // Determine the total number of threads based on mode var max_threads: u32; - var i: u32; if {{BLOCK_SIZE}} > 1 { // Vectorized: one thread per vector of elements // # of total rows to go through * (# of threads per row) max_threads = (params.n_rows * params.ne2 * params.ne3) * (params.ne0 / {{BLOCK_SIZE}}); - - // calculations are based off i being row, but when vectorized, it corresponds to a vector in a row - // getting the row from gid - i = gid.x / (params.ne0 / {{BLOCK_SIZE}}); } else { - // Non-vectorized: one thread per row - // # of total rows in the matrix - max_threads = params.n_rows * params.ne2 * params.ne3; - i = gid.x; // i corresponds to the row + // Non-vectorized: one thread per element + // # of total elemtns in matrix + max_threads = params.ne0 * params.n_rows * params.ne2 * params.ne3; } if (gid.x >= max_threads) { return; } + // calculations are based off i being row, but when vectorized, it corresponds to a vector in a row + // getting the row from gid + var i = gid.x / (params.ne0 / {{BLOCK_SIZE}}); + let i_src3 = i / (params.ne2 * params.n_rows); @@ -142,20 +140,10 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3; - if {{BLOCK_SIZE}} > 1 { - // Vectorized: one thread per vector of elements - - // starts at what element of that row? - let element_offset = (gid.x % (params.ne0 / {{BLOCK_SIZE}})) * {{BLOCK_SIZE}}; - copy_elements(i_src_row, i_dst_row, element_offset); - - } else { - // Non-vectorized: go through each element in row, copy one by one - for (var i: u32 = 0; i < params.ne0; i++) { - copy_elements(i_src_row, i_dst_row, i); - } - } + // starts at what element of that row? + let element_offset = (gid.x % (params.ne0 / {{BLOCK_SIZE}})) * {{BLOCK_SIZE}}; + copy_elements(i_src_row, i_dst_row, element_offset); } From 2cc96eb87ba46607ef635d3034c770edbd5039e6 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Mon, 27 Oct 2025 12:43:06 -0700 Subject: [PATCH 3/4] Simplify --- .../wgsl-shaders/set_rows.tmpl.wgsl | 58 ++++--------------- 1 file changed, 10 insertions(+), 48 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl index e46d44a4365ac..4a6d819d3b145 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl @@ -6,47 +6,25 @@ "REPLS": { "TYPE" : "vec4", "DST_TYPE": "vec4", - "BLOCK_SIZE": 4 - }, - "DECLS": ["F16_VEC"] + "VEC_SIZE": 4 + } }, { "SHADER_SUFFIX": "f16", "REPLS": { "TYPE" : "f32", "DST_TYPE": "f16", - "BLOCK_SIZE": 1 - }, - "DECLS": ["F16"] + "VEC_SIZE": 1 + } } ] #end(VARIANTS) -#define(DECLS) - -#decl(F16_VEC) -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let src_vec_index = (src_base + offset) / {{BLOCK_SIZE}}; - let dst_vec_index = (dst_base + offset) / {{BLOCK_SIZE}}; - dst[dst_vec_index] = vec4(src[src_vec_index]); -} -#enddecl(F16_VEC) - -#decl(F16) -fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - dst[dst_base + offset] = f16(src[src_base + offset]); -} -#enddecl(F16) - -#end(DECLS) - #define(SHADER) enable f16; -DECLS - @group(0) @binding(0) var src: array<{{TYPE}}>; @@ -79,7 +57,7 @@ struct Params { // Shape of src ne0: u32, - n_rows: u32, // n_rows = ne1 = rows per slice + n_rows: u32, ne2: u32, ne3: u32, @@ -94,27 +72,13 @@ var params: Params; override wg_size: u32; @compute @workgroup_size(wg_size) fn main(@builtin(global_invocation_id) gid: vec3) { - - // Determine the total number of threads based on mode - var max_threads: u32; - if {{BLOCK_SIZE}} > 1 { - // Vectorized: one thread per vector of elements - // # of total rows to go through * (# of threads per row) - max_threads = (params.n_rows * params.ne2 * params.ne3) * (params.ne0 / {{BLOCK_SIZE}}); - } else { - // Non-vectorized: one thread per element - // # of total elemtns in matrix - max_threads = params.ne0 * params.n_rows * params.ne2 * params.ne3; - } - - if (gid.x >= max_threads) { + if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / {{VEC_SIZE}}) { return; } - // calculations are based off i being row, but when vectorized, it corresponds to a vector in a row // getting the row from gid - var i = gid.x / (params.ne0 / {{BLOCK_SIZE}}); - + let elems_per_row = params.ne0 / {{VEC_SIZE}}; + var i = gid.x / elems_per_row; let i_src3 = i / (params.ne2 * params.n_rows); @@ -141,10 +105,8 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3; // starts at what element of that row? - let element_offset = (gid.x % (params.ne0 / {{BLOCK_SIZE}})) * {{BLOCK_SIZE}}; - - copy_elements(i_src_row, i_dst_row, element_offset); - + let col_idx = (gid.x % elems_per_row); + dst[i_dst_row/{{VEC_SIZE}} + col_idx] = {{DST_TYPE}}(src[i_src_row/{{VEC_SIZE}} + col_idx]); } #end(SHADER) From 32ca54e7887ae8391a07fa8ea35182e988e78d98 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Mon, 27 Oct 2025 12:46:25 -0700 Subject: [PATCH 4/4] Change logic for set_rows pipelines --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 774b9a8354dc6..353c7729bd1f8 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -248,7 +248,7 @@ struct webgpu_context_struct { webgpu_pipeline memset_pipeline; webgpu_pipeline mul_mat_pipeline[30][2]; - webgpu_pipeline set_rows_pipeline[1][2]; // dst->type, vectorized (0 for vectorized, 1 for non vectorized) + webgpu_pipeline set_rows_pipeline[1][2]; // dst->type, vectorized (0 for vectorized, 1 for non vectorized) webgpu_pipeline get_rows_pipeline[30]; webgpu_pipeline get_rows_f32_no_vec_pipeline; webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type @@ -766,15 +766,15 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() } }; - size_t max_wg_size = ctx->max_wg_size_x; - // number of threads needed with vec4 = (total number of rows in matrix) * (number of elements in a row / 4) - uint32_t threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4); + size_t max_wg_size = ctx->max_wg_size_x; - webgpu_pipeline pipeline = ctx->set_rows_pipeline[0][0]; + int vectorized = src->ne[0] % 4 == 0; + webgpu_pipeline pipeline = ctx->set_rows_pipeline[0][vectorized]; // if not evenly divisble by 4, use the non-vectorized version - if (src->ne[0] % 4 != 0) { - pipeline = ctx->set_rows_pipeline[0][1]; - // threads = number of elements + uint32_t threads; + if (vectorized) { + threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4); + } else { threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; } @@ -1631,11 +1631,10 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { - // create_pipeline(device, pipeline, shader_code, label, constants) - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][1], wgsl_set_rows_f16, "set_rows_f16", - ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][0], wgsl_set_rows_f16_vec, "set_rows_f16_vec", - ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][0], wgsl_set_rows_f16, + "set_rows_f16", ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][1], wgsl_set_rows_f16_vec, + "set_rows_f16_vec", ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); } static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {