From 4529332e2296298cbe400fdd555d276f86cec0dc Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 6 Sep 2025 06:20:52 +0700 Subject: [PATCH 01/40] webgpu : fix build on emscripten --- .gitignore | 3 +++ common/common.cpp | 2 ++ ggml/src/ggml-webgpu/CMakeLists.txt | 13 +++++++++++-- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 12 ++++++++++-- 4 files changed, 26 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 595831accb05d..ed034e40e2795 100644 --- a/.gitignore +++ b/.gitignore @@ -148,3 +148,6 @@ poetry.toml /run-vim.sh /run-chat.sh .ccache/ + +# emscripten +a.out.* diff --git a/common/common.cpp b/common/common.cpp index 0c92d4d57ddbf..a95cca8a2b174 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -871,6 +871,8 @@ std::string fs_get_cache_directory() { cache_directory = std::getenv("HOME") + std::string("/Library/Caches/"); #elif defined(_WIN32) cache_directory = std::getenv("LOCALAPPDATA"); +#elif defined(__EMSCRIPTEN__) + GGML_ABORT("not implemented on this platform"); #else # error Unknown architecture #endif diff --git a/ggml/src/ggml-webgpu/CMakeLists.txt b/ggml/src/ggml-webgpu/CMakeLists.txt index 78a985a4d167a..dce7e14ca83c9 100644 --- a/ggml/src/ggml-webgpu/CMakeLists.txt +++ b/ggml/src/ggml-webgpu/CMakeLists.txt @@ -39,8 +39,17 @@ add_dependencies(ggml-webgpu generate_shaders) if(EMSCRIPTEN) set(EMDAWNWEBGPU_DIR "" CACHE PATH "Path to emdawnwebgpu_pkg") - target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") - target_link_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") + if(NOT EMDAWNWEBGPU_DIR) + # default built-in port + target_compile_options(ggml-webgpu PRIVATE "--use-port=emdawnwebgpu") + target_link_options(ggml-webgpu INTERFACE "--use-port=emdawnwebgpu") + else() + # custom port + target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") + target_link_options(ggml-webgpu INTERFACE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") + endif() + + set(DawnWebGPU_TARGET webgpu_cpp) else() find_package(Dawn REQUIRED) set(DawnWebGPU_TARGET dawn::webgpu_dawn) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index e5df883c1367e..acd43a8a94b97 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -9,6 +9,10 @@ #include "ggml-impl.h" #include "ggml-wgsl-shaders.hpp" +#ifdef __EMSCRIPTEN__ +#include +#endif + #include #include @@ -1173,8 +1177,12 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ctx->adapter.GetInfo(&info); // Initialize device - std::vector required_features = { wgpu::FeatureName::ShaderF16, - wgpu::FeatureName::ImplicitDeviceSynchronization }; + std::vector required_features = { + wgpu::FeatureName::ShaderF16, +#ifndef __EMSCRIPTEN__ + wgpu::FeatureName::ImplicitDeviceSynchronization, +#endif + }; wgpu::DeviceDescriptor dev_desc; dev_desc.requiredLimits = &ctx->limits; dev_desc.requiredFeatures = required_features.data(); From 990a98ae642cec417623081dd273b1ff5b1aea02 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 6 Sep 2025 08:36:16 +0700 Subject: [PATCH 02/40] more debugging stuff --- CMakeLists.txt | 6 +++++- ggml/src/ggml-webgpu/CMakeLists.txt | 8 ++++++-- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 11 +++++++++++ 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 36a2078e4c9fa..9e0a042edb2fc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt 
@@ -36,7 +36,11 @@ option(LLAMA_USE_SYSTEM_GGML "Use system libggml" OFF) if (EMSCRIPTEN) set(BUILD_SHARED_LIBS_DEFAULT OFF) - option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" ON) + option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" OFF) + option(LLAMA_BUILD_HTML "llama: build HTML file" ON) + if (LLAMA_BUILD_HTML) + set(CMAKE_EXECUTABLE_SUFFIX ".html") + endif() else() if (MINGW) set(BUILD_SHARED_LIBS_DEFAULT OFF) diff --git a/ggml/src/ggml-webgpu/CMakeLists.txt b/ggml/src/ggml-webgpu/CMakeLists.txt index dce7e14ca83c9..73eb66fb874c3 100644 --- a/ggml/src/ggml-webgpu/CMakeLists.txt +++ b/ggml/src/ggml-webgpu/CMakeLists.txt @@ -42,11 +42,11 @@ if(EMSCRIPTEN) if(NOT EMDAWNWEBGPU_DIR) # default built-in port target_compile_options(ggml-webgpu PRIVATE "--use-port=emdawnwebgpu") - target_link_options(ggml-webgpu INTERFACE "--use-port=emdawnwebgpu") + target_link_options(ggml-webgpu INTERFACE "--use-port=emdawnwebgpu" "-sASYNCIFY=1") else() # custom port target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") - target_link_options(ggml-webgpu INTERFACE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") + target_link_options(ggml-webgpu INTERFACE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py" "-sASYNCIFY=1") endif() set(DawnWebGPU_TARGET webgpu_cpp) @@ -57,6 +57,10 @@ endif() if (GGML_WEBGPU_DEBUG) target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_DEBUG=1) + if(EMSCRIPTEN) + target_compile_options(ggml-webgpu PRIVATE "-fexceptions") + target_link_options(ggml-webgpu INTERFACE "-sASSERTIONS=2" "-fexceptions") + endif() endif() target_include_directories(ggml-webgpu PRIVATE ${SHADER_OUTPUT_DIR}) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index acd43a8a94b97..53cd8df50b836 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1295,6 +1295,17 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { instance_descriptor.requiredFeatures = instance_features.data(); instance_descriptor.requiredFeatureCount = instance_features.size(); webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor); + +#ifdef __EMSCRIPTEN__ +#ifndef __EMSCRIPTEN_PTHREADS__ + GGML_LOG_WARN("ggml_webgpu: pthread is disabled. This may cause bugs\n"); +#endif + + if (webgpu_ctx->instance == nullptr) { + GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. 
Make sure -sASYNCIFY is set\n"); + return nullptr; + } +#endif GGML_ASSERT(webgpu_ctx->instance != nullptr); static ggml_backend_reg reg = { From 5616b9c246d6d50c1419b6a14703d62818adf579 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 6 Sep 2025 19:26:22 +0700 Subject: [PATCH 03/40] test-backend-ops: force single thread on wasm --- tests/test-backend-ops.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index d638a96ee9be8..af8add41b76ba 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -40,12 +40,18 @@ #include #include +#ifdef __EMSCRIPTEN__ +# define N_THREADS 1 +#else +# define N_THREADS std::thread::hardware_concurrency() +#endif + static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { size_t nels = ggml_nelements(tensor); std::vector data(nels); { // parallel initialization - static const size_t n_threads = std::thread::hardware_concurrency(); + static const size_t n_threads = N_THREADS; // static RNG initialization (revisit if n_threads stops being constant) static std::vector generators = []() { std::random_device rd; @@ -104,7 +110,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m }; const size_t min_blocks_per_thread = 1; - const size_t n_threads = std::min(std::thread::hardware_concurrency()/2, + const size_t n_threads = std::min(N_THREADS/2, std::max(1, n_blocks / min_blocks_per_thread)); std::vector> tasks; tasks.reserve(n_threads); @@ -6934,7 +6940,7 @@ int main(int argc, char ** argv) { auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); if (ggml_backend_set_n_threads_fn) { // TODO: better value for n_threads - ggml_backend_set_n_threads_fn(backend, std::thread::hardware_concurrency()); + ggml_backend_set_n_threads_fn(backend, N_THREADS); } size_t free, total; // NOLINT From 56d02f6f424a75c7fc085effe4c8bc078680058e Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 6 Sep 2025 19:34:27 +0700 Subject: [PATCH 04/40] fix single-thread case for init_tensor_uniform --- tests/test-backend-ops.cpp | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index af8add41b76ba..86f8087c0fd25 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -70,15 +70,19 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m } }; - std::vector> tasks; - tasks.reserve(n_threads); - for (size_t i = 0; i < n_threads; i++) { - size_t start = i*nels/n_threads; - size_t end = (i+1)*nels/n_threads; - tasks.push_back(std::async(std::launch::async, init_thread, i, start, end)); - } - for (auto & t : tasks) { - t.get(); + if (n_threads == 1) { + init_thread(0, 0, nels); + } else { + std::vector> tasks; + tasks.reserve(n_threads); + for (size_t i = 0; i < n_threads; i++) { + size_t start = i*nels/n_threads; + size_t end = (i+1)*nels/n_threads; + tasks.push_back(std::async(std::launch::async, init_thread, i, start, end)); + } + for (auto & t : tasks) { + t.get(); + } } } From 1cd87e07df867dd25a90d52761bce7732f97c864 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 7 Sep 2025 00:04:32 +0700 Subject: [PATCH 05/40] use jspi --- ggml/CMakeLists.txt | 1 + ggml/src/ggml-webgpu/CMakeLists.txt | 15 +++++++++++---- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 6 +----- 3 files changed, 13 
insertions(+), 9 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 9ef88c6fd0a85..a67e3421d0554 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -189,6 +189,7 @@ option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF) option(GGML_WEBGPU "ggml: use WebGPU" OFF) option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF) +option(GGML_WEBGPU_JSPI "ggml: use JSPI for WebGPU" ON) option(GGML_ZDNN "ggml: use zDNN" OFF) option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT}) option(GGML_METAL_USE_BF16 "ggml: use bfloat if available" OFF) diff --git a/ggml/src/ggml-webgpu/CMakeLists.txt b/ggml/src/ggml-webgpu/CMakeLists.txt index 73eb66fb874c3..7ab450564c3b3 100644 --- a/ggml/src/ggml-webgpu/CMakeLists.txt +++ b/ggml/src/ggml-webgpu/CMakeLists.txt @@ -42,11 +42,19 @@ if(EMSCRIPTEN) if(NOT EMDAWNWEBGPU_DIR) # default built-in port target_compile_options(ggml-webgpu PRIVATE "--use-port=emdawnwebgpu") - target_link_options(ggml-webgpu INTERFACE "--use-port=emdawnwebgpu" "-sASYNCIFY=1") + target_link_options(ggml-webgpu INTERFACE "--use-port=emdawnwebgpu") else() # custom port target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") - target_link_options(ggml-webgpu INTERFACE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py" "-sASYNCIFY=1") + target_link_options(ggml-webgpu INTERFACE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") + endif() + + if (GGML_WEBGPU_JSPI) + target_compile_options(ggml-webgpu PRIVATE "-fwasm-exceptions") + target_link_options(ggml-webgpu INTERFACE "-sJSPI" "-fwasm-exceptions") + else() + target_compile_options(ggml-webgpu PRIVATE "-fexceptions") + target_link_options(ggml-webgpu INTERFACE "-sASYNCIFY" "-exceptions") endif() set(DawnWebGPU_TARGET webgpu_cpp) @@ -58,8 +66,7 @@ endif() if (GGML_WEBGPU_DEBUG) target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_DEBUG=1) if(EMSCRIPTEN) - target_compile_options(ggml-webgpu PRIVATE "-fexceptions") - target_link_options(ggml-webgpu INTERFACE "-sASSERTIONS=2" "-fexceptions") + target_link_options(ggml-webgpu INTERFACE "-sASSERTIONS=2") endif() endif() diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 53cd8df50b836..35503093b4bb1 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1297,12 +1297,8 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor); #ifdef __EMSCRIPTEN__ -#ifndef __EMSCRIPTEN_PTHREADS__ - GGML_LOG_WARN("ggml_webgpu: pthread is disabled. This may cause bugs\n"); -#endif - if (webgpu_ctx->instance == nullptr) { - GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure -sASYNCIFY is set\n"); + GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. 
Make sure either -sASYNCIFY or -sJSPI is set\n"); return nullptr; } #endif From 8549245c71eef0c1b269f0a56c88e8e9d52d454f Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 8 Sep 2025 04:17:01 +0700 Subject: [PATCH 06/40] add pthread --- CMakeLists.txt | 5 ++ scripts/serve-static.js | 110 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+) create mode 100644 scripts/serve-static.js diff --git a/CMakeLists.txt b/CMakeLists.txt index 9e0a042edb2fc..8099cc3be03c8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -161,6 +161,11 @@ endif() # 3rd-party # +if (EMSCRIPTEN) + add_compile_options(-pthread) + link_libraries (-pthread) +endif() + if (LLAMA_USE_SYSTEM_GGML) message(STATUS "Using system-provided libggml, skipping ggml build") find_package(ggml REQUIRED) diff --git a/scripts/serve-static.js b/scripts/serve-static.js new file mode 100644 index 0000000000000..df6cf534055f1 --- /dev/null +++ b/scripts/serve-static.js @@ -0,0 +1,110 @@ +const http = require('http'); +const fs = require('fs').promises; +const path = require('path'); + +// This file is used for testing wasm build from emscripten +// Example build command: +// emcmake cmake -B build-wasm -DGGML_WEBGPU=ON -DLLAMA_CURL=OFF +// cmake --build build-wasm --target test-backend-ops -j + +const PORT = 8080; +const STATIC_DIR = path.join(__dirname, '../build-wasm/bin'); +console.log(`Serving static files from: ${STATIC_DIR}`); + +const mimeTypes = { + '.html': 'text/html', + '.js': 'text/javascript', + '.css': 'text/css', + '.png': 'image/png', + '.jpg': 'image/jpeg', + '.gif': 'image/gif', + '.svg': 'image/svg+xml', + '.json': 'application/json', + '.woff': 'font/woff', + '.woff2': 'font/woff2', +}; + +async function generateDirListing(dirPath, reqUrl) { + const files = await fs.readdir(dirPath); + let html = ` + + + + Directory Listing + + + +

+    <h1>Directory: ${reqUrl}</h1>
+    <ul>
+    `;
+
+    if (reqUrl !== '/') {
+        html += `<li><a href="../">../ (Parent Directory)</a></li>`;
+    }
+
+    for (const file of files) {
+        const filePath = path.join(dirPath, file);
+        const stats = await fs.stat(filePath);
+        const link = encodeURIComponent(file) + (stats.isDirectory() ? '/' : '');
+        html += `<li><a href="${link}">${file}${stats.isDirectory() ? '/' : ''}</a></li>`;
+    }
+
+    html += `
+    </ul>
+ + + `; + return html; +} + +const server = http.createServer(async (req, res) => { + try { + // Set COOP and COEP headers + res.setHeader('Cross-Origin-Opener-Policy', 'same-origin'); + res.setHeader('Cross-Origin-Embedder-Policy', 'require-corp'); + res.setHeader('Cache-Control', 'no-store, no-cache, must-revalidate, proxy-revalidate'); + res.setHeader('Pragma', 'no-cache'); + res.setHeader('Expires', '0'); + + const filePath = path.join(STATIC_DIR, decodeURIComponent(req.url)); + const stats = await fs.stat(filePath); + + if (stats.isDirectory()) { + const indexPath = path.join(filePath, 'index.html'); + try { + const indexData = await fs.readFile(indexPath); + res.writeHeader(200, { 'Content-Type': 'text/html' }); + res.end(indexData); + } catch { + // No index.html, generate directory listing + const dirListing = await generateDirListing(filePath, req.url); + res.writeHeader(200, { 'Content-Type': 'text/html' }); + res.end(dirListing); + } + } else { + const ext = path.extname(filePath).toLowerCase(); + const contentType = mimeTypes[ext] || 'application/octet-stream'; + const data = await fs.readFile(filePath); + res.writeHeader(200, { 'Content-Type': contentType }); + res.end(data); + } + } catch (err) { + if (err.code === 'ENOENT') { + res.writeHeader(404, { 'Content-Type': 'text/plain' }); + res.end('404 Not Found'); + } else { + res.writeHeader(500, { 'Content-Type': 'text/plain' }); + res.end('500 Internal Server Error'); + } + } +}); + +server.listen(PORT, () => { + console.log(`Server running at http://localhost:${PORT}/`); +}); From bf9d14cd01b1ac166308b77d2f710b374acce932 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 8 Sep 2025 17:12:22 +0700 Subject: [PATCH 07/40] test: remember to set n_thread for cpu backend --- CMakeLists.txt | 5 ----- tests/test-backend-ops.cpp | 4 ++++ 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8099cc3be03c8..9e0a042edb2fc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -161,11 +161,6 @@ endif() # 3rd-party # -if (EMSCRIPTEN) - add_compile_options(-pthread) - link_libraries (-pthread) -endif() - if (LLAMA_USE_SYSTEM_GGML) message(STATUS "Using system-provided libggml, skipping ggml build") find_package(ggml REQUIRED) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 86f8087c0fd25..676af05d6dfb4 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -6704,6 +6705,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op return false; } + // TODO: find a better way to set the number of threads for the CPU backend + ggml_backend_cpu_set_n_threads(backend_cpu, N_THREADS); + size_t n_ok = 0; for (auto & test : test_cases) { if (test->eval(backend, backend_cpu, op_names_filter, output_printer)) { From b56681191313ee1c78e7717374e2ec8210972616 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 15 Oct 2025 19:04:48 +0800 Subject: [PATCH 08/40] Add buffer label and enable dawn-specific toggles to turn off some checks --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 56 ++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 05e16cd432ad3..b4558a9e3f1d2 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -309,10 +309,12 @@ struct ggml_backend_webgpu_context { struct 
ggml_backend_webgpu_buffer_context { webgpu_context webgpu_ctx; wgpu::Buffer buffer; + std::string label; - ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf) : + ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf, std::string lbl) : webgpu_ctx(std::move(ctx)), - buffer(std::move(buf)) {} + buffer(std::move(buf)), + label(std::move(lbl)) {} }; /* End struct definitions */ @@ -1336,11 +1338,11 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe WEBGPU_CPU_PROFILE_TOTAL_START(memset_tensor); - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " - << offset << ", " << size << ")"); - ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", " << value + << ", " << offset << ", " << size << ")"); + size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; // This is a trick to set all bytes of a u32 to the same 1 byte value. @@ -1354,12 +1356,13 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, const void * data, size_t offset, size_t size) { - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " - << offset << ", " << size << ")"); WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor); ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data + << ", " << offset << ", " << size << ")"); + size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4); @@ -1397,12 +1400,12 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, void * data, size_t offset, size_t size) { - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " - << offset << ", " << size << ")"); WEBGPU_CPU_PROFILE_TOTAL_START(get_tensor); - ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; - webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; - wgpu::Device device = webgpu_ctx->device; + ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buf_ctx->label << ", " << tensor << ", " << data + << ", " << offset << ", " << size << ")"); + webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; + wgpu::Device device = webgpu_ctx->device; size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; @@ -1473,16 +1476,20 @@ static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")"); + static std::atomic buffer_count; + int buffer_id = buffer_count++; + std::string buf_name = "tensor_buf" + std::to_string(buffer_id); + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer_" << buffer_id << ": " << size << " bytes"); ggml_backend_webgpu_device_context * ctx = 
static_cast(buft->device->context); wgpu::Buffer buf; ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, (size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1), wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst, - "allocated_buffer"); + buf_name.c_str()); - ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf); + ggml_backend_webgpu_buffer_context * buf_ctx = + new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf, buf_name); return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size); } @@ -2129,6 +2136,15 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t required_features.push_back(wgpu::FeatureName::TimestampQuery); #endif + const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init", + "disable_polyfills_on_integer_div_and_mod" }; + const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; + wgpu::DawnTogglesDescriptor deviceTogglesDesc; + deviceTogglesDesc.enabledToggles = deviceEnabledToggles; + deviceTogglesDesc.enabledToggleCount = 4; + deviceTogglesDesc.disabledToggles = deviceDisabledToggles; + deviceTogglesDesc.disabledToggleCount = 1; + wgpu::DeviceDescriptor dev_desc; dev_desc.requiredLimits = &ctx->limits; dev_desc.requiredFeatures = required_features.data(); @@ -2146,6 +2162,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), std::string(message).c_str()); }); + dev_desc.nextInChain = &deviceTogglesDesc; ctx->instance.WaitAny(ctx->adapter.RequestDevice( &dev_desc, wgpu::CallbackMode::AllowSpontaneous, [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { @@ -2243,11 +2260,18 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { ctx.name = GGML_WEBGPU_NAME; ctx.device_count = 1; + const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" }; + + wgpu::DawnTogglesDescriptor instanceTogglesDesc; + instanceTogglesDesc.enabledToggles = instanceEnabledToggles; + instanceTogglesDesc.enabledToggleCount = 1; wgpu::InstanceDescriptor instance_descriptor{}; std::vector instance_features = { wgpu::InstanceFeatureName::TimedWaitAny }; instance_descriptor.requiredFeatures = instance_features.data(); instance_descriptor.requiredFeatureCount = instance_features.size(); - webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor); + instance_descriptor.nextInChain = &instanceTogglesDesc; + + webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor); GGML_ASSERT(webgpu_ctx->instance != nullptr); static ggml_backend_reg reg = { From 2560412ee3fc141ea4f2f2e73f865e8539660d94 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Fri, 17 Oct 2025 11:42:03 +0800 Subject: [PATCH 09/40] Intermediate state --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 16 +- .../wgsl-shaders/mul_mat.tmpl.wgsl | 4 +- .../wgsl-shaders/mul_mat_fast.wgsl | 181 ++++++++++++++++++ 3 files changed, 198 insertions(+), 3 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_fast.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index b4558a9e3f1d2..6e97902de1a7f 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -248,6 +248,7 @@ struct webgpu_context_struct { webgpu_pipeline memset_pipeline; 
webgpu_pipeline mul_mat_pipeline[30][2]; + webgpu_pipeline mul_mat_fast_pipeline; webgpu_pipeline set_rows_pipeline; webgpu_pipeline get_rows_pipeline[30]; webgpu_pipeline get_rows_f32_no_vec_pipeline; @@ -855,9 +856,20 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, }; + const uint32_t M = dst->ne[1]; // number of rows in result + const uint32_t N = dst->ne[0]; // number of columns in result + + webgpu_pipeline pipeline = ctx->mul_mat_pipeline[src0->type][src1->type]; uint32_t wg_x = (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE; - return ggml_backend_webgpu_build(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x); + if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && src0->ne[0] % 4 == 0) { + pipeline = ctx->mul_mat_fast_pipeline; + uint32_t tiles_x = (M + 64 - 1) / 64; // rows + uint32_t tiles_y = (N + 32 - 1) / 32; // columns + wg_x = tiles_x * tiles_y * dst->ne[2] * dst->ne[3]; + } + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, @@ -1617,6 +1629,8 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32"); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32], wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); + + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_fast_pipeline, wgsl_mul_mat_fast, "mul_mat_fast"); } static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl index 141db9b39d957..6cc1a4d0af888 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl @@ -864,8 +864,8 @@ struct MulMatParams { broadcast3: u32 }; -@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // N rows, K columns -@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // M rows, K columns (transposed) +@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns +@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) @group(0) @binding(2) var dst: array; // M rows, N columns @group(0) @binding(3) var params: MulMatParams; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_fast.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_fast.wgsl new file mode 100644 index 0000000000000..aa38bfe862f23 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_fast.wgsl @@ -0,0 +1,181 @@ +enable f16; + +const WORKGROUP_SIZE_X: u32 = 16u; +const WORKGROUP_SIZE_Y: u32 = 8u; +const TOTAL_WORKGROUP_SIZE: u32 = WORKGROUP_SIZE_X * WORKGROUP_SIZE_Y; +const TILE_X: u32 = 4u; +const TILE_Y: u32 = 4u; +const TILE_K: u32 = 32u; +const VECTOR_SIZE: u32 = 4u; + +struct MulMatParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + m: u32, + n: u32, + k: u32, + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +@group(0) @binding(0) var src0: array>; // N rows, K columns +@group(0) @binding(1) var src1: array>; // M rows, K columns (transposed) +@group(0) @binding(2) var dst: array>; // M rows, N columns + +@group(0) @binding(3) var params: 
MulMatParams; + +var A_shared: array, (WORKGROUP_SIZE_Y * TILE_Y * TILE_K)/VECTOR_SIZE>; +var B_shared: array, (WORKGROUP_SIZE_X * TILE_X * TILE_K)/VECTOR_SIZE>; + +fn get_local_x(thread_id: u32) -> u32 { + return thread_id / WORKGROUP_SIZE_Y; +} +fn get_local_y(thread_id: u32) -> u32 { + return thread_id % WORKGROUP_SIZE_Y; +} + +fn compute_vec4(a: vec4, b: vec4) -> f32 { + let a_f32 = vec4(f32(a.x), f32(a.y), f32(a.z), f32(a.w)); + return dot(a_f32, b); +} + +//override const WORKGROUP_SIZE_X: u32; +//override const WORKGROUP_SIZE_Y: u32; +//override const TILE_X: u32; +//override const TILE_Y: u32; +//override const TILE_K: u32; +//override const VECTOR_SIZE: u32; + +//const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_X * WORKGROUP_SIZE_Y; + +@compute @workgroup_size(TOTAL_WORKGROUP_SIZE) +fn main(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3) { + + let thread_id = local_id.x; + let local_x = get_local_x(thread_id); + let local_y = get_local_y(thread_id); + + let wg_linear = global_id.x / TOTAL_WORKGROUP_SIZE; + + let wg_x_count = (params.m + WORKGROUP_SIZE_X * TILE_X - 1u) / (WORKGROUP_SIZE_X * TILE_X); + let wg_y_count = (params.n + WORKGROUP_SIZE_Y * TILE_Y - 1u) / (WORKGROUP_SIZE_Y * TILE_Y); + let wg_per_matrix = wg_x_count * wg_y_count; + + let batch_idx = wg_linear / wg_per_matrix; + + let wg_in_batch = wg_linear % wg_per_matrix; + let wg_y = wg_in_batch % wg_y_count; + let wg_x = wg_in_batch / wg_y_count; + + let output_row_base = wg_x * WORKGROUP_SIZE_X * TILE_X + local_x * TILE_X; + let output_col_base = wg_y * WORKGROUP_SIZE_Y * TILE_Y + local_y * TILE_Y; + + let dst2_stride = params.m * params.n; + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + + let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); + let src03_idx = dst3_idx / params.broadcast3; + let src13_idx = dst3_idx; + let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); + let src02_idx = dst2_idx / params.broadcast2; + let src12_idx = dst2_idx; + + var acc: array, TILE_X>; + + let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; + let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + + for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { + + let a_tile_size = TILE_K * WORKGROUP_SIZE_Y * TILE_Y; + let a_loads_per_thread = ((a_tile_size + TOTAL_WORKGROUP_SIZE - 1u) / TOTAL_WORKGROUP_SIZE) / 4; + + for (var load_idx = 0u; load_idx < a_loads_per_thread; load_idx++) { + let elem_idx = ((thread_id / 4) * 4) + (thread_id % 4) * TOTAL_WORKGROUP_SIZE + load_idx * 4 * TOTAL_WORKGROUP_SIZE; +// if (elem_idx < a_tile_size) { + let tile_col = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_col = wg_y * WORKGROUP_SIZE_Y * TILE_Y + tile_col; + let global_k = k_outer + tile_k; + +// if (global_col < params.n && global_k < params.k) { + let src0_idx = src0_batch_offset + global_col * params.stride_01 + global_k; + A_shared[elem_idx/4] = src0[src0_idx/4]; +// } +// } + } + + let b_tile_size = WORKGROUP_SIZE_X * TILE_X * TILE_K; + let b_loads_per_thread = ((b_tile_size + TOTAL_WORKGROUP_SIZE - 1u) / TOTAL_WORKGROUP_SIZE) / 4; + + for (var load_idx = 0u; load_idx < b_loads_per_thread; load_idx++) { + let elem_idx = ((thread_id / 4) * 4) + (thread_id % 4) * TOTAL_WORKGROUP_SIZE + load_idx * 4 * TOTAL_WORKGROUP_SIZE; +// if (elem_idx < b_tile_size) { + let tile_row = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let 
global_row = wg_x * WORKGROUP_SIZE_X * TILE_X + tile_row; + let global_k = k_outer + tile_k; + +// if (global_row < params.m && global_k < params.k) { + let src1_idx = src1_batch_offset + global_row * params.stride_11 + global_k; + B_shared[elem_idx/4] = src1[src1_idx/4]; +// } +// } + } + + workgroupBarrier(); + + let k_end = min(TILE_K, params.k - k_outer); + + for (var k_inner = 0u; k_inner < k_end / 4u; k_inner++) { + var a_r_tile: array, TILE_Y>; + for (var ty = 0u; ty < TILE_Y; ty++) { + let a_col = local_y * TILE_Y + ty; +// if (output_col_base + ty < params.n) { + let a_idx = k_inner * 4 + a_col * TILE_K; + a_r_tile[ty] = A_shared[a_idx/4]; +// } + } + for (var tx = 0u; tx < TILE_X; tx++) { + let b_row = local_x * TILE_X + tx; +// if (output_row_base + tx < params.m) { + let b_idx = b_row * TILE_K + k_inner * 4u; + let b_vec = B_shared[b_idx/4]; + + for (var ty = 0u; ty < TILE_Y; ty++) { +// if (output_col_base + ty < params.n) { + acc[tx][ty] += compute_vec4(a_r_tile[ty], b_vec); +// } + } +// } + } + } + + workgroupBarrier(); + } + + let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; + + for (var tx = 0u; tx < TILE_X; tx++) { + let global_row = output_row_base + tx; + if (global_row < params.m) { + for (var ty = 0u; ty < TILE_Y; ty += 4) { + let global_col = output_col_base + ty; + if (global_col < params.n) { + let dst_idx = dst_batch_offset + global_row * params.n + global_col; + dst[dst_idx/4] = vec4(acc[tx][ty], acc[tx][ty + 1], acc[tx][ty + 2], acc[tx][ty + 3]); + } + } + } + } +} \ No newline at end of file From 833c4a86e1d23f48d3c6dad1620a239970bc3ed4 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Sun, 19 Oct 2025 18:33:16 +0800 Subject: [PATCH 10/40] Fast working f16/f32 vec4 --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 43 ++++- .../wgsl-shaders/mul_mat_fast.wgsl | 166 +++++++++--------- 2 files changed, 115 insertions(+), 94 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 6e97902de1a7f..ec396ba871d66 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -73,6 +73,17 @@ // For operations which process a row in parallel, this seems like a reasonable default #define WEBGPU_ROW_SPLIT_WG_SIZE 64 +// Matrix multiplication fast path parameters + +// Warning: must match values in mul_mat_fast.wgsl +#define WEBGPU_MUL_MAT_TILE_X 4 +#define WEBGPU_MUL_MAT_TILE_Y 4 + +#define WEBGPU_MUL_MAT_WG_SIZE_X 16 +#define WEBGPU_MUL_MAT_WG_SIZE_Y 8 +#define WEBGPU_MUL_MAT_TILE_K 32 +#define WEBGPU_MUL_MAT_VEC_SIZE 4 + /* End Constants */ // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations. 
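To make the new fast-path constants concrete, here is the dispatch arithmetic they imply, worked through for a hypothetical shape (the real computation is the ggml_webgpu_mul_mat hunk that follows; rows/columns use that hunk's convention of dst->ne[1] rows along X and dst->ne[0] columns along Y):

    // hypothetical example, not from the patch: dst->ne[1] = 512, dst->ne[0] = 256, ne[2] = ne[3] = 1
    tile_x_span = WEBGPU_MUL_MAT_TILE_X * WEBGPU_MUL_MAT_WG_SIZE_X = 4 * 16 = 64   // output rows per workgroup
    tile_y_span = WEBGPU_MUL_MAT_TILE_Y * WEBGPU_MUL_MAT_WG_SIZE_Y = 4 * 8  = 32   // output columns per workgroup
    tiles_x = ceil(512 / 64) = 8
    tiles_y = ceil(256 / 32) = 8
    wg_x    = tiles_x * tiles_y * 1 * 1 = 64 workgroups of WG_SIZE_X * WG_SIZE_Y = 128 threads each

Each workgroup therefore owns a 64x32 tile of dst and walks the shared K dimension in TILE_K = 32 element steps through workgroup shared memory.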
@@ -856,17 +867,19 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, }; - const uint32_t M = dst->ne[1]; // number of rows in result - const uint32_t N = dst->ne[0]; // number of columns in result - webgpu_pipeline pipeline = ctx->mul_mat_pipeline[src0->type][src1->type]; + uint32_t wg_x = (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE; - if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && src0->ne[0] % 4 == 0) { - pipeline = ctx->mul_mat_fast_pipeline; - uint32_t tiles_x = (M + 64 - 1) / 64; // rows - uint32_t tiles_y = (N + 32 - 1) / 32; // columns - wg_x = tiles_x * tiles_y * dst->ne[2] * dst->ne[3]; + + if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && + dst->ne[1] % 4 == 0) { + pipeline = ctx->mul_mat_fast_pipeline; + uint32_t tile_x_s = WEBGPU_MUL_MAT_TILE_X * WEBGPU_MUL_MAT_WG_SIZE_X; + uint32_t tiles_x = (dst->ne[1] + tile_x_s - 1) / tile_x_s; + uint32_t tile_y_s = WEBGPU_MUL_MAT_TILE_Y * WEBGPU_MUL_MAT_WG_SIZE_Y; + uint32_t tiles_y = (dst->ne[0] + tile_y_s - 1) / tile_y_s; + wg_x = tiles_x * tiles_y * dst->ne[2] * dst->ne[3]; } return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); @@ -1630,7 +1643,19 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32], wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_fast_pipeline, wgsl_mul_mat_fast, "mul_mat_fast"); + // override constants + std::vector mul_mat_fast_constants(4); + mul_mat_fast_constants[0].key = "WORKGROUP_SIZE_X"; + mul_mat_fast_constants[0].value = WEBGPU_MUL_MAT_WG_SIZE_X; + mul_mat_fast_constants[1].key = "WORKGROUP_SIZE_Y"; + mul_mat_fast_constants[1].value = WEBGPU_MUL_MAT_WG_SIZE_Y; + mul_mat_fast_constants[2].key = "TILE_K"; + mul_mat_fast_constants[2].value = WEBGPU_MUL_MAT_TILE_K; + mul_mat_fast_constants[3].key = "VEC_SIZE"; + mul_mat_fast_constants[3].value = WEBGPU_MUL_MAT_VEC_SIZE; + + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_fast_pipeline, wgsl_mul_mat_fast, + "mul_mat_fast", mul_mat_fast_constants); } static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_fast.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_fast.wgsl index aa38bfe862f23..51c110a5d62bb 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_fast.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_fast.wgsl @@ -1,13 +1,5 @@ enable f16; -const WORKGROUP_SIZE_X: u32 = 16u; -const WORKGROUP_SIZE_Y: u32 = 8u; -const TOTAL_WORKGROUP_SIZE: u32 = WORKGROUP_SIZE_X * WORKGROUP_SIZE_Y; -const TILE_X: u32 = 4u; -const TILE_Y: u32 = 4u; -const TILE_K: u32 = 32u; -const VECTOR_SIZE: u32 = 4u; - struct MulMatParams { offset_src0: u32, offset_src1: u32, @@ -33,8 +25,8 @@ struct MulMatParams { @group(0) @binding(3) var params: MulMatParams; -var A_shared: array, (WORKGROUP_SIZE_Y * TILE_Y * TILE_K)/VECTOR_SIZE>; -var B_shared: array, (WORKGROUP_SIZE_X * TILE_X * TILE_K)/VECTOR_SIZE>; +var src0_shmem: array, (WORKGROUP_SIZE_Y * TILE_Y * TILE_K)/VEC_SIZE>; +var src1_shmem: array, (WORKGROUP_SIZE_X * TILE_X * TILE_K)/VEC_SIZE>; fn get_local_x(thread_id: u32) -> u32 { return thread_id / WORKGROUP_SIZE_Y; @@ -43,19 +35,39 @@ fn 
get_local_y(thread_id: u32) -> u32 { return thread_id % WORKGROUP_SIZE_Y; } +fn zero_vec4_f32() -> vec4 { + return vec4(0.0, 0.0, 0.0, 0.0); +} + +fn zero_vec4_f16() -> vec4 { + return vec4(0.0, 0.0, 0.0, 0.0); +} + fn compute_vec4(a: vec4, b: vec4) -> f32 { let a_f32 = vec4(f32(a.x), f32(a.y), f32(a.z), f32(a.w)); return dot(a_f32, b); } -//override const WORKGROUP_SIZE_X: u32; -//override const WORKGROUP_SIZE_Y: u32; -//override const TILE_X: u32; -//override const TILE_Y: u32; -//override const TILE_K: u32; -//override const VECTOR_SIZE: u32; +fn store_val(acc: array, TILE_X>, tx: u32, ty: u32) -> vec4 { + return vec4(acc[tx][ty], acc[tx][ty + 1], acc[tx][ty + 2], acc[tx][ty + 3]); +} + +// Warning: cannot be overrides, must match values in ggml-webgpu.cpp +const TILE_X = 4u; +// must be multiple of 4 for vec4 loads +const TILE_Y = 4u; -//const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_X * WORKGROUP_SIZE_Y; +override WORKGROUP_SIZE_X: u32; +override WORKGROUP_SIZE_Y: u32; +override TILE_K: u32; +override VEC_SIZE: u32; + +override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_X * WORKGROUP_SIZE_Y; +override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_Y * TILE_Y; +override TILE_SRC1_SHMEM = WORKGROUP_SIZE_X * TILE_X * TILE_K; +// assumes WG_SIZE divides SHMEM TILES +override TILE_SRC0_LD_PER_THREAD = (TILE_SRC0_SHMEM + TOTAL_WORKGROUP_SIZE * VEC_SIZE - 1) / (TOTAL_WORKGROUP_SIZE * VEC_SIZE); +override TILE_SRC1_LD_PER_THREAD = (TILE_SRC1_SHMEM + TOTAL_WORKGROUP_SIZE * VEC_SIZE - 1) / (TOTAL_WORKGROUP_SIZE * VEC_SIZE); @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) fn main(@builtin(global_invocation_id) global_id: vec3, @@ -90,92 +102,76 @@ fn main(@builtin(global_invocation_id) global_id: vec3, let src02_idx = dst2_idx / params.broadcast2; let src12_idx = dst2_idx; - var acc: array, TILE_X>; - let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + var acc: array, TILE_X>; + for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { - let a_tile_size = TILE_K * WORKGROUP_SIZE_Y * TILE_Y; - let a_loads_per_thread = ((a_tile_size + TOTAL_WORKGROUP_SIZE - 1u) / TOTAL_WORKGROUP_SIZE) / 4; - - for (var load_idx = 0u; load_idx < a_loads_per_thread; load_idx++) { - let elem_idx = ((thread_id / 4) * 4) + (thread_id % 4) * TOTAL_WORKGROUP_SIZE + load_idx * 4 * TOTAL_WORKGROUP_SIZE; -// if (elem_idx < a_tile_size) { - let tile_col = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - let global_col = wg_y * WORKGROUP_SIZE_Y * TILE_Y + tile_col; - let global_k = k_outer + tile_k; - -// if (global_col < params.n && global_k < params.k) { - let src0_idx = src0_batch_offset + global_col * params.stride_01 + global_k; - A_shared[elem_idx/4] = src0[src0_idx/4]; -// } -// } + for (var load_idx = 0u; load_idx < TILE_SRC0_LD_PER_THREAD; load_idx++) { + let elem_idx = (thread_id + load_idx * TOTAL_WORKGROUP_SIZE) * VEC_SIZE; + let tile_col = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_col = wg_y * WORKGROUP_SIZE_Y * TILE_Y + tile_col; + let global_k = k_outer + tile_k; + let src0_idx = src0_batch_offset + global_col * params.stride_01 + global_k; + src0_shmem[elem_idx/VEC_SIZE] = select( // taking a slight performance hit to avoid oob + zero_vec4_f16(), + src0[src0_idx/VEC_SIZE], + global_col < params.n && global_k < params.k); } - let b_tile_size = WORKGROUP_SIZE_X * TILE_X * TILE_K; - let b_loads_per_thread = ((b_tile_size + 
TOTAL_WORKGROUP_SIZE - 1u) / TOTAL_WORKGROUP_SIZE) / 4; - - for (var load_idx = 0u; load_idx < b_loads_per_thread; load_idx++) { - let elem_idx = ((thread_id / 4) * 4) + (thread_id % 4) * TOTAL_WORKGROUP_SIZE + load_idx * 4 * TOTAL_WORKGROUP_SIZE; -// if (elem_idx < b_tile_size) { - let tile_row = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - let global_row = wg_x * WORKGROUP_SIZE_X * TILE_X + tile_row; - let global_k = k_outer + tile_k; - -// if (global_row < params.m && global_k < params.k) { - let src1_idx = src1_batch_offset + global_row * params.stride_11 + global_k; - B_shared[elem_idx/4] = src1[src1_idx/4]; -// } -// } + for (var load_idx = 0u; load_idx < TILE_SRC1_LD_PER_THREAD; load_idx++) { + let elem_idx = (thread_id + load_idx * TOTAL_WORKGROUP_SIZE) * VEC_SIZE; + let tile_row = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_row = wg_x * WORKGROUP_SIZE_X * TILE_X + tile_row; + let global_k = k_outer + tile_k; + + let src1_idx = src1_batch_offset + global_row * params.stride_11 + global_k; + src1_shmem[elem_idx/VEC_SIZE] = select( + zero_vec4_f32(), + src1[src1_idx/VEC_SIZE], + global_row < params.m && global_k < params.k); } workgroupBarrier(); let k_end = min(TILE_K, params.k - k_outer); - for (var k_inner = 0u; k_inner < k_end / 4u; k_inner++) { - var a_r_tile: array, TILE_Y>; - for (var ty = 0u; ty < TILE_Y; ty++) { - let a_col = local_y * TILE_Y + ty; -// if (output_col_base + ty < params.n) { - let a_idx = k_inner * 4 + a_col * TILE_K; - a_r_tile[ty] = A_shared[a_idx/4]; -// } - } - for (var tx = 0u; tx < TILE_X; tx++) { - let b_row = local_x * TILE_X + tx; -// if (output_row_base + tx < params.m) { - let b_idx = b_row * TILE_K + k_inner * 4u; - let b_vec = B_shared[b_idx/4]; - - for (var ty = 0u; ty < TILE_Y; ty++) { -// if (output_col_base + ty < params.n) { - acc[tx][ty] += compute_vec4(a_r_tile[ty], b_vec); -// } - } -// } - } + for (var k_inner = 0u; k_inner < k_end; k_inner += VEC_SIZE) { + var src0_tile: array, TILE_Y>; + for (var ty = 0u; ty < TILE_Y; ty++) { + let src0_col = local_y * TILE_Y + ty; + let src0_idx = k_inner + src0_col * TILE_K; + src0_tile[ty] = src0_shmem[src0_idx/VEC_SIZE]; + } + for (var tx = 0u; tx < TILE_X; tx++) { + let src1_row = local_x * TILE_X + tx; + let src1_idx = src1_row * TILE_K + k_inner; + let src1_vec = src1_shmem[src1_idx/VEC_SIZE]; + for (var ty = 0u; ty < TILE_Y; ty++) { + acc[tx][ty] += compute_vec4(src0_tile[ty], src1_vec); } + } + } workgroupBarrier(); } - let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; - - for (var tx = 0u; tx < TILE_X; tx++) { - let global_row = output_row_base + tx; - if (global_row < params.m) { - for (var ty = 0u; ty < TILE_Y; ty += 4) { - let global_col = output_col_base + ty; - if (global_col < params.n) { - let dst_idx = dst_batch_offset + global_row * params.n + global_col; - dst[dst_idx/4] = vec4(acc[tx][ty], acc[tx][ty + 1], acc[tx][ty + 2], acc[tx][ty + 3]); - } + let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; + + for (var tx = 0u; tx < TILE_X; tx++) { + let global_row = output_row_base + tx; + if (global_row < params.m) { + for (var ty = 0u; ty < TILE_Y; ty += VEC_SIZE) { + let global_col = output_col_base + ty; + if (global_col < params.n) { + let dst_idx = dst_batch_offset + global_row * params.n + global_col; + dst[dst_idx/VEC_SIZE] = store_val(acc, tx, ty); } } } -} \ No newline at end of file + } +} From bd380915dace74c74bd12d5eecd18a6b197556e0 Mon Sep 17 00:00:00 2001 From: 
Reese Levine Date: Mon, 20 Oct 2025 20:19:06 +0700 Subject: [PATCH 11/40] Working float fast mul mat --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 73 +++++- ...l_mat_fast.wgsl => mul_mat_fast.tmpl.wgsl} | 218 ++++++++++++++---- 2 files changed, 240 insertions(+), 51 deletions(-) rename ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_fast.wgsl => mul_mat_fast.tmpl.wgsl} (53%) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index ec396ba871d66..a54d29fc66b5d 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -82,7 +83,6 @@ #define WEBGPU_MUL_MAT_WG_SIZE_X 16 #define WEBGPU_MUL_MAT_WG_SIZE_Y 8 #define WEBGPU_MUL_MAT_TILE_K 32 -#define WEBGPU_MUL_MAT_VEC_SIZE 4 /* End Constants */ @@ -258,8 +258,10 @@ struct webgpu_context_struct { webgpu_buf_pool set_rows_error_buf_pool; webgpu_pipeline memset_pipeline; + + std::map>> mul_mat_pipelines; // src0_type, src1_type, vectorized + webgpu_pipeline mul_mat_pipeline[30][2]; - webgpu_pipeline mul_mat_fast_pipeline; webgpu_pipeline set_rows_pipeline; webgpu_pipeline get_rows_pipeline[30]; webgpu_pipeline get_rows_f32_no_vec_pipeline; @@ -358,6 +360,30 @@ static void ggml_webgpu_create_pipeline(wgpu::Device & pipeline = { device.CreateComputePipeline(&pipeline_desc), label }; } +static webgpu_pipeline ggml_webgpu_create_pipeline2(wgpu::Device & device, + const char * shader_code, + const char * label, + const std::vector & constants = {}) { + wgpu::ShaderSourceWGSL shader_source; + shader_source.code = shader_code; + + wgpu::ShaderModuleDescriptor shader_desc; + shader_desc.nextInChain = &shader_source; + + wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc); + + wgpu::ComputePipelineDescriptor pipeline_desc; + pipeline_desc.label = label; + pipeline_desc.compute.module = shader_module; + pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code + pipeline_desc.layout = nullptr; // nullptr means auto layout + if (constants.size() > 0) { + pipeline_desc.compute.constants = constants.data(); + pipeline_desc.compute.constantCount = constants.size(); + } + return { device.CreateComputePipeline(&pipeline_desc), label }; +} + static void ggml_webgpu_create_buffer(wgpu::Device & device, wgpu::Buffer & buffer, size_t size, @@ -872,9 +898,28 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, uint32_t wg_x = (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE; - if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && - dst->ne[1] % 4 == 0) { - pipeline = ctx->mul_mat_fast_pipeline; + bool use_fast = false; + switch (src1->type) { + case GGML_TYPE_F16: + use_fast = (src0->type == GGML_TYPE_F16); + break; + case GGML_TYPE_F32: + switch (src0->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + use_fast = true; + break; + default: + break; + } + break; + default: + break; + } + + if (use_fast) { + int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0; + pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; uint32_t tile_x_s = WEBGPU_MUL_MAT_TILE_X * WEBGPU_MUL_MAT_WG_SIZE_X; uint32_t tiles_x = (dst->ne[1] + tile_x_s - 1) / tile_x_s; uint32_t tile_y_s = WEBGPU_MUL_MAT_TILE_Y * WEBGPU_MUL_MAT_WG_SIZE_Y; @@ -1644,18 +1689,26 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) 
{ wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); // override constants - std::vector mul_mat_fast_constants(4); + std::vector mul_mat_fast_constants(3); mul_mat_fast_constants[0].key = "WORKGROUP_SIZE_X"; mul_mat_fast_constants[0].value = WEBGPU_MUL_MAT_WG_SIZE_X; mul_mat_fast_constants[1].key = "WORKGROUP_SIZE_Y"; mul_mat_fast_constants[1].value = WEBGPU_MUL_MAT_WG_SIZE_Y; mul_mat_fast_constants[2].key = "TILE_K"; mul_mat_fast_constants[2].value = WEBGPU_MUL_MAT_TILE_K; - mul_mat_fast_constants[3].key = "VEC_SIZE"; - mul_mat_fast_constants[3].value = WEBGPU_MUL_MAT_VEC_SIZE; - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_fast_pipeline, wgsl_mul_mat_fast, - "mul_mat_fast", mul_mat_fast_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_fast_f32_f32, "mul_mat_fast_f32_f32", mul_mat_fast_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_fast_f32_f32_vec, "mul_mat_fast_f32_f32_vec", mul_mat_fast_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_fast_f16_f32, "mul_mat_fast_f16_f32", mul_mat_fast_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_fast_f16_f32_vec, "mul_mat_fast_f16_f32_vec", mul_mat_fast_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_fast_f16_f16, "mul_mat_fast_f16_f16", mul_mat_fast_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_fast_f16_f16_vec, "mul_mat_fast_f16_f16_vec", mul_mat_fast_constants); } static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_fast.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_fast.tmpl.wgsl similarity index 53% rename from ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_fast.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_fast.tmpl.wgsl index 51c110a5d62bb..9ade05d09e139 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_fast.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_fast.tmpl.wgsl @@ -1,3 +1,154 @@ +#define(VARIANTS) +[ + { + "SHADER_SUFFIX": "f32_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "VEC_SIZE" : "4", + }, + "DECLS": ["SRC0_F32_VEC", "SRC1_F32_VEC"] + }, + { + "SHADER_SUFFIX": "f32_f32", + "REPLS": { + "SRC0_TYPE" : "f32", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "VEC_SIZE" : "1", + }, + "DECLS": ["SRC0_F32", "SRC1_F32"] + }, + { + "SHADER_SUFFIX": "f16_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "VEC_SIZE" : "4", + }, + "DECLS": ["SRC0_F16_VEC", "SRC1_F32_VEC"] + }, + { + "SHADER_SUFFIX": "f16_f32", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "VEC_SIZE" : "1", + }, + "DECLS": ["SRC0_F16", "SRC1_F32"] + }, + { + "SHADER_SUFFIX": "f16_f16_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "VEC_SIZE" : "4", + }, + "DECLS": ["SRC0_F16_VEC", "SRC1_F16_VEC"] + }, + { + "SHADER_SUFFIX": "f16_f16", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f16", + "DST_TYPE" : "f32", + 
"VEC_SIZE" : "1", + }, + "DECLS": ["SRC0_F16", "SRC1_F16"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(SRC0_F32_VEC) +fn zero_val_src0() -> vec4 { + return vec4(0.0, 0.0, 0.0, 0.0); +} +#enddecl(SRC0_F32_VEC) + +#decl(SRC0_F32) +fn zero_val_src0() -> f32 { + return 0.0; +} +#enddecl(SRC0_F32) + +#decl(SRC0_F16_VEC) +fn zero_val_src0() -> vec4 { + return vec4(0.0, 0.0, 0.0, 0.0); +} +#enddecl(SRC0_F16_VEC) + +#decl(SRC0_F16) +fn zero_val_src0() -> f16 { + return 0.0; +} +#enddecl(SRC0_F16) + +#decl(SRC1_F32_VEC) +fn zero_val_src1() -> vec4 { + return vec4(0.0, 0.0, 0.0, 0.0); +} + +fn store_val(acc: array, TILE_X>, tx: u32, ty: u32) -> vec4 { + return vec4(acc[tx][ty], acc[tx][ty + 1], acc[tx][ty + 2], acc[tx][ty + 3]); +} + +fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: vec4) -> f32 { + return dot(vec4(src0_val), src1_val); +} +#enddecl(SRC1_F32_VEC) + +#decl(SRC1_F32) +fn zero_val_src1() -> f32 { + return 0.0; +} + +fn store_val(acc: array, TILE_X>, tx: u32, ty: u32) -> f32 { + return acc[tx][ty]; +} + +fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: f32) -> f32 { + return f32(src0_val) * src1_val; +} +#enddecl(SRC1_F32) + +#decl(SRC1_F16_VEC) +fn zero_val_src1() -> vec4 { + return vec4(0.0, 0.0, 0.0, 0.0); +} + +fn store_val(acc: array, TILE_X>, tx: u32, ty: u32) -> vec4 { + return vec4(acc[tx][ty], f32(acc[tx][ty + 1]), f32(acc[tx][ty + 2]), f32(acc[tx][ty + 3])); +} + +fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: vec4) -> f32 { + return dot(vec4(src0_val), vec4(src1_val)); +} +#enddecl(SRC1_F16_VEC) + +#decl(SRC1_F16) +fn zero_val_src1() -> f16 { + return 0.0; +} + +fn store_val(acc: array, TILE_X>, tx: u32, ty: u32) -> f32 { + return acc[tx][ty]; +} + +fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: f16) -> f32 { + return f32(src0_val) * f32(src1_val); +} +#enddecl(SRC1_F16) + +#end(DECLS) + +#define(SHADER) enable f16; struct MulMatParams { @@ -19,14 +170,16 @@ struct MulMatParams { broadcast3: u32 }; -@group(0) @binding(0) var src0: array>; // N rows, K columns -@group(0) @binding(1) var src1: array>; // M rows, K columns (transposed) -@group(0) @binding(2) var dst: array>; // M rows, N columns +@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // N rows, K columns +@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // M rows, K columns (transposed) +@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // M rows, N columns @group(0) @binding(3) var params: MulMatParams; -var src0_shmem: array, (WORKGROUP_SIZE_Y * TILE_Y * TILE_K)/VEC_SIZE>; -var src1_shmem: array, (WORKGROUP_SIZE_X * TILE_X * TILE_K)/VEC_SIZE>; +var src0_shmem: array<{{SRC0_TYPE}}, (WORKGROUP_SIZE_Y * TILE_Y * TILE_K)/{{VEC_SIZE}}>; +var src1_shmem: array<{{SRC1_TYPE}}, (WORKGROUP_SIZE_X * TILE_X * TILE_K)/{{VEC_SIZE}}>; + +DECLS fn get_local_x(thread_id: u32) -> u32 { return thread_id / WORKGROUP_SIZE_Y; @@ -35,23 +188,6 @@ fn get_local_y(thread_id: u32) -> u32 { return thread_id % WORKGROUP_SIZE_Y; } -fn zero_vec4_f32() -> vec4 { - return vec4(0.0, 0.0, 0.0, 0.0); -} - -fn zero_vec4_f16() -> vec4 { - return vec4(0.0, 0.0, 0.0, 0.0); -} - -fn compute_vec4(a: vec4, b: vec4) -> f32 { - let a_f32 = vec4(f32(a.x), f32(a.y), f32(a.z), f32(a.w)); - return dot(a_f32, b); -} - -fn store_val(acc: array, TILE_X>, tx: u32, ty: u32) -> vec4 { - return vec4(acc[tx][ty], acc[tx][ty + 1], acc[tx][ty + 2], acc[tx][ty + 3]); -} - // Warning: cannot be overrides, must match values in ggml-webgpu.cpp const TILE_X = 4u; // must be multiple of 4 for vec4 loads @@ -60,14 +196,12 @@ const TILE_Y = 4u; override 
WORKGROUP_SIZE_X: u32; override WORKGROUP_SIZE_Y: u32; override TILE_K: u32; -override VEC_SIZE: u32; override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_X * WORKGROUP_SIZE_Y; override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_Y * TILE_Y; override TILE_SRC1_SHMEM = WORKGROUP_SIZE_X * TILE_X * TILE_K; -// assumes WG_SIZE divides SHMEM TILES -override TILE_SRC0_LD_PER_THREAD = (TILE_SRC0_SHMEM + TOTAL_WORKGROUP_SIZE * VEC_SIZE - 1) / (TOTAL_WORKGROUP_SIZE * VEC_SIZE); -override TILE_SRC1_LD_PER_THREAD = (TILE_SRC1_SHMEM + TOTAL_WORKGROUP_SIZE * VEC_SIZE - 1) / (TOTAL_WORKGROUP_SIZE * VEC_SIZE); +override TILE_SRC0_LD_PER_THREAD = (TILE_SRC0_SHMEM + TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}} - 1) / (TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}); +override TILE_SRC1_LD_PER_THREAD = (TILE_SRC1_SHMEM + TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}} - 1) / (TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}); @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) fn main(@builtin(global_invocation_id) global_id: vec3, @@ -110,29 +244,29 @@ fn main(@builtin(global_invocation_id) global_id: vec3, for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { for (var load_idx = 0u; load_idx < TILE_SRC0_LD_PER_THREAD; load_idx++) { - let elem_idx = (thread_id + load_idx * TOTAL_WORKGROUP_SIZE) * VEC_SIZE; + let elem_idx = (thread_id + load_idx * TOTAL_WORKGROUP_SIZE) * {{VEC_SIZE}}; let tile_col = elem_idx / TILE_K; let tile_k = elem_idx % TILE_K; let global_col = wg_y * WORKGROUP_SIZE_Y * TILE_Y + tile_col; let global_k = k_outer + tile_k; let src0_idx = src0_batch_offset + global_col * params.stride_01 + global_k; - src0_shmem[elem_idx/VEC_SIZE] = select( // taking a slight performance hit to avoid oob - zero_vec4_f16(), - src0[src0_idx/VEC_SIZE], + src0_shmem[elem_idx/{{VEC_SIZE}}] = select( // taking a slight performance hit to avoid oob + zero_val_src0(), + src0[src0_idx/{{VEC_SIZE}}], global_col < params.n && global_k < params.k); } for (var load_idx = 0u; load_idx < TILE_SRC1_LD_PER_THREAD; load_idx++) { - let elem_idx = (thread_id + load_idx * TOTAL_WORKGROUP_SIZE) * VEC_SIZE; + let elem_idx = (thread_id + load_idx * TOTAL_WORKGROUP_SIZE) * {{VEC_SIZE}}; let tile_row = elem_idx / TILE_K; let tile_k = elem_idx % TILE_K; let global_row = wg_x * WORKGROUP_SIZE_X * TILE_X + tile_row; let global_k = k_outer + tile_k; let src1_idx = src1_batch_offset + global_row * params.stride_11 + global_k; - src1_shmem[elem_idx/VEC_SIZE] = select( - zero_vec4_f32(), - src1[src1_idx/VEC_SIZE], + src1_shmem[elem_idx/{{VEC_SIZE}}] = select( + zero_val_src1(), + src1[src1_idx/{{VEC_SIZE}}], global_row < params.m && global_k < params.k); } @@ -140,19 +274,19 @@ fn main(@builtin(global_invocation_id) global_id: vec3, let k_end = min(TILE_K, params.k - k_outer); - for (var k_inner = 0u; k_inner < k_end; k_inner += VEC_SIZE) { - var src0_tile: array, TILE_Y>; + for (var k_inner = 0u; k_inner < k_end; k_inner += {{VEC_SIZE}}) { + var src0_tile: array<{{SRC0_TYPE}}, TILE_Y>; for (var ty = 0u; ty < TILE_Y; ty++) { let src0_col = local_y * TILE_Y + ty; let src0_idx = k_inner + src0_col * TILE_K; - src0_tile[ty] = src0_shmem[src0_idx/VEC_SIZE]; + src0_tile[ty] = src0_shmem[src0_idx/{{VEC_SIZE}}]; } for (var tx = 0u; tx < TILE_X; tx++) { let src1_row = local_x * TILE_X + tx; let src1_idx = src1_row * TILE_K + k_inner; - let src1_vec = src1_shmem[src1_idx/VEC_SIZE]; + let src1_vec = src1_shmem[src1_idx/{{VEC_SIZE}}]; for (var ty = 0u; ty < TILE_Y; ty++) { - acc[tx][ty] += compute_vec4(src0_tile[ty], src1_vec); + acc[tx][ty] += mul_acc(src0_tile[ty], src1_vec); } } } @@ -165,13 
+299,15 @@ fn main(@builtin(global_invocation_id) global_id: vec3, for (var tx = 0u; tx < TILE_X; tx++) { let global_row = output_row_base + tx; if (global_row < params.m) { - for (var ty = 0u; ty < TILE_Y; ty += VEC_SIZE) { + for (var ty = 0u; ty < TILE_Y; ty += {{VEC_SIZE}}) { let global_col = output_col_base + ty; if (global_col < params.n) { let dst_idx = dst_batch_offset + global_row * params.n + global_col; - dst[dst_idx/VEC_SIZE] = store_val(acc, tx, ty); + dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tx, ty); } } } } } + +#end(SHADER) From f808c48f0ce0ca59d2354a29c4b3602af642110e Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 22 Oct 2025 18:23:51 +0700 Subject: [PATCH 12/40] Clean up naming of mul_mat to match logical model, start work on q mul_mat --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 30 +-- .../wgsl-shaders/mul_mat.tmpl.wgsl | 6 +- .../wgsl-shaders/mul_mat_fast.tmpl.wgsl | 139 ++++++------ .../wgsl-shaders/mul_mat_q.tmpl.wgsl | 214 ++++++++++++++++++ 4 files changed, 300 insertions(+), 89 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_q.tmpl.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index a54d29fc66b5d..efc2eb09cc264 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -77,11 +77,11 @@ // Matrix multiplication fast path parameters // Warning: must match values in mul_mat_fast.wgsl -#define WEBGPU_MUL_MAT_TILE_X 4 -#define WEBGPU_MUL_MAT_TILE_Y 4 +#define WEBGPU_MUL_MAT_TILE_M 4 +#define WEBGPU_MUL_MAT_TILE_N 2 -#define WEBGPU_MUL_MAT_WG_SIZE_X 16 -#define WEBGPU_MUL_MAT_WG_SIZE_Y 8 +#define WEBGPU_MUL_MAT_WG_SIZE_M 16 +#define WEBGPU_MUL_MAT_WG_SIZE_N 8 #define WEBGPU_MUL_MAT_TILE_K 32 /* End Constants */ @@ -863,8 +863,8 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - (uint32_t) dst->ne[1], // number of rows in result (M) - (uint32_t) dst->ne[0], // number of columns in result (N) + (uint32_t) dst->ne[0], // number of rows in result (M, transposed) + (uint32_t) dst->ne[1], // number of columns in result (N) (uint32_t) src0->ne[0], // number of columns in src0/src1 (K) (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1 (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1 @@ -920,11 +920,11 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, if (use_fast) { int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0; pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; - uint32_t tile_x_s = WEBGPU_MUL_MAT_TILE_X * WEBGPU_MUL_MAT_WG_SIZE_X; - uint32_t tiles_x = (dst->ne[1] + tile_x_s - 1) / tile_x_s; - uint32_t tile_y_s = WEBGPU_MUL_MAT_TILE_Y * WEBGPU_MUL_MAT_WG_SIZE_Y; - uint32_t tiles_y = (dst->ne[0] + tile_y_s - 1) / tile_y_s; - wg_x = tiles_x * tiles_y * dst->ne[2] * dst->ne[3]; + uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N; + uint32_t tiles_n = (dst->ne[1] + tile_n_s - 1) / tile_n_s; + uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; + uint32_t tiles_m = (dst->ne[0] + tile_m_s - 1) / tile_m_s; + wg_x = tiles_m * tiles_n * dst->ne[2] * dst->ne[3]; } return 
ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); @@ -1690,10 +1690,10 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { // override constants std::vector mul_mat_fast_constants(3); - mul_mat_fast_constants[0].key = "WORKGROUP_SIZE_X"; - mul_mat_fast_constants[0].value = WEBGPU_MUL_MAT_WG_SIZE_X; - mul_mat_fast_constants[1].key = "WORKGROUP_SIZE_Y"; - mul_mat_fast_constants[1].value = WEBGPU_MUL_MAT_WG_SIZE_Y; + mul_mat_fast_constants[0].key = "WORKGROUP_SIZE_M"; + mul_mat_fast_constants[0].value = WEBGPU_MUL_MAT_WG_SIZE_M; + mul_mat_fast_constants[1].key = "WORKGROUP_SIZE_N"; + mul_mat_fast_constants[1].value = WEBGPU_MUL_MAT_WG_SIZE_N; mul_mat_fast_constants[2].key = "TILE_K"; mul_mat_fast_constants[2].value = WEBGPU_MUL_MAT_TILE_K; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl index 6cc1a4d0af888..0f8e6e5ac3dd6 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl @@ -891,8 +891,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { let dst2_rem = dst3_rem % dst2_stride; - let row = dst2_rem / params.n; // output row - let col = dst2_rem % params.n; // output column + let row = dst2_rem / params.m; // output row + let col = dst2_rem % params.m; // output column let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01; let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11; @@ -901,7 +901,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) { sum += multiply_add(src0_idx_base, src1_idx_base, i); } - dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum; + dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum; } #end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_fast.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_fast.tmpl.wgsl index 9ade05d09e139..d7779729dc14a 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_fast.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_fast.tmpl.wgsl @@ -95,8 +95,8 @@ fn zero_val_src1() -> vec4 { return vec4(0.0, 0.0, 0.0, 0.0); } -fn store_val(acc: array, TILE_X>, tx: u32, ty: u32) -> vec4 { - return vec4(acc[tx][ty], acc[tx][ty + 1], acc[tx][ty + 2], acc[tx][ty + 3]); +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> vec4 { + return vec4(acc[tm][tn], acc[tm + 1][tn], acc[tm + 2][tn], acc[tm + 3][tn]); } fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: vec4) -> f32 { @@ -109,8 +109,8 @@ fn zero_val_src1() -> f32 { return 0.0; } -fn store_val(acc: array, TILE_X>, tx: u32, ty: u32) -> f32 { - return acc[tx][ty]; +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { + return acc[tm][tn]; } fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: f32) -> f32 { @@ -123,8 +123,8 @@ fn zero_val_src1() -> vec4 { return vec4(0.0, 0.0, 0.0, 0.0); } -fn store_val(acc: array, TILE_X>, tx: u32, ty: u32) -> vec4 { - return vec4(acc[tx][ty], f32(acc[tx][ty + 1]), f32(acc[tx][ty + 2]), f32(acc[tx][ty + 3])); +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> vec4 { + return vec4(acc[tm][tn], f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn])); } fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: vec4) -> f32 { @@ -137,8 +137,8 @@ 
fn zero_val_src1() -> f16 { return 0.0; } -fn store_val(acc: array, TILE_X>, tx: u32, ty: u32) -> f32 { - return acc[tx][ty]; +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { + return acc[tm][tn]; } fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: f16) -> f32 { @@ -170,61 +170,60 @@ struct MulMatParams { broadcast3: u32 }; -@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // N rows, K columns -@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // M rows, K columns (transposed) +@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns +@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) @group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // M rows, N columns @group(0) @binding(3) var params: MulMatParams; -var src0_shmem: array<{{SRC0_TYPE}}, (WORKGROUP_SIZE_Y * TILE_Y * TILE_K)/{{VEC_SIZE}}>; -var src1_shmem: array<{{SRC1_TYPE}}, (WORKGROUP_SIZE_X * TILE_X * TILE_K)/{{VEC_SIZE}}>; - DECLS -fn get_local_x(thread_id: u32) -> u32 { - return thread_id / WORKGROUP_SIZE_Y; +fn get_local_n(thread_id: u32) -> u32 { + return thread_id / WORKGROUP_SIZE_M; } -fn get_local_y(thread_id: u32) -> u32 { - return thread_id % WORKGROUP_SIZE_Y; +fn get_local_m(thread_id: u32) -> u32 { + return thread_id % WORKGROUP_SIZE_M; } + // Warning: cannot be overrides, must match values in ggml-webgpu.cpp -const TILE_X = 4u; +const TILE_N = 2u; // must be multiple of 4 for vec4 loads -const TILE_Y = 4u; +const TILE_M = 4u; -override WORKGROUP_SIZE_X: u32; -override WORKGROUP_SIZE_Y: u32; +override WORKGROUP_SIZE_M: u32; +override WORKGROUP_SIZE_N: u32; override TILE_K: u32; -override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_X * WORKGROUP_SIZE_Y; -override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_Y * TILE_Y; -override TILE_SRC1_SHMEM = WORKGROUP_SIZE_X * TILE_X * TILE_K; -override TILE_SRC0_LD_PER_THREAD = (TILE_SRC0_SHMEM + TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}} - 1) / (TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}); -override TILE_SRC1_LD_PER_THREAD = (TILE_SRC1_SHMEM + TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}} - 1) / (TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}); +override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N; +override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M; +override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N; + +var src0_shmem: array<{{SRC0_TYPE}}, TILE_SRC0_SHMEM/{{VEC_SIZE}}>; +var src1_shmem: array<{{SRC1_TYPE}}, TILE_SRC1_SHMEM/{{VEC_SIZE}}>; @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) fn main(@builtin(global_invocation_id) global_id: vec3, @builtin(local_invocation_id) local_id: vec3) { let thread_id = local_id.x; - let local_x = get_local_x(thread_id); - let local_y = get_local_y(thread_id); + let local_m = get_local_m(thread_id); + let local_n = get_local_n(thread_id); let wg_linear = global_id.x / TOTAL_WORKGROUP_SIZE; - let wg_x_count = (params.m + WORKGROUP_SIZE_X * TILE_X - 1u) / (WORKGROUP_SIZE_X * TILE_X); - let wg_y_count = (params.n + WORKGROUP_SIZE_Y * TILE_Y - 1u) / (WORKGROUP_SIZE_Y * TILE_Y); - let wg_per_matrix = wg_x_count * wg_y_count; + let wg_n_count = (params.n + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N); + let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M); + let wg_per_matrix = wg_m_count * wg_n_count; let batch_idx = wg_linear / wg_per_matrix; let wg_in_batch = wg_linear % wg_per_matrix; - let wg_y = wg_in_batch % wg_y_count; - let wg_x = wg_in_batch / wg_y_count; + let wg_m = wg_in_batch % wg_m_count; + let wg_n = wg_in_batch / 
wg_m_count; - let output_row_base = wg_x * WORKGROUP_SIZE_X * TILE_X + local_x * TILE_X; - let output_col_base = wg_y * WORKGROUP_SIZE_Y * TILE_Y + local_y * TILE_Y; + let output_row_base = wg_n * WORKGROUP_SIZE_N * TILE_N + local_n * TILE_N; + let output_col_base = wg_m * WORKGROUP_SIZE_M * TILE_M + local_m * TILE_M; let dst2_stride = params.m * params.n; let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; @@ -239,54 +238,52 @@ fn main(@builtin(global_invocation_id) global_id: vec3, let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; - var acc: array, TILE_X>; + var acc: array, TILE_M>; for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { - for (var load_idx = 0u; load_idx < TILE_SRC0_LD_PER_THREAD; load_idx++) { - let elem_idx = (thread_id + load_idx * TOTAL_WORKGROUP_SIZE) * {{VEC_SIZE}}; - let tile_col = elem_idx / TILE_K; + for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + let tile_m = elem_idx / TILE_K; let tile_k = elem_idx % TILE_K; - let global_col = wg_y * WORKGROUP_SIZE_Y * TILE_Y + tile_col; + let global_m = wg_m * WORKGROUP_SIZE_M * TILE_M + tile_m; let global_k = k_outer + tile_k; - let src0_idx = src0_batch_offset + global_col * params.stride_01 + global_k; + let src0_idx = src0_batch_offset + global_m * params.stride_01 + global_k; src0_shmem[elem_idx/{{VEC_SIZE}}] = select( // taking a slight performance hit to avoid oob - zero_val_src0(), - src0[src0_idx/{{VEC_SIZE}}], - global_col < params.n && global_k < params.k); + zero_val_src0(), + src0[src0_idx/{{VEC_SIZE}}], + global_m < params.m && global_k < params.k); } - for (var load_idx = 0u; load_idx < TILE_SRC1_LD_PER_THREAD; load_idx++) { - let elem_idx = (thread_id + load_idx * TOTAL_WORKGROUP_SIZE) * {{VEC_SIZE}}; - let tile_row = elem_idx / TILE_K; + for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + let tile_n = elem_idx / TILE_K; let tile_k = elem_idx % TILE_K; - let global_row = wg_x * WORKGROUP_SIZE_X * TILE_X + tile_row; + let global_n = wg_n * WORKGROUP_SIZE_N * TILE_N + tile_n; let global_k = k_outer + tile_k; - let src1_idx = src1_batch_offset + global_row * params.stride_11 + global_k; + let src1_idx = src1_batch_offset + global_n * params.stride_11 + global_k; src1_shmem[elem_idx/{{VEC_SIZE}}] = select( - zero_val_src1(), - src1[src1_idx/{{VEC_SIZE}}], - global_row < params.m && global_k < params.k); - } + zero_val_src1(), + src1[src1_idx/{{VEC_SIZE}}], + global_n < params.n && global_k < params.k); + } workgroupBarrier(); let k_end = min(TILE_K, params.k - k_outer); for (var k_inner = 0u; k_inner < k_end; k_inner += {{VEC_SIZE}}) { - var src0_tile: array<{{SRC0_TYPE}}, TILE_Y>; - for (var ty = 0u; ty < TILE_Y; ty++) { - let src0_col = local_y * TILE_Y + ty; - let src0_idx = k_inner + src0_col * TILE_K; - src0_tile[ty] = src0_shmem[src0_idx/{{VEC_SIZE}}]; + var src0_tile: array<{{SRC0_TYPE}}, TILE_M>; + for (var tm = 0u; tm < TILE_M; tm++) { + let src0_m = local_m * TILE_M + tm; + let src0_idx = k_inner + src0_m * TILE_K; + src0_tile[tm] = src0_shmem[src0_idx/{{VEC_SIZE}}]; } - for (var tx = 0u; tx < TILE_X; tx++) { - let src1_row = local_x * TILE_X + tx; - let src1_idx = src1_row * TILE_K + k_inner; + for (var tn = 0u; tn < TILE_N; tn++) { + let src1_n = local_n * TILE_N + tn; 
+ let src1_idx = src1_n * TILE_K + k_inner; let src1_vec = src1_shmem[src1_idx/{{VEC_SIZE}}]; - for (var ty = 0u; ty < TILE_Y; ty++) { - acc[tx][ty] += mul_acc(src0_tile[ty], src1_vec); + for (var tm = 0u; tm < TILE_M; tm++) { + acc[tm][tn] += mul_acc(src0_tile[tm], src1_vec); } } } @@ -296,14 +293,14 @@ fn main(@builtin(global_invocation_id) global_id: vec3, let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; - for (var tx = 0u; tx < TILE_X; tx++) { - let global_row = output_row_base + tx; - if (global_row < params.m) { - for (var ty = 0u; ty < TILE_Y; ty += {{VEC_SIZE}}) { - let global_col = output_col_base + ty; - if (global_col < params.n) { - let dst_idx = dst_batch_offset + global_row * params.n + global_col; - dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tx, ty); + for (var tn = 0u; tn < TILE_N; tn++) { + let global_row = output_row_base + tn; + if (global_row < params.n) { + for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) { + let global_col = output_col_base + tm; + if (global_col < params.m) { + let dst_idx = dst_batch_offset + global_row * params.m + global_col; + dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm); } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_q.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_q.tmpl.wgsl new file mode 100644 index 0000000000000..b96f5b3acb10e --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_q.tmpl.wgsl @@ -0,0 +1,214 @@ +#define(VARIANTS) +[ + { + "SHADER_SUFFIX": "q4_0_vec", + "REPLS": { + "SRC0_TYPE" : "q4_0", + "SRC1_TYPE" : "vec4", + "VEC_SIZE" : "4", + "BLOCK_SIZE": 32, + }, + "DECLS": ["BYTE_HELPERS", "Q4_0_T"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(SRC1_F32_VEC) +fn zero_val_src1() -> vec4 { + return vec4(0.0, 0.0, 0.0, 0.0); +} + +fn store_val(acc: array, TILE_X>, tx: u32, ty: u32) -> vec4 { + return vec4(acc[tx][ty], acc[tx][ty + 1], acc[tx][ty + 2], acc[tx][ty + 3]); +} + +fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: vec4) -> f32 { + return dot(vec4(src0_val), src1_val); +} +#enddecl(SRC1_F32_VEC) + +#decl(SRC1_F32) +fn zero_val_src1() -> f32 { + return 0.0; +} + +fn store_val(acc: array, TILE_X>, tx: u32, ty: u32) -> f32 { + return acc[tx][ty]; +} + +fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: f32) -> f32 { + return f32(src0_val) * src1_val; +} +#enddecl(SRC1_F32) + +#end(DECLS) + +#define(SHADER) +enable f16; + +struct MulMatParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + m: u32, + n: u32, + k: u32, + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // N rows, K columns. 
+@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // M rows, K columns (transposed) +@group(0) @binding(2) var dst: array<{{SRC1_TYPE}}>; // M rows, N columns + +@group(0) @binding(3) var params: MulMatParams; + +DECLS + +fn get_local_x(thread_id: u32) -> u32 { + return thread_id / WORKGROUP_SIZE_Y; +} +fn get_local_y(thread_id: u32) -> u32 { + return thread_id % WORKGROUP_SIZE_Y; +} + +// Warning: cannot be overrides, must match values in ggml-webgpu.cpp +const TILE_X = 4u; +// must be multiple of 4 for vec4 loads +const TILE_Y = 4u; + +override WORKGROUP_SIZE_X: u32; +override WORKGROUP_SIZE_Y: u32; +override TILE_K: u32; + +override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_X * WORKGROUP_SIZE_Y; +override BLOCKS = max(1, TILE_K/{{BLOCK_SIZE}}); // the number of blocks we need to store at least TILE_K elements per thread. Note that since TILE_K may be less than BLOCK_SIZE, we need at least room for 1 block. Otherwise, TILE_K must be divisible by BLOCK_SIZE, or a clean fraction if it, or things will get weird. +override TILE_SRC0_SHMEM = BLOCKS * WORKGROUP_SIZE_Y * TILE_Y; +override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_X * TILE_X; + +var src0_shmem: array<{{SRC0_TYPE}}, TILE_SRC0_SHMEM>; // stores tiles of quantized weights without dequantizing +var src1_shmem: array<{{SRC1_TYPE}}, TILE_SRC1_SHMEM/{{VEC_SIZE}}>; + +@compute @workgroup_size(TOTAL_WORKGROUP_SIZE) +fn main(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3) { + + let thread_id = local_id.x; + let local_x = get_local_x(thread_id); + let local_y = get_local_y(thread_id); + + let wg_linear = global_id.x / TOTAL_WORKGROUP_SIZE; + + let wg_x_count = (params.m + WORKGROUP_SIZE_X * TILE_X - 1u) / (WORKGROUP_SIZE_X * TILE_X); + let wg_y_count = (params.n + WORKGROUP_SIZE_Y * TILE_Y - 1u) / (WORKGROUP_SIZE_Y * TILE_Y); + let wg_per_matrix = wg_x_count * wg_y_count; + + let batch_idx = wg_linear / wg_per_matrix; + + let wg_in_batch = wg_linear % wg_per_matrix; + let wg_y = wg_in_batch % wg_y_count; + let wg_x = wg_in_batch / wg_y_count; + + let output_row_base = wg_x * WORKGROUP_SIZE_X * TILE_X + local_x * TILE_X; + let output_col_base = wg_y * WORKGROUP_SIZE_Y * TILE_Y + local_y * TILE_Y; + + let dst2_stride = params.m * params.n; + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + + let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); + let src03_idx = dst3_idx / params.broadcast3; + let src13_idx = dst3_idx; + let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); + let src02_idx = dst2_idx / params.broadcast2; + let src12_idx = dst2_idx; + + let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; + let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + + var acc: array, TILE_X>; + + for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { + + // we only load new blocks of src0 when the k_outer is at a block boundary + if (k_outer % {{BLOCK_SIZE}} == 0) { + // load src0 tile + // need to figure out mapping from blocks to where they should be stored here + for (var block_idx = thread_id; block_idx < TILE_SRC0_SHMEM; block_idx += TOTAL_WORKGROUP_SIZE) { + let block_col = block_idx / BLOCKS; + let block_start_idx = block_idx * {{BLOCK_SIZE}}; + let tile_col = block_start_idx / TILE_K; + let tile_k = block_start_idx % TILE_K; + let global_col = wg_y * WORKGROUP_SIZE_Y * TILE_Y + tile_col; + let global_k = k_outer + tile_k; + let src0_idx = 
src0_batch_offset + global_col * params.stride_01 + global_k; + src0_shmem[block_idx] = select( // taking a slight performance hit to avoid oob + zero_val_src0(), + src0[src0_idx], + global_col < params.n && global_k < params.k); + } + + for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + let tile_row = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_row = wg_x * WORKGROUP_SIZE_X * TILE_X + tile_row; + let global_k = k_outer + tile_k; + + let src1_idx = src1_batch_offset + global_row * params.stride_11 + global_k; + src1_shmem[elem_idx/{{VEC_SIZE}}] = select( + zero_val_src1(), + src1[src1_idx/{{VEC_SIZE}}], + global_row < params.m && global_k < params.k); + } + + workgroupBarrier(); + + let k_end = min(TILE_K, params.k - k_outer); + + for (var k_inner = 0u; k_inner < k_end; k_inner += {{VEC_SIZE}}) { + var src0_tile: array<{{SRC0_TYPE}}, TILE_Y>; + for (var ty = 0u; ty < TILE_Y; ty++) { + let src0_col = local_y * TILE_Y + ty; + let src0_idx = k_inner + src0_col * TILE_K; + src0_tile[ty] = src0_shmem[src0_idx/{{VEC_SIZE}}]; + } + for (var tx = 0u; tx < TILE_X; tx++) { + let src1_row = local_x * TILE_X + tx; + let src1_idx = src1_row * TILE_K + k_inner; + let src1_vec = src1_shmem[src1_idx/{{VEC_SIZE}}]; + for (var ty = 0u; ty < TILE_Y; ty++) { + acc[tx][ty] += mul_acc(src0_tile[ty], src1_vec); + } + } + } + + workgroupBarrier(); + } + + let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; + + for (var tx = 0u; tx < TILE_X; tx++) { + let global_row = output_row_base + tx; + if (global_row < params.m) { + for (var ty = 0u; ty < TILE_Y; ty += {{VEC_SIZE}}) { + let global_col = output_col_base + ty; + if (global_col < params.n) { + let dst_idx = dst_batch_offset + global_row * params.n + global_col; + dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tx, ty); + } + } + } + } +} + +#end(SHADER) From a3b2f67676d65d5417bbba448fd3d5e8ea0fde30 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Sat, 25 Oct 2025 12:35:07 +0700 Subject: [PATCH 13/40] Setup for subgroup matrix mat mul --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 102 ++++-- .../wgsl-shaders/mul_mat_q.tmpl.wgsl | 214 ------------ ...t.tmpl.wgsl => mul_mat_reg_tile.tmpl.wgsl} | 7 +- .../mul_mat_subgroup_matrix.tmpl.wgsl | 325 ++++++++++++++++++ 4 files changed, 408 insertions(+), 240 deletions(-) delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_q.tmpl.wgsl rename ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_fast.tmpl.wgsl => mul_mat_reg_tile.tmpl.wgsl} (98%) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index efc2eb09cc264..af52d9ef46df3 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -78,12 +78,15 @@ // Warning: must match values in mul_mat_fast.wgsl #define WEBGPU_MUL_MAT_TILE_M 4 -#define WEBGPU_MUL_MAT_TILE_N 2 +#define WEBGPU_MUL_MAT_TILE_N 4 #define WEBGPU_MUL_MAT_WG_SIZE_M 16 #define WEBGPU_MUL_MAT_WG_SIZE_N 8 #define WEBGPU_MUL_MAT_TILE_K 32 +#define WEBGPU_MUL_MAT_SUBGROUP_M 1 +#define WEBGPU_MUL_MAT_SUBGROUP_N 1 + /* End Constants */ // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations. 
@@ -247,6 +250,10 @@ struct webgpu_context_struct { wgpu::Queue queue; wgpu::Limits limits; + bool supports_subgroup_matrix = false; + uint32_t subgroup_size; + wgpu::SubgroupMatrixConfig subgroup_matrix_config; + // Separate this out from limits since on some Metal systems, the limit returned by // querying the limits is higher than the actual allowed maximum. uint32_t max_wg_size_x; @@ -1689,26 +1696,59 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); // override constants - std::vector mul_mat_fast_constants(3); - mul_mat_fast_constants[0].key = "WORKGROUP_SIZE_M"; - mul_mat_fast_constants[0].value = WEBGPU_MUL_MAT_WG_SIZE_M; - mul_mat_fast_constants[1].key = "WORKGROUP_SIZE_N"; - mul_mat_fast_constants[1].value = WEBGPU_MUL_MAT_WG_SIZE_N; - mul_mat_fast_constants[2].key = "TILE_K"; - mul_mat_fast_constants[2].value = WEBGPU_MUL_MAT_TILE_K; - - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, wgsl_mul_mat_fast_f32_f32, "mul_mat_fast_f32_f32", mul_mat_fast_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, wgsl_mul_mat_fast_f32_f32_vec, "mul_mat_fast_f32_f32_vec", mul_mat_fast_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, wgsl_mul_mat_fast_f16_f32, "mul_mat_fast_f16_f32", mul_mat_fast_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, wgsl_mul_mat_fast_f16_f32_vec, "mul_mat_fast_f16_f32_vec", mul_mat_fast_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, wgsl_mul_mat_fast_f16_f16, "mul_mat_fast_f16_f16", mul_mat_fast_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, wgsl_mul_mat_fast_f16_f16_vec, "mul_mat_fast_f16_f16_vec", mul_mat_fast_constants); + std::vector mul_mat_opt_constants(3); + mul_mat_opt_constants[0].key = "WORKGROUP_SIZE_M"; + mul_mat_opt_constants[0].value = WEBGPU_MUL_MAT_WG_SIZE_M; + mul_mat_opt_constants[1].key = "WORKGROUP_SIZE_N"; + mul_mat_opt_constants[1].value = WEBGPU_MUL_MAT_WG_SIZE_N; + mul_mat_opt_constants[2].key = "TILE_K"; + mul_mat_opt_constants[2].value = WEBGPU_MUL_MAT_TILE_K; + + if (webgpu_ctx->supports_subgroup_matrix) { + mul_mat_opt_constants.push_back({ .key = "SUBGROUP_M", .value = WEBGPU_MUL_MAT_SUBGROUP_M }); + mul_mat_opt_constants.push_back( + { .key = "SUBGROUP_MATRIX_M", .value = static_cast(webgpu_ctx->subgroup_matrix_config.M) }); + mul_mat_opt_constants.push_back({ .key = "SUBGROUP_N", .value = WEBGPU_MUL_MAT_SUBGROUP_N }); + mul_mat_opt_constants.push_back( + { .key = "SUBGROUP_MATRIX_N", .value = static_cast(webgpu_ctx->subgroup_matrix_config.N) }); + mul_mat_opt_constants.push_back( + { .key = "SUBGROUP_SIZE", .value = static_cast(webgpu_ctx->subgroup_size) }); + + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f32_f32, + "mul_mat_subgroup_matrix_f32_f32", mul_mat_opt_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f32_f32_vec, + "mul_mat_subgroup_matrix_f32_f32_vec", 
mul_mat_opt_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f16_f32, + "mul_mat_subgroup_matrix_f16_f32", mul_mat_opt_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f16_f32_vec, + "mul_mat_subgroup_matrix_f16_f32_vec", mul_mat_opt_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f16_f16, + "mul_mat_subgroup_matrix_f16_f16", mul_mat_opt_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f16_f16_vec, + "mul_mat_subgroup_matrix_f16_f16_vec", mul_mat_opt_constants); + } else { + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_reg_tile_f32_f32, "mul_mat_reg_tile_f32_f32", mul_mat_opt_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_reg_tile_f32_f32_vec, + "mul_mat_reg_tile_f32_f32_vec", mul_mat_opt_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f32, "mul_mat_reg_tile_f16_f32", mul_mat_opt_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f32_vec, + "mul_mat_reg_tile_f16_f32_vec", mul_mat_opt_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f16, "mul_mat_reg_tile_f16_f16", mul_mat_opt_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f16_vec, + "mul_mat_reg_tile_f16_f16_vec", mul_mat_opt_constants); + } } static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { @@ -2218,12 +2258,30 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ctx->adapter.GetLimits(&ctx->limits); ctx->max_wg_size_x = 288; // default value - wgpu::AdapterInfo info{}; + wgpu::AdapterInfo info{}; + wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{}; + info.nextInChain = &subgroup_matrix_configs; ctx->adapter.GetInfo(&info); + ctx->subgroup_matrix_config = *subgroup_matrix_configs.configs; + wgpu::SupportedFeatures features; + ctx->adapter.GetFeatures(&features); + // we require f16 support + GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); + // Only support devices/configurations where subgroup implementations make sense + // (i.e. 
it won't change on you at runtime) + if (info.subgroupMinSize == info.subgroupMaxSize) { + ctx->subgroup_size = info.subgroupMaxSize; + ctx->supports_subgroup_matrix = ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); + } + // Initialize device std::vector required_features = { wgpu::FeatureName::ShaderF16, wgpu::FeatureName::ImplicitDeviceSynchronization }; + if (ctx->supports_subgroup_matrix) { + required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); + } + #ifdef GGML_WEBGPU_GPU_PROFILE required_features.push_back(wgpu::FeatureName::TimestampQuery); #endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_q.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_q.tmpl.wgsl deleted file mode 100644 index b96f5b3acb10e..0000000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_q.tmpl.wgsl +++ /dev/null @@ -1,214 +0,0 @@ -#define(VARIANTS) -[ - { - "SHADER_SUFFIX": "q4_0_vec", - "REPLS": { - "SRC0_TYPE" : "q4_0", - "SRC1_TYPE" : "vec4", - "VEC_SIZE" : "4", - "BLOCK_SIZE": 32, - }, - "DECLS": ["BYTE_HELPERS", "Q4_0_T"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(SRC1_F32_VEC) -fn zero_val_src1() -> vec4 { - return vec4(0.0, 0.0, 0.0, 0.0); -} - -fn store_val(acc: array, TILE_X>, tx: u32, ty: u32) -> vec4 { - return vec4(acc[tx][ty], acc[tx][ty + 1], acc[tx][ty + 2], acc[tx][ty + 3]); -} - -fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: vec4) -> f32 { - return dot(vec4(src0_val), src1_val); -} -#enddecl(SRC1_F32_VEC) - -#decl(SRC1_F32) -fn zero_val_src1() -> f32 { - return 0.0; -} - -fn store_val(acc: array, TILE_X>, tx: u32, ty: u32) -> f32 { - return acc[tx][ty]; -} - -fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: f32) -> f32 { - return f32(src0_val) * src1_val; -} -#enddecl(SRC1_F32) - -#end(DECLS) - -#define(SHADER) -enable f16; - -struct MulMatParams { - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - m: u32, - n: u32, - k: u32, - stride_01: u32, - stride_11: u32, - stride_02: u32, - stride_12: u32, - stride_03: u32, - stride_13: u32, - bs02: u32, - bs03: u32, - broadcast2: u32, - broadcast3: u32 -}; - -@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // N rows, K columns. -@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // M rows, K columns (transposed) -@group(0) @binding(2) var dst: array<{{SRC1_TYPE}}>; // M rows, N columns - -@group(0) @binding(3) var params: MulMatParams; - -DECLS - -fn get_local_x(thread_id: u32) -> u32 { - return thread_id / WORKGROUP_SIZE_Y; -} -fn get_local_y(thread_id: u32) -> u32 { - return thread_id % WORKGROUP_SIZE_Y; -} - -// Warning: cannot be overrides, must match values in ggml-webgpu.cpp -const TILE_X = 4u; -// must be multiple of 4 for vec4 loads -const TILE_Y = 4u; - -override WORKGROUP_SIZE_X: u32; -override WORKGROUP_SIZE_Y: u32; -override TILE_K: u32; - -override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_X * WORKGROUP_SIZE_Y; -override BLOCKS = max(1, TILE_K/{{BLOCK_SIZE}}); // the number of blocks we need to store at least TILE_K elements per thread. Note that since TILE_K may be less than BLOCK_SIZE, we need at least room for 1 block. Otherwise, TILE_K must be divisible by BLOCK_SIZE, or a clean fraction if it, or things will get weird. 
-override TILE_SRC0_SHMEM = BLOCKS * WORKGROUP_SIZE_Y * TILE_Y; -override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_X * TILE_X; - -var src0_shmem: array<{{SRC0_TYPE}}, TILE_SRC0_SHMEM>; // stores tiles of quantized weights without dequantizing -var src1_shmem: array<{{SRC1_TYPE}}, TILE_SRC1_SHMEM/{{VEC_SIZE}}>; - -@compute @workgroup_size(TOTAL_WORKGROUP_SIZE) -fn main(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3) { - - let thread_id = local_id.x; - let local_x = get_local_x(thread_id); - let local_y = get_local_y(thread_id); - - let wg_linear = global_id.x / TOTAL_WORKGROUP_SIZE; - - let wg_x_count = (params.m + WORKGROUP_SIZE_X * TILE_X - 1u) / (WORKGROUP_SIZE_X * TILE_X); - let wg_y_count = (params.n + WORKGROUP_SIZE_Y * TILE_Y - 1u) / (WORKGROUP_SIZE_Y * TILE_Y); - let wg_per_matrix = wg_x_count * wg_y_count; - - let batch_idx = wg_linear / wg_per_matrix; - - let wg_in_batch = wg_linear % wg_per_matrix; - let wg_y = wg_in_batch % wg_y_count; - let wg_x = wg_in_batch / wg_y_count; - - let output_row_base = wg_x * WORKGROUP_SIZE_X * TILE_X + local_x * TILE_X; - let output_col_base = wg_y * WORKGROUP_SIZE_Y * TILE_Y + local_y * TILE_Y; - - let dst2_stride = params.m * params.n; - let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; - - let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); - let src03_idx = dst3_idx / params.broadcast3; - let src13_idx = dst3_idx; - let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); - let src02_idx = dst2_idx / params.broadcast2; - let src12_idx = dst2_idx; - - let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; - let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; - - var acc: array, TILE_X>; - - for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { - - // we only load new blocks of src0 when the k_outer is at a block boundary - if (k_outer % {{BLOCK_SIZE}} == 0) { - // load src0 tile - // need to figure out mapping from blocks to where they should be stored here - for (var block_idx = thread_id; block_idx < TILE_SRC0_SHMEM; block_idx += TOTAL_WORKGROUP_SIZE) { - let block_col = block_idx / BLOCKS; - let block_start_idx = block_idx * {{BLOCK_SIZE}}; - let tile_col = block_start_idx / TILE_K; - let tile_k = block_start_idx % TILE_K; - let global_col = wg_y * WORKGROUP_SIZE_Y * TILE_Y + tile_col; - let global_k = k_outer + tile_k; - let src0_idx = src0_batch_offset + global_col * params.stride_01 + global_k; - src0_shmem[block_idx] = select( // taking a slight performance hit to avoid oob - zero_val_src0(), - src0[src0_idx], - global_col < params.n && global_k < params.k); - } - - for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { - let tile_row = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - let global_row = wg_x * WORKGROUP_SIZE_X * TILE_X + tile_row; - let global_k = k_outer + tile_k; - - let src1_idx = src1_batch_offset + global_row * params.stride_11 + global_k; - src1_shmem[elem_idx/{{VEC_SIZE}}] = select( - zero_val_src1(), - src1[src1_idx/{{VEC_SIZE}}], - global_row < params.m && global_k < params.k); - } - - workgroupBarrier(); - - let k_end = min(TILE_K, params.k - k_outer); - - for (var k_inner = 0u; k_inner < k_end; k_inner += {{VEC_SIZE}}) { - var src0_tile: array<{{SRC0_TYPE}}, TILE_Y>; - for (var ty = 0u; ty < TILE_Y; ty++) { - let src0_col = local_y * TILE_Y + ty; 
- let src0_idx = k_inner + src0_col * TILE_K; - src0_tile[ty] = src0_shmem[src0_idx/{{VEC_SIZE}}]; - } - for (var tx = 0u; tx < TILE_X; tx++) { - let src1_row = local_x * TILE_X + tx; - let src1_idx = src1_row * TILE_K + k_inner; - let src1_vec = src1_shmem[src1_idx/{{VEC_SIZE}}]; - for (var ty = 0u; ty < TILE_Y; ty++) { - acc[tx][ty] += mul_acc(src0_tile[ty], src1_vec); - } - } - } - - workgroupBarrier(); - } - - let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; - - for (var tx = 0u; tx < TILE_X; tx++) { - let global_row = output_row_base + tx; - if (global_row < params.m) { - for (var ty = 0u; ty < TILE_Y; ty += {{VEC_SIZE}}) { - let global_col = output_col_base + ty; - if (global_col < params.n) { - let dst_idx = dst_batch_offset + global_row * params.n + global_col; - dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tx, ty); - } - } - } - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_fast.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl similarity index 98% rename from ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_fast.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl index d7779729dc14a..5cafa7c4c08a4 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_fast.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl @@ -172,7 +172,7 @@ struct MulMatParams { @group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns @group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) -@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // M rows, N columns +@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed) @group(0) @binding(3) var params: MulMatParams; @@ -185,11 +185,10 @@ fn get_local_m(thread_id: u32) -> u32 { return thread_id % WORKGROUP_SIZE_M; } - // Warning: cannot be overrides, must match values in ggml-webgpu.cpp -const TILE_N = 2u; -// must be multiple of 4 for vec4 loads +// TILE_M must be multiple of 4 for vec4 loads const TILE_M = 4u; +const TILE_N = 4u; override WORKGROUP_SIZE_M: u32; override WORKGROUP_SIZE_N: u32; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl new file mode 100644 index 0000000000000..33076efc1f3b5 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl @@ -0,0 +1,325 @@ +#define(VARIANTS) +[ + { + "SHADER_SUFFIX": "f32_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "VEC_SIZE" : "4", + }, + "DECLS": ["SRC0_F32_VEC", "SRC1_F32_VEC"] + }, + { + "SHADER_SUFFIX": "f32_f32", + "REPLS": { + "SRC0_TYPE" : "f32", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "VEC_SIZE" : "1", + }, + "DECLS": ["SRC0_F32", "SRC1_F32"] + }, + { + "SHADER_SUFFIX": "f16_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "VEC_SIZE" : "4", + }, + "DECLS": ["SRC0_F16_VEC", "SRC1_F32_VEC"] + }, + { + "SHADER_SUFFIX": "f16_f32", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "VEC_SIZE" : "1", + }, + "DECLS": ["SRC0_F16", "SRC1_F32"] + }, + { + "SHADER_SUFFIX": "f16_f16_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "VEC_SIZE" : "4", + }, + "DECLS": ["SRC0_F16_VEC", "SRC1_F16_VEC"] + }, + { + "SHADER_SUFFIX": "f16_f16", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f16", 
+ "DST_TYPE" : "f32", + "VEC_SIZE" : "1", + }, + "DECLS": ["SRC0_F16", "SRC1_F16"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(SRC0_F32_VEC) +fn zero_val_src0() -> vec4 { + return vec4(0.0, 0.0, 0.0, 0.0); +} +#enddecl(SRC0_F32_VEC) + +#decl(SRC0_F32) +fn zero_val_src0() -> f32 { + return 0.0; +} +#enddecl(SRC0_F32) + +#decl(SRC0_F16_VEC) +fn zero_val_src0() -> vec4 { + return vec4(0.0, 0.0, 0.0, 0.0); +} +#enddecl(SRC0_F16_VEC) + +#decl(SRC0_F16) +fn zero_val_src0() -> f16 { + return 0.0; +} +#enddecl(SRC0_F16) + +#decl(SRC1_F32_VEC) +fn zero_val_src1() -> vec4 { + return vec4(0.0, 0.0, 0.0, 0.0); +} + +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> vec4 { + return vec4(acc[tm][tn], acc[tm + 1][tn], acc[tm + 2][tn], acc[tm + 3][tn]); +} + +fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: vec4) -> f32 { + return dot(vec4(src0_val), src1_val); +} +#enddecl(SRC1_F32_VEC) + +#decl(SRC1_F32) +fn zero_val_src1() -> f32 { + return 0.0; +} + +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { + return acc[tm][tn]; +} + +fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: f32) -> f32 { + return f32(src0_val) * src1_val; +} +#enddecl(SRC1_F32) + +#decl(SRC1_F16_VEC) +fn zero_val_src1() -> vec4 { + return vec4(0.0, 0.0, 0.0, 0.0); +} + +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> vec4 { + return vec4(acc[tm][tn], f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn])); +} + +fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: vec4) -> f32 { + return dot(vec4(src0_val), vec4(src1_val)); +} +#enddecl(SRC1_F16_VEC) + +#decl(SRC1_F16) +fn zero_val_src1() -> f16 { + return 0.0; +} + +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { + return acc[tm][tn]; +} + +fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: f16) -> f32 { + return f32(src0_val) * f32(src1_val); +} +#enddecl(SRC1_F16) + +#end(DECLS) + +#define(SHADER) +enable f16; +enable chromium_experimental_subgroup_matrix; + +struct MulMatParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + m: u32, + n: u32, + k: u32, + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns +@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) +@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed) + +@group(0) @binding(3) var params: MulMatParams; + +DECLS + +fn get_local_n(thread_id: u32) -> u32 { + return thread_id / WORKGROUP_SIZE_M; +} +fn get_local_m(thread_id: u32) -> u32 { + return thread_id % WORKGROUP_SIZE_M; +} + +// Number of threads per workgroup: SUBGROUP_M * SUBGROUP_N * SUBGROUP_SIZE +// Shared memory src0: SUBGROUP_M * SUBGROUP_MATRIX_M * TILE_K +// Shared memory src1: SUBGROUP_N * SUBGROUP_MATRIX_N * TILE_K +// TILE_K must be divisible by SUBGROUP_MATRIX_K + +override SUBGROUP_M: u32; +override SUBGROUP_MATRIX_M: u32; +override SUBGROUP_N: u32; +override SUBGROUP_MATRIX_N: u32; +override SUBGROUP_SIZE: u32; + +// Warning: cannot be overrides, must match values in ggml-webgpu.cpp +// TILE_M must be multiple of 4 for vec4 loads +const TILE_M = 4u; +const TILE_N = 4u; + +override WORKGROUP_SIZE_M: u32; +override WORKGROUP_SIZE_N: u32; +override TILE_K: u32; + +override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N; +override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M; +override TILE_SRC1_SHMEM = TILE_K * 
WORKGROUP_SIZE_N * TILE_N; + +var src0_shmem: array<{{SRC0_TYPE}}, TILE_SRC0_SHMEM/{{VEC_SIZE}}>; +var src1_shmem: array<{{SRC1_TYPE}}, TILE_SRC1_SHMEM/{{VEC_SIZE}}>; + +@compute @workgroup_size(TOTAL_WORKGROUP_SIZE) +fn main(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(subgroup_id) subgroup_id: u32) { + + let thread_id = local_id.x; + let local_m = get_local_m(thread_id); + let local_n = get_local_n(thread_id); + + let wg_linear = global_id.x / TOTAL_WORKGROUP_SIZE; + + let wg_n_count = (params.n + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N); + let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M); + let wg_per_matrix = wg_m_count * wg_n_count; + + let batch_idx = wg_linear / wg_per_matrix; + + let wg_in_batch = wg_linear % wg_per_matrix; + let wg_m = wg_in_batch % wg_m_count; + let wg_n = wg_in_batch / wg_m_count; + + let output_row_base = wg_n * WORKGROUP_SIZE_N * TILE_N + local_n * TILE_N; + let output_col_base = wg_m * WORKGROUP_SIZE_M * TILE_M + local_m * TILE_M; + + let dst2_stride = params.m * params.n; + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + + let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); + let src03_idx = dst3_idx / params.broadcast3; + let src13_idx = dst3_idx; + let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); + let src02_idx = dst2_idx / params.broadcast2; + let src12_idx = dst2_idx; + + let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; + let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + + var acc: array, TILE_M>; + var src0_sg_mat : subgroup_matrix_left; + var src1_sg_mat : subgroup_matrix_right; + var acc_sg_mat : subgroup_matrix_result; + + for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { + + for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = wg_m * WORKGROUP_SIZE_M * TILE_M + tile_m; + let global_k = k_outer + tile_k; + let src0_idx = src0_batch_offset + global_m * params.stride_01 + global_k; + src0_shmem[elem_idx/{{VEC_SIZE}}] = select( // taking a slight performance hit to avoid oob + zero_val_src0(), + src0[src0_idx/{{VEC_SIZE}}], + global_m < params.m && global_k < params.k); + } + + for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + let tile_n = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_n = wg_n * WORKGROUP_SIZE_N * TILE_N + tile_n; + let global_k = k_outer + tile_k; + + let src1_idx = src1_batch_offset + global_n * params.stride_11 + global_k; + src1_shmem[elem_idx/{{VEC_SIZE}}] = select( + zero_val_src1(), + src1[src1_idx/{{VEC_SIZE}}], + global_n < params.n && global_k < params.k); + } + + workgroupBarrier(); + + let k_end = min(TILE_K, params.k - k_outer); + + for (var k_inner = 0u; k_inner < k_end; k_inner += {{VEC_SIZE}}) { + var src0_tile: array<{{SRC0_TYPE}}, TILE_M>; + for (var tm = 0u; tm < TILE_M; tm++) { + let src0_m = local_m * TILE_M + tm; + let src0_idx = k_inner + src0_m * TILE_K; + src0_tile[tm] = src0_shmem[src0_idx/{{VEC_SIZE}}]; + } + for (var tn = 0u; tn < TILE_N; tn++) { + let src1_n = local_n * TILE_N + tn; + let src1_idx = src1_n * TILE_K + k_inner; + let src1_vec = 
src1_shmem[src1_idx/{{VEC_SIZE}}]; + for (var tm = 0u; tm < TILE_M; tm++) { + acc[tm][tn] += mul_acc(src0_tile[tm], src1_vec); + } + } + } + + workgroupBarrier(); + } + + let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; + + for (var tn = 0u; tn < TILE_N; tn++) { + let global_row = output_row_base + tn; + if (global_row < params.n) { + for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) { + let global_col = output_col_base + tm; + if (global_col < params.m) { + let dst_idx = dst_batch_offset + global_row * params.m + global_col; + dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm); + } + } + } + } +} + +#end(SHADER) From 0bdd9f47518de891de96eb63a9755e04723cd4f7 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Sat, 25 Oct 2025 20:00:27 +0700 Subject: [PATCH 14/40] Basic working subgroup matrix --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 57 +++-- .../mul_mat_subgroup_matrix.tmpl.wgsl | 214 +++++++----------- 2 files changed, 116 insertions(+), 155 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index af52d9ef46df3..e6e317819c67d 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -84,8 +84,8 @@ #define WEBGPU_MUL_MAT_WG_SIZE_N 8 #define WEBGPU_MUL_MAT_TILE_K 32 -#define WEBGPU_MUL_MAT_SUBGROUP_M 1 -#define WEBGPU_MUL_MAT_SUBGROUP_N 1 +#define WEBGPU_MUL_MAT_SUBGROUP_M 2 +#define WEBGPU_MUL_MAT_SUBGROUP_N 2 /* End Constants */ @@ -925,13 +925,24 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, } if (use_fast) { - int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0; - pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; - uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N; - uint32_t tiles_n = (dst->ne[1] + tile_n_s - 1) / tile_n_s; - uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; - uint32_t tiles_m = (dst->ne[0] + tile_m_s - 1) / tile_m_s; - wg_x = tiles_m * tiles_n * dst->ne[2] * dst->ne[3]; + int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0; + pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; + + uint32_t wg_m; + uint32_t wg_n; + if (ctx->supports_subgroup_matrix) { + // The total number of subgroups/workgroups needed per matrix. 
+ uint32_t subgroups_m = (dst->ne[0] + ctx->subgroup_matrix_config.M - 1) / ctx->subgroup_matrix_config.M; + wg_m = subgroups_m / WEBGPU_MUL_MAT_SUBGROUP_M; + uint32_t subgroups_n = (dst->ne[1] + ctx->subgroup_matrix_config.N - 1) / ctx->subgroup_matrix_config.N; + wg_n = subgroups_n / WEBGPU_MUL_MAT_SUBGROUP_N; + } else { + uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; + uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N; + wg_m = (dst->ne[0] + tile_m_s - 1) / tile_m_s; + wg_n = (dst->ne[1] + tile_n_s - 1) / tile_n_s; + } + wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; } return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); @@ -1696,13 +1707,9 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); // override constants - std::vector mul_mat_opt_constants(3); - mul_mat_opt_constants[0].key = "WORKGROUP_SIZE_M"; - mul_mat_opt_constants[0].value = WEBGPU_MUL_MAT_WG_SIZE_M; - mul_mat_opt_constants[1].key = "WORKGROUP_SIZE_N"; - mul_mat_opt_constants[1].value = WEBGPU_MUL_MAT_WG_SIZE_N; - mul_mat_opt_constants[2].key = "TILE_K"; - mul_mat_opt_constants[2].value = WEBGPU_MUL_MAT_TILE_K; + std::vector mul_mat_opt_constants(1); + mul_mat_opt_constants[0].key = "TILE_K"; + mul_mat_opt_constants[0].value = WEBGPU_MUL_MAT_TILE_K; if (webgpu_ctx->supports_subgroup_matrix) { mul_mat_opt_constants.push_back({ .key = "SUBGROUP_M", .value = WEBGPU_MUL_MAT_SUBGROUP_M }); @@ -1713,6 +1720,8 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { { .key = "SUBGROUP_MATRIX_N", .value = static_cast(webgpu_ctx->subgroup_matrix_config.N) }); mul_mat_opt_constants.push_back( { .key = "SUBGROUP_SIZE", .value = static_cast(webgpu_ctx->subgroup_size) }); + mul_mat_opt_constants.push_back( + { .key = "SUBGROUP_MATRIX_K", .value = static_cast(webgpu_ctx->subgroup_matrix_config.K) }); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f32_f32, @@ -1733,6 +1742,9 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f16_f16_vec, "mul_mat_subgroup_matrix_f16_f16_vec", mul_mat_opt_constants); } else { + mul_mat_opt_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M }); + mul_mat_opt_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N }); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( webgpu_ctx->device, wgsl_mul_mat_reg_tile_f32_f32, "mul_mat_reg_tile_f32_f32", mul_mat_opt_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = @@ -2268,12 +2280,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ctx->adapter.GetFeatures(&features); // we require f16 support GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); - // Only support devices/configurations where subgroup implementations make sense - // (i.e. it won't change on you at runtime) - if (info.subgroupMinSize == info.subgroupMaxSize) { - ctx->subgroup_size = info.subgroupMaxSize; - ctx->supports_subgroup_matrix = ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); - } + + // For subgroup matrix code to be workable, we really need a consistent subgroup size. 
+ // Unfortunately, WebGPU allows info.subgroup{Min/Max}Size to be different, and even on devices + // where it is consistent, e.g., Apple M-series GPUs, the min/max sizes report different values. + // Therefore, hardcoding the subgroup size to 32 for now for development. + ctx->subgroup_size = 32; + ctx->supports_subgroup_matrix = ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); // Initialize device std::vector required_features = { wgpu::FeatureName::ShaderF16, diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl index 33076efc1f3b5..aedbd61cfa9d1 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl @@ -5,10 +5,10 @@ "REPLS": { "SRC0_TYPE" : "vec4", "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", + "DST_TYPE" : "f32", "VEC_SIZE" : "4", }, - "DECLS": ["SRC0_F32_VEC", "SRC1_F32_VEC"] + "DECLS": ["SHMEM_VEC"] }, { "SHADER_SUFFIX": "f32_f32", @@ -18,17 +18,17 @@ "DST_TYPE" : "f32", "VEC_SIZE" : "1", }, - "DECLS": ["SRC0_F32", "SRC1_F32"] + "DECLS": ["SHMEM_SCALAR"] }, { "SHADER_SUFFIX": "f16_f32_vec", "REPLS": { "SRC0_TYPE" : "vec4", "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", + "DST_TYPE" : "f32", "VEC_SIZE" : "4", }, - "DECLS": ["SRC0_F16_VEC", "SRC1_F32_VEC"] + "DECLS": ["SHMEM_VEC"] }, { "SHADER_SUFFIX": "f16_f32", @@ -38,17 +38,17 @@ "DST_TYPE" : "f32", "VEC_SIZE" : "1", }, - "DECLS": ["SRC0_F16", "SRC1_F32"] + "DECLS": ["SHMEM_SCALAR"] }, { "SHADER_SUFFIX": "f16_f16_vec", "REPLS": { "SRC0_TYPE" : "vec4", "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", + "DST_TYPE" : "f32", "VEC_SIZE" : "4", }, - "DECLS": ["SRC0_F16_VEC", "SRC1_F16_VEC"] + "DECLS": ["SHMEM_VEC"] }, { "SHADER_SUFFIX": "f16_f16", @@ -58,7 +58,7 @@ "DST_TYPE" : "f32", "VEC_SIZE" : "1", }, - "DECLS": ["SRC0_F16", "SRC1_F16"] + "DECLS": ["SHMEM_SCALAR"] } ] @@ -66,89 +66,52 @@ #define(DECLS) -#decl(SRC0_F32_VEC) -fn zero_val_src0() -> vec4 { - return vec4(0.0, 0.0, 0.0, 0.0); -} -#enddecl(SRC0_F32_VEC) - -#decl(SRC0_F32) -fn zero_val_src0() -> f32 { - return 0.0; +#decl(SHMEM_VEC) +fn zero_val_src0() -> {{SRC0_TYPE}} { + return {{SRC0_TYPE}}(0.0, 0.0, 0.0, 0.0); } -#enddecl(SRC0_F32) -#decl(SRC0_F16_VEC) -fn zero_val_src0() -> vec4 { - return vec4(0.0, 0.0, 0.0, 0.0); +fn store_src0_shmem(val: {{SRC0_TYPE}}, idx: u32) { + src0_shmem[idx] = f32(val.x); + src0_shmem[idx + 1] = f32(val.y); + src0_shmem[idx + 2] = f32(val.z); + src0_shmem[idx + 3] = f32(val.w); } -#enddecl(SRC0_F16_VEC) -#decl(SRC0_F16) -fn zero_val_src0() -> f16 { - return 0.0; +fn zero_val_src1() -> {{SRC1_TYPE}} { + return {{SRC1_TYPE}}(0.0, 0.0, 0.0, 0.0); } -#enddecl(SRC0_F16) -#decl(SRC1_F32_VEC) -fn zero_val_src1() -> vec4 { - return vec4(0.0, 0.0, 0.0, 0.0); +fn store_src1_shmem(val: {{SRC1_TYPE}}, idx: u32) { + src1_shmem[idx] = f32(val.x); + src1_shmem[idx + 1] = f32(val.y); + src1_shmem[idx + 2] = f32(val.z); + src1_shmem[idx + 3] = f32(val.w); } +#enddecl(SHMEM_VEC) -fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> vec4 { - return vec4(acc[tm][tn], acc[tm + 1][tn], acc[tm + 2][tn], acc[tm + 3][tn]); -} - -fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: vec4) -> f32 { - return dot(vec4(src0_val), src1_val); -} -#enddecl(SRC1_F32_VEC) - -#decl(SRC1_F32) -fn zero_val_src1() -> f32 { +#decl(SHMEM_SCALAR) +fn zero_val_src0() -> {{SRC0_TYPE}} { return 0.0; } -fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { - return 
acc[tm][tn]; -} - -fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: f32) -> f32 { - return f32(src0_val) * src1_val; -} -#enddecl(SRC1_F32) - -#decl(SRC1_F16_VEC) -fn zero_val_src1() -> vec4 { - return vec4(0.0, 0.0, 0.0, 0.0); +fn store_src0_shmem(val: {{SRC0_TYPE}}, idx: u32) { + src0_shmem[idx] = f32(val); } -fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> vec4 { - return vec4(acc[tm][tn], f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn])); -} - -fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: vec4) -> f32 { - return dot(vec4(src0_val), vec4(src1_val)); -} -#enddecl(SRC1_F16_VEC) - -#decl(SRC1_F16) -fn zero_val_src1() -> f16 { +fn zero_val_src1() -> {{SRC1_TYPE}} { return 0.0; } -fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { - return acc[tm][tn]; +fn store_src1_shmem(val: {{SRC1_TYPE}}, idx: u32) { + src1_shmem[idx] = f32(val); } - -fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: f16) -> f32 { - return f32(src0_val) * f32(src1_val); -} -#enddecl(SRC1_F16) +#enddecl(SHMEM_SCALAR) #end(DECLS) #define(SHADER) +diagnostic(off, chromium.subgroup_matrix_uniformity); enable f16; enable chromium_experimental_subgroup_matrix; @@ -179,13 +142,6 @@ struct MulMatParams { DECLS -fn get_local_n(thread_id: u32) -> u32 { - return thread_id / WORKGROUP_SIZE_M; -} -fn get_local_m(thread_id: u32) -> u32 { - return thread_id % WORKGROUP_SIZE_M; -} - // Number of threads per workgroup: SUBGROUP_M * SUBGROUP_N * SUBGROUP_SIZE // Shared memory src0: SUBGROUP_M * SUBGROUP_MATRIX_M * TILE_K // Shared memory src1: SUBGROUP_N * SUBGROUP_MATRIX_N * TILE_K @@ -197,21 +153,19 @@ override SUBGROUP_N: u32; override SUBGROUP_MATRIX_N: u32; override SUBGROUP_SIZE: u32; -// Warning: cannot be overrides, must match values in ggml-webgpu.cpp -// TILE_M must be multiple of 4 for vec4 loads -const TILE_M = 4u; -const TILE_N = 4u; - -override WORKGROUP_SIZE_M: u32; -override WORKGROUP_SIZE_N: u32; override TILE_K: u32; +// Note: we assume TILE_K is divisible by SUBGROUP_MATRIX_K; +override SUBGROUP_MATRIX_K: u32; -override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N; -override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M; -override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N; +override TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * SUBGROUP_SIZE; +override TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M; +override TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N; -var src0_shmem: array<{{SRC0_TYPE}}, TILE_SRC0_SHMEM/{{VEC_SIZE}}>; -var src1_shmem: array<{{SRC1_TYPE}}, TILE_SRC1_SHMEM/{{VEC_SIZE}}>; +// Note: apparently current dawn doesn't like override constant shared memory size along with subgroup matrix loads +//var src0_shmem: array; +//var src1_shmem: array; +var src0_shmem: array; +var src1_shmem: array; @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) fn main(@builtin(global_invocation_id) global_id: vec3, @@ -219,13 +173,15 @@ fn main(@builtin(global_invocation_id) global_id: vec3, @builtin(subgroup_id) subgroup_id: u32) { let thread_id = local_id.x; - let local_m = get_local_m(thread_id); - let local_n = get_local_n(thread_id); + let subgroup_m = subgroup_id % SUBGROUP_M; + let subgroup_n = subgroup_id / SUBGROUP_M; let wg_linear = global_id.x / TOTAL_WORKGROUP_SIZE; - let wg_n_count = (params.n + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N); - let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M); + let subgroups_m = (params.m + SUBGROUP_MATRIX_M - 1) / SUBGROUP_MATRIX_M; + let 
wg_m_count = subgroups_m / SUBGROUP_M; + let subgroups_n = (params.n + SUBGROUP_MATRIX_N - 1) / SUBGROUP_MATRIX_N; + let wg_n_count = subgroups_n / SUBGROUP_N; let wg_per_matrix = wg_m_count * wg_n_count; let batch_idx = wg_linear / wg_per_matrix; @@ -234,8 +190,6 @@ fn main(@builtin(global_invocation_id) global_id: vec3, let wg_m = wg_in_batch % wg_m_count; let wg_n = wg_in_batch / wg_m_count; - let output_row_base = wg_n * WORKGROUP_SIZE_N * TILE_N + local_n * TILE_N; - let output_col_base = wg_m * WORKGROUP_SIZE_M * TILE_M + local_m * TILE_M; let dst2_stride = params.m * params.n; let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; @@ -250,76 +204,70 @@ fn main(@builtin(global_invocation_id) global_id: vec3, let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; - var acc: array, TILE_M>; - var src0_sg_mat : subgroup_matrix_left; - var src1_sg_mat : subgroup_matrix_right; - var acc_sg_mat : subgroup_matrix_result; + var acc_sg_mat : subgroup_matrix_result; for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { let tile_m = elem_idx / TILE_K; let tile_k = elem_idx % TILE_K; - let global_m = wg_m * WORKGROUP_SIZE_M * TILE_M + tile_m; + let global_m = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M + tile_m; let global_k = k_outer + tile_k; let src0_idx = src0_batch_offset + global_m * params.stride_01 + global_k; - src0_shmem[elem_idx/{{VEC_SIZE}}] = select( // taking a slight performance hit to avoid oob + let src0_val = select( // taking a slight performance hit to avoid oob zero_val_src0(), src0[src0_idx/{{VEC_SIZE}}], global_m < params.m && global_k < params.k); + store_src0_shmem(src0_val, elem_idx); } for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { let tile_n = elem_idx / TILE_K; let tile_k = elem_idx % TILE_K; - let global_n = wg_n * WORKGROUP_SIZE_N * TILE_N + tile_n; + let global_n = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N + tile_n; let global_k = k_outer + tile_k; let src1_idx = src1_batch_offset + global_n * params.stride_11 + global_k; - src1_shmem[elem_idx/{{VEC_SIZE}}] = select( + let src1_val = select( zero_val_src1(), src1[src1_idx/{{VEC_SIZE}}], global_n < params.n && global_k < params.k); - } + store_src1_shmem(src1_val, elem_idx); + } workgroupBarrier(); let k_end = min(TILE_K, params.k - k_outer); - for (var k_inner = 0u; k_inner < k_end; k_inner += {{VEC_SIZE}}) { - var src0_tile: array<{{SRC0_TYPE}}, TILE_M>; - for (var tm = 0u; tm < TILE_M; tm++) { - let src0_m = local_m * TILE_M + tm; - let src0_idx = k_inner + src0_m * TILE_K; - src0_tile[tm] = src0_shmem[src0_idx/{{VEC_SIZE}}]; - } - for (var tn = 0u; tn < TILE_N; tn++) { - let src1_n = local_n * TILE_N + tn; - let src1_idx = src1_n * TILE_K + k_inner; - let src1_vec = src1_shmem[src1_idx/{{VEC_SIZE}}]; - for (var tm = 0u; tm < TILE_M; tm++) { - acc[tm][tn] += mul_acc(src0_tile[tm], src1_vec); - } - } + for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K) { + + let src0_shmem_idx = subgroup_m * SUBGROUP_MATRIX_M * TILE_K + k_inner; + let src0_sg_mat = subgroupMatrixLoad>( + &src0_shmem, + src0_shmem_idx, + false, + TILE_K + ); + + let src1_shmem_idx = subgroup_n * SUBGROUP_MATRIX_N * TILE_K + k_inner; + let src1_sg_mat = 
subgroupMatrixLoad>( + &src1_shmem, + src1_shmem_idx, + true, + TILE_K + ); + + acc_sg_mat = subgroupMatrixMultiplyAccumulate(src0_sg_mat, src1_sg_mat, acc_sg_mat); } workgroupBarrier(); } let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; - - for (var tn = 0u; tn < TILE_N; tn++) { - let global_row = output_row_base + tn; - if (global_row < params.n) { - for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) { - let global_col = output_col_base + tm; - if (global_col < params.m) { - let dst_idx = dst_batch_offset + global_row * params.m + global_col; - dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm); - } - } - } - } + let dst_row_base = (wg_n * SUBGROUP_N + subgroup_n) * SUBGROUP_MATRIX_N; + let dst_col_base = (wg_m * SUBGROUP_M + subgroup_m) * SUBGROUP_MATRIX_M; + let dst_idx = dst_batch_offset + dst_row_base * params.m + dst_col_base; + subgroupMatrixStore(&dst, dst_idx, acc_sg_mat, true, params.m); } #end(SHADER) From a80e2bb1242a0c9c32b2ea9f259cac6255a23f06 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Sun, 26 Oct 2025 13:55:57 +0700 Subject: [PATCH 15/40] Working subgroup matrix tiling --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 21 ++-- .../mul_mat_subgroup_matrix.tmpl.wgsl | 101 ++++++++++-------- 2 files changed, 70 insertions(+), 52 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index e6e317819c67d..8f9549a847edb 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -84,8 +84,13 @@ #define WEBGPU_MUL_MAT_WG_SIZE_N 8 #define WEBGPU_MUL_MAT_TILE_K 32 -#define WEBGPU_MUL_MAT_SUBGROUP_M 2 -#define WEBGPU_MUL_MAT_SUBGROUP_N 2 +// The number of subgroups in the M dimension +#define WEBGPU_MUL_MAT_SUBGROUP_M 2 +// The number of subgroups in the N dimension +#define WEBGPU_MUL_MAT_SUBGROUP_N 2 +// The number of subgroup matrices each subgroup accumulates over +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 /* End Constants */ @@ -933,9 +938,11 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, if (ctx->supports_subgroup_matrix) { // The total number of subgroups/workgroups needed per matrix. 
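Editor's illustrative aside (not part of the patch): the workgroup-count logic in this hunk reduces to two ceiling divisions — first covering the output with subgroup matrices of the adapter-reported size, then grouping those matrices per workgroup. A minimal host-side sketch follows; the 8x8 subgroup matrix shape and the 512x512 output are assumed example values, while the SUBGROUP_* factors match the defines above.

    // Sketch only: mirrors the subgroups_m / subgroups_per_wg_m arithmetic below.
    // The 8x8 subgroup matrix config and 512x512 output are assumed, not from the patch.
    #include <cstdint>
    #include <cstdio>

    static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

    int main() {
        const uint32_t cfg_m = 8, cfg_n = 8;        // subgroup_matrix_config.M / .N (assumed)
        const uint32_t sg_m = 2, sg_n = 2;          // WEBGPU_MUL_MAT_SUBGROUP_M / _N
        const uint32_t sg_mat_m = 4, sg_mat_n = 2;  // WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M / _N
        const uint32_t ne0 = 512, ne1 = 512;        // dst->ne[0], dst->ne[1] (assumed)

        // Subgroup matrices needed to cover each output dimension.
        uint32_t subgroups_m = ceil_div(ne0, cfg_m);              // 64
        uint32_t subgroups_n = ceil_div(ne1, cfg_n);              // 64
        // Each workgroup holds SUBGROUP_M x SUBGROUP_N subgroups, and each subgroup
        // accumulates SUBGROUP_MATRIX_M x SUBGROUP_MATRIX_N matrices.
        uint32_t wg_m = ceil_div(subgroups_m, sg_m * sg_mat_m);   // 8
        uint32_t wg_n = ceil_div(subgroups_n, sg_n * sg_mat_n);   // 16
        printf("workgroups per 2D matrix: %u\n", wg_m * wg_n);    // 128
        return 0;
    }
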
uint32_t subgroups_m = (dst->ne[0] + ctx->subgroup_matrix_config.M - 1) / ctx->subgroup_matrix_config.M; - wg_m = subgroups_m / WEBGPU_MUL_MAT_SUBGROUP_M; + uint32_t subgroups_per_wg_m = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M; + wg_m = (subgroups_m + subgroups_per_wg_m - 1) / subgroups_per_wg_m; uint32_t subgroups_n = (dst->ne[1] + ctx->subgroup_matrix_config.N - 1) / ctx->subgroup_matrix_config.N; - wg_n = subgroups_n / WEBGPU_MUL_MAT_SUBGROUP_N; + uint32_t subgroups_per_wg_n = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N; + wg_n = (subgroups_n + subgroups_per_wg_n - 1) / subgroups_per_wg_n; } else { uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N; @@ -1714,14 +1721,14 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { if (webgpu_ctx->supports_subgroup_matrix) { mul_mat_opt_constants.push_back({ .key = "SUBGROUP_M", .value = WEBGPU_MUL_MAT_SUBGROUP_M }); mul_mat_opt_constants.push_back( - { .key = "SUBGROUP_MATRIX_M", .value = static_cast(webgpu_ctx->subgroup_matrix_config.M) }); + { .key = "SUBGROUP_MATRIX_M_SIZE", .value = static_cast(webgpu_ctx->subgroup_matrix_config.M) }); mul_mat_opt_constants.push_back({ .key = "SUBGROUP_N", .value = WEBGPU_MUL_MAT_SUBGROUP_N }); mul_mat_opt_constants.push_back( - { .key = "SUBGROUP_MATRIX_N", .value = static_cast(webgpu_ctx->subgroup_matrix_config.N) }); + { .key = "SUBGROUP_MATRIX_N_SIZE", .value = static_cast(webgpu_ctx->subgroup_matrix_config.N) }); mul_mat_opt_constants.push_back( { .key = "SUBGROUP_SIZE", .value = static_cast(webgpu_ctx->subgroup_size) }); mul_mat_opt_constants.push_back( - { .key = "SUBGROUP_MATRIX_K", .value = static_cast(webgpu_ctx->subgroup_matrix_config.K) }); + { .key = "SUBGROUP_MATRIX_K_SIZE", .value = static_cast(webgpu_ctx->subgroup_matrix_config.K) }); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f32_f32, diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl index aedbd61cfa9d1..b675d781fcd9a 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl @@ -142,30 +142,29 @@ struct MulMatParams { DECLS -// Number of threads per workgroup: SUBGROUP_M * SUBGROUP_N * SUBGROUP_SIZE -// Shared memory src0: SUBGROUP_M * SUBGROUP_MATRIX_M * TILE_K -// Shared memory src1: SUBGROUP_N * SUBGROUP_MATRIX_N * TILE_K -// TILE_K must be divisible by SUBGROUP_MATRIX_K - override SUBGROUP_M: u32; -override SUBGROUP_MATRIX_M: u32; +override SUBGROUP_MATRIX_M_SIZE: u32; override SUBGROUP_N: u32; -override SUBGROUP_MATRIX_N: u32; +override SUBGROUP_MATRIX_N_SIZE: u32; override SUBGROUP_SIZE: u32; +// Note: must match values in ggml-webgpu.cpp +const SUBGROUP_MATRIX_M = 4u; +const SUBGROUP_MATRIX_N = 2u; + override TILE_K: u32; // Note: we assume TILE_K is divisible by SUBGROUP_MATRIX_K; -override SUBGROUP_MATRIX_K: u32; +override SUBGROUP_MATRIX_K_SIZE: u32; override TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * SUBGROUP_SIZE; -override TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M; -override TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N; +override TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; +override TILE_SRC1_SHMEM = TILE_K * 
SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; // Note: apparently current dawn doesn't like override constant shared memory size along with subgroup matrix loads //var src0_shmem: array; //var src1_shmem: array; -var src0_shmem: array; -var src1_shmem: array; +var src0_shmem: array; +var src1_shmem: array; @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) fn main(@builtin(global_invocation_id) global_id: vec3, @@ -178,10 +177,12 @@ fn main(@builtin(global_invocation_id) global_id: vec3, let wg_linear = global_id.x / TOTAL_WORKGROUP_SIZE; - let subgroups_m = (params.m + SUBGROUP_MATRIX_M - 1) / SUBGROUP_MATRIX_M; - let wg_m_count = subgroups_m / SUBGROUP_M; - let subgroups_n = (params.n + SUBGROUP_MATRIX_N - 1) / SUBGROUP_MATRIX_N; - let wg_n_count = subgroups_n / SUBGROUP_N; + let subgroups_m = (params.m + SUBGROUP_MATRIX_M_SIZE - 1) / SUBGROUP_MATRIX_M_SIZE; + let subgroups_per_wg_m = SUBGROUP_M * SUBGROUP_MATRIX_M; + let wg_m_count = (subgroups_m + subgroups_per_wg_m - 1) / subgroups_per_wg_m; + let subgroups_n = (params.n + SUBGROUP_MATRIX_N_SIZE - 1) / SUBGROUP_MATRIX_N_SIZE; + let subgroups_per_wg_n = SUBGROUP_N * SUBGROUP_MATRIX_N; + let wg_n_count = (subgroups_n + subgroups_per_wg_n - 1) / subgroups_per_wg_n; let wg_per_matrix = wg_m_count * wg_n_count; let batch_idx = wg_linear / wg_per_matrix; @@ -190,7 +191,6 @@ fn main(@builtin(global_invocation_id) global_id: vec3, let wg_m = wg_in_batch % wg_m_count; let wg_n = wg_in_batch / wg_m_count; - let dst2_stride = params.m * params.n; let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; @@ -204,14 +204,14 @@ fn main(@builtin(global_invocation_id) global_id: vec3, let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; - var acc_sg_mat : subgroup_matrix_result; + var acc_sg_mat : array, SUBGROUP_MATRIX_N>, SUBGROUP_MATRIX_M>; for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { let tile_m = elem_idx / TILE_K; let tile_k = elem_idx % TILE_K; - let global_m = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M + tile_m; + let global_m = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE + tile_m; let global_k = k_outer + tile_k; let src0_idx = src0_batch_offset + global_m * params.stride_01 + global_k; let src0_val = select( // taking a slight performance hit to avoid oob @@ -224,7 +224,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3, for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { let tile_n = elem_idx / TILE_K; let tile_k = elem_idx % TILE_K; - let global_n = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N + tile_n; + let global_n = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE + tile_n; let global_k = k_outer + tile_k; let src1_idx = src1_batch_offset + global_n * params.stride_11 + global_k; @@ -237,37 +237,48 @@ fn main(@builtin(global_invocation_id) global_id: vec3, workgroupBarrier(); - let k_end = min(TILE_K, params.k - k_outer); - - for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K) { - - let src0_shmem_idx = subgroup_m * SUBGROUP_MATRIX_M * TILE_K + k_inner; - let src0_sg_mat = subgroupMatrixLoad>( - &src0_shmem, - src0_shmem_idx, - false, - TILE_K - ); - - let src1_shmem_idx = subgroup_n * SUBGROUP_MATRIX_N * 
TILE_K + k_inner; - let src1_sg_mat = subgroupMatrixLoad>( - &src1_shmem, - src1_shmem_idx, - true, - TILE_K - ); - - acc_sg_mat = subgroupMatrixMultiplyAccumulate(src0_sg_mat, src1_sg_mat, acc_sg_mat); + for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K_SIZE) { + + let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner; + var src0_sg_mats: array, SUBGROUP_MATRIX_M>; + for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { + src0_sg_mats[m] = subgroupMatrixLoad>( + &src0_shmem, + src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K, + false, + TILE_K + ); + } + + let src1_shmem_idx_base = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner; + for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) { + let src1_sg_mat = subgroupMatrixLoad>( + &src1_shmem, + src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K, + true, + TILE_K + ); + for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { + acc_sg_mat[m][n] = subgroupMatrixMultiplyAccumulate(src0_sg_mats[m], src1_sg_mat, acc_sg_mat[m][n]); + } + } } workgroupBarrier(); } let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; - let dst_row_base = (wg_n * SUBGROUP_N + subgroup_n) * SUBGROUP_MATRIX_N; - let dst_col_base = (wg_m * SUBGROUP_M + subgroup_m) * SUBGROUP_MATRIX_M; - let dst_idx = dst_batch_offset + dst_row_base * params.m + dst_col_base; - subgroupMatrixStore(&dst, dst_idx, acc_sg_mat, true, params.m); + let dst_row_base = (wg_n * SUBGROUP_N + subgroup_n) * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + let dst_col_base = (wg_m * SUBGROUP_M + subgroup_m) * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; + + for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) { + let global_row = dst_row_base + n * SUBGROUP_MATRIX_N_SIZE; + for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { + let global_col = dst_col_base + m * SUBGROUP_MATRIX_M_SIZE; + let dst_idx = dst_batch_offset + global_row * params.m + global_col; + subgroupMatrixStore(&dst, dst_idx, acc_sg_mat[m][n], true, params.m); + } + } } #end(SHADER) From e4fd0b5544eaffb01616f7e2f102c3f1dd84dfcd Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Sun, 26 Oct 2025 14:32:56 +0700 Subject: [PATCH 16/40] Handle weirder sg matrix sizes (but still % sg matrix size) --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 121 ++++++++++-------- .../mul_mat_subgroup_matrix.tmpl.wgsl | 23 ++-- 2 files changed, 78 insertions(+), 66 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 8f9549a847edb..b85dda6bcaa9e 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -271,7 +271,8 @@ struct webgpu_context_struct { webgpu_pipeline memset_pipeline; - std::map>> mul_mat_pipelines; // src0_type, src1_type, vectorized + std::map>>> + mul_mat_pipelines; // src0_type, src1_type, sg matrix, vectorized webgpu_pipeline mul_mat_pipeline[30][2]; webgpu_pipeline set_rows_pipeline; @@ -931,18 +932,21 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, if (use_fast) { int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0; - pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; + int sg_matrix = ctx->supports_subgroup_matrix && src0->ne[0] % ctx->subgroup_matrix_config.K == 0 && + dst->ne[0] % ctx->subgroup_matrix_config.M == 0 && + dst->ne[1] % ctx->subgroup_matrix_config.N == 0; + pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][sg_matrix][vectorized]; uint32_t wg_m; 
uint32_t wg_n; - if (ctx->supports_subgroup_matrix) { + if (sg_matrix) { // The total number of subgroups/workgroups needed per matrix. - uint32_t subgroups_m = (dst->ne[0] + ctx->subgroup_matrix_config.M - 1) / ctx->subgroup_matrix_config.M; - uint32_t subgroups_per_wg_m = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M; - wg_m = (subgroups_m + subgroups_per_wg_m - 1) / subgroups_per_wg_m; - uint32_t subgroups_n = (dst->ne[1] + ctx->subgroup_matrix_config.N - 1) / ctx->subgroup_matrix_config.N; - uint32_t subgroups_per_wg_n = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N; - wg_n = (subgroups_n + subgroups_per_wg_n - 1) / subgroups_per_wg_n; + uint32_t wg_m_sg_tile = + WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->subgroup_matrix_config.M; + wg_m = (dst->ne[0] + wg_m_sg_tile - 1) / wg_m_sg_tile; + uint32_t wg_n_sg_tile = + WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N; + wg_n = (dst->ne[1] + wg_n_sg_tile - 1) / wg_n_sg_tile; } else { uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N; @@ -1713,61 +1717,66 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32], wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); - // override constants - std::vector mul_mat_opt_constants(1); - mul_mat_opt_constants[0].key = "TILE_K"; - mul_mat_opt_constants[0].value = WEBGPU_MUL_MAT_TILE_K; - if (webgpu_ctx->supports_subgroup_matrix) { - mul_mat_opt_constants.push_back({ .key = "SUBGROUP_M", .value = WEBGPU_MUL_MAT_SUBGROUP_M }); - mul_mat_opt_constants.push_back( - { .key = "SUBGROUP_MATRIX_M_SIZE", .value = static_cast(webgpu_ctx->subgroup_matrix_config.M) }); - mul_mat_opt_constants.push_back({ .key = "SUBGROUP_N", .value = WEBGPU_MUL_MAT_SUBGROUP_N }); - mul_mat_opt_constants.push_back( - { .key = "SUBGROUP_MATRIX_N_SIZE", .value = static_cast(webgpu_ctx->subgroup_matrix_config.N) }); - mul_mat_opt_constants.push_back( - { .key = "SUBGROUP_SIZE", .value = static_cast(webgpu_ctx->subgroup_size) }); - mul_mat_opt_constants.push_back( - { .key = "SUBGROUP_MATRIX_K_SIZE", .value = static_cast(webgpu_ctx->subgroup_matrix_config.K) }); - - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = + std::vector mul_mat_sg_mat_constants(7); + mul_mat_sg_mat_constants[0].key = "TILE_K"; + mul_mat_sg_mat_constants[0].value = WEBGPU_MUL_MAT_TILE_K; + mul_mat_sg_mat_constants[1].key = "SUBGROUP_M"; + mul_mat_sg_mat_constants[1].value = WEBGPU_MUL_MAT_SUBGROUP_M; + mul_mat_sg_mat_constants[2].key = "SUBGROUP_N"; + mul_mat_sg_mat_constants[2].value = WEBGPU_MUL_MAT_SUBGROUP_N; + mul_mat_sg_mat_constants[3].key = "SUBGROUP_MATRIX_M_SIZE"; + mul_mat_sg_mat_constants[3].value = static_cast(webgpu_ctx->subgroup_matrix_config.M); + mul_mat_sg_mat_constants[4].key = "SUBGROUP_MATRIX_N_SIZE"; + mul_mat_sg_mat_constants[4].value = static_cast(webgpu_ctx->subgroup_matrix_config.N); + mul_mat_sg_mat_constants[5].key = "SUBGROUP_SIZE"; + mul_mat_sg_mat_constants[5].value = static_cast(webgpu_ctx->subgroup_size); + mul_mat_sg_mat_constants[6].key = "SUBGROUP_MATRIX_K_SIZE"; + mul_mat_sg_mat_constants[6].value = static_cast(webgpu_ctx->subgroup_matrix_config.K); + + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1][0] = ggml_webgpu_create_pipeline2(webgpu_ctx->device, 
wgsl_mul_mat_subgroup_matrix_f32_f32, - "mul_mat_subgroup_matrix_f32_f32", mul_mat_opt_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = + "mul_mat_subgroup_matrix_f32_f32", mul_mat_sg_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f32_f32_vec, - "mul_mat_subgroup_matrix_f32_f32_vec", mul_mat_opt_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = + "mul_mat_subgroup_matrix_f32_f32_vec", mul_mat_sg_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1][0] = ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f16_f32, - "mul_mat_subgroup_matrix_f16_f32", mul_mat_opt_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = + "mul_mat_subgroup_matrix_f16_f32", mul_mat_sg_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f16_f32_vec, - "mul_mat_subgroup_matrix_f16_f32_vec", mul_mat_opt_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = + "mul_mat_subgroup_matrix_f16_f32_vec", mul_mat_sg_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1][0] = ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f16_f16, - "mul_mat_subgroup_matrix_f16_f16", mul_mat_opt_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = + "mul_mat_subgroup_matrix_f16_f16", mul_mat_sg_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f16_f16_vec, - "mul_mat_subgroup_matrix_f16_f16_vec", mul_mat_opt_constants); - } else { - mul_mat_opt_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M }); - mul_mat_opt_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N }); - - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, wgsl_mul_mat_reg_tile_f32_f32, "mul_mat_reg_tile_f32_f32", mul_mat_opt_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_reg_tile_f32_f32_vec, - "mul_mat_reg_tile_f32_f32_vec", mul_mat_opt_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f32, "mul_mat_reg_tile_f16_f32", mul_mat_opt_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f32_vec, - "mul_mat_reg_tile_f16_f32_vec", mul_mat_opt_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f16, "mul_mat_reg_tile_f16_f16", mul_mat_opt_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f16_vec, - "mul_mat_reg_tile_f16_f16_vec", mul_mat_opt_constants); + "mul_mat_subgroup_matrix_f16_f16_vec", mul_mat_sg_mat_constants); } + + std::vector mul_mat_reg_tile_constants(3); + mul_mat_reg_tile_constants[0].key = "TILE_K"; + mul_mat_reg_tile_constants[0].value = 
WEBGPU_MUL_MAT_TILE_K; + mul_mat_reg_tile_constants[1].key = "WORKGROUP_SIZE_M"; + mul_mat_reg_tile_constants[1].value = WEBGPU_MUL_MAT_WG_SIZE_M; + mul_mat_reg_tile_constants[2].key = "WORKGROUP_SIZE_N"; + mul_mat_reg_tile_constants[2].value = WEBGPU_MUL_MAT_WG_SIZE_N; + + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_reg_tile_f32_f32, "mul_mat_reg_tile_f32_f32", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_reg_tile_f32_f32_vec, + "mul_mat_reg_tile_f32_f32_vec", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f32, "mul_mat_reg_tile_f16_f32", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f32_vec, + "mul_mat_reg_tile_f16_f32_vec", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f16, "mul_mat_reg_tile_f16_f16", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f16_vec, + "mul_mat_reg_tile_f16_f16_vec", mul_mat_reg_tile_constants); } static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl index b675d781fcd9a..9ec741489d2e9 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl @@ -156,6 +156,9 @@ override TILE_K: u32; // Note: we assume TILE_K is divisible by SUBGROUP_MATRIX_K; override SUBGROUP_MATRIX_K_SIZE: u32; +override WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; +override WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + override TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * SUBGROUP_SIZE; override TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; override TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; @@ -177,12 +180,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3, let wg_linear = global_id.x / TOTAL_WORKGROUP_SIZE; - let subgroups_m = (params.m + SUBGROUP_MATRIX_M_SIZE - 1) / SUBGROUP_MATRIX_M_SIZE; - let subgroups_per_wg_m = SUBGROUP_M * SUBGROUP_MATRIX_M; - let wg_m_count = (subgroups_m + subgroups_per_wg_m - 1) / subgroups_per_wg_m; - let subgroups_n = (params.n + SUBGROUP_MATRIX_N_SIZE - 1) / SUBGROUP_MATRIX_N_SIZE; - let subgroups_per_wg_n = SUBGROUP_N * SUBGROUP_MATRIX_N; - let wg_n_count = (subgroups_n + subgroups_per_wg_n - 1) / subgroups_per_wg_n; + let wg_m_count = (params.m + WG_M_SG_TILE_SIZE - 1) / WG_M_SG_TILE_SIZE; + let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE; let wg_per_matrix = wg_m_count * wg_n_count; let batch_idx = wg_linear / wg_per_matrix; @@ -273,10 +272,14 @@ fn main(@builtin(global_invocation_id) global_id: vec3, for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) { let global_row = dst_row_base + n * 
SUBGROUP_MATRIX_N_SIZE; - for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { - let global_col = dst_col_base + m * SUBGROUP_MATRIX_M_SIZE; - let dst_idx = dst_batch_offset + global_row * params.m + global_col; - subgroupMatrixStore(&dst, dst_idx, acc_sg_mat[m][n], true, params.m); + if (global_row < params.n) { + for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { + let global_col = dst_col_base + m * SUBGROUP_MATRIX_M_SIZE; + if (global_col < params.m) { + let dst_idx = dst_batch_offset + global_row * params.m + global_col; + subgroupMatrixStore(&dst, dst_idx, acc_sg_mat[m][n], true, params.m); + } + } } } } From b524249e394288067934a0eb36ec74135083b711 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Sun, 26 Oct 2025 17:24:56 +0800 Subject: [PATCH 17/40] Working start to gemv --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 112 ++++--- .../wgsl-shaders/gemv_f16_f32.wgsl | 285 ++++++++++++++++++ 2 files changed, 354 insertions(+), 43 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/gemv_f16_f32.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index b85dda6bcaa9e..4d2863bcc9227 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -275,6 +275,8 @@ struct webgpu_context_struct { mul_mat_pipelines; // src0_type, src1_type, sg matrix, vectorized webgpu_pipeline mul_mat_pipeline[30][2]; + // Specialized gemv for f16/f32 + webgpu_pipeline mul_mat_gemv_pipeline; webgpu_pipeline set_rows_pipeline; webgpu_pipeline get_rows_pipeline[30]; webgpu_pipeline get_rows_f32_no_vec_pipeline; @@ -911,52 +913,73 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, uint32_t wg_x = (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE; - bool use_fast = false; - switch (src1->type) { - case GGML_TYPE_F16: - use_fast = (src0->type == GGML_TYPE_F16); - break; - case GGML_TYPE_F32: - switch (src0->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - use_fast = true; - break; - default: - break; - } - break; - default: - break; - } - - if (use_fast) { - int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0; - int sg_matrix = ctx->supports_subgroup_matrix && src0->ne[0] % ctx->subgroup_matrix_config.K == 0 && - dst->ne[0] % ctx->subgroup_matrix_config.M == 0 && - dst->ne[1] % ctx->subgroup_matrix_config.N == 0; - pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][sg_matrix][vectorized]; - - uint32_t wg_m; - uint32_t wg_n; - if (sg_matrix) { - // The total number of subgroups/workgroups needed per matrix. - uint32_t wg_m_sg_tile = - WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->subgroup_matrix_config.M; - wg_m = (dst->ne[0] + wg_m_sg_tile - 1) / wg_m_sg_tile; - uint32_t wg_n_sg_tile = - WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N; - wg_n = (dst->ne[1] + wg_n_sg_tile - 1) / wg_n_sg_tile; + // Use specialized gemv + if ((dst->ne[0] == 1 || dst->ne[1] == 1) && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { + // gemv_fast: 256 threads per workgroup, computes 16 outputs per workgroup (written as 4x vec4) + // Uses cooperative reduction and vec4 operations (requires K % 4 == 0 and outputs % 16 == 0) + uint32_t output_elements = (dst->ne[0] == 1) ? 
dst->ne[1] : dst->ne[0]; + uint32_t batches = dst->ne[2] * dst->ne[3]; + + // Use gemv_fast for larger vectors where reduction overhead pays off + // Requires K divisible by 4 for vec4 alignment, outputs divisible by 16 for optimal vec4 output + if (output_elements >= 64 && src0->ne[0] % 4 == 0 && output_elements % 16 == 0) { + // Each workgroup computes 16 consecutive outputs (4x vec4 writes) + uint32_t output_vec4_groups = output_elements / 16; + uint32_t wg_x = output_vec4_groups * batches; + + return ggml_backend_webgpu_build(ctx, ctx->mul_mat_gemv_pipeline, params, entries, wg_x); } else { - uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; - uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N; - wg_m = (dst->ne[0] + tile_m_s - 1) / tile_m_s; - wg_n = (dst->ne[1] + tile_n_s - 1) / tile_n_s; + // Small vectors, unaligned K, or outputs not divisible by 16: use template shader + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + } + } else { + bool use_fast = false; + switch (src1->type) { + case GGML_TYPE_F16: + use_fast = (src0->type == GGML_TYPE_F16); + break; + case GGML_TYPE_F32: + switch (src0->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + use_fast = true; + break; + default: + break; + } + break; + default: + break; } - wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; - } - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + if (use_fast) { + int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0; + int sg_matrix = ctx->supports_subgroup_matrix && src0->ne[0] % ctx->subgroup_matrix_config.K == 0 && + dst->ne[0] % ctx->subgroup_matrix_config.M == 0 && + dst->ne[1] % ctx->subgroup_matrix_config.N == 0; + pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][sg_matrix][vectorized]; + + uint32_t wg_m; + uint32_t wg_n; + if (sg_matrix) { + // The total number of subgroups/workgroups needed per matrix. 
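Editor's illustrative aside (not part of the patch): the single ceiling division over the combined workgroup tile used below yields the same workgroup count as the two-step rounding from the previous patch, since ceil(ceil(a/b)/c) == ceil(a/(b*c)) for positive integers. A minimal check, with assumed example sizes:

    // Sketch only: compares the collapsed wg_m_sg_tile form with the two-step form.
    // The 8-wide subgroup matrix and 512-wide output are assumed example values.
    #include <cassert>
    #include <cstdint>

    static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

    int main() {
        const uint32_t ne0 = 512, cfg_m = 8;    // dst->ne[0], subgroup_matrix_config.M (assumed)
        const uint32_t sg_m = 2, sg_mat_m = 4;  // WEBGPU_MUL_MAT_SUBGROUP_M, _SUBGROUP_MATRIX_M

        uint32_t two_step = ceil_div(ceil_div(ne0, cfg_m), sg_m * sg_mat_m); // 8
        uint32_t one_step = ceil_div(ne0, sg_m * sg_mat_m * cfg_m);          // 8
        assert(two_step == one_step);
        return 0;
    }
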
+ uint32_t wg_m_sg_tile = + WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->subgroup_matrix_config.M; + wg_m = (dst->ne[0] + wg_m_sg_tile - 1) / wg_m_sg_tile; + uint32_t wg_n_sg_tile = + WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N; + wg_n = (dst->ne[1] + wg_n_sg_tile - 1) / wg_n_sg_tile; + } else { + uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; + uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N; + wg_m = (dst->ne[0] + tile_m_s - 1) / tile_m_s; + wg_n = (dst->ne[1] + tile_n_s - 1) / tile_n_s; + } + wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; + } + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + } } static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, @@ -1777,6 +1800,9 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0][1] = ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f16_vec, "mul_mat_reg_tile_f16_f16_vec", mul_mat_reg_tile_constants); + + webgpu_ctx->mul_mat_gemv_pipeline = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_gemv_f16_f32, "gemv_f16_f32"); } static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/gemv_f16_f32.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/gemv_f16_f32.wgsl new file mode 100644 index 0000000000000..d56f4fe915b50 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/gemv_f16_f32.wgsl @@ -0,0 +1,285 @@ +enable f16; + +// Optimized GEMV shader for F16xF32 matrix-vector multiplication +// Handles both M=1 (row vector * matrix) and N=1 (matrix * column vector) cases +// Uses vectorized memory access and shared memory for better performance + +const WORKGROUP_SIZE: u32 = 256u; // Larger workgroup for better occupancy +const VECTOR_WIDTH: u32 = 4u; // Process 4 elements at a time with vec4 +const TILE_K: u32 = 128u; // Tile size along K dimension for cache efficiency +const OUTPUTS_PER_WG: u32 = 16u; // Each workgroup computes 16 outputs (written as 4x vec4) - OPTIMAL + +struct MulMatParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + m: u32, + n: u32, + k: u32, + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +@group(0) @binding(0) var src0: array>; // Matrix (N x K in vec4s) +@group(0) @binding(1) var src1: array>; // Vector (M x K or K in vec4s) +@group(0) @binding(2) var dst: array>; // Result vector (vec4 for bandwidth) + +@group(0) @binding(3) var params: MulMatParams; + +// Shared memory for collaborative loading and reduction +var shared_vector: array, TILE_K/4>; // Cache vector tile +var partial_sums: array; // For reduction (4 groups) + +// Helper function for vectorized dot product +fn dot_vec4_f16_f32(a: vec4, b: vec4) -> f32 { + return f32(a.x) * b.x + f32(a.y) * b.y + f32(a.z) * b.z + f32(a.w) * b.w; +} + +@compute @workgroup_size(WORKGROUP_SIZE) +fn main( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3 +) { + let thread_id = local_id.x; + + // Handle batch dimensions + let dst2_stride = params.m * params.n; + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + + let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + let 
output_elements = select(params.n, params.m, params.n == 1u); + + // Each workgroup computes OUTPUTS_PER_WG consecutive outputs (written as vec4) + // Using 2D dispatch to avoid exceeding 65535 limit per dimension + // wg_linear = wg_id.y * 65535 + wg_id.x + let wg_linear = wg_id.y * 65535u + wg_id.x; + let output_vec4_groups = (output_elements + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG; + let batch_idx = wg_linear / output_vec4_groups; + let output_vec4_idx = wg_linear % output_vec4_groups; + let base_output_idx = output_vec4_idx * OUTPUTS_PER_WG; + + // Which of the 16 outputs does this thread belong to? + let threads_per_output = WORKGROUP_SIZE / OUTPUTS_PER_WG; // 256/16 = 16 + let output_offset = thread_id / threads_per_output; // 0-15 + let thread_in_group = thread_id % threads_per_output; // 0-15 + + if (batch_idx >= total_batches) { + return; + } + + let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); + let src03_idx = dst3_idx / params.broadcast3; + let src13_idx = dst3_idx; + + let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); + let src02_idx = dst2_idx / params.broadcast2; + let src12_idx = dst2_idx; + + // Case 1: M == 1 (result is a row vector: 1 x N) + // Each workgroup computes OUTPUTS_PER_WG (16) consecutive output elements + // 256 threads split into 16 groups of 16, each group computes one output + if (params.n == 1u) { + let output_col_base = base_output_idx; + let output_col = output_col_base + output_offset; + + // Check bounds but don't early return (must hit all barriers) + let is_valid = output_col < params.m; + + var local_sum = 0.0; + let k_vec = params.k / VECTOR_WIDTH; + + // Each thread processes multiple K elements and accumulates + for (var k_tile = 0u; k_tile < k_vec; k_tile += TILE_K/VECTOR_WIDTH) { + let tile_size = min(TILE_K/VECTOR_WIDTH, k_vec - k_tile); + + // Cooperatively load vector tile into shared memory (all threads) + for (var i = thread_id; i < tile_size; i += WORKGROUP_SIZE) { + let k_idx = (k_tile + i) * VECTOR_WIDTH; + if (k_idx < params.k) { + let src1_idx = params.offset_src1 + src13_idx * params.stride_13 + + src12_idx * params.stride_12 + k_idx; + shared_vector[i] = src1[src1_idx / VECTOR_WIDTH]; + } + } + + workgroupBarrier(); + + // Each sub-group of 16 threads computes its own output + // thread_in_group = 0-15, provides stride for this sub-group + let threads_per_output = WORKGROUP_SIZE / OUTPUTS_PER_WG; // 16 + + if (is_valid) { + for (var i = thread_in_group; i < tile_size; i += threads_per_output) { + let k_idx = (k_tile + i) * VECTOR_WIDTH; + // Removed redundant k_idx < params.k check (guaranteed by tile_size) + let src0_idx = params.offset_src0 + src03_idx * params.stride_03 + + src02_idx * params.stride_02 + output_col * params.stride_01 + k_idx; + let a = src0[src0_idx / VECTOR_WIDTH]; + let b = shared_vector[i]; + local_sum += dot_vec4_f16_f32(a, b); + } + } + + workgroupBarrier(); + } + + // Handle remaining elements (K % VECTOR_WIDTH) + if (is_valid) { + let k_remainder_start = k_vec * VECTOR_WIDTH; + if (thread_in_group < (params.k - k_remainder_start)) { + let k_idx = k_remainder_start + thread_in_group; + let src0_idx = params.offset_src0 + src03_idx * params.stride_03 + + src02_idx * params.stride_02 + output_col * params.stride_01 + k_idx; + let src1_idx = params.offset_src1 + src13_idx * params.stride_13 + + src12_idx * params.stride_12 + k_idx; + // Read individual elements (last vec4 might be partial) + let vec_idx = k_idx / VECTOR_WIDTH; + let elem_idx = k_idx % VECTOR_WIDTH; + let a_vec = 
src0[src0_idx / VECTOR_WIDTH]; + let b_vec = src1[src1_idx / VECTOR_WIDTH]; + local_sum += f32(a_vec[elem_idx]) * b_vec[elem_idx]; + } + } + + // Store partial sums and reduce within each sub-group (16 threads per output) + partial_sums[thread_id] = local_sum; + workgroupBarrier(); + + // Reduce within each sub-group: 16 threads → 1 result + // Each sub-group occupies 16 consecutive slots in partial_sums + let group_base = output_offset * (WORKGROUP_SIZE / OUTPUTS_PER_WG); // 0, 16, 32, ..., 240 + + // Reduction for 16 threads: 16 → 8 → 4 → 2 → 1 (loop version for correctness) + for (var stride = 8u; stride > 0u; stride = stride / 2u) { + if (thread_in_group < stride) { + partial_sums[group_base + thread_in_group] += partial_sums[group_base + thread_in_group + stride]; + } + workgroupBarrier(); + } + + // First thread of each sub-group has the result + // Threads 0, 16, 32, 48, ... 240 hold the 16 output values + if (thread_id == 0u && output_col_base < params.m) { + // Gather 16 results and write as 4 vec4s + let result_vec0 = vec4(partial_sums[0], partial_sums[16], partial_sums[32], partial_sums[48]); + let result_vec1 = vec4(partial_sums[64], partial_sums[80], partial_sums[96], partial_sums[112]); + let result_vec2 = vec4(partial_sums[128], partial_sums[144], partial_sums[160], partial_sums[176]); + let result_vec3 = vec4(partial_sums[192], partial_sums[208], partial_sums[224], partial_sums[240]); + + let dst_idx = params.offset_dst + dst3_idx * dst3_stride + + dst2_idx * dst2_stride + output_col_base; + dst[dst_idx / VECTOR_WIDTH] = result_vec0; + dst[dst_idx / VECTOR_WIDTH + 1u] = result_vec1; + dst[dst_idx / VECTOR_WIDTH + 2u] = result_vec2; + dst[dst_idx / VECTOR_WIDTH + 3u] = result_vec3; + } + } + // Case 2: N == 1 (result is a column vector: M x 1) + // Each workgroup computes OUTPUTS_PER_WG (16) consecutive output elements + // 256 threads split into 16 groups of 16, each group computes one output + else if (params.m == 1u) { + let output_row_base = base_output_idx; + let output_row = output_row_base + output_offset; + + // Check bounds but don't early return (must hit all barriers) + let is_valid = output_row < params.n; + + var local_sum = 0.0; + let k_vec = params.k / VECTOR_WIDTH; + + // Each thread processes multiple K elements and accumulates + for (var k_tile = 0u; k_tile < k_vec; k_tile += TILE_K/VECTOR_WIDTH) { + let tile_size = min(TILE_K/VECTOR_WIDTH, k_vec - k_tile); + + // Cooperatively load vector tile into shared memory (all threads) + // Note: In this case, src0 is the vector input + for (var i = thread_id; i < tile_size; i += WORKGROUP_SIZE) { + let k_idx = (k_tile + i) * VECTOR_WIDTH; + if (k_idx < params.k) { + let src0_idx = params.offset_src0 + src03_idx * params.stride_03 + + src02_idx * params.stride_02 + k_idx; + shared_vector[i] = vec4(src0[src0_idx / VECTOR_WIDTH]); + } + } + + workgroupBarrier(); + + // Each sub-group of 16 threads computes its own output + // thread_in_group = 0-15, provides stride for this sub-group + let threads_per_output = WORKGROUP_SIZE / OUTPUTS_PER_WG; // 16 + + if (is_valid) { + for (var i = thread_in_group; i < tile_size; i += threads_per_output) { + let k_idx = (k_tile + i) * VECTOR_WIDTH; + // Removed redundant k_idx < params.k check (guaranteed by tile_size) + let src1_idx = params.offset_src1 + src13_idx * params.stride_13 + + src12_idx * params.stride_12 + output_row * params.stride_11 + k_idx; + let a = shared_vector[i]; // from src0 + let b = src1[src1_idx / VECTOR_WIDTH]; + local_sum += dot(a, b); + } + } + + 
workgroupBarrier(); + } + + // Handle remaining elements (K % VECTOR_WIDTH) + if (is_valid) { + let k_remainder_start = k_vec * VECTOR_WIDTH; + if (thread_in_group < (params.k - k_remainder_start)) { + let k_idx = k_remainder_start + thread_in_group; + let src0_idx = params.offset_src0 + src03_idx * params.stride_03 + + src02_idx * params.stride_02 + k_idx; + let src1_idx = params.offset_src1 + src13_idx * params.stride_13 + + src12_idx * params.stride_12 + output_row * params.stride_11 + k_idx; + let vec_idx = k_idx / VECTOR_WIDTH; + let elem_idx = k_idx % VECTOR_WIDTH; + let a_vec = src0[src0_idx / VECTOR_WIDTH]; + let b_vec = src1[src1_idx / VECTOR_WIDTH]; + local_sum += f32(a_vec[elem_idx]) * b_vec[elem_idx]; + } + } + + // Store partial sums and reduce within each sub-group (16 threads per output) + partial_sums[thread_id] = local_sum; + workgroupBarrier(); + + // Reduce within each sub-group: 16 threads → 1 result + // Each sub-group occupies 16 consecutive slots in partial_sums + let group_base = output_offset * (WORKGROUP_SIZE / OUTPUTS_PER_WG); // 0, 16, 32, ..., 240 + + // Reduction for 16 threads: 16 → 8 → 4 → 2 → 1 (loop version for correctness) + for (var stride = 8u; stride > 0u; stride = stride / 2u) { + if (thread_in_group < stride) { + partial_sums[group_base + thread_in_group] += partial_sums[group_base + thread_in_group + stride]; + } + workgroupBarrier(); + } + + // First thread of each sub-group has the result + // Threads 0, 16, 32, 48, ... 240 hold the 16 output values + if (thread_id == 0u && output_row_base < params.n) { + // Gather 16 results and write as 4 vec4s + let result_vec0 = vec4(partial_sums[0], partial_sums[16], partial_sums[32], partial_sums[48]); + let result_vec1 = vec4(partial_sums[64], partial_sums[80], partial_sums[96], partial_sums[112]); + let result_vec2 = vec4(partial_sums[128], partial_sums[144], partial_sums[160], partial_sums[176]); + let result_vec3 = vec4(partial_sums[192], partial_sums[208], partial_sums[224], partial_sums[240]); + + let dst_idx = params.offset_dst + dst3_idx * dst3_stride + + dst2_idx * dst2_stride + output_row_base * params.m; + dst[dst_idx / VECTOR_WIDTH] = result_vec0; + dst[dst_idx / VECTOR_WIDTH + 1u] = result_vec1; + dst[dst_idx / VECTOR_WIDTH + 2u] = result_vec2; + dst[dst_idx / VECTOR_WIDTH + 3u] = result_vec3; + } + } +} \ No newline at end of file From 749a791e919a9063b6e23eb5dc738c2700997df9 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Sun, 26 Oct 2025 19:40:19 +0800 Subject: [PATCH 18/40] working f16 accumulation with shared memory staging --- .../mul_mat_subgroup_matrix.tmpl.wgsl | 84 ++++++++++++------- 1 file changed, 56 insertions(+), 28 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl index 9ec741489d2e9..2bbe91532f1f6 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl @@ -72,10 +72,10 @@ fn zero_val_src0() -> {{SRC0_TYPE}} { } fn store_src0_shmem(val: {{SRC0_TYPE}}, idx: u32) { - src0_shmem[idx] = f32(val.x); - src0_shmem[idx + 1] = f32(val.y); - src0_shmem[idx + 2] = f32(val.z); - src0_shmem[idx + 3] = f32(val.w); + src0_shmem[idx] = f16(val.x); + src0_shmem[idx + 1] = f16(val.y); + src0_shmem[idx + 2] = f16(val.z); + src0_shmem[idx + 3] = f16(val.w); } fn zero_val_src1() -> {{SRC1_TYPE}} { @@ -83,10 +83,10 @@ fn zero_val_src1() -> {{SRC1_TYPE}} { } fn store_src1_shmem(val: 
{{SRC1_TYPE}}, idx: u32) { - src1_shmem[idx] = f32(val.x); - src1_shmem[idx + 1] = f32(val.y); - src1_shmem[idx + 2] = f32(val.z); - src1_shmem[idx + 3] = f32(val.w); + src1_shmem[idx] = f16(val.x); + src1_shmem[idx + 1] = f16(val.y); + src1_shmem[idx + 2] = f16(val.z); + src1_shmem[idx + 3] = f16(val.w); } #enddecl(SHMEM_VEC) @@ -96,7 +96,7 @@ fn zero_val_src0() -> {{SRC0_TYPE}} { } fn store_src0_shmem(val: {{SRC0_TYPE}}, idx: u32) { - src0_shmem[idx] = f32(val); + src0_shmem[idx] = f16(val); } fn zero_val_src1() -> {{SRC1_TYPE}} { @@ -104,7 +104,7 @@ fn zero_val_src1() -> {{SRC1_TYPE}} { } fn store_src1_shmem(val: {{SRC1_TYPE}}, idx: u32) { - src1_shmem[idx] = f32(val); + src1_shmem[idx] = f16(val); } #enddecl(SHMEM_SCALAR) @@ -163,11 +163,16 @@ override TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * SUBGROUP_SIZE; override TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; override TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; +override SG_MAT_ACCUM_SHMEM = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_M_SIZE * SUBGROUP_MATRIX_N_SIZE; + +// We reuse src0_shmem for accumulation matrices +override SHMEM_SIZE = max(TILE_SRC0_SHMEM, SG_MAT_ACCUM_SHMEM); + // Note: apparently current dawn doesn't like override constant shared memory size along with subgroup matrix loads -//var src0_shmem: array; +//var src0_shmem: array; //var src1_shmem: array; -var src0_shmem: array; -var src1_shmem: array; +var src0_shmem: array; +var src1_shmem: array; @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) fn main(@builtin(global_invocation_id) global_id: vec3, @@ -203,7 +208,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3, let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; - var acc_sg_mat : array, SUBGROUP_MATRIX_N>, SUBGROUP_MATRIX_M>; + var acc_sg_mat : array, SUBGROUP_MATRIX_N>, SUBGROUP_MATRIX_M>; for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { @@ -239,9 +244,9 @@ fn main(@builtin(global_invocation_id) global_id: vec3, for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K_SIZE) { let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner; - var src0_sg_mats: array, SUBGROUP_MATRIX_M>; + var src0_sg_mats: array, SUBGROUP_MATRIX_M>; for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { - src0_sg_mats[m] = subgroupMatrixLoad>( + src0_sg_mats[m] = subgroupMatrixLoad>( &src0_shmem, src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K, false, @@ -251,7 +256,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3, let src1_shmem_idx_base = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner; for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) { - let src1_sg_mat = subgroupMatrixLoad>( + let src1_sg_mat = subgroupMatrixLoad>( &src1_shmem, src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K, true, @@ -267,19 +272,42 @@ fn main(@builtin(global_invocation_id) global_id: vec3, } let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; - let dst_row_base = (wg_n * SUBGROUP_N + subgroup_n) * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; - let dst_col_base = (wg_m * SUBGROUP_M + subgroup_m) * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; + + + // Stage the subgroup matrix tiles into shared memory + // This uses 
WG_M_SG_TILE_SIZE as the stride (number of columns in the workgroup tile). + let WG_TILE_STRIDE = WG_M_SG_TILE_SIZE; + let tile_row_base_local = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + let tile_col_base_local = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) { - let global_row = dst_row_base + n * SUBGROUP_MATRIX_N_SIZE; - if (global_row < params.n) { - for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { - let global_col = dst_col_base + m * SUBGROUP_MATRIX_M_SIZE; - if (global_col < params.m) { - let dst_idx = dst_batch_offset + global_row * params.m + global_col; - subgroupMatrixStore(&dst, dst_idx, acc_sg_mat[m][n], true, params.m); - } - } + for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { + let local_row = tile_row_base_local + n * SUBGROUP_MATRIX_N_SIZE; + let local_col = tile_col_base_local + m * SUBGROUP_MATRIX_M_SIZE; + let out_base = local_row * WG_TILE_STRIDE + local_col; + subgroupMatrixStore(&src0_shmem, out_base, acc_sg_mat[m][n], true, WG_TILE_STRIDE); + } + } + + workgroupBarrier(); + + // Cooperative write: iterate over the entire workgroup tile + let tile_rows = WG_N_SG_TILE_SIZE; + let tile_cols = WG_M_SG_TILE_SIZE; + let total_tile_elems = tile_rows * tile_cols; + let tile_dst_row_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + let tile_dst_col_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; + + for (var idx = thread_id; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE) { + let local_row = idx / WG_TILE_STRIDE; + let local_col = idx % WG_TILE_STRIDE; + + let global_row = tile_dst_row_base + local_row; + let global_col = tile_dst_col_base + local_col; + + if (global_row < params.n && global_col < params.m) { + let dst_idx = dst_batch_offset + global_row * params.m + global_col; + dst[dst_idx] = f32(src0_shmem[idx]); } } } From abaf12e5f9b891c08f1af4394906b869aca7d956 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Sun, 26 Oct 2025 19:42:05 +0800 Subject: [PATCH 19/40] Print out available subgroup matrix configurations --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 4d2863bcc9227..5bb692439d1b8 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2317,6 +2317,17 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t info.nextInChain = &subgroup_matrix_configs; ctx->adapter.GetInfo(&info); + // print configs + for (int i = 0; i < subgroup_matrix_configs.configCount; i++) { + const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i]; + std::cout << "ggml_webgpu: Subgroup Matrix Config " << i << ":\n"; + std::cout << " M: " << config.M << "\n"; + std::cout << " N: " << config.N << "\n"; + std::cout << " K: " << config.K << "\n"; + std::cout << " Component Type: " << static_cast(config.componentType) << "\n"; + std::cout << " Result Type: " << static_cast(config.resultComponentType) << "\n"; + } + ctx->subgroup_matrix_config = *subgroup_matrix_configs.configs; wgpu::SupportedFeatures features; ctx->adapter.GetFeatures(&features); From 54c31c101534cb138b200042da67f2825041b5ea Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Sun, 26 Oct 2025 21:21:44 +0800 Subject: [PATCH 20/40] Vectorize dst stores for sg matrix shader --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 76 +++++++++---------- .../mul_mat_subgroup_matrix.tmpl.wgsl | 23 ++++-- 2 
files changed, 54 insertions(+), 45 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 5bb692439d1b8..87347232ea68d 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -271,8 +271,7 @@ struct webgpu_context_struct { webgpu_pipeline memset_pipeline; - std::map<ggml_type, std::map<ggml_type, std::map<int, std::map<int, webgpu_pipeline>>>> - mul_mat_pipelines; // src0_type, src1_type, sg matrix, vectorized + std::map<ggml_type, std::map<ggml_type, std::map<int, webgpu_pipeline>>> mul_mat_pipelines; // src0_type, src1_type, vectorized webgpu_pipeline mul_mat_pipeline[30][2]; // Specialized gemv for f16/f32 @@ -925,7 +924,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, if (output_elements >= 64 && src0->ne[0] % 4 == 0 && output_elements % 16 == 0) { // Each workgroup computes 16 consecutive outputs (4x vec4 writes) uint32_t output_vec4_groups = output_elements / 16; - uint32_t wg_x = output_vec4_groups * batches; + uint32_t wg_x = output_vec4_groups * batches; return ggml_backend_webgpu_build(ctx, ctx->mul_mat_gemv_pipeline, params, entries, wg_x); } else { @@ -954,14 +953,11 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, if (use_fast) { int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0; - int sg_matrix = ctx->supports_subgroup_matrix && src0->ne[0] % ctx->subgroup_matrix_config.K == 0 && - dst->ne[0] % ctx->subgroup_matrix_config.M == 0 && - dst->ne[1] % ctx->subgroup_matrix_config.N == 0; - pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][sg_matrix][vectorized]; + pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; uint32_t wg_m; uint32_t wg_n; - if (sg_matrix) { + if (ctx->supports_subgroup_matrix) { // The total number of subgroups/workgroups needed per matrix. uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->subgroup_matrix_config.M; @@ -1757,52 +1753,52 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { mul_mat_sg_mat_constants[6].key = "SUBGROUP_MATRIX_K_SIZE"; mul_mat_sg_mat_constants[6].value = static_cast<double>(webgpu_ctx->subgroup_matrix_config.K); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1][0] = + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f32_f32, "mul_mat_subgroup_matrix_f32_f32", mul_mat_sg_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1][1] = + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f32_f32_vec, "mul_mat_subgroup_matrix_f32_f32_vec", mul_mat_sg_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1][0] = + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f16_f32, "mul_mat_subgroup_matrix_f16_f32", mul_mat_sg_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1][1] = + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f16_f32_vec, "mul_mat_subgroup_matrix_f16_f32_vec", mul_mat_sg_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1][0] = + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f16_f16, "mul_mat_subgroup_matrix_f16_f16", 
mul_mat_sg_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1][1] = + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f16_f16_vec, "mul_mat_subgroup_matrix_f16_f16_vec", mul_mat_sg_mat_constants); + } else { + std::vector<wgpu::ConstantEntry> mul_mat_reg_tile_constants(3); + mul_mat_reg_tile_constants[0].key = "TILE_K"; + mul_mat_reg_tile_constants[0].value = WEBGPU_MUL_MAT_TILE_K; + mul_mat_reg_tile_constants[1].key = "WORKGROUP_SIZE_M"; + mul_mat_reg_tile_constants[1].value = WEBGPU_MUL_MAT_WG_SIZE_M; + mul_mat_reg_tile_constants[2].key = "WORKGROUP_SIZE_N"; + mul_mat_reg_tile_constants[2].value = WEBGPU_MUL_MAT_WG_SIZE_N; + + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_reg_tile_f32_f32, "mul_mat_reg_tile_f32_f32", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_reg_tile_f32_f32_vec, + "mul_mat_reg_tile_f32_f32_vec", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f32, "mul_mat_reg_tile_f16_f32", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f32_vec, + "mul_mat_reg_tile_f16_f32_vec", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f16, "mul_mat_reg_tile_f16_f16", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f16_vec, + "mul_mat_reg_tile_f16_f16_vec", mul_mat_reg_tile_constants); } + webgpu_ctx->mul_mat_gemv_pipeline = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_gemv_f16_f32, "gemv_f16_f32"); - std::vector<wgpu::ConstantEntry> mul_mat_reg_tile_constants(3); - mul_mat_reg_tile_constants[0].key = "TILE_K"; - mul_mat_reg_tile_constants[0].value = WEBGPU_MUL_MAT_TILE_K; - mul_mat_reg_tile_constants[1].key = "WORKGROUP_SIZE_M"; - mul_mat_reg_tile_constants[1].value = WEBGPU_MUL_MAT_WG_SIZE_M; - mul_mat_reg_tile_constants[2].key = "WORKGROUP_SIZE_N"; - mul_mat_reg_tile_constants[2].value = WEBGPU_MUL_MAT_WG_SIZE_N; - - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, wgsl_mul_mat_reg_tile_f32_f32, "mul_mat_reg_tile_f32_f32", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_reg_tile_f32_f32_vec, - "mul_mat_reg_tile_f32_f32_vec", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f32, "mul_mat_reg_tile_f16_f32", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f32_vec, - "mul_mat_reg_tile_f16_f32_vec", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f16, 
"mul_mat_reg_tile_f16_f16", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f16_vec, - "mul_mat_reg_tile_f16_f16_vec", mul_mat_reg_tile_constants); - - webgpu_ctx->mul_mat_gemv_pipeline = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, wgsl_gemv_f16_f32, "gemv_f16_f32"); } static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl index 2bbe91532f1f6..c8757bf207251 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl @@ -5,7 +5,7 @@ "REPLS": { "SRC0_TYPE" : "vec4", "SRC1_TYPE" : "vec4", - "DST_TYPE" : "f32", + "DST_TYPE" : "vec4", "VEC_SIZE" : "4", }, "DECLS": ["SHMEM_VEC"] @@ -25,7 +25,7 @@ "REPLS": { "SRC0_TYPE" : "vec4", "SRC1_TYPE" : "vec4", - "DST_TYPE" : "f32", + "DST_TYPE" : "vec4", "VEC_SIZE" : "4", }, "DECLS": ["SHMEM_VEC"] @@ -45,7 +45,7 @@ "REPLS": { "SRC0_TYPE" : "vec4", "SRC1_TYPE" : "vec4", - "DST_TYPE" : "f32", + "DST_TYPE" : "vec4", "VEC_SIZE" : "4", }, "DECLS": ["SHMEM_VEC"] @@ -88,6 +88,15 @@ fn store_src1_shmem(val: {{SRC1_TYPE}}, idx: u32) { src1_shmem[idx + 2] = f16(val.z); src1_shmem[idx + 3] = f16(val.w); } + +fn store_dst(shmem_idx: u32, dst_idx: u32) { + dst[dst_idx] = vec4( + f32(src0_shmem[shmem_idx]), + f32(src0_shmem[shmem_idx + 1]), + f32(src0_shmem[shmem_idx + 2]), + f32(src0_shmem[shmem_idx + 3]) + ); +} #enddecl(SHMEM_VEC) #decl(SHMEM_SCALAR) @@ -106,6 +115,10 @@ fn zero_val_src1() -> {{SRC1_TYPE}} { fn store_src1_shmem(val: {{SRC1_TYPE}}, idx: u32) { src1_shmem[idx] = f16(val); } + +fn store_dst(shmem_idx: u32, dst_idx: u32) { + dst[dst_idx] = f32(src0_shmem[shmem_idx]); +} #enddecl(SHMEM_SCALAR) #end(DECLS) @@ -298,7 +311,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3, let tile_dst_row_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; let tile_dst_col_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; - for (var idx = thread_id; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE) { + for (var idx = thread_id * {{VEC_SIZE}}; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { let local_row = idx / WG_TILE_STRIDE; let local_col = idx % WG_TILE_STRIDE; @@ -307,7 +320,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3, if (global_row < params.n && global_col < params.m) { let dst_idx = dst_batch_offset + global_row * params.m + global_col; - dst[dst_idx] = f32(src0_shmem[idx]); + store_dst(idx, dst_idx/{{VEC_SIZE}}); } } } From 0f6e38da4f08c9c58b88629497151d0a81f2d527 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Mon, 27 Oct 2025 12:41:10 +0800 Subject: [PATCH 21/40] Gemv working scalar --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 92 ++++-- .../ggml-webgpu/wgsl-shaders/gemv.tmpl.wgsl | 212 +++++++++++++ .../wgsl-shaders/gemv_f16_f32.wgsl | 285 ------------------ 3 files changed, 273 insertions(+), 316 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/gemv.tmpl.wgsl delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/gemv_f16_f32.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 87347232ea68d..f8ab01abc6fba 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -92,6 +92,12 
@@ #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 +// GEMV constants +#define WEBGPU_GEMV_WG_SIZE 256 +// Must be multiple of 4 to work with vectorized paths, and must divide gemv wg size +#define WEBGPU_GEMV_OUTPUTS_PER_WG 16 +#define WEBGPU_GEMV_TILE_K 128 + /* End Constants */ // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations. @@ -272,10 +278,10 @@ struct webgpu_context_struct { webgpu_pipeline memset_pipeline; std::map>> mul_mat_pipelines; // src0_type, src1_type, vectorized + std::map>> gemv_pipelines; // src0_type, src1_type, vectorized webgpu_pipeline mul_mat_pipeline[30][2]; // Specialized gemv for f16/f32 - webgpu_pipeline mul_mat_gemv_pipeline; webgpu_pipeline set_rows_pipeline; webgpu_pipeline get_rows_pipeline[30]; webgpu_pipeline get_rows_f32_no_vec_pipeline; @@ -564,6 +570,7 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_context & std::vector params, std::vector bind_group_entries, uint32_t wg_x, + uint32_t wg_y = 1, std::optional set_rows_error_bufs = std::nullopt) { webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs(); @@ -609,7 +616,7 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_context & #endif pass.SetPipeline(pipeline.pipeline); pass.SetBindGroup(0, bind_group); - pass.DispatchWorkgroups(wg_x, 1, 1); + pass.DispatchWorkgroups(wg_x, wg_y, 1); pass.End(); #ifdef GGML_WEBGPU_GPU_PROFILE @@ -821,7 +828,7 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size; - return ggml_backend_webgpu_build(ctx, ctx->set_rows_pipeline, params, entries, wg_x, error_bufs); + return ggml_backend_webgpu_build(ctx, ctx->set_rows_pipeline, params, entries, wg_x, 1, error_bufs); } static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, @@ -911,26 +918,38 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, uint32_t wg_x = (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE; + uint32_t wg_y = 1; - // Use specialized gemv - if ((dst->ne[0] == 1 || dst->ne[1] == 1) && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { - // gemv_fast: 256 threads per workgroup, computes 16 outputs per workgroup (written as 4x vec4) - // Uses cooperative reduction and vec4 operations (requires K % 4 == 0 and outputs % 16 == 0) - uint32_t output_elements = (dst->ne[0] == 1) ? 
dst->ne[1] : dst->ne[0]; - uint32_t batches = dst->ne[2] * dst->ne[3]; - - // Use gemv_fast for larger vectors where reduction overhead pays off - // Requires K divisible by 4 for vec4 alignment, outputs divisible by 16 for optimal vec4 output - if (output_elements >= 64 && src0->ne[0] % 4 == 0 && output_elements % 16 == 0) { - // Each workgroup computes 16 consecutive outputs (4x vec4 writes) - uint32_t output_vec4_groups = output_elements / 16; - uint32_t wg_x = output_vec4_groups * batches; - - return ggml_backend_webgpu_build(ctx, ctx->mul_mat_gemv_pipeline, params, entries, wg_x); - } else { - // Small vectors, unaligned K, or outputs not divisible by 16: use template shader - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + bool use_gemv = false; + if (dst->ne[1] == 1) { + switch (src1->type) { + case GGML_TYPE_F16: + use_gemv = (src0->type == GGML_TYPE_F16); + break; + case GGML_TYPE_F32: + switch (src0->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + use_gemv = true; + break; + default: + break; + } + default: + break; } + } + + if (use_gemv) { + uint32_t batches = dst->ne[2] * dst->ne[3]; + uint32_t output_groups = (dst->ne[0] + WEBGPU_GEMV_OUTPUTS_PER_WG - 1) / WEBGPU_GEMV_OUTPUTS_PER_WG; + uint32_t total_wg = output_groups * batches; + wg_x = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension; + wg_y = (total_wg + ctx->limits.maxComputeWorkgroupsPerDimension - 1) / + ctx->limits.maxComputeWorkgroupsPerDimension; + int vectorized = src0->ne[0] % 4 == 0 && src1->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0; + pipeline = ctx->gemv_pipelines[src0->type][src1->type][vectorized]; + } else { bool use_fast = false; switch (src1->type) { @@ -973,9 +992,8 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, } wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; } - - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, @@ -1691,12 +1709,6 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32], - wgsl_mul_mat_f32_f32, "mul_mat_f32_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16], - wgsl_mul_mat_f16_f16, "mul_mat_f16_f16"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32], - wgsl_mul_mat_f16_f32, "mul_mat_f16_f32"); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32], wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32"); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32], @@ -1796,9 +1808,27 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f16_vec, "mul_mat_reg_tile_f16_f16_vec", mul_mat_reg_tile_constants); } - webgpu_ctx->mul_mat_gemv_pipeline = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_gemv_f16_f32, "gemv_f16_f32"); + std::vector gemv_constants(3); + gemv_constants[0].key = "WORKGROUP_SIZE"; + gemv_constants[0].value = WEBGPU_GEMV_WG_SIZE; + gemv_constants[1].key = "TILE_K"; + gemv_constants[1].value = WEBGPU_GEMV_TILE_K; + 
gemv_constants[2].key = "OUTPUTS_PER_WG"; + gemv_constants[2].value = WEBGPU_GEMV_OUTPUTS_PER_WG; + + webgpu_ctx->gemv_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_gemv_f32_f32, "gemv_f32_f32", gemv_constants); + webgpu_ctx->gemv_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_gemv_f32_f32_vec, "gemv_f32_f32_vec", gemv_constants); + webgpu_ctx->gemv_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_gemv_f16_f32, "gemv_f16_f32", gemv_constants); + webgpu_ctx->gemv_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_gemv_f16_f32_vec, "gemv_f16_f32_vec", gemv_constants); + webgpu_ctx->gemv_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_gemv_f16_f16, "gemv_f16_f16", gemv_constants); + webgpu_ctx->gemv_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_gemv_f16_f16_vec, "gemv_f16_f16_vec", gemv_constants); } static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/gemv.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/gemv.tmpl.wgsl new file mode 100644 index 0000000000000..3964352f28ce8 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/gemv.tmpl.wgsl @@ -0,0 +1,212 @@ +#define(VARIANTS) +[ + { + "SHADER_SUFFIX": "f32_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE": "vec4", + "VEC_SIZE" : "4", + }, + "DECLS": ["VEC"] + }, + { + "SHADER_SUFFIX": "f32_f32", + "REPLS": { + "SRC0_TYPE" : "f32", + "SRC1_TYPE" : "f32", + "DST_TYPE": "f32", + "VEC_SIZE" : "1", + }, + "DECLS": ["SCALAR"] + }, + { + "SHADER_SUFFIX": "f16_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE": "vec4", + "VEC_SIZE" : "4", + }, + "DECLS": ["VEC"] + }, + { + "SHADER_SUFFIX": "f16_f32", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE": "f32", + "VEC_SIZE" : "1", + }, + "DECLS": ["SCALAR"] + }, + { + "SHADER_SUFFIX": "f16_f16_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE": "vec4", + "VEC_SIZE" : "4", + }, + "DECLS": ["VEC"] + }, + { + "SHADER_SUFFIX": "f16_f16", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f16", + "DST_TYPE": "f32", + "VEC_SIZE" : "1", + }, + "DECLS": ["SCALAR"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(VEC) +fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { + return f32(dot({{SRC1_TYPE}}(src0_val), src1_val)); +} + +fn store_val(group_base: u32) -> vec4 { + return vec4(partial_sums[group_base], + partial_sums[group_base + THREADS_PER_OUTPUT], + partial_sums[group_base + THREADS_PER_OUTPUT * 2], + partial_sums[group_base + THREADS_PER_OUTPUT * 3]); +} +#enddecl(VEC) + +#decl(SCALAR) +fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { + return f32(src0_val) * f32(src1_val); +} + +fn store_val(group_base: u32) -> f32 { + return partial_sums[group_base]; +} +#enddecl(SCALAR) + +#end(DECLS) + +#define(SHADER) +enable f16; + +DECLS + +struct MulMatParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + m: u32, + n: u32, + k: u32, + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +@group(0) @binding(0) var src0: 
array<{{SRC0_TYPE}}>; // Matrix (M x K) +@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed +@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // Result vector (transposed) + +@group(0) @binding(3) var params: MulMatParams; + +override WORKGROUP_SIZE: u32; +override TILE_K: u32; +override OUTPUTS_PER_WG: u32; +override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG; + +// Shared memory for collaborative loading and reduction +var shared_vector: array<{{SRC1_TYPE}}, TILE_K/{{VEC_SIZE}}>; // Cache vector tile +var partial_sums: array; // For reduction + +@compute @workgroup_size(WORKGROUP_SIZE) +fn main( + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3) { + let thread_id = local_id.x; + + // Handle batch dimensions + let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG; + let batch_idx = wg_linear / output_groups; + if (batch_idx >= total_batches) { + return; + } + + // Which of the outputs does this thread belong to? + let thread_group = thread_id / THREADS_PER_OUTPUT; + let thread_in_group = thread_id % THREADS_PER_OUTPUT; + + // Each workgroup computes OUTPUTS_PER_WG consecutive outputs + let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group; + + let dst2_stride = params.m * params.n; + let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); + let src03_idx = dst3_idx / params.broadcast3; + let src13_idx = dst3_idx; + let src02_idx = dst2_idx / params.broadcast2; + let src12_idx = dst2_idx; + + let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01; + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row; + + var local_sum = 0.0; + + // Each thread processes multiple K elements and accumulates + for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) { + let tile_size = min(TILE_K, params.k - k_tile); + + // Cooperatively load vector tile into shared memory (all threads) + for (var i = thread_id * {{VEC_SIZE}}; i < tile_size; i += WORKGROUP_SIZE * {{VEC_SIZE}}) { + shared_vector[i / {{VEC_SIZE}}] = src1[(src1_idx_base + k_tile + i) / {{VEC_SIZE}}]; + } + + workgroupBarrier(); + + if (output_row < params.m) { + for (var i = thread_in_group * {{VEC_SIZE}}; i < tile_size; i += THREADS_PER_OUTPUT * {{VEC_SIZE}}) { + let a = src0[(src0_idx_base + k_tile + i) / {{VEC_SIZE}}]; + let b = shared_vector[i / {{VEC_SIZE}}]; + local_sum += mul_acc(a, b); + } + } + + workgroupBarrier(); + } + + // Store partial sums and reduce within each partition + partial_sums[thread_id] = local_sum; + workgroupBarrier(); + let group_base = thread_group * THREADS_PER_OUTPUT; + let thread_base = group_base + thread_in_group; + var offset = THREADS_PER_OUTPUT / 2; + while (offset > 0) { + if (thread_in_group < offset) { + partial_sums[thread_base] += partial_sums[thread_base + offset]; + } + offset = offset / 2; + workgroupBarrier(); + } + + // Store back to global memory + if (output_row < params.m && thread_group % {{VEC_SIZE}} == 0 && thread_in_group == 0) { + 
dst[dst_idx / {{VEC_SIZE}}] = store_val(group_base); + } +} +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/gemv_f16_f32.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/gemv_f16_f32.wgsl deleted file mode 100644 index d56f4fe915b50..0000000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/gemv_f16_f32.wgsl +++ /dev/null @@ -1,285 +0,0 @@ -enable f16; - -// Optimized GEMV shader for F16xF32 matrix-vector multiplication -// Handles both M=1 (row vector * matrix) and N=1 (matrix * column vector) cases -// Uses vectorized memory access and shared memory for better performance - -const WORKGROUP_SIZE: u32 = 256u; // Larger workgroup for better occupancy -const VECTOR_WIDTH: u32 = 4u; // Process 4 elements at a time with vec4 -const TILE_K: u32 = 128u; // Tile size along K dimension for cache efficiency -const OUTPUTS_PER_WG: u32 = 16u; // Each workgroup computes 16 outputs (written as 4x vec4) - OPTIMAL - -struct MulMatParams { - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - m: u32, - n: u32, - k: u32, - stride_01: u32, - stride_11: u32, - stride_02: u32, - stride_12: u32, - stride_03: u32, - stride_13: u32, - bs02: u32, - bs03: u32, - broadcast2: u32, - broadcast3: u32 -}; - -@group(0) @binding(0) var src0: array>; // Matrix (N x K in vec4s) -@group(0) @binding(1) var src1: array>; // Vector (M x K or K in vec4s) -@group(0) @binding(2) var dst: array>; // Result vector (vec4 for bandwidth) - -@group(0) @binding(3) var params: MulMatParams; - -// Shared memory for collaborative loading and reduction -var shared_vector: array, TILE_K/4>; // Cache vector tile -var partial_sums: array; // For reduction (4 groups) - -// Helper function for vectorized dot product -fn dot_vec4_f16_f32(a: vec4, b: vec4) -> f32 { - return f32(a.x) * b.x + f32(a.y) * b.y + f32(a.z) * b.z + f32(a.w) * b.w; -} - -@compute @workgroup_size(WORKGROUP_SIZE) -fn main( - @builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) wg_id: vec3 -) { - let thread_id = local_id.x; - - // Handle batch dimensions - let dst2_stride = params.m * params.n; - let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; - - let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; - let output_elements = select(params.n, params.m, params.n == 1u); - - // Each workgroup computes OUTPUTS_PER_WG consecutive outputs (written as vec4) - // Using 2D dispatch to avoid exceeding 65535 limit per dimension - // wg_linear = wg_id.y * 65535 + wg_id.x - let wg_linear = wg_id.y * 65535u + wg_id.x; - let output_vec4_groups = (output_elements + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG; - let batch_idx = wg_linear / output_vec4_groups; - let output_vec4_idx = wg_linear % output_vec4_groups; - let base_output_idx = output_vec4_idx * OUTPUTS_PER_WG; - - // Which of the 16 outputs does this thread belong to? 
- let threads_per_output = WORKGROUP_SIZE / OUTPUTS_PER_WG; // 256/16 = 16 - let output_offset = thread_id / threads_per_output; // 0-15 - let thread_in_group = thread_id % threads_per_output; // 0-15 - - if (batch_idx >= total_batches) { - return; - } - - let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); - let src03_idx = dst3_idx / params.broadcast3; - let src13_idx = dst3_idx; - - let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); - let src02_idx = dst2_idx / params.broadcast2; - let src12_idx = dst2_idx; - - // Case 1: M == 1 (result is a row vector: 1 x N) - // Each workgroup computes OUTPUTS_PER_WG (16) consecutive output elements - // 256 threads split into 16 groups of 16, each group computes one output - if (params.n == 1u) { - let output_col_base = base_output_idx; - let output_col = output_col_base + output_offset; - - // Check bounds but don't early return (must hit all barriers) - let is_valid = output_col < params.m; - - var local_sum = 0.0; - let k_vec = params.k / VECTOR_WIDTH; - - // Each thread processes multiple K elements and accumulates - for (var k_tile = 0u; k_tile < k_vec; k_tile += TILE_K/VECTOR_WIDTH) { - let tile_size = min(TILE_K/VECTOR_WIDTH, k_vec - k_tile); - - // Cooperatively load vector tile into shared memory (all threads) - for (var i = thread_id; i < tile_size; i += WORKGROUP_SIZE) { - let k_idx = (k_tile + i) * VECTOR_WIDTH; - if (k_idx < params.k) { - let src1_idx = params.offset_src1 + src13_idx * params.stride_13 + - src12_idx * params.stride_12 + k_idx; - shared_vector[i] = src1[src1_idx / VECTOR_WIDTH]; - } - } - - workgroupBarrier(); - - // Each sub-group of 16 threads computes its own output - // thread_in_group = 0-15, provides stride for this sub-group - let threads_per_output = WORKGROUP_SIZE / OUTPUTS_PER_WG; // 16 - - if (is_valid) { - for (var i = thread_in_group; i < tile_size; i += threads_per_output) { - let k_idx = (k_tile + i) * VECTOR_WIDTH; - // Removed redundant k_idx < params.k check (guaranteed by tile_size) - let src0_idx = params.offset_src0 + src03_idx * params.stride_03 + - src02_idx * params.stride_02 + output_col * params.stride_01 + k_idx; - let a = src0[src0_idx / VECTOR_WIDTH]; - let b = shared_vector[i]; - local_sum += dot_vec4_f16_f32(a, b); - } - } - - workgroupBarrier(); - } - - // Handle remaining elements (K % VECTOR_WIDTH) - if (is_valid) { - let k_remainder_start = k_vec * VECTOR_WIDTH; - if (thread_in_group < (params.k - k_remainder_start)) { - let k_idx = k_remainder_start + thread_in_group; - let src0_idx = params.offset_src0 + src03_idx * params.stride_03 + - src02_idx * params.stride_02 + output_col * params.stride_01 + k_idx; - let src1_idx = params.offset_src1 + src13_idx * params.stride_13 + - src12_idx * params.stride_12 + k_idx; - // Read individual elements (last vec4 might be partial) - let vec_idx = k_idx / VECTOR_WIDTH; - let elem_idx = k_idx % VECTOR_WIDTH; - let a_vec = src0[src0_idx / VECTOR_WIDTH]; - let b_vec = src1[src1_idx / VECTOR_WIDTH]; - local_sum += f32(a_vec[elem_idx]) * b_vec[elem_idx]; - } - } - - // Store partial sums and reduce within each sub-group (16 threads per output) - partial_sums[thread_id] = local_sum; - workgroupBarrier(); - - // Reduce within each sub-group: 16 threads → 1 result - // Each sub-group occupies 16 consecutive slots in partial_sums - let group_base = output_offset * (WORKGROUP_SIZE / OUTPUTS_PER_WG); // 0, 16, 32, ..., 240 - - // Reduction for 16 threads: 16 → 8 → 4 → 2 → 1 (loop version for correctness) - for (var stride = 8u; stride > 
0u; stride = stride / 2u) { - if (thread_in_group < stride) { - partial_sums[group_base + thread_in_group] += partial_sums[group_base + thread_in_group + stride]; - } - workgroupBarrier(); - } - - // First thread of each sub-group has the result - // Threads 0, 16, 32, 48, ... 240 hold the 16 output values - if (thread_id == 0u && output_col_base < params.m) { - // Gather 16 results and write as 4 vec4s - let result_vec0 = vec4(partial_sums[0], partial_sums[16], partial_sums[32], partial_sums[48]); - let result_vec1 = vec4(partial_sums[64], partial_sums[80], partial_sums[96], partial_sums[112]); - let result_vec2 = vec4(partial_sums[128], partial_sums[144], partial_sums[160], partial_sums[176]); - let result_vec3 = vec4(partial_sums[192], partial_sums[208], partial_sums[224], partial_sums[240]); - - let dst_idx = params.offset_dst + dst3_idx * dst3_stride + - dst2_idx * dst2_stride + output_col_base; - dst[dst_idx / VECTOR_WIDTH] = result_vec0; - dst[dst_idx / VECTOR_WIDTH + 1u] = result_vec1; - dst[dst_idx / VECTOR_WIDTH + 2u] = result_vec2; - dst[dst_idx / VECTOR_WIDTH + 3u] = result_vec3; - } - } - // Case 2: N == 1 (result is a column vector: M x 1) - // Each workgroup computes OUTPUTS_PER_WG (16) consecutive output elements - // 256 threads split into 16 groups of 16, each group computes one output - else if (params.m == 1u) { - let output_row_base = base_output_idx; - let output_row = output_row_base + output_offset; - - // Check bounds but don't early return (must hit all barriers) - let is_valid = output_row < params.n; - - var local_sum = 0.0; - let k_vec = params.k / VECTOR_WIDTH; - - // Each thread processes multiple K elements and accumulates - for (var k_tile = 0u; k_tile < k_vec; k_tile += TILE_K/VECTOR_WIDTH) { - let tile_size = min(TILE_K/VECTOR_WIDTH, k_vec - k_tile); - - // Cooperatively load vector tile into shared memory (all threads) - // Note: In this case, src0 is the vector input - for (var i = thread_id; i < tile_size; i += WORKGROUP_SIZE) { - let k_idx = (k_tile + i) * VECTOR_WIDTH; - if (k_idx < params.k) { - let src0_idx = params.offset_src0 + src03_idx * params.stride_03 + - src02_idx * params.stride_02 + k_idx; - shared_vector[i] = vec4(src0[src0_idx / VECTOR_WIDTH]); - } - } - - workgroupBarrier(); - - // Each sub-group of 16 threads computes its own output - // thread_in_group = 0-15, provides stride for this sub-group - let threads_per_output = WORKGROUP_SIZE / OUTPUTS_PER_WG; // 16 - - if (is_valid) { - for (var i = thread_in_group; i < tile_size; i += threads_per_output) { - let k_idx = (k_tile + i) * VECTOR_WIDTH; - // Removed redundant k_idx < params.k check (guaranteed by tile_size) - let src1_idx = params.offset_src1 + src13_idx * params.stride_13 + - src12_idx * params.stride_12 + output_row * params.stride_11 + k_idx; - let a = shared_vector[i]; // from src0 - let b = src1[src1_idx / VECTOR_WIDTH]; - local_sum += dot(a, b); - } - } - - workgroupBarrier(); - } - - // Handle remaining elements (K % VECTOR_WIDTH) - if (is_valid) { - let k_remainder_start = k_vec * VECTOR_WIDTH; - if (thread_in_group < (params.k - k_remainder_start)) { - let k_idx = k_remainder_start + thread_in_group; - let src0_idx = params.offset_src0 + src03_idx * params.stride_03 + - src02_idx * params.stride_02 + k_idx; - let src1_idx = params.offset_src1 + src13_idx * params.stride_13 + - src12_idx * params.stride_12 + output_row * params.stride_11 + k_idx; - let vec_idx = k_idx / VECTOR_WIDTH; - let elem_idx = k_idx % VECTOR_WIDTH; - let a_vec = src0[src0_idx / VECTOR_WIDTH]; - 
let b_vec = src1[src1_idx / VECTOR_WIDTH]; - local_sum += f32(a_vec[elem_idx]) * b_vec[elem_idx]; - } - } - - // Store partial sums and reduce within each sub-group (16 threads per output) - partial_sums[thread_id] = local_sum; - workgroupBarrier(); - - // Reduce within each sub-group: 16 threads → 1 result - // Each sub-group occupies 16 consecutive slots in partial_sums - let group_base = output_offset * (WORKGROUP_SIZE / OUTPUTS_PER_WG); // 0, 16, 32, ..., 240 - - // Reduction for 16 threads: 16 → 8 → 4 → 2 → 1 (loop version for correctness) - for (var stride = 8u; stride > 0u; stride = stride / 2u) { - if (thread_in_group < stride) { - partial_sums[group_base + thread_in_group] += partial_sums[group_base + thread_in_group + stride]; - } - workgroupBarrier(); - } - - // First thread of each sub-group has the result - // Threads 0, 16, 32, 48, ... 240 hold the 16 output values - if (thread_id == 0u && output_row_base < params.n) { - // Gather 16 results and write as 4 vec4s - let result_vec0 = vec4(partial_sums[0], partial_sums[16], partial_sums[32], partial_sums[48]); - let result_vec1 = vec4(partial_sums[64], partial_sums[80], partial_sums[96], partial_sums[112]); - let result_vec2 = vec4(partial_sums[128], partial_sums[144], partial_sums[160], partial_sums[176]); - let result_vec3 = vec4(partial_sums[192], partial_sums[208], partial_sums[224], partial_sums[240]); - - let dst_idx = params.offset_dst + dst3_idx * dst3_stride + - dst2_idx * dst2_stride + output_row_base * params.m; - dst[dst_idx / VECTOR_WIDTH] = result_vec0; - dst[dst_idx / VECTOR_WIDTH + 1u] = result_vec1; - dst[dst_idx / VECTOR_WIDTH + 2u] = result_vec2; - dst[dst_idx / VECTOR_WIDTH + 3u] = result_vec3; - } - } -} \ No newline at end of file From f2e187c7f26cdf0f4a9284db8ecbb2d74ffb7213 Mon Sep 17 00:00:00 2001 From: neha-ha <137219201+neha-ha@users.noreply.github.com> Date: Mon, 27 Oct 2025 14:48:39 -0500 Subject: [PATCH 22/40] Minor set_rows optimization (#4) * updated optimization, fixed errors * non vectorized version now dispatches one thread per element * Simplify * Change logic for set_rows pipelines --------- Co-authored-by: Neha Abbas Co-authored-by: Neha Abbas Co-authored-by: Reese Levine --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 25 +++++++--- .../{set_rows.wgsl => set_rows.tmpl.wgsl} | 46 ++++++++++++++++--- 2 files changed, 58 insertions(+), 13 deletions(-) rename ggml/src/ggml-webgpu/wgsl-shaders/{set_rows.wgsl => set_rows.tmpl.wgsl} (68%) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index b4558a9e3f1d2..353c7729bd1f8 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -248,7 +248,7 @@ struct webgpu_context_struct { webgpu_pipeline memset_pipeline; webgpu_pipeline mul_mat_pipeline[30][2]; - webgpu_pipeline set_rows_pipeline; + webgpu_pipeline set_rows_pipeline[1][2]; // dst->type, vectorized (0 for vectorized, 1 for non vectorized) webgpu_pipeline get_rows_pipeline[30]; webgpu_pipeline get_rows_f32_no_vec_pipeline; webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type @@ -766,10 +766,21 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() } }; - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size; + size_t max_wg_size = ctx->max_wg_size_x; + + int vectorized = src->ne[0] % 4 == 0; + webgpu_pipeline pipeline = 
ctx->set_rows_pipeline[0][vectorized]; + // if not evenly divisble by 4, use the non-vectorized version + uint32_t threads; + if (vectorized) { + threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4); + } else { + threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; + } + + uint32_t wg_x = (threads + max_wg_size - 1) / max_wg_size; - return ggml_backend_webgpu_build(ctx, ctx->set_rows_pipeline, params, entries, wg_x, error_bufs); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, error_bufs); } static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, @@ -1620,8 +1631,10 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", - ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][0], wgsl_set_rows_f16, + "set_rows_f16", ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][1], wgsl_set_rows_f16_vec, + "set_rows_f16_vec", ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); } static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl similarity index 68% rename from ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl index 3567713dc215c..4a6d819d3b145 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl @@ -1,13 +1,38 @@ +#define(VARIANTS) + +[ + { + "SHADER_SUFFIX": "f16_vec", + "REPLS": { + "TYPE" : "vec4", + "DST_TYPE": "vec4", + "VEC_SIZE": 4 + } + }, + { + "SHADER_SUFFIX": "f16", + "REPLS": { + "TYPE" : "f32", + "DST_TYPE": "f16", + "VEC_SIZE": 1 + } + } +] + +#end(VARIANTS) + +#define(SHADER) + enable f16; @group(0) @binding(0) -var src: array; +var src: array<{{TYPE}}>; @group(0) @binding(1) var idx: array; @group(0) @binding(2) -var dst: array; +var dst: array<{{DST_TYPE}}>; @group(0) @binding(3) var error: atomic; @@ -47,10 +72,14 @@ var params: Params; override wg_size: u32; @compute @workgroup_size(wg_size) fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.n_rows * params.ne2 * params.ne3) { + if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / {{VEC_SIZE}}) { return; } - var i = gid.x; + + // getting the row from gid + let elems_per_row = params.ne0 / {{VEC_SIZE}}; + var i = gid.x / elems_per_row; + let i_src3 = i / (params.ne2 * params.n_rows); i = i % (params.ne2 * params.n_rows); @@ -75,7 +104,10 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3; - for (var i: u32 = 0; i < params.ne0; i++) { - dst[i_dst_row + i] = f16(src[i_src_row + i]); - } + // starts at what element of that row? 
+ let col_idx = (gid.x % elems_per_row); + dst[i_dst_row/{{VEC_SIZE}} + col_idx] = {{DST_TYPE}}(src[i_src_row/{{VEC_SIZE}} + col_idx]); } + +#end(SHADER) + From 51aae63b49034ff8171ab623f66ac47a3441628a Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Mon, 27 Oct 2025 13:00:19 -0700 Subject: [PATCH 23/40] Comment on dawn toggles --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 353c7729bd1f8..b4a9f6d579b89 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2149,6 +2149,10 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t required_features.push_back(wgpu::FeatureName::TimestampQuery); #endif + // Enable Dawn-specific toggles to increase native performance + // TODO: Don't enable for WASM builds, they won't have an effect anyways + // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these, + // only for native performance? const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init", "disable_polyfills_on_integer_div_and_mod" }; const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; From 9edfcc9d67cbd4ca8fbcbb04f8b5748fc96f12ab Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Mon, 27 Oct 2025 14:30:41 -0700 Subject: [PATCH 24/40] Working subgroup matrix code for (semi)generic sizes --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 142 ++++++++++++------ .../wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl | 4 +- .../mul_mat_subgroup_matrix.tmpl.wgsl | 93 ++++++------ 3 files changed, 148 insertions(+), 91 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f8ab01abc6fba..286bc2b59ec61 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -355,6 +355,25 @@ struct ggml_backend_webgpu_buffer_context { /* WebGPU object initializations */ +// Process a WGSL shader string, replacing tokens of the form {{KEY}} with +// the corresponding values provided in `repls`. 
+static std::string ggml_webgpu_process_shader_repls(const char * src, + const std::vector> & repls) { + if (!src) { + return std::string(); + } + std::string s = src; + for (const auto & kv : repls) { + std::string token = "{{" + kv.first + "}}"; + size_t pos = 0; + while ((pos = s.find(token, pos)) != std::string::npos) { + s.replace(pos, token.length(), kv.second); + pos += kv.second.length(); + } + } + return s; +} + static void ggml_webgpu_create_pipeline(wgpu::Device & device, webgpu_pipeline & pipeline, const char * shader_code, @@ -1749,40 +1768,45 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); if (webgpu_ctx->supports_subgroup_matrix) { - std::vector mul_mat_sg_mat_constants(7); - mul_mat_sg_mat_constants[0].key = "TILE_K"; - mul_mat_sg_mat_constants[0].value = WEBGPU_MUL_MAT_TILE_K; - mul_mat_sg_mat_constants[1].key = "SUBGROUP_M"; - mul_mat_sg_mat_constants[1].value = WEBGPU_MUL_MAT_SUBGROUP_M; - mul_mat_sg_mat_constants[2].key = "SUBGROUP_N"; - mul_mat_sg_mat_constants[2].value = WEBGPU_MUL_MAT_SUBGROUP_N; - mul_mat_sg_mat_constants[3].key = "SUBGROUP_MATRIX_M_SIZE"; - mul_mat_sg_mat_constants[3].value = static_cast(webgpu_ctx->subgroup_matrix_config.M); - mul_mat_sg_mat_constants[4].key = "SUBGROUP_MATRIX_N_SIZE"; - mul_mat_sg_mat_constants[4].value = static_cast(webgpu_ctx->subgroup_matrix_config.N); - mul_mat_sg_mat_constants[5].key = "SUBGROUP_SIZE"; - mul_mat_sg_mat_constants[5].value = static_cast(webgpu_ctx->subgroup_size); - mul_mat_sg_mat_constants[6].key = "SUBGROUP_MATRIX_K_SIZE"; - mul_mat_sg_mat_constants[6].value = static_cast(webgpu_ctx->subgroup_matrix_config.K); + std::vector> sg_matrix_repls; + sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_SIZE", std::to_string(webgpu_ctx->subgroup_size)); + sg_matrix_repls.emplace_back("WEBGPU_TILE_K", std::to_string(WEBGPU_MUL_MAT_TILE_K)); + sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_M", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M)); + sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_N", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N)); + sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_MATRIX_M", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M)); + sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_MATRIX_N", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N)); + sg_matrix_repls.emplace_back("WEBGPU_SG_MAT_M_SIZE", std::to_string(webgpu_ctx->subgroup_matrix_config.M)); + sg_matrix_repls.emplace_back("WEBGPU_SG_MAT_N_SIZE", std::to_string(webgpu_ctx->subgroup_matrix_config.N)); + sg_matrix_repls.emplace_back("WEBGPU_SG_MAT_K_SIZE", std::to_string(webgpu_ctx->subgroup_matrix_config.K)); + + std::string proc_mul_mat_subgroup_matrix_f32_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_f32_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_f16_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_f16_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_f16_f16 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_f16_f16_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls); - 
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f32_f32, - "mul_mat_subgroup_matrix_f32_f32", mul_mat_sg_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32.c_str(), "mul_mat_subgroup_matrix_f32_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f32_f32_vec, - "mul_mat_subgroup_matrix_f32_f32_vec", mul_mat_sg_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f16_f32, - "mul_mat_subgroup_matrix_f16_f32", mul_mat_sg_mat_constants); + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32_vec.c_str(), + "mul_mat_subgroup_matrix_f32_f32_vec"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32.c_str(), "mul_mat_subgroup_matrix_f16_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f16_f32_vec, - "mul_mat_subgroup_matrix_f16_f32_vec", mul_mat_sg_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f16_f16, - "mul_mat_subgroup_matrix_f16_f16", mul_mat_sg_mat_constants); + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32_vec.c_str(), + "mul_mat_subgroup_matrix_f16_f32_vec"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16.c_str(), "mul_mat_subgroup_matrix_f16_f16"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f16_f16_vec, - "mul_mat_subgroup_matrix_f16_f16_vec", mul_mat_sg_mat_constants); + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16_vec.c_str(), + "mul_mat_subgroup_matrix_f16_f16_vec"); } else { std::vector mul_mat_reg_tile_constants(3); mul_mat_reg_tile_constants[0].key = "TILE_K"; @@ -1792,20 +1816,42 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { mul_mat_reg_tile_constants[2].key = "WORKGROUP_SIZE_N"; mul_mat_reg_tile_constants[2].value = WEBGPU_MUL_MAT_WG_SIZE_N; - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, wgsl_mul_mat_reg_tile_f32_f32, "mul_mat_reg_tile_f32_f32", mul_mat_reg_tile_constants); + std::vector> reg_repls; + reg_repls.emplace_back("WEBGPU_TILE_M", std::to_string(WEBGPU_MUL_MAT_TILE_M)); + reg_repls.emplace_back("WEBGPU_TILE_N", std::to_string(WEBGPU_MUL_MAT_TILE_N)); + + // Process each reg-tile shader with tile replacements. + // Keep the processed strings in-scope so .c_str() remains valid. 
+ std::string proc_mul_mat_reg_tile_f32_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls); + std::string proc_mul_mat_reg_tile_f32_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls); + std::string proc_mul_mat_reg_tile_f16_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls); + std::string proc_mul_mat_reg_tile_f16_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls); + std::string proc_mul_mat_reg_tile_f16_f16 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls); + std::string proc_mul_mat_reg_tile_f16_f16_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls); + + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32.c_str(), + "mul_mat_reg_tile_f32_f32", mul_mat_reg_tile_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_reg_tile_f32_f32_vec, + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32_vec.c_str(), "mul_mat_reg_tile_f32_f32_vec", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f32, "mul_mat_reg_tile_f16_f32", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32.c_str(), + "mul_mat_reg_tile_f16_f32", mul_mat_reg_tile_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f32_vec, + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32_vec.c_str(), "mul_mat_reg_tile_f16_f32_vec", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f16, "mul_mat_reg_tile_f16_f16", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16.c_str(), + "mul_mat_reg_tile_f16_f16", mul_mat_reg_tile_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f16_vec, + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16_vec.c_str(), "mul_mat_reg_tile_f16_f16_vec", mul_mat_reg_tile_constants); } @@ -2354,18 +2400,30 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t std::cout << " Result Type: " << static_cast(config.resultComponentType) << "\n"; } - ctx->subgroup_matrix_config = *subgroup_matrix_configs.configs; wgpu::SupportedFeatures features; ctx->adapter.GetFeatures(&features); // we require f16 support GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); + // Only support square f16 matrices of size 8 or 16 for now + bool valid_subgroup_matrix_config = false; + for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) { + const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i]; + if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) && + config.componentType == 
wgpu::SubgroupMatrixComponentType::F16 && + config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) { + ctx->subgroup_matrix_config = config; + valid_subgroup_matrix_config = true; + break; + } + } // For subgroup matrix code to be workable, we really need a consistent subgroup size. // Unfortunately, WebGPU allows info.subgroup{Min/Max}Size to be different, and even on devices // where it is consistent, e.g., Apple M-series GPUs, the min/max sizes report different values. // Therefore, hardcoding the subgroup size to 32 for now for development. - ctx->subgroup_size = 32; - ctx->supports_subgroup_matrix = ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); + ctx->subgroup_size = 32; + ctx->supports_subgroup_matrix = + valid_subgroup_matrix_config && ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); // Initialize device std::vector required_features = { wgpu::FeatureName::ShaderF16, diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl index 5cafa7c4c08a4..2efe48009916c 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl @@ -187,8 +187,8 @@ fn get_local_m(thread_id: u32) -> u32 { // Warning: cannot be overrides, must match values in ggml-webgpu.cpp // TILE_M must be multiple of 4 for vec4 loads -const TILE_M = 4u; -const TILE_N = 4u; +const TILE_M = {{WEBGPU_TILE_M}}u; +const TILE_N = {{WEBGPU_TILE_N}}u; override WORKGROUP_SIZE_M: u32; override WORKGROUP_SIZE_N: u32; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl index c8757bf207251..41d2aa0befd0b 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl @@ -72,10 +72,10 @@ fn zero_val_src0() -> {{SRC0_TYPE}} { } fn store_src0_shmem(val: {{SRC0_TYPE}}, idx: u32) { - src0_shmem[idx] = f16(val.x); - src0_shmem[idx + 1] = f16(val.y); - src0_shmem[idx + 2] = f16(val.z); - src0_shmem[idx + 3] = f16(val.w); + shmem[idx] = f16(val.x); + shmem[idx + 1] = f16(val.y); + shmem[idx + 2] = f16(val.z); + shmem[idx + 3] = f16(val.w); } fn zero_val_src1() -> {{SRC1_TYPE}} { @@ -83,18 +83,18 @@ fn zero_val_src1() -> {{SRC1_TYPE}} { } fn store_src1_shmem(val: {{SRC1_TYPE}}, idx: u32) { - src1_shmem[idx] = f16(val.x); - src1_shmem[idx + 1] = f16(val.y); - src1_shmem[idx + 2] = f16(val.z); - src1_shmem[idx + 3] = f16(val.w); + shmem[idx] = f16(val.x); + shmem[idx + 1] = f16(val.y); + shmem[idx + 2] = f16(val.z); + shmem[idx + 3] = f16(val.w); } fn store_dst(shmem_idx: u32, dst_idx: u32) { dst[dst_idx] = vec4( - f32(src0_shmem[shmem_idx]), - f32(src0_shmem[shmem_idx + 1]), - f32(src0_shmem[shmem_idx + 2]), - f32(src0_shmem[shmem_idx + 3]) + f32(shmem[shmem_idx]), + f32(shmem[shmem_idx + 1]), + f32(shmem[shmem_idx + 2]), + f32(shmem[shmem_idx + 3]) ); } #enddecl(SHMEM_VEC) @@ -105,7 +105,7 @@ fn zero_val_src0() -> {{SRC0_TYPE}} { } fn store_src0_shmem(val: {{SRC0_TYPE}}, idx: u32) { - src0_shmem[idx] = f16(val); + shmem[idx] = f16(val); } fn zero_val_src1() -> {{SRC1_TYPE}} { @@ -113,11 +113,11 @@ fn zero_val_src1() -> {{SRC1_TYPE}} { } fn store_src1_shmem(val: {{SRC1_TYPE}}, idx: u32) { - src1_shmem[idx] = f16(val); + shmem[idx] = f16(val); } fn store_dst(shmem_idx: u32, dst_idx: u32) { - dst[dst_idx] = 
f32(src0_shmem[shmem_idx]); + dst[dst_idx] = f32(shmem[shmem_idx]); } #enddecl(SHMEM_SCALAR) @@ -155,37 +155,36 @@ struct MulMatParams { DECLS -override SUBGROUP_M: u32; -override SUBGROUP_MATRIX_M_SIZE: u32; -override SUBGROUP_N: u32; -override SUBGROUP_MATRIX_N_SIZE: u32; -override SUBGROUP_SIZE: u32; +// Note: These are string interpolated at build time, cannot use override constants due to limitations in +// current Dawn version type definitions/matrix load requirements for constant memory sizes. +const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u; +const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u; +const SUBGROUP_SIZE = {{WEBGPU_SUBGROUP_SIZE}}u; -// Note: must match values in ggml-webgpu.cpp -const SUBGROUP_MATRIX_M = 4u; -const SUBGROUP_MATRIX_N = 2u; +const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u; +const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u; +const SUBGROUP_MATRIX_K_SIZE = {{WEBGPU_SG_MAT_K_SIZE}}u; -override TILE_K: u32; -// Note: we assume TILE_K is divisible by SUBGROUP_MATRIX_K; -override SUBGROUP_MATRIX_K_SIZE: u32; +const SUBGROUP_MATRIX_M = {{WEBGPU_SUBGROUP_MATRIX_M}}u; +const SUBGROUP_MATRIX_N = {{WEBGPU_SUBGROUP_MATRIX_N}}u; -override WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; -override WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; +const TILE_K = {{WEBGPU_TILE_K}}u; -override TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * SUBGROUP_SIZE; -override TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; -override TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; +const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; +const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; -override SG_MAT_ACCUM_SHMEM = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_M_SIZE * SUBGROUP_MATRIX_N_SIZE; +const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * SUBGROUP_SIZE; +const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; +const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; -// We reuse src0_shmem for accumulation matrices -override SHMEM_SIZE = max(TILE_SRC0_SHMEM, SG_MAT_ACCUM_SHMEM); +const SG_MAT_ACCUM_SHMEM = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_M_SIZE * SUBGROUP_MATRIX_N_SIZE; + +// We reuse shmem for accumulation matrices +const SHMEM_SIZE = max(TILE_SRC0_SHMEM + TILE_SRC1_SHMEM, SG_MAT_ACCUM_SHMEM); // Note: apparently current dawn doesn't like override constant shared memory size along with subgroup matrix loads -//var src0_shmem: array; -//var src1_shmem: array; -var src0_shmem: array; -var src1_shmem: array; +var shmem: array; +//var shmem: array; @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) fn main(@builtin(global_invocation_id) global_id: vec3, @@ -221,7 +220,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3, let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; - var acc_sg_mat : array, SUBGROUP_MATRIX_N>, SUBGROUP_MATRIX_M>; + var acc_sg_mat : array, SUBGROUP_MATRIX_N>, SUBGROUP_MATRIX_M>; for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { @@ -249,7 +248,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3, zero_val_src1(), 
src1[src1_idx/{{VEC_SIZE}}], global_n < params.n && global_k < params.k); - store_src1_shmem(src1_val, elem_idx); + store_src1_shmem(src1_val, TILE_SRC0_SHMEM + elem_idx); } workgroupBarrier(); @@ -257,10 +256,10 @@ fn main(@builtin(global_invocation_id) global_id: vec3, for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K_SIZE) { let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner; - var src0_sg_mats: array, SUBGROUP_MATRIX_M>; + var src0_sg_mats: array, SUBGROUP_MATRIX_M>; for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { - src0_sg_mats[m] = subgroupMatrixLoad>( - &src0_shmem, + src0_sg_mats[m] = subgroupMatrixLoad>( + &shmem, src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K, false, TILE_K @@ -269,9 +268,9 @@ fn main(@builtin(global_invocation_id) global_id: vec3, let src1_shmem_idx_base = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner; for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) { - let src1_sg_mat = subgroupMatrixLoad>( - &src1_shmem, - src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K, + let src1_sg_mat = subgroupMatrixLoad>( + &shmem, + TILE_SRC0_SHMEM + src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K, true, TILE_K ); @@ -298,7 +297,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3, let local_row = tile_row_base_local + n * SUBGROUP_MATRIX_N_SIZE; let local_col = tile_col_base_local + m * SUBGROUP_MATRIX_M_SIZE; let out_base = local_row * WG_TILE_STRIDE + local_col; - subgroupMatrixStore(&src0_shmem, out_base, acc_sg_mat[m][n], true, WG_TILE_STRIDE); + subgroupMatrixStore(&shmem, out_base, acc_sg_mat[m][n], true, WG_TILE_STRIDE); } } From f0cfae49d637ea48b901a72c0b344ba9393f38e8 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Mon, 27 Oct 2025 15:48:00 -0700 Subject: [PATCH 25/40] Remove some comments --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 3 +-- ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index b4a9f6d579b89..70e3013537b2d 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -248,7 +248,7 @@ struct webgpu_context_struct { webgpu_pipeline memset_pipeline; webgpu_pipeline mul_mat_pipeline[30][2]; - webgpu_pipeline set_rows_pipeline[1][2]; // dst->type, vectorized (0 for vectorized, 1 for non vectorized) + webgpu_pipeline set_rows_pipeline[1][2]; // dst->type, vectorized webgpu_pipeline get_rows_pipeline[30]; webgpu_pipeline get_rows_f32_no_vec_pipeline; webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type @@ -770,7 +770,6 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, int vectorized = src->ne[0] % 4 == 0; webgpu_pipeline pipeline = ctx->set_rows_pipeline[0][vectorized]; - // if not evenly divisble by 4, use the non-vectorized version uint32_t threads; if (vectorized) { threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl index 4a6d819d3b145..fca3be6bc27ed 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl @@ -104,7 +104,6 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; let i_src_row = 
params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3; - // starts at what element of that row? let col_idx = (gid.x % elems_per_row); dst[i_dst_row/{{VEC_SIZE}} + col_idx] = {{DST_TYPE}}(src[i_src_row/{{VEC_SIZE}} + col_idx]); } From cf0c5364d45ab06d8f959847c9b6553eca5e69db Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Mon, 27 Oct 2025 17:26:41 -0700 Subject: [PATCH 26/40] Cleanup code --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 104 ++++++------------ .../wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl | 89 +++------------ .../mul_mat_subgroup_matrix.tmpl.wgsl | 58 +++------- 3 files changed, 66 insertions(+), 185 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 97bf2348bb02d..e1343b34ae5ea 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -74,16 +74,16 @@ // For operations which process a row in parallel, this seems like a reasonable default #define WEBGPU_ROW_SPLIT_WG_SIZE 64 -// Matrix multiplication fast path parameters - -// Warning: must match values in mul_mat_fast.wgsl -#define WEBGPU_MUL_MAT_TILE_M 4 -#define WEBGPU_MUL_MAT_TILE_N 4 +// Matrix multiplication parameters +// Register tiling parameters +#define WEBGPU_MUL_MAT_TILE_M 4 +#define WEBGPU_MUL_MAT_TILE_N 4 #define WEBGPU_MUL_MAT_WG_SIZE_M 16 #define WEBGPU_MUL_MAT_WG_SIZE_N 8 #define WEBGPU_MUL_MAT_TILE_K 32 +// Subgroup matrix parameters // The number of subgroups in the M dimension #define WEBGPU_MUL_MAT_SUBGROUP_M 2 // The number of subgroups in the N dimension @@ -92,7 +92,7 @@ #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 -// GEMV constants +// gemv parameters #define WEBGPU_GEMV_WG_SIZE 256 // Must be multiple of 4 to work with vectorized paths, and must divide gemv wg size #define WEBGPU_GEMV_OUTPUTS_PER_WG 16 @@ -948,60 +948,37 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE; uint32_t wg_y = 1; - bool use_gemv = false; - if (dst->ne[1] == 1) { - switch (src1->type) { - case GGML_TYPE_F16: - use_gemv = (src0->type == GGML_TYPE_F16); - break; - case GGML_TYPE_F32: - switch (src0->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - use_gemv = true; - break; - default: - break; - } - default: - break; - } + bool use_fast = false; + switch (src1->type) { + case GGML_TYPE_F16: + use_fast = (src0->type == GGML_TYPE_F16); + break; + case GGML_TYPE_F32: + switch (src0->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + use_fast = true; + break; + default: + break; + } + break; + default: + break; } - if (use_gemv) { - uint32_t batches = dst->ne[2] * dst->ne[3]; - uint32_t output_groups = (dst->ne[0] + WEBGPU_GEMV_OUTPUTS_PER_WG - 1) / WEBGPU_GEMV_OUTPUTS_PER_WG; - uint32_t total_wg = output_groups * batches; - wg_x = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension; - wg_y = (total_wg + ctx->limits.maxComputeWorkgroupsPerDimension - 1) / - ctx->limits.maxComputeWorkgroupsPerDimension; - int vectorized = src0->ne[0] % 4 == 0 && src1->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0; - pipeline = ctx->gemv_pipelines[src0->type][src1->type][vectorized]; - - } else { - bool use_fast = false; - switch (src1->type) { - case GGML_TYPE_F16: - use_fast = (src0->type == GGML_TYPE_F16); - break; - case GGML_TYPE_F32: - switch (src0->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - use_fast = true; - break; - 
default: - break; - } - break; - default: - break; - } - - if (use_fast) { - int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0; - pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; - + if (use_fast) { + int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0; + if (dst->ne[1] == 1) { + pipeline = ctx->gemv_pipelines[src0->type][src1->type][vectorized]; + uint32_t batches = dst->ne[2] * dst->ne[3]; + uint32_t output_groups = (dst->ne[0] + WEBGPU_GEMV_OUTPUTS_PER_WG - 1) / WEBGPU_GEMV_OUTPUTS_PER_WG; + uint32_t total_wg = output_groups * batches; + wg_x = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension; + wg_y = (total_wg + ctx->limits.maxComputeWorkgroupsPerDimension - 1) / + ctx->limits.maxComputeWorkgroupsPerDimension; + } else { + pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; uint32_t wg_m; uint32_t wg_n; if (ctx->supports_subgroup_matrix) { @@ -2400,17 +2377,6 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t info.nextInChain = &subgroup_matrix_configs; ctx->adapter.GetInfo(&info); - // print configs - for (int i = 0; i < subgroup_matrix_configs.configCount; i++) { - const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i]; - std::cout << "ggml_webgpu: Subgroup Matrix Config " << i << ":\n"; - std::cout << " M: " << config.M << "\n"; - std::cout << " N: " << config.N << "\n"; - std::cout << " K: " << config.K << "\n"; - std::cout << " Component Type: " << static_cast(config.componentType) << "\n"; - std::cout << " Result Type: " << static_cast(config.resultComponentType) << "\n"; - } - wgpu::SupportedFeatures features; ctx->adapter.GetFeatures(&features); // we require f16 support diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl index 2efe48009916c..4b1a5a4040d2b 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl @@ -8,7 +8,7 @@ "DST_TYPE" : "vec4", "VEC_SIZE" : "4", }, - "DECLS": ["SRC0_F32_VEC", "SRC1_F32_VEC"] + "DECLS": ["VEC"] }, { "SHADER_SUFFIX": "f32_f32", @@ -18,7 +18,7 @@ "DST_TYPE" : "f32", "VEC_SIZE" : "1", }, - "DECLS": ["SRC0_F32", "SRC1_F32"] + "DECLS": ["SCALAR"] }, { "SHADER_SUFFIX": "f16_f32_vec", @@ -28,7 +28,7 @@ "DST_TYPE" : "vec4", "VEC_SIZE" : "4", }, - "DECLS": ["SRC0_F16_VEC", "SRC1_F32_VEC"] + "DECLS": ["VEC"] }, { "SHADER_SUFFIX": "f16_f32", @@ -38,7 +38,7 @@ "DST_TYPE" : "f32", "VEC_SIZE" : "1", }, - "DECLS": ["SRC0_F16", "SRC1_F32"] + "DECLS": ["SCALAR"] }, { "SHADER_SUFFIX": "f16_f16_vec", @@ -48,7 +48,7 @@ "DST_TYPE" : "vec4", "VEC_SIZE" : "4", }, - "DECLS": ["SRC0_F16_VEC", "SRC1_F16_VEC"] + "DECLS": ["VEC"] }, { "SHADER_SUFFIX": "f16_f16", @@ -58,7 +58,7 @@ "DST_TYPE" : "f32", "VEC_SIZE" : "1", }, - "DECLS": ["SRC0_F16", "SRC1_F16"] + "DECLS": ["SCALAR"] } ] @@ -66,85 +66,25 @@ #define(DECLS) -#decl(SRC0_F32_VEC) -fn zero_val_src0() -> vec4 { - return vec4(0.0, 0.0, 0.0, 0.0); -} -#enddecl(SRC0_F32_VEC) - -#decl(SRC0_F32) -fn zero_val_src0() -> f32 { - return 0.0; -} -#enddecl(SRC0_F32) - -#decl(SRC0_F16_VEC) -fn zero_val_src0() -> vec4 { - return vec4(0.0, 0.0, 0.0, 0.0); -} -#enddecl(SRC0_F16_VEC) - -#decl(SRC0_F16) -fn zero_val_src0() -> f16 { - return 0.0; -} -#enddecl(SRC0_F16) - -#decl(SRC1_F32_VEC) -fn zero_val_src1() -> vec4 { - return vec4(0.0, 0.0, 0.0, 0.0); -} - +#decl(VEC) fn store_val(acc: array, 
TILE_M>, tn: u32, tm: u32) -> vec4 { return vec4(acc[tm][tn], acc[tm + 1][tn], acc[tm + 2][tn], acc[tm + 3][tn]); } -fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: vec4) -> f32 { - return dot(vec4(src0_val), src1_val); -} -#enddecl(SRC1_F32_VEC) - -#decl(SRC1_F32) -fn zero_val_src1() -> f32 { - return 0.0; -} - -fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { - return acc[tm][tn]; -} - -fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: f32) -> f32 { - return f32(src0_val) * src1_val; -} -#enddecl(SRC1_F32) - -#decl(SRC1_F16_VEC) -fn zero_val_src1() -> vec4 { - return vec4(0.0, 0.0, 0.0, 0.0); -} - -fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> vec4 { - return vec4(acc[tm][tn], f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn])); -} - -fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: vec4) -> f32 { +fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { return dot(vec4(src0_val), vec4(src1_val)); } -#enddecl(SRC1_F16_VEC) - -#decl(SRC1_F16) -fn zero_val_src1() -> f16 { - return 0.0; -} +#enddecl(VEC) +#decl(SCALAR) fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { return acc[tm][tn]; } -fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: f16) -> f32 { +fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { return f32(src0_val) * f32(src1_val); } -#enddecl(SRC1_F16) +#enddecl(SCALAR) #end(DECLS) @@ -185,7 +125,6 @@ fn get_local_m(thread_id: u32) -> u32 { return thread_id % WORKGROUP_SIZE_M; } -// Warning: cannot be overrides, must match values in ggml-webgpu.cpp // TILE_M must be multiple of 4 for vec4 loads const TILE_M = {{WEBGPU_TILE_M}}u; const TILE_N = {{WEBGPU_TILE_N}}u; @@ -248,7 +187,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3, let global_k = k_outer + tile_k; let src0_idx = src0_batch_offset + global_m * params.stride_01 + global_k; src0_shmem[elem_idx/{{VEC_SIZE}}] = select( // taking a slight performance hit to avoid oob - zero_val_src0(), + {{SRC0_TYPE}}(0.0), src0[src0_idx/{{VEC_SIZE}}], global_m < params.m && global_k < params.k); } @@ -261,7 +200,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3, let src1_idx = src1_batch_offset + global_n * params.stride_11 + global_k; src1_shmem[elem_idx/{{VEC_SIZE}}] = select( - zero_val_src1(), + {{SRC1_TYPE}}(0.0), src1[src1_idx/{{VEC_SIZE}}], global_n < params.n && global_k < params.k); } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl index 41d2aa0befd0b..f3ff0b0dc617e 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl @@ -6,6 +6,7 @@ "SRC0_TYPE" : "vec4", "SRC1_TYPE" : "vec4", "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", "VEC_SIZE" : "4", }, "DECLS": ["SHMEM_VEC"] @@ -16,6 +17,7 @@ "SRC0_TYPE" : "f32", "SRC1_TYPE" : "f32", "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", "VEC_SIZE" : "1", }, "DECLS": ["SHMEM_SCALAR"] @@ -26,6 +28,7 @@ "SRC0_TYPE" : "vec4", "SRC1_TYPE" : "vec4", "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", "VEC_SIZE" : "4", }, "DECLS": ["SHMEM_VEC"] @@ -36,6 +39,7 @@ "SRC0_TYPE" : "f16", "SRC1_TYPE" : "f32", "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", "VEC_SIZE" : "1", }, "DECLS": ["SHMEM_SCALAR"] @@ -46,6 +50,7 @@ "SRC0_TYPE" : "vec4", "SRC1_TYPE" : "vec4", "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", "VEC_SIZE" : "4", }, "DECLS": ["SHMEM_VEC"] @@ -56,6 +61,7 @@ "SRC0_TYPE" : "f16", "SRC1_TYPE" : "f16", "DST_TYPE" : "f32", + 
"SHMEM_TYPE" : "f16", "VEC_SIZE" : "1", }, "DECLS": ["SHMEM_SCALAR"] @@ -67,26 +73,11 @@ #define(DECLS) #decl(SHMEM_VEC) -fn zero_val_src0() -> {{SRC0_TYPE}} { - return {{SRC0_TYPE}}(0.0, 0.0, 0.0, 0.0); -} - -fn store_src0_shmem(val: {{SRC0_TYPE}}, idx: u32) { - shmem[idx] = f16(val.x); - shmem[idx + 1] = f16(val.y); - shmem[idx + 2] = f16(val.z); - shmem[idx + 3] = f16(val.w); -} - -fn zero_val_src1() -> {{SRC1_TYPE}} { - return {{SRC1_TYPE}}(0.0, 0.0, 0.0, 0.0); -} - -fn store_src1_shmem(val: {{SRC1_TYPE}}, idx: u32) { - shmem[idx] = f16(val.x); - shmem[idx + 1] = f16(val.y); - shmem[idx + 2] = f16(val.z); - shmem[idx + 3] = f16(val.w); +fn store_shmem(val: vec4, idx: u32) { + shmem[idx] = val.x; + shmem[idx + 1] = val.y; + shmem[idx + 2] = val.z; + shmem[idx + 3] = val.w; } fn store_dst(shmem_idx: u32, dst_idx: u32) { @@ -100,20 +91,8 @@ fn store_dst(shmem_idx: u32, dst_idx: u32) { #enddecl(SHMEM_VEC) #decl(SHMEM_SCALAR) -fn zero_val_src0() -> {{SRC0_TYPE}} { - return 0.0; -} - -fn store_src0_shmem(val: {{SRC0_TYPE}}, idx: u32) { - shmem[idx] = f16(val); -} - -fn zero_val_src1() -> {{SRC1_TYPE}} { - return 0.0; -} - -fn store_src1_shmem(val: {{SRC1_TYPE}}, idx: u32) { - shmem[idx] = f16(val); +fn store_shmem(val: f16, idx: u32) { + shmem[idx] = val; } fn store_dst(shmem_idx: u32, dst_idx: u32) { @@ -182,9 +161,7 @@ const SG_MAT_ACCUM_SHMEM = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_N * SUBGROU // We reuse shmem for accumulation matrices const SHMEM_SIZE = max(TILE_SRC0_SHMEM + TILE_SRC1_SHMEM, SG_MAT_ACCUM_SHMEM); -// Note: apparently current dawn doesn't like override constant shared memory size along with subgroup matrix loads var shmem: array; -//var shmem: array; @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) fn main(@builtin(global_invocation_id) global_id: vec3, @@ -231,10 +208,10 @@ fn main(@builtin(global_invocation_id) global_id: vec3, let global_k = k_outer + tile_k; let src0_idx = src0_batch_offset + global_m * params.stride_01 + global_k; let src0_val = select( // taking a slight performance hit to avoid oob - zero_val_src0(), + {{SRC0_TYPE}}(0.0), src0[src0_idx/{{VEC_SIZE}}], global_m < params.m && global_k < params.k); - store_src0_shmem(src0_val, elem_idx); + store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx); } for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { @@ -245,10 +222,10 @@ fn main(@builtin(global_invocation_id) global_id: vec3, let src1_idx = src1_batch_offset + global_n * params.stride_11 + global_k; let src1_val = select( - zero_val_src1(), + {{SRC1_TYPE}}(0.0), src1[src1_idx/{{VEC_SIZE}}], global_n < params.n && global_k < params.k); - store_src1_shmem(src1_val, TILE_SRC0_SHMEM + elem_idx); + store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx); } workgroupBarrier(); @@ -285,7 +262,6 @@ fn main(@builtin(global_invocation_id) global_id: vec3, let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; - // Stage the subgroup matrix tiles into shared memory // This uses WG_M_SG_TILE_SIZE as the stride (number of columns in the workgroup tile). 
let WG_TILE_STRIDE = WG_M_SG_TILE_SIZE; From 71c7a4a8e21e66b1f8b9883a4246c377c139d025 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Mon, 27 Oct 2025 21:08:53 -0700 Subject: [PATCH 27/40] Update dawn version and move to portable subgroup size --- .github/workflows/build.yml | 16 ++-- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 5 +- .../wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl | 8 +- .../mul_mat_subgroup_matrix.tmpl.wgsl | 74 ++++++++++--------- 4 files changed, 55 insertions(+), 48 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 15e1133095213..df013383d7d28 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -161,15 +161,15 @@ jobs: - name: Dawn Dependency id: dawn-depends run: | - DAWN_VERSION="v1.0.0" + DAWN_VERSION="v2.0.0" DAWN_OWNER="reeselevine" DAWN_REPO="dawn" - DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-macos-latest-Release.tar.gz" + DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.zip" echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" - curl -L -o artifact.tar.gz \ + curl -L -o artifact.zip \ "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" mkdir dawn - tar -xvf artifact.tar.gz -C dawn --strip-components=1 + unzip artifact.zip -d dawn - name: Build id: cmake_build @@ -521,15 +521,15 @@ jobs: id: dawn-depends run: | sudo apt-get install -y libxrandr-dev libxinerama-dev libxcursor-dev mesa-common-dev libx11-xcb-dev libxi-dev - DAWN_VERSION="v1.0.0" + DAWN_VERSION="v2.0.0" DAWN_OWNER="reeselevine" DAWN_REPO="dawn" - DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-ubuntu-latest-Release.tar.gz" + DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.zip" echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" - curl -L -o artifact.tar.gz \ + curl -L -o artifact.zip \ "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" mkdir dawn - tar -xvf artifact.tar.gz -C dawn --strip-components=1 + unzip artifact.zip -d dawn - name: Build id: cmake_build diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index e1343b34ae5ea..4b9eca195931b 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1755,7 +1755,7 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { if (webgpu_ctx->supports_subgroup_matrix) { std::vector> sg_matrix_repls; - sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_SIZE", std::to_string(webgpu_ctx->subgroup_size)); + sg_matrix_repls.emplace_back("WEBGPU_MAX_SUBGROUP_SIZE", std::to_string(webgpu_ctx->subgroup_size)); sg_matrix_repls.emplace_back("WEBGPU_TILE_K", std::to_string(WEBGPU_MUL_MAT_TILE_K)); sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_M", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M)); sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_N", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N)); @@ -2398,7 +2398,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t // Unfortunately, WebGPU allows info.subgroup{Min/Max}Size to be different, and even on devices // where it is consistent, e.g., Apple M-series GPUs, the min/max sizes report different values. // Therefore, hardcoding the subgroup size to 32 for now for development. 
- ctx->subgroup_size = 32; + ctx->subgroup_size = info.subgroupMaxSize; ctx->supports_subgroup_matrix = valid_subgroup_matrix_config && ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); @@ -2406,6 +2406,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t std::vector required_features = { wgpu::FeatureName::ShaderF16, wgpu::FeatureName::ImplicitDeviceSynchronization }; if (ctx->supports_subgroup_matrix) { + required_features.push_back(wgpu::FeatureName::Subgroups); required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl index 4b1a5a4040d2b..8a20d6976e8fa 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl @@ -141,22 +141,20 @@ var src0_shmem: array<{{SRC0_TYPE}}, TILE_SRC0_SHMEM/{{VEC_SIZE}}>; var src1_shmem: array<{{SRC1_TYPE}}, TILE_SRC1_SHMEM/{{VEC_SIZE}}>; @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) -fn main(@builtin(global_invocation_id) global_id: vec3, +fn main(@builtin(workgroup_id) wg_id: vec3, @builtin(local_invocation_id) local_id: vec3) { let thread_id = local_id.x; let local_m = get_local_m(thread_id); let local_n = get_local_n(thread_id); - let wg_linear = global_id.x / TOTAL_WORKGROUP_SIZE; - let wg_n_count = (params.n + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N); let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M); let wg_per_matrix = wg_m_count * wg_n_count; - let batch_idx = wg_linear / wg_per_matrix; + let batch_idx = wg_id.x / wg_per_matrix; - let wg_in_batch = wg_linear % wg_per_matrix; + let wg_in_batch = wg_id.x % wg_per_matrix; let wg_m = wg_in_batch % wg_m_count; let wg_n = wg_in_batch / wg_m_count; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl index f3ff0b0dc617e..3e55be6feafc6 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl @@ -105,6 +105,7 @@ fn store_dst(shmem_idx: u32, dst_idx: u32) { #define(SHADER) diagnostic(off, chromium.subgroup_matrix_uniformity); enable f16; +enable subgroups; enable chromium_experimental_subgroup_matrix; struct MulMatParams { @@ -138,7 +139,11 @@ DECLS // current Dawn version type definitions/matrix load requirements for constant memory sizes. const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u; const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u; -const SUBGROUP_SIZE = {{WEBGPU_SUBGROUP_SIZE}}u; +// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the +// runtime subgroup size is smaller. 
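A rough worked example of what that masking means in practice; the numbers below are illustrative assumptions, not values taken from this patch:

// Suppose the adapter reports subgroupMaxSize = 64 but the device actually runs
// subgroups of 32. With the defaults WEBGPU_MUL_MAT_SUBGROUP_M = 2 and
// WEBGPU_MUL_MAT_SUBGROUP_N = 2:
//   TOTAL_WORKGROUP_SIZE = 2 * 2 * 64 = 256 threads
//   runtime subgroups    = 256 / 32   = 8
//   EXPECTED_SUBGROUPS   = 2 * 2      = 4
// so the four subgroups with subgroup_id >= EXPECTED_SUBGROUPS are simply idled
// by the `subgroup_id < EXPECTED_SUBGROUPS` guards around the multiply and store
// phases below.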
+const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE}}u; + +const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N; const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u; const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u; @@ -152,7 +157,7 @@ const TILE_K = {{WEBGPU_TILE_K}}u; const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; -const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * SUBGROUP_SIZE; +const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE; const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; @@ -164,7 +169,7 @@ const SHMEM_SIZE = max(TILE_SRC0_SHMEM + TILE_SRC1_SHMEM, SG_MAT_ACCUM_SHMEM); var shmem: array; @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) -fn main(@builtin(global_invocation_id) global_id: vec3, +fn main(@builtin(workgroup_id) wg_id: vec3, @builtin(local_invocation_id) local_id: vec3, @builtin(subgroup_id) subgroup_id: u32) { @@ -172,15 +177,13 @@ fn main(@builtin(global_invocation_id) global_id: vec3, let subgroup_m = subgroup_id % SUBGROUP_M; let subgroup_n = subgroup_id / SUBGROUP_M; - let wg_linear = global_id.x / TOTAL_WORKGROUP_SIZE; - let wg_m_count = (params.m + WG_M_SG_TILE_SIZE - 1) / WG_M_SG_TILE_SIZE; let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE; let wg_per_matrix = wg_m_count * wg_n_count; - let batch_idx = wg_linear / wg_per_matrix; + let batch_idx = wg_id.x / wg_per_matrix; - let wg_in_batch = wg_linear % wg_per_matrix; + let wg_in_batch = wg_id.x % wg_per_matrix; let wg_m = wg_in_batch % wg_m_count; let wg_n = wg_in_batch / wg_m_count; @@ -230,29 +233,32 @@ fn main(@builtin(global_invocation_id) global_id: vec3, workgroupBarrier(); - for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K_SIZE) { + if (subgroup_id < EXPECTED_SUBGROUPS) { - let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner; - var src0_sg_mats: array, SUBGROUP_MATRIX_M>; - for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { - src0_sg_mats[m] = subgroupMatrixLoad>( - &shmem, - src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K, - false, - TILE_K - ); - } + for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K_SIZE) { - let src1_shmem_idx_base = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner; - for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) { - let src1_sg_mat = subgroupMatrixLoad>( - &shmem, - TILE_SRC0_SHMEM + src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K, - true, - TILE_K - ); + let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner; + var src0_sg_mats: array, SUBGROUP_MATRIX_M>; for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { - acc_sg_mat[m][n] = subgroupMatrixMultiplyAccumulate(src0_sg_mats[m], src1_sg_mat, acc_sg_mat[m][n]); + src0_sg_mats[m] = subgroupMatrixLoad>( + &shmem, + src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K, + false, + TILE_K + ); + } + + let src1_shmem_idx_base = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner; + for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) { + let src1_sg_mat = subgroupMatrixLoad>( + &shmem, + TILE_SRC0_SHMEM + src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K, + true, + TILE_K + ); + for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { + acc_sg_mat[m][n] 
= subgroupMatrixMultiplyAccumulate(src0_sg_mats[m], src1_sg_mat, acc_sg_mat[m][n]); + } } } } @@ -268,12 +274,14 @@ fn main(@builtin(global_invocation_id) global_id: vec3, let tile_row_base_local = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; let tile_col_base_local = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; - for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) { - for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { - let local_row = tile_row_base_local + n * SUBGROUP_MATRIX_N_SIZE; - let local_col = tile_col_base_local + m * SUBGROUP_MATRIX_M_SIZE; - let out_base = local_row * WG_TILE_STRIDE + local_col; - subgroupMatrixStore(&shmem, out_base, acc_sg_mat[m][n], true, WG_TILE_STRIDE); + if (subgroup_id < EXPECTED_SUBGROUPS) { // 2-5% performance hit :( + for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) { + for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { + let local_row = tile_row_base_local + n * SUBGROUP_MATRIX_N_SIZE; + let local_col = tile_col_base_local + m * SUBGROUP_MATRIX_M_SIZE; + let out_base = local_row * WG_TILE_STRIDE + local_col; + subgroupMatrixStore(&shmem, out_base, acc_sg_mat[m][n], true, WG_TILE_STRIDE); + } } } From c73893e5dc8cc07cac8ef8c26204a947edfda271 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Mon, 27 Oct 2025 21:25:18 -0700 Subject: [PATCH 28/40] Try to fix new dawn release --- .github/workflows/build.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index df013383d7d28..36084c55078ef 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -169,7 +169,8 @@ jobs: curl -L -o artifact.zip \ "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" mkdir dawn - unzip artifact.zip -d dawn + unzip artifact.zip + tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.tar.gz -C dawn --strip-components=1 - name: Build id: cmake_build @@ -529,7 +530,8 @@ jobs: curl -L -o artifact.zip \ "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" mkdir dawn - unzip artifact.zip -d dawn + unzip artifact.zip + tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.tar.gz -C dawn --strip-components=1 - name: Build id: cmake_build From f538ca364f95ec143211c720f49cd01badaa4e1f Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Mon, 27 Oct 2025 21:33:52 -0700 Subject: [PATCH 29/40] Update subgroup size comment --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 4b9eca195931b..d3b505acd06d5 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2394,10 +2394,9 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t break; } } - // For subgroup matrix code to be workable, we really need a consistent subgroup size. - // Unfortunately, WebGPU allows info.subgroup{Min/Max}Size to be different, and even on devices - // where it is consistent, e.g., Apple M-series GPUs, the min/max sizes report different values. - // Therefore, hardcoding the subgroup size to 32 for now for development. + + // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. + // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. 
ctx->subgroup_size = info.subgroupMaxSize; ctx->supports_subgroup_matrix = valid_subgroup_matrix_config && ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); From f5001d8fb6c074e88e7a953abd493b028f16d44a Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 28 Oct 2025 16:12:13 -0700 Subject: [PATCH 30/40] Only check for subgroup matrix configs if they are supported --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index d3b505acd06d5..fff090ab54259 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2374,7 +2374,9 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t wgpu::AdapterInfo info{}; wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{}; - info.nextInChain = &subgroup_matrix_configs; + if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { + info.nextInChain = &subgroup_matrix_configs; + } ctx->adapter.GetInfo(&info); wgpu::SupportedFeatures features; @@ -2384,22 +2386,23 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t // Only support square f16 matrices of size 8 or 16 for now bool valid_subgroup_matrix_config = false; - for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) { - const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i]; - if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) && - config.componentType == wgpu::SubgroupMatrixComponentType::F16 && - config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) { - ctx->subgroup_matrix_config = config; - valid_subgroup_matrix_config = true; - break; + if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { + for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) { + const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i]; + if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) && + config.componentType == wgpu::SubgroupMatrixComponentType::F16 && + config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) { + ctx->subgroup_matrix_config = config; + valid_subgroup_matrix_config = true; + break; + } } } // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. 
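A short sketch of what the config filter above accepts; the 8x8x8 / 16x16x16 shapes are examples of configurations that pass the check, not a list of what any particular device reports:

// accepted:  M == N == K, K in {8, 16}, component type f16, result type f16
//            (e.g. an 8x8x8 or a 16x16x16 f16 configuration)
// rejected:  rectangular shapes, or configurations with f32 component/result types
// fallback:  if no config matches, supports_subgroup_matrix stays false and the
//            register-tile mul_mat pipelines are compiled instead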
ctx->subgroup_size = info.subgroupMaxSize; - ctx->supports_subgroup_matrix = - valid_subgroup_matrix_config && ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); + ctx->supports_subgroup_matrix = valid_subgroup_matrix_config; // Initialize device std::vector required_features = { wgpu::FeatureName::ShaderF16, From 844ba40c5943039b4673e6c72ce6bbc67f5f8660 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 28 Oct 2025 19:09:16 -0700 Subject: [PATCH 31/40] Add toggles for subgroup matrix/f16 support on nvidia+vulkan --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index fff090ab54259..a6b53da75b209 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2356,7 +2356,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t webgpu_context ctx = reg_ctx->webgpu_ctx; + // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215 + const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" }; + wgpu::DawnTogglesDescriptor adapterTogglesDesc; + adapterTogglesDesc.enabledToggles = adapterEnabledToggles; + adapterTogglesDesc.enabledToggleCount = 2; wgpu::RequestAdapterOptions options = {}; + options.nextInChain = &adapterTogglesDesc; ctx->instance.WaitAny(ctx->instance.RequestAdapter( &options, wgpu::CallbackMode::AllowSpontaneous, [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { From d426436c7babca4325898460a1221adf29ebe7de Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 28 Oct 2025 20:55:41 -0700 Subject: [PATCH 32/40] Make row/col naming consistent --- .../ggml-webgpu/wgsl-shaders/gemv.tmpl.wgsl | 2 +- .../wgsl-shaders/mat_mul_decls.tmpl | 19 +++++++++++++++++++ .../wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl | 14 +++++++------- .../mul_mat_subgroup_matrix.tmpl.wgsl | 12 ++++++------ 4 files changed, 33 insertions(+), 14 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mat_mul_decls.tmpl diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/gemv.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/gemv.tmpl.wgsl index 3964352f28ce8..32e6b0361523b 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/gemv.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/gemv.tmpl.wgsl @@ -116,7 +116,7 @@ struct MulMatParams { }; @group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // Matrix (M x K) -@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed +@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed) @group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // Result vector (transposed) @group(0) @binding(3) var params: MulMatParams; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mat_mul_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mat_mul_decls.tmpl new file mode 100644 index 0000000000000..825ec23d4ebe7 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mat_mul_decls.tmpl @@ -0,0 +1,19 @@ +#decl(LD_SHMEM) + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + let src0_idx = batch_offset + global_m * params.stride_01 + 
global_k; + let src0_val = select( // taking a slight performance hit to avoid oob + {{SRC0_TYPE}}(0.0), + src0[src0_idx/{{VEC_SIZE}}], + global_m < params.m && global_k < params.k); + store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx); + } +} + +#enddecl(LD_SHMEM) + diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl index 8a20d6976e8fa..e9a4e6306e69f 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl @@ -158,8 +158,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let wg_m = wg_in_batch % wg_m_count; let wg_n = wg_in_batch / wg_m_count; - let output_row_base = wg_n * WORKGROUP_SIZE_N * TILE_N + local_n * TILE_N; - let output_col_base = wg_m * WORKGROUP_SIZE_M * TILE_M + local_m * TILE_M; + let output_row_base = wg_m * WORKGROUP_SIZE_M * TILE_M + local_m * TILE_M; + let output_col_base = wg_n * WORKGROUP_SIZE_N * TILE_N + local_n * TILE_N; let dst2_stride = params.m * params.n; let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; @@ -230,12 +230,12 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; for (var tn = 0u; tn < TILE_N; tn++) { - let global_row = output_row_base + tn; - if (global_row < params.n) { + let global_col = output_col_base + tn; + if (global_col < params.n) { for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) { - let global_col = output_col_base + tm; - if (global_col < params.m) { - let dst_idx = dst_batch_offset + global_row * params.m + global_col; + let global_row = output_row_base + tm; + if (global_row < params.m) { + let dst_idx = dst_batch_offset + global_col * params.m + global_row; dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm); } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl index 3e55be6feafc6..eda4d04ce3d1d 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl @@ -291,18 +291,18 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let tile_rows = WG_N_SG_TILE_SIZE; let tile_cols = WG_M_SG_TILE_SIZE; let total_tile_elems = tile_rows * tile_cols; - let tile_dst_row_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; - let tile_dst_col_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; + let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; + let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; for (var idx = thread_id * {{VEC_SIZE}}; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { - let local_row = idx / WG_TILE_STRIDE; - let local_col = idx % WG_TILE_STRIDE; + let local_row = idx % WG_TILE_STRIDE; + let local_col = idx / WG_TILE_STRIDE; let global_row = tile_dst_row_base + local_row; let global_col = tile_dst_col_base + local_col; - if (global_row < params.n && global_col < params.m) { - let dst_idx = dst_batch_offset + global_row * params.m + global_col; + if (global_col < params.n && global_row < params.m) { + let dst_idx = dst_batch_offset + global_col * params.m + global_row; store_dst(idx, dst_idx/{{VEC_SIZE}}); } } From a46d0933206795229ecc32651d687ac9517440d6 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 29 Oct 2025 15:36:52 -0700 Subject: [PATCH 
33/40] Refactor shared memory loading --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 6 +- .../ggml-webgpu/wgsl-shaders/embed_wgsl.py | 9 ++- .../wgsl-shaders/mat_mul_decls.tmpl | 47 ++++++++++- .../wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl | 79 +++++++----------- .../mul_mat_subgroup_matrix.tmpl.wgsl | 80 ++++--------------- 5 files changed, 97 insertions(+), 124 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index a6b53da75b209..7d70335838e3e 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -77,9 +77,9 @@ // Matrix multiplication parameters // Register tiling parameters -#define WEBGPU_MUL_MAT_TILE_M 4 -#define WEBGPU_MUL_MAT_TILE_N 4 -#define WEBGPU_MUL_MAT_WG_SIZE_M 16 +#define WEBGPU_MUL_MAT_TILE_M 8 +#define WEBGPU_MUL_MAT_TILE_N 8 +#define WEBGPU_MUL_MAT_WG_SIZE_M 8 #define WEBGPU_MUL_MAT_WG_SIZE_N 8 #define WEBGPU_MUL_MAT_TILE_K 32 diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index 251051eaeca0f..ed8068d416ebf 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -72,9 +72,12 @@ def generate_variants(fname, input_dir, output_dir, outfile): except ValueError: decls_map = {} - with open(os.path.join(input_dir, "common_decls.tmpl"), "r", encoding="utf-8") as f: - common_decls = f.read() - decls_map.update(parse_decls(common_decls)) + for fname in sorted(os.listdir(input_dir)): + if fname.endswith(".tmpl"): + tmpl_path = os.path.join(input_dir, fname) + with open(tmpl_path, "r", encoding="utf-8") as f_tmpl: + decls = f_tmpl.read() + decls_map.update(parse_decls(decls)) shader_template = extract_block(text, "SHADER") for variant in variants: diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mat_mul_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mat_mul_decls.tmpl index 825ec23d4ebe7..7f84fdcf9e285 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mat_mul_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mat_mul_decls.tmpl @@ -1,4 +1,32 @@ -#decl(LD_SHMEM) +#decl(SHMEM_VEC) +fn store_shmem(val: vec4, idx: u32) { + shmem[idx] = val.x; + shmem[idx + 1] = val.y; + shmem[idx + 2] = val.z; + shmem[idx + 3] = val.w; +} + +fn store_dst(shmem_idx: u32, dst_idx: u32) { + dst[dst_idx] = vec4( + f32(shmem[shmem_idx]), + f32(shmem[shmem_idx + 1]), + f32(shmem[shmem_idx + 2]), + f32(shmem[shmem_idx + 3]) + ); +} +#enddecl(SHMEM_VEC) + +#decl(SHMEM_SCALAR) +fn store_shmem(val: f16, idx: u32) { + shmem[idx] = val; +} + +fn store_dst(shmem_idx: u32, dst_idx: u32) { + dst[dst_idx] = f32(shmem[shmem_idx]); +} +#enddecl(SHMEM_SCALAR) + +#decl(INIT_SHMEM_FLOAT) fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { @@ -15,5 +43,20 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } } -#enddecl(LD_SHMEM) +fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) { + for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + let tile_n = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_n = offset_n + tile_n; + let global_k = k_outer + tile_k; + let src1_idx = batch_offset + global_n * params.stride_11 + global_k; + let src1_val = select( + {{SRC1_TYPE}}(0.0), + src1[src1_idx/{{VEC_SIZE}}], + global_n 
< params.n && global_k < params.k); + store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx); + } +} + +#enddecl(INIT_SHMEM_FLOAT) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl index e9a4e6306e69f..fd33209abee66 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl @@ -6,9 +6,10 @@ "SRC0_TYPE" : "vec4", "SRC1_TYPE" : "vec4", "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", "VEC_SIZE" : "4", }, - "DECLS": ["VEC"] + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SHMEM_FLOAT"] }, { "SHADER_SUFFIX": "f32_f32", @@ -16,9 +17,10 @@ "SRC0_TYPE" : "f32", "SRC1_TYPE" : "f32", "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", "VEC_SIZE" : "1", }, - "DECLS": ["SCALAR"] + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SHMEM_FLOAT"] }, { "SHADER_SUFFIX": "f16_f32_vec", @@ -26,9 +28,10 @@ "SRC0_TYPE" : "vec4", "SRC1_TYPE" : "vec4", "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", "VEC_SIZE" : "4", }, - "DECLS": ["VEC"] + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SHMEM_FLOAT"] }, { "SHADER_SUFFIX": "f16_f32", @@ -36,9 +39,10 @@ "SRC0_TYPE" : "f16", "SRC1_TYPE" : "f32", "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", "VEC_SIZE" : "1", }, - "DECLS": ["SCALAR"] + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SHMEM_FLOAT"] }, { "SHADER_SUFFIX": "f16_f16_vec", @@ -46,9 +50,10 @@ "SRC0_TYPE" : "vec4", "SRC1_TYPE" : "vec4", "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", "VEC_SIZE" : "4", }, - "DECLS": ["VEC"] + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SHMEM_FLOAT"] }, { "SHADER_SUFFIX": "f16_f16", @@ -56,9 +61,10 @@ "SRC0_TYPE" : "f16", "SRC1_TYPE" : "f16", "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", "VEC_SIZE" : "1", }, - "DECLS": ["SCALAR"] + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SHMEM_FLOAT"] } ] @@ -67,22 +73,14 @@ #define(DECLS) #decl(VEC) -fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> vec4 { - return vec4(acc[tm][tn], acc[tm + 1][tn], acc[tm + 2][tn], acc[tm + 3][tn]); -} - -fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { - return dot(vec4(src0_val), vec4(src1_val)); +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> vec4 { + return vec4(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn])); } #enddecl(VEC) #decl(SCALAR) -fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { - return acc[tm][tn]; -} - -fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { - return f32(src0_val) * f32(src1_val); +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { + return f32(acc[tm][tn]); } #enddecl(SCALAR) @@ -137,8 +135,7 @@ override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N; override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M; override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N; -var src0_shmem: array<{{SRC0_TYPE}}, TILE_SRC0_SHMEM/{{VEC_SIZE}}>; -var src1_shmem: array<{{SRC1_TYPE}}, TILE_SRC1_SHMEM/{{VEC_SIZE}}>; +var shmem: array; @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, @@ -174,52 +171,34 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; - var acc: array, TILE_M>; + let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M; + let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N; - for (var k_outer = 
0u; k_outer < params.k; k_outer += TILE_K) { + var acc: array, TILE_M>; - for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - let global_m = wg_m * WORKGROUP_SIZE_M * TILE_M + tile_m; - let global_k = k_outer + tile_k; - let src0_idx = src0_batch_offset + global_m * params.stride_01 + global_k; - src0_shmem[elem_idx/{{VEC_SIZE}}] = select( // taking a slight performance hit to avoid oob - {{SRC0_TYPE}}(0.0), - src0[src0_idx/{{VEC_SIZE}}], - global_m < params.m && global_k < params.k); - } + for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { - for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { - let tile_n = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - let global_n = wg_n * WORKGROUP_SIZE_N * TILE_N + tile_n; - let global_k = k_outer + tile_k; - - let src1_idx = src1_batch_offset + global_n * params.stride_11 + global_k; - src1_shmem[elem_idx/{{VEC_SIZE}}] = select( - {{SRC1_TYPE}}(0.0), - src1[src1_idx/{{VEC_SIZE}}], - global_n < params.n && global_k < params.k); - } + // see mat_mul_decls.tmpl + init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer); + init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer); workgroupBarrier(); let k_end = min(TILE_K, params.k - k_outer); - for (var k_inner = 0u; k_inner < k_end; k_inner += {{VEC_SIZE}}) { - var src0_tile: array<{{SRC0_TYPE}}, TILE_M>; + for (var k_inner = 0u; k_inner < k_end; k_inner++) { + var src0_tile: array; for (var tm = 0u; tm < TILE_M; tm++) { let src0_m = local_m * TILE_M + tm; let src0_idx = k_inner + src0_m * TILE_K; - src0_tile[tm] = src0_shmem[src0_idx/{{VEC_SIZE}}]; + src0_tile[tm] = shmem[src0_idx]; } for (var tn = 0u; tn < TILE_N; tn++) { let src1_n = local_n * TILE_N + tn; let src1_idx = src1_n * TILE_K + k_inner; - let src1_vec = src1_shmem[src1_idx/{{VEC_SIZE}}]; + let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx]; for (var tm = 0u; tm < TILE_M; tm++) { - acc[tm][tn] += mul_acc(src0_tile[tm], src1_vec); + acc[tm][tn] += src0_tile[tm] * src1_val; } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl index eda4d04ce3d1d..715a16bfc2339 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl @@ -9,7 +9,7 @@ "SHMEM_TYPE" : "vec4", "VEC_SIZE" : "4", }, - "DECLS": ["SHMEM_VEC"] + "DECLS": ["SHMEM_VEC", "INIT_SHMEM_FLOAT"] }, { "SHADER_SUFFIX": "f32_f32", @@ -20,7 +20,7 @@ "SHMEM_TYPE" : "f16", "VEC_SIZE" : "1", }, - "DECLS": ["SHMEM_SCALAR"] + "DECLS": ["SHMEM_SCALAR", "INIT_SHMEM_FLOAT"] }, { "SHADER_SUFFIX": "f16_f32_vec", @@ -31,7 +31,7 @@ "SHMEM_TYPE" : "vec4", "VEC_SIZE" : "4", }, - "DECLS": ["SHMEM_VEC"] + "DECLS": ["SHMEM_VEC", "INIT_SHMEM_FLOAT"] }, { "SHADER_SUFFIX": "f16_f32", @@ -42,7 +42,7 @@ "SHMEM_TYPE" : "f16", "VEC_SIZE" : "1", }, - "DECLS": ["SHMEM_SCALAR"] + "DECLS": ["SHMEM_SCALAR", "INIT_SHMEM_FLOAT"] }, { "SHADER_SUFFIX": "f16_f16_vec", @@ -53,7 +53,7 @@ "SHMEM_TYPE" : "vec4", "VEC_SIZE" : "4", }, - "DECLS": ["SHMEM_VEC"] + "DECLS": ["SHMEM_VEC", "INIT_SHMEM_FLOAT"] }, { "SHADER_SUFFIX": "f16_f16", @@ -64,44 +64,12 @@ "SHMEM_TYPE" : "f16", "VEC_SIZE" : "1", }, - "DECLS": ["SHMEM_SCALAR"] + "DECLS": ["SHMEM_SCALAR", "INIT_SHMEM_FLOAT"] } ] #end(VARIANTS) 
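For reference, a worked shared-memory budget for the register-tile path under the defaults set earlier in this patch (TILE_M = TILE_N = 8, WORKGROUP_SIZE_M = WORKGROUP_SIZE_N = 8, TILE_K = 32); the byte count assumes the f16 shmem element type used by store_shmem:

// TOTAL_WORKGROUP_SIZE = 8 * 8          = 64 threads
// TILE_SRC0_SHMEM      = 32 * 8 * 8     = 2048 f16 elements
// TILE_SRC1_SHMEM      = 32 * 8 * 8     = 2048 f16 elements
// shmem                = 4096 * 2 bytes = 8 KiB per workgroup,
// plus an 8 x 8 accumulator tile kept in registers by each thread.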
-#define(DECLS) - -#decl(SHMEM_VEC) -fn store_shmem(val: vec4, idx: u32) { - shmem[idx] = val.x; - shmem[idx + 1] = val.y; - shmem[idx + 2] = val.z; - shmem[idx + 3] = val.w; -} - -fn store_dst(shmem_idx: u32, dst_idx: u32) { - dst[dst_idx] = vec4( - f32(shmem[shmem_idx]), - f32(shmem[shmem_idx + 1]), - f32(shmem[shmem_idx + 2]), - f32(shmem[shmem_idx + 3]) - ); -} -#enddecl(SHMEM_VEC) - -#decl(SHMEM_SCALAR) -fn store_shmem(val: f16, idx: u32) { - shmem[idx] = val; -} - -fn store_dst(shmem_idx: u32, dst_idx: u32) { - dst[dst_idx] = f32(shmem[shmem_idx]); -} -#enddecl(SHMEM_SCALAR) - -#end(DECLS) - #define(SHADER) diagnostic(off, chromium.subgroup_matrix_uniformity); enable f16; @@ -200,36 +168,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + let offset_m = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; + let offset_n = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + var acc_sg_mat : array, SUBGROUP_MATRIX_N>, SUBGROUP_MATRIX_M>; for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { - for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - let global_m = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE + tile_m; - let global_k = k_outer + tile_k; - let src0_idx = src0_batch_offset + global_m * params.stride_01 + global_k; - let src0_val = select( // taking a slight performance hit to avoid oob - {{SRC0_TYPE}}(0.0), - src0[src0_idx/{{VEC_SIZE}}], - global_m < params.m && global_k < params.k); - store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx); - } - - for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { - let tile_n = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - let global_n = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE + tile_n; - let global_k = k_outer + tile_k; - - let src1_idx = src1_batch_offset + global_n * params.stride_11 + global_k; - let src1_val = select( - {{SRC1_TYPE}}(0.0), - src1[src1_idx/{{VEC_SIZE}}], - global_n < params.n && global_k < params.k); - store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx); - } + // see mat_mul_decls.tmpl + init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer); + init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer); workgroupBarrier(); @@ -248,11 +196,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3, ); } - let src1_shmem_idx_base = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner; + let src1_shmem_idx_base = TILE_SRC0_SHMEM + subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner; for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) { let src1_sg_mat = subgroupMatrixLoad>( &shmem, - TILE_SRC0_SHMEM + src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K, + src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K, true, TILE_K ); From eb7150a1e35cb6000e87c7e23d7dd8e9771dc6ae Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 29 Oct 2025 15:42:52 -0700 Subject: [PATCH 34/40] Move sg matrix stores to correct file --- .../wgsl-shaders/mat_mul_decls.tmpl | 13 ------- .../mul_mat_subgroup_matrix.tmpl.wgsl | 34 +++++++++++++++---- 2 files changed, 28 insertions(+), 19 
deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mat_mul_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mat_mul_decls.tmpl index 7f84fdcf9e285..18c15d5d76cfe 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mat_mul_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mat_mul_decls.tmpl @@ -5,25 +5,12 @@ fn store_shmem(val: vec4, idx: u32) { shmem[idx + 2] = val.z; shmem[idx + 3] = val.w; } - -fn store_dst(shmem_idx: u32, dst_idx: u32) { - dst[dst_idx] = vec4( - f32(shmem[shmem_idx]), - f32(shmem[shmem_idx + 1]), - f32(shmem[shmem_idx + 2]), - f32(shmem[shmem_idx + 3]) - ); -} #enddecl(SHMEM_VEC) #decl(SHMEM_SCALAR) fn store_shmem(val: f16, idx: u32) { shmem[idx] = val; } - -fn store_dst(shmem_idx: u32, dst_idx: u32) { - dst[dst_idx] = f32(shmem[shmem_idx]); -} #enddecl(SHMEM_SCALAR) #decl(INIT_SHMEM_FLOAT) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl index 715a16bfc2339..345a934e2425b 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl @@ -9,7 +9,7 @@ "SHMEM_TYPE" : "vec4", "VEC_SIZE" : "4", }, - "DECLS": ["SHMEM_VEC", "INIT_SHMEM_FLOAT"] + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SHMEM_FLOAT"] }, { "SHADER_SUFFIX": "f32_f32", @@ -20,7 +20,7 @@ "SHMEM_TYPE" : "f16", "VEC_SIZE" : "1", }, - "DECLS": ["SHMEM_SCALAR", "INIT_SHMEM_FLOAT"] + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SHMEM_FLOAT"] }, { "SHADER_SUFFIX": "f16_f32_vec", @@ -31,7 +31,7 @@ "SHMEM_TYPE" : "vec4", "VEC_SIZE" : "4", }, - "DECLS": ["SHMEM_VEC", "INIT_SHMEM_FLOAT"] + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SHMEM_FLOAT"] }, { "SHADER_SUFFIX": "f16_f32", @@ -42,7 +42,7 @@ "SHMEM_TYPE" : "f16", "VEC_SIZE" : "1", }, - "DECLS": ["SHMEM_SCALAR", "INIT_SHMEM_FLOAT"] + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SHMEM_FLOAT"] }, { "SHADER_SUFFIX": "f16_f16_vec", @@ -53,7 +53,7 @@ "SHMEM_TYPE" : "vec4", "VEC_SIZE" : "4", }, - "DECLS": ["SHMEM_VEC", "INIT_SHMEM_FLOAT"] + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SHMEM_FLOAT"] }, { "SHADER_SUFFIX": "f16_f16", @@ -64,12 +64,34 @@ "SHMEM_TYPE" : "f16", "VEC_SIZE" : "1", }, - "DECLS": ["SHMEM_SCALAR", "INIT_SHMEM_FLOAT"] + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SHMEM_FLOAT"] } ] #end(VARIANTS) +#define(DECLS) + +#decl(VEC) +fn store_dst(shmem_idx: u32, dst_idx: u32) { + dst[dst_idx] = vec4( + f32(shmem[shmem_idx]), + f32(shmem[shmem_idx + 1]), + f32(shmem[shmem_idx + 2]), + f32(shmem[shmem_idx + 3]) + ); +} +#enddecl(VEC) + +#decl(SCALAR) +fn store_dst(shmem_idx: u32, dst_idx: u32) { + dst[dst_idx] = f32(shmem[shmem_idx]); +} +#enddecl(SCALAR) + + +#end(DECLS) + #define(SHADER) diagnostic(off, chromium.subgroup_matrix_uniformity); enable f16; From 4ec09e4e3e5683b20528567f62219f915b17f2ec Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Thu, 30 Oct 2025 11:48:18 -0700 Subject: [PATCH 35/40] Working q4_0 --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 81 ++++++++++------ .../wgsl-shaders/mat_mul_decls.tmpl | 53 ++++++++++- .../wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl | 46 ++++++--- .../mul_mat_subgroup_matrix.tmpl.wgsl | 47 +++++++--- .../{gemv.tmpl.wgsl => mul_mat_vec.tmpl.wgsl} | 93 +++++++++++++++---- 5 files changed, 246 insertions(+), 74 deletions(-) rename ggml/src/ggml-webgpu/wgsl-shaders/{gemv.tmpl.wgsl => mul_mat_vec.tmpl.wgsl} (67%) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 7d70335838e3e..f34b4c707a627 100644 --- 
a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -92,11 +92,11 @@ #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 -// gemv parameters -#define WEBGPU_GEMV_WG_SIZE 256 -// Must be multiple of 4 to work with vectorized paths, and must divide gemv wg size -#define WEBGPU_GEMV_OUTPUTS_PER_WG 16 -#define WEBGPU_GEMV_TILE_K 128 +// Matrix-vector multiplication parameters +#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 +// Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size +#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64 +#define WEBGPU_MUL_MAT_VEC_TILE_K 256 /* End Constants */ @@ -278,7 +278,7 @@ struct webgpu_context_struct { webgpu_pipeline memset_pipeline; std::map>> mul_mat_pipelines; // src0_type, src1_type, vectorized - std::map>> gemv_pipelines; // src0_type, src1_type, vectorized + std::map>> mul_mat_vec_pipelines; // src0_type, src1_type, vectorized webgpu_pipeline mul_mat_pipeline[30][2]; webgpu_pipeline set_rows_pipeline[1][2]; // dst->type, vectorized @@ -957,6 +957,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, switch (src0->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: use_fast = true; break; default: @@ -970,9 +971,11 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, if (use_fast) { int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0; if (dst->ne[1] == 1) { - pipeline = ctx->gemv_pipelines[src0->type][src1->type][vectorized]; + // We don't support vectorized mul_mat_vec for quantized types + vectorized = vectorized && (src0->type < 2); + pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized]; uint32_t batches = dst->ne[2] * dst->ne[3]; - uint32_t output_groups = (dst->ne[0] + WEBGPU_GEMV_OUTPUTS_PER_WG - 1) / WEBGPU_GEMV_OUTPUTS_PER_WG; + uint32_t output_groups = (dst->ne[0] + WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG - 1) / WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; uint32_t total_wg = output_groups * batches; wg_x = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension; wg_y = (total_wg + ctx->limits.maxComputeWorkgroupsPerDimension - 1) / @@ -1777,6 +1780,10 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls); std::string proc_mul_mat_subgroup_matrix_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_q4_0_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_q4_0_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32.c_str(), "mul_mat_subgroup_matrix_f32_f32"); @@ -1793,6 +1800,11 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16_vec.c_str(), "mul_mat_subgroup_matrix_f16_f16_vec"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32.c_str(), "mul_mat_subgroup_matrix_q4_0_f32"); + 
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32_vec.c_str(), + "mul_mat_subgroup_matrix_q4_0_f32_vec"); } else { std::vector mul_mat_reg_tile_constants(3); mul_mat_reg_tile_constants[0].key = "TILE_K"; @@ -1820,6 +1832,10 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls); std::string proc_mul_mat_reg_tile_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls); + std::string proc_mul_mat_reg_tile_q4_0_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls); + std::string proc_mul_mat_reg_tile_q4_0_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32.c_str(), @@ -1839,28 +1855,37 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16_vec.c_str(), "mul_mat_reg_tile_f16_f16_vec", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32.c_str(), + "mul_mat_reg_tile_q4_0_f32", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32_vec.c_str(), + "mul_mat_reg_tile_q4_0_f32_vec", mul_mat_reg_tile_constants); + } - std::vector gemv_constants(3); - gemv_constants[0].key = "WORKGROUP_SIZE"; - gemv_constants[0].value = WEBGPU_GEMV_WG_SIZE; - gemv_constants[1].key = "TILE_K"; - gemv_constants[1].value = WEBGPU_GEMV_TILE_K; - gemv_constants[2].key = "OUTPUTS_PER_WG"; - gemv_constants[2].value = WEBGPU_GEMV_OUTPUTS_PER_WG; - - webgpu_ctx->gemv_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_gemv_f32_f32, "gemv_f32_f32", gemv_constants); - webgpu_ctx->gemv_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_gemv_f32_f32_vec, "gemv_f32_f32_vec", gemv_constants); - webgpu_ctx->gemv_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_gemv_f16_f32, "gemv_f16_f32", gemv_constants); - webgpu_ctx->gemv_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_gemv_f16_f32_vec, "gemv_f16_f32_vec", gemv_constants); - webgpu_ctx->gemv_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_gemv_f16_f16, "gemv_f16_f16", gemv_constants); - webgpu_ctx->gemv_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_gemv_f16_f16_vec, "gemv_f16_f16_vec", gemv_constants); + std::vector mul_mat_vec_constants(3); + mul_mat_vec_constants[0].key = "WORKGROUP_SIZE"; + mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE; + mul_mat_vec_constants[1].key = "TILE_K"; + mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K; + mul_mat_vec_constants[2].key = "OUTPUTS_PER_WG"; + mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; + + 
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants); } static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mat_mul_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mat_mul_decls.tmpl index 18c15d5d76cfe..3b599d8dbb724 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mat_mul_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mat_mul_decls.tmpl @@ -13,7 +13,7 @@ fn store_shmem(val: f16, idx: u32) { } #enddecl(SHMEM_SCALAR) -#decl(INIT_SHMEM_FLOAT) +#decl(INIT_SRC0_SHMEM_FLOAT) fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { @@ -30,6 +30,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } } +#enddecl(INIT_SRC0_SHMEM_FLOAT) + +#decl(INIT_SRC1_SHMEM) + fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) { for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { let tile_n = elem_idx / TILE_K; @@ -45,5 +49,50 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3 } } -#enddecl(INIT_SHMEM_FLOAT) +#enddecl(INIT_SRC1_SHMEM) + +#decl(INIT_SRC0_SHMEM_Q4_0) + +const BLOCK_SIZE = 32u; +// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. 
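+// Layout refresher (standard ggml q4_0, stated here for readability of the loads below):
+// each block is one f16 scale d followed by 16 bytes packing 32 4-bit weights; byte i holds
+// weight i in its low nibble and weight i+16 in its high nibble, and a weight dequantizes
+// as (nibble - 8) * d. The constants below describe that layout in units of f16 loads.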
+override BLOCKS_K = TILE_K/BLOCK_SIZE; +const NQ = 16u; +const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + + let tile_m = blck_idx / BLOCKS_K; + let global_m = offset_m + tile_m; + let block_k = blck_idx % BLOCKS_K; + let global_k = k_outer / BLOCK_SIZE + block_k; + + if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + let d = src0[scale_idx]; + + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = src0[scale_idx + 1u + block_offset + j]; + let q_1 = src0[scale_idx + 1u + block_offset + j + 1]; + + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f16(q_byte & 0xF) - 8.0) * d; + shmem[shmem_idx + j * 2 + k] = q_lo; + shmem[shmem_idx + j * 2 + k + 16u] = q_hi; + } + } + } + } +} + +#enddecl(INIT_SRC0_SHMEM_Q4_0) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl index fd33209abee66..a33c7793383ef 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl @@ -7,9 +7,9 @@ "SRC1_TYPE" : "vec4", "DST_TYPE" : "vec4", "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : "4", + "VEC_SIZE" : 4, }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SHMEM_FLOAT"] + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] }, { "SHADER_SUFFIX": "f32_f32", @@ -18,9 +18,9 @@ "SRC1_TYPE" : "f32", "DST_TYPE" : "f32", "SHMEM_TYPE" : "f16", - "VEC_SIZE" : "1", + "VEC_SIZE" : 1, }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SHMEM_FLOAT"] + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] }, { "SHADER_SUFFIX": "f16_f32_vec", @@ -29,9 +29,9 @@ "SRC1_TYPE" : "vec4", "DST_TYPE" : "vec4", "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : "4", + "VEC_SIZE" : 4, }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SHMEM_FLOAT"] + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] }, { "SHADER_SUFFIX": "f16_f32", @@ -40,9 +40,9 @@ "SRC1_TYPE" : "f32", "DST_TYPE" : "f32", "SHMEM_TYPE" : "f16", - "VEC_SIZE" : "1", + "VEC_SIZE" : 1, }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SHMEM_FLOAT"] + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] }, { "SHADER_SUFFIX": "f16_f16_vec", @@ -51,9 +51,9 @@ "SRC1_TYPE" : "vec4", "DST_TYPE" : "vec4", "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : "4", + "VEC_SIZE" : 4, }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SHMEM_FLOAT"] + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] }, { "SHADER_SUFFIX": "f16_f16", @@ -62,9 +62,31 @@ "SRC1_TYPE" : "f16", "DST_TYPE" : "f32", "SHMEM_TYPE" : "f16", - "VEC_SIZE" : "1", + "VEC_SIZE" : 1, }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SHMEM_FLOAT"] + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "q4_0_f32_vec", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "vec4", + 
"DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "q4_0_f32", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] } ] diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl index 345a934e2425b..2253847eb6804 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl @@ -7,9 +7,9 @@ "SRC1_TYPE" : "vec4", "DST_TYPE" : "vec4", "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : "4", + "VEC_SIZE" : 4, }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SHMEM_FLOAT"] + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] }, { "SHADER_SUFFIX": "f32_f32", @@ -18,9 +18,9 @@ "SRC1_TYPE" : "f32", "DST_TYPE" : "f32", "SHMEM_TYPE" : "f16", - "VEC_SIZE" : "1", + "VEC_SIZE" : 1, }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SHMEM_FLOAT"] + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] }, { "SHADER_SUFFIX": "f16_f32_vec", @@ -29,9 +29,9 @@ "SRC1_TYPE" : "vec4", "DST_TYPE" : "vec4", "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : "4", + "VEC_SIZE" : 4, }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SHMEM_FLOAT"] + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] }, { "SHADER_SUFFIX": "f16_f32", @@ -40,9 +40,9 @@ "SRC1_TYPE" : "f32", "DST_TYPE" : "f32", "SHMEM_TYPE" : "f16", - "VEC_SIZE" : "1", + "VEC_SIZE" : 1, }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SHMEM_FLOAT"] + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] }, { "SHADER_SUFFIX": "f16_f16_vec", @@ -51,9 +51,9 @@ "SRC1_TYPE" : "vec4", "DST_TYPE" : "vec4", "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : "4", + "VEC_SIZE" : 4, }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SHMEM_FLOAT"] + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] }, { "SHADER_SUFFIX": "f16_f16", @@ -62,9 +62,31 @@ "SRC1_TYPE" : "f16", "DST_TYPE" : "f32", "SHMEM_TYPE" : "f16", - "VEC_SIZE" : "1", + "VEC_SIZE" : 1, }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SHMEM_FLOAT"] + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "q4_0_f32_vec", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "q4_0_f32", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] } ] @@ -89,7 +111,6 @@ fn store_dst(shmem_idx: u32, dst_idx: u32) { } #enddecl(SCALAR) - #end(DECLS) #define(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/gemv.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl similarity index 67% rename from ggml/src/ggml-webgpu/wgsl-shaders/gemv.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl index 32e6b0361523b..ffbb64032854e 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/gemv.tmpl.wgsl +++ 
b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl @@ -6,9 +6,9 @@ "SRC0_TYPE" : "vec4", "SRC1_TYPE" : "vec4", "DST_TYPE": "vec4", - "VEC_SIZE" : "4", + "VEC_SIZE" : 4, }, - "DECLS": ["VEC"] + "DECLS": ["VEC", "MUL_ACC_FLOAT"] }, { "SHADER_SUFFIX": "f32_f32", @@ -16,9 +16,9 @@ "SRC0_TYPE" : "f32", "SRC1_TYPE" : "f32", "DST_TYPE": "f32", - "VEC_SIZE" : "1", + "VEC_SIZE" : 1, }, - "DECLS": ["SCALAR"] + "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] }, { "SHADER_SUFFIX": "f16_f32_vec", @@ -26,9 +26,9 @@ "SRC0_TYPE" : "vec4", "SRC1_TYPE" : "vec4", "DST_TYPE": "vec4", - "VEC_SIZE" : "4", + "VEC_SIZE" : 4, }, - "DECLS": ["VEC"] + "DECLS": ["VEC", "MUL_ACC_FLOAT"] }, { "SHADER_SUFFIX": "f16_f32", @@ -36,9 +36,9 @@ "SRC0_TYPE" : "f16", "SRC1_TYPE" : "f32", "DST_TYPE": "f32", - "VEC_SIZE" : "1", + "VEC_SIZE" : 1, }, - "DECLS": ["SCALAR"] + "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] }, { "SHADER_SUFFIX": "f16_f16_vec", @@ -46,9 +46,9 @@ "SRC0_TYPE" : "vec4", "SRC1_TYPE" : "vec4", "DST_TYPE": "vec4", - "VEC_SIZE" : "4", + "VEC_SIZE" : 4, }, - "DECLS": ["VEC"] + "DECLS": ["VEC", "MUL_ACC_FLOAT"] }, { "SHADER_SUFFIX": "f16_f16", @@ -56,9 +56,19 @@ "SRC0_TYPE" : "f16", "SRC1_TYPE" : "f16", "DST_TYPE": "f32", - "VEC_SIZE" : "1", + "VEC_SIZE" : 1, }, - "DECLS": ["SCALAR"] + "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] + }, + { + "SHADER_SUFFIX": "q4_0_f32", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE": "f32", + "VEC_SIZE" : 1, + }, + "DECLS": ["BYTE_HELPERS", "SCALAR", "MUL_ACC_Q4_0"] } ] @@ -67,7 +77,7 @@ #define(DECLS) #decl(VEC) -fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { +fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { return f32(dot({{SRC1_TYPE}}(src0_val), src1_val)); } @@ -80,7 +90,7 @@ fn store_val(group_base: u32) -> vec4 { #enddecl(VEC) #decl(SCALAR) -fn mul_acc(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { +fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { return f32(src0_val) * f32(src1_val); } @@ -89,6 +99,55 @@ fn store_val(group_base: u32) -> f32 { } #enddecl(SCALAR) +#decl(MUL_ACC_FLOAT) + +fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { + var local_sum = 0.0; + for (var i = tig * {{VEC_SIZE}}; i < tile_size; i += THREADS_PER_OUTPUT * {{VEC_SIZE}}) { + let a = src0[(idx_base + k_outer + i) / {{VEC_SIZE}}]; + let b = shared_vector[i / {{VEC_SIZE}}]; + local_sum += inner_dot(a, b); + } + return local_sum; +} + +#enddecl(MUL_ACC_FLOAT) + +#decl(MUL_ACC_Q4_0) + +const BLOCK_SIZE = 32; +const NQ = 16u; // number of weights per thread +const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; + +fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { + var local_sum = 0.0; + for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let d = f32(src0[scale_idx]); + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = src0[scale_idx + 1 + block_offset + j]; + let q_1 = src0[scale_idx + 1 + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = 
get_byte(q_packed, k); + let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f32(q_byte & 0xF) - 8.0) * d; + local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k]; + local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16]; + } + } + } + return local_sum; +} + +#enddecl(MUL_ACC_Q4_0) + #end(DECLS) #define(SHADER) @@ -180,11 +239,7 @@ fn main( workgroupBarrier(); if (output_row < params.m) { - for (var i = thread_in_group * {{VEC_SIZE}}; i < tile_size; i += THREADS_PER_OUTPUT * {{VEC_SIZE}}) { - let a = src0[(src0_idx_base + k_tile + i) / {{VEC_SIZE}}]; - let b = shared_vector[i / {{VEC_SIZE}}]; - local_sum += mul_acc(a, b); - } + local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile); } workgroupBarrier(); From 97266403d96fc1b5ba1b06e907b34a11e2fa8904 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Thu, 30 Oct 2025 21:17:52 -0700 Subject: [PATCH 36/40] Formatting --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 59 ++++++++++--------- ...{mat_mul_decls.tmpl => mul_mat_decls.tmpl} | 1 - .../wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl | 2 +- .../mul_mat_subgroup_matrix.tmpl.wgsl | 2 +- 4 files changed, 32 insertions(+), 32 deletions(-) rename ggml/src/ggml-webgpu/wgsl-shaders/{mat_mul_decls.tmpl => mul_mat_decls.tmpl} (99%) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f34b4c707a627..991ad4dcbde22 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -278,7 +278,8 @@ struct webgpu_context_struct { webgpu_pipeline memset_pipeline; std::map>> mul_mat_pipelines; // src0_type, src1_type, vectorized - std::map>> mul_mat_vec_pipelines; // src0_type, src1_type, vectorized + std::map>> + mul_mat_vec_pipelines; // src0_type, src1_type, vectorized webgpu_pipeline mul_mat_pipeline[30][2]; webgpu_pipeline set_rows_pipeline[1][2]; // dst->type, vectorized @@ -972,13 +973,14 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0; if (dst->ne[1] == 1) { // We don't support vectorized mul_mat_vec for quantized types - vectorized = vectorized && (src0->type < 2); - pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized]; - uint32_t batches = dst->ne[2] * dst->ne[3]; - uint32_t output_groups = (dst->ne[0] + WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG - 1) / WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; - uint32_t total_wg = output_groups * batches; - wg_x = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension; - wg_y = (total_wg + ctx->limits.maxComputeWorkgroupsPerDimension - 1) / + vectorized = vectorized && (src0->type < 2); + pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized]; + uint32_t batches = dst->ne[2] * dst->ne[3]; + uint32_t output_groups = + (dst->ne[0] + WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG - 1) / WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; + uint32_t total_wg = output_groups * batches; + wg_x = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension; + wg_y = (total_wg + ctx->limits.maxComputeWorkgroupsPerDimension - 1) / ctx->limits.maxComputeWorkgroupsPerDimension; } else { pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; @@ -1861,7 +1863,6 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32_vec.c_str(), "mul_mat_reg_tile_q4_0_f32_vec", mul_mat_reg_tile_constants); - } 
std::vector mul_mat_vec_constants(3); @@ -1872,20 +1873,20 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { mul_mat_vec_constants[2].key = "OUTPUTS_PER_WG"; mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants); } static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { @@ -2382,12 +2383,12 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t webgpu_context ctx = reg_ctx->webgpu_ctx; // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215 - const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" }; + const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", 
"use_vulkan_memory_model" }; wgpu::DawnTogglesDescriptor adapterTogglesDesc; - adapterTogglesDesc.enabledToggles = adapterEnabledToggles; - adapterTogglesDesc.enabledToggleCount = 2; - wgpu::RequestAdapterOptions options = {}; - options.nextInChain = &adapterTogglesDesc; + adapterTogglesDesc.enabledToggles = adapterEnabledToggles; + adapterTogglesDesc.enabledToggleCount = 2; + wgpu::RequestAdapterOptions options = {}; + options.nextInChain = &adapterTogglesDesc; ctx->instance.WaitAny(ctx->instance.RequestAdapter( &options, wgpu::CallbackMode::AllowSpontaneous, [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { @@ -2432,7 +2433,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. - ctx->subgroup_size = info.subgroupMaxSize; + ctx->subgroup_size = info.subgroupMaxSize; ctx->supports_subgroup_matrix = valid_subgroup_matrix_config; // Initialize device diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mat_mul_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl similarity index 99% rename from ggml/src/ggml-webgpu/wgsl-shaders/mat_mul_decls.tmpl rename to ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index 3b599d8dbb724..109ff8d6159e1 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mat_mul_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -95,4 +95,3 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } #enddecl(INIT_SRC0_SHMEM_Q4_0) - diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl index a33c7793383ef..6b1dd26cd9e0d 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl @@ -200,7 +200,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { - // see mat_mul_decls.tmpl + // see mul_mat_decls.tmpl init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer); init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl index 2253847eb6804..47c8ce36ab336 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl @@ -218,7 +218,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { - // see mat_mul_decls.tmpl + // see mul_mat_decls.tmpl init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer); init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer); From b51edae83fb0319621bca9d62c3e628213bf649e Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Sun, 2 Nov 2025 21:54:13 -0800 Subject: [PATCH 37/40] Work with emscripten builds --- common/arg.cpp | 4 + ggml/src/ggml-webgpu/ggml-webgpu.cpp | 171 ++++++++++++--------------- 2 files changed, 79 insertions(+), 96 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index a25743c899862..13195023a7645 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -39,6 +39,7 @@ #include "http.h" #endif +#ifndef __EMSCRIPTEN__ #ifdef __linux__ #include #elif 
defined(_WIN32) @@ -50,8 +51,11 @@ #else #include #endif +#endif + #define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 + // isatty #if defined(_WIN32) #include diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 8fb705be6440c..7787b779f15ee 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -265,9 +265,12 @@ struct webgpu_context_struct { wgpu::Queue queue; wgpu::Limits limits; + uint32_t subgroup_size; + +#ifndef __EMSCRIPTEN__ bool supports_subgroup_matrix = false; - uint32_t subgroup_size; wgpu::SubgroupMatrixConfig subgroup_matrix_config; +#endif // Separate this out from limits since on some Metal systems, the limit returned by // querying the limits is higher than the actual allowed maximum. @@ -453,8 +456,8 @@ static void ggml_backend_webgpu_wait(webgpu_context & ct // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads, // inflight_max may be 0, meaning that we must wait on all futures. uint64_t timeout_ms = block ? UINT64_MAX : 0; - uint inflight_threads = ctx->inflight_threads; - uint inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u); + uint32_t inflight_threads = ctx->inflight_threads; + uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u); while (futures.size() >= inflight_max && futures.size() > 0) { ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX); futures.erase(futures.begin()); @@ -990,6 +993,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; uint32_t wg_m; uint32_t wg_n; +#ifndef __EMSCRIPTEN__ if (ctx->supports_subgroup_matrix) { // The total number of subgroups/workgroups needed per matrix. 
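            // (For scale: with the 8x8 subgroup-matrix config, one of the two supported sizes, and
            // WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M = 4, a workgroup tile covers WEBGPU_MUL_MAT_SUBGROUP_M * 32
            // output rows; wg_m and wg_n below are just ceiling divides of dst->ne[0] and dst->ne[1]
            // by the respective tile sizes.)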
uint32_t wg_m_sg_tile = @@ -999,11 +1003,15 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N; wg_n = (dst->ne[1] + wg_n_sg_tile - 1) / wg_n_sg_tile; } else { +#endif uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N; wg_m = (dst->ne[0] + tile_m_s - 1) / tile_m_s; wg_n = (dst->ne[1] + tile_n_s - 1) / tile_n_s; +#ifndef __EMSCRIPTEN__ } +#endif + wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; } } @@ -1423,9 +1431,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str commands.push_back(*cmd); } // compute the batch size based on the number of inflight threads - uint inflight_threads = ctx->inflight_threads; - uint batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)), - WEBGPU_COMMAND_SUBMIT_BATCH_SIZE); + uint32_t inflight_threads = ctx->inflight_threads; + uint32_t batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)), + WEBGPU_COMMAND_SUBMIT_BATCH_SIZE); if (commands.size() >= batch_size) { futures.push_back(ggml_backend_webgpu_submit(ctx, commands)); // Process events and check for completed submissions @@ -1762,6 +1770,17 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32], wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); + std::string proc_mul_mat_f32_f32; + std::string proc_mul_mat_f32_f32_vec; + std::string proc_mul_mat_f16_f32; + std::string proc_mul_mat_f16_f32_vec; + std::string proc_mul_mat_f16_f16; + std::string proc_mul_mat_f16_f16_vec; + std::string proc_mul_mat_q4_0_f32; + std::string proc_mul_mat_q4_0_f32_vec; + + std::vector mul_mat_constants; +#ifndef __EMSCRIPTEN__ if (webgpu_ctx->supports_subgroup_matrix) { std::vector> sg_matrix_repls; sg_matrix_repls.emplace_back("WEBGPU_MAX_SUBGROUP_SIZE", std::to_string(webgpu_ctx->subgroup_size)); @@ -1774,100 +1793,57 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { sg_matrix_repls.emplace_back("WEBGPU_SG_MAT_N_SIZE", std::to_string(webgpu_ctx->subgroup_matrix_config.N)); sg_matrix_repls.emplace_back("WEBGPU_SG_MAT_K_SIZE", std::to_string(webgpu_ctx->subgroup_matrix_config.K)); - std::string proc_mul_mat_subgroup_matrix_f32_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f32_f32_vec = + proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); + proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f16_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f16_f32_vec = + proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls); + proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f16_f16 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f16_f16_vec = + proc_mul_mat_f16_f16 = 
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls); + proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_q4_0_f32 = + proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_q4_0_f32_vec = + proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls); - - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32.c_str(), "mul_mat_subgroup_matrix_f32_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32_vec.c_str(), - "mul_mat_subgroup_matrix_f32_f32_vec"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32.c_str(), "mul_mat_subgroup_matrix_f16_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32_vec.c_str(), - "mul_mat_subgroup_matrix_f16_f32_vec"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16.c_str(), "mul_mat_subgroup_matrix_f16_f16"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16_vec.c_str(), - "mul_mat_subgroup_matrix_f16_f16_vec"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32.c_str(), "mul_mat_subgroup_matrix_q4_0_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32_vec.c_str(), - "mul_mat_subgroup_matrix_q4_0_f32_vec"); } else { - std::vector mul_mat_reg_tile_constants(3); - mul_mat_reg_tile_constants[0].key = "TILE_K"; - mul_mat_reg_tile_constants[0].value = WEBGPU_MUL_MAT_TILE_K; - mul_mat_reg_tile_constants[1].key = "WORKGROUP_SIZE_M"; - mul_mat_reg_tile_constants[1].value = WEBGPU_MUL_MAT_WG_SIZE_M; - mul_mat_reg_tile_constants[2].key = "WORKGROUP_SIZE_N"; - mul_mat_reg_tile_constants[2].value = WEBGPU_MUL_MAT_WG_SIZE_N; +#endif + mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K }); + mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M }); + mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N }); std::vector> reg_repls; reg_repls.emplace_back("WEBGPU_TILE_M", std::to_string(WEBGPU_MUL_MAT_TILE_M)); reg_repls.emplace_back("WEBGPU_TILE_N", std::to_string(WEBGPU_MUL_MAT_TILE_N)); - // Process each reg-tile shader with tile replacements. - // Keep the processed strings in-scope so .c_str() remains valid. 
- std::string proc_mul_mat_reg_tile_f32_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls); - std::string proc_mul_mat_reg_tile_f32_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls); - std::string proc_mul_mat_reg_tile_f16_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls); - std::string proc_mul_mat_reg_tile_f16_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls); - std::string proc_mul_mat_reg_tile_f16_f16 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls); - std::string proc_mul_mat_reg_tile_f16_f16_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls); - std::string proc_mul_mat_reg_tile_q4_0_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls); - std::string proc_mul_mat_reg_tile_q4_0_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls); - - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32.c_str(), - "mul_mat_reg_tile_f32_f32", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32_vec.c_str(), - "mul_mat_reg_tile_f32_f32_vec", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32.c_str(), - "mul_mat_reg_tile_f16_f32", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32_vec.c_str(), - "mul_mat_reg_tile_f16_f32_vec", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16.c_str(), - "mul_mat_reg_tile_f16_f16", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16_vec.c_str(), - "mul_mat_reg_tile_f16_f16_vec", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32.c_str(), - "mul_mat_reg_tile_q4_0_f32", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32_vec.c_str(), - "mul_mat_reg_tile_q4_0_f32_vec", mul_mat_reg_tile_constants); + proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls); + proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls); + proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls); + proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls); + proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls); + proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls); + proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls); + 
proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls); +#ifndef __EMSCRIPTEN__ } +#endif + + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants); std::vector mul_mat_vec_constants(3); mul_mat_vec_constants[0].key = "WORKGROUP_SIZE"; @@ -2412,11 +2388,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ctx->adapter.GetLimits(&ctx->limits); ctx->max_wg_size_x = 288; // default value - wgpu::AdapterInfo info{}; + wgpu::AdapterInfo info{}; +#ifndef __EMSCRIPTEN__ wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{}; if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { info.nextInChain = &subgroup_matrix_configs; } +#endif ctx->adapter.GetInfo(&info); wgpu::SupportedFeatures features; @@ -2424,6 +2402,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t // we require f16 support GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); +#ifndef __EMSCRIPTEN__ // Only support square f16 matrices of size 8 or 16 for now bool valid_subgroup_matrix_config = false; if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { @@ -2439,22 +2418,22 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t } } + ctx->supports_subgroup_matrix = valid_subgroup_matrix_config; +#endif // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. 
- ctx->subgroup_size = info.subgroupMaxSize; - ctx->supports_subgroup_matrix = valid_subgroup_matrix_config; + ctx->subgroup_size = info.subgroupMaxSize; // Initialize device - std::vector required_features = { wgpu::FeatureName::ShaderF16, -#ifndef __EMSCRIPTEN__ - wgpu::FeatureName::ImplicitDeviceSynchronization -#endif - }; + std::vector required_features = { wgpu::FeatureName::ShaderF16 }; +#ifndef __EMSCRIPTEN__ + required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization); if (ctx->supports_subgroup_matrix) { required_features.push_back(wgpu::FeatureName::Subgroups); required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); } +#endif #ifdef GGML_WEBGPU_GPU_PROFILE required_features.push_back(wgpu::FeatureName::TimestampQuery); From fd6d56b1ee720f1b84b073366886f0e0c1d9e85e Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Mon, 3 Nov 2025 21:15:39 -0800 Subject: [PATCH 38/40] Fix test-backend-ops emscripten for f16/quantized types --- tests/test-backend-ops.cpp | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index b6657eb7c3186..2a021505a5330 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -115,17 +115,23 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m }; const size_t min_blocks_per_thread = 1; - const size_t n_threads = std::min(N_THREADS/2, - std::max(1, n_blocks / min_blocks_per_thread)); - std::vector> tasks; - tasks.reserve(n_threads); - for (size_t i = 0; i < n_threads; i++) { - size_t start = i*n_blocks/n_threads; - size_t end = (i+1)*n_blocks/n_threads; - tasks.push_back(std::async(std::launch::async, quantize_thread, start, end)); - } - for (auto & t : tasks) { - t.get(); + const size_t n_quant_threads = std::min(std::max(N_THREADS/2, 1), + std::max(1, n_blocks / min_blocks_per_thread)); + + if (n_quant_threads == 1) { + // single-threaded quantization: do all blocks in the current thread + quantize_thread(0, n_blocks); + } else { + std::vector> tasks; + tasks.reserve(n_quant_threads); + for (size_t i = 0; i < n_quant_threads; i++) { + size_t start = i*n_blocks/n_quant_threads; + size_t end = (i+1)*n_blocks/n_quant_threads; + tasks.push_back(std::async(std::launch::async, quantize_thread, start, end)); + } + for (auto & t : tasks) { + t.get(); + } } } ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size()); From 427f1f7b6ec4c45bcd564969456d884c15fd50d2 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 4 Nov 2025 11:02:13 -0800 Subject: [PATCH 39/40] Use emscripten memory64 to support get_memory --- CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 352f8011c8e75..ccde73afb4b0c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,6 +36,11 @@ option(LLAMA_USE_SYSTEM_GGML "Use system libggml" OFF) if (EMSCRIPTEN) set(BUILD_SHARED_LIBS_DEFAULT OFF) + # Use 64-bit memory to support backend_get_memory queries + # TODO: analyze performance impact, see https://spidermonkey.dev/blog/2025/01/15/is-memory64-actually-worth-using + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -sMEMORY64=1") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -sMEMORY64=1") + option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" OFF) option(LLAMA_BUILD_HTML "llama: build HTML file" ON) if (LLAMA_BUILD_HTML) From cbc830984f2a248645aa3ebcfc4601966e1f45de Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 11 Nov 2025 
14:31:33 -0800 Subject: [PATCH 40/40] Add build flags and try ci --- .github/workflows/build.yml | 40 +++++++++++++++++++++++++++++ CMakeLists.txt | 5 ++-- ggml/src/ggml-webgpu/CMakeLists.txt | 2 -- 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 36084c55078ef..81d57b039d406 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -547,6 +547,46 @@ jobs: # This is using llvmpipe and runs slower than other backends ctest -L main --verbose --timeout 3600 + ubuntu-24-wasm-webgpu: + runs-on: ubuntu-24.04 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.16 + with: + key: ubuntu-latest-wasm-webgpu + evict-old-files: 1d + + - name: Install Emscripten + run: | + git clone https://github.com/emscripten-core/emsdk.git + cd emsdk + ./emsdk install latest + ./emsdk activate latest + + - name: Fetch emdawnwebgpu + run: | + DAWN_TAG="v20251027.212519" + EMDAWN_PKG="emdawnwebgpu_pkg-${DAWN_TAG}.zip" + echo "Downloading ${EMDAWN_PKG}" + curl -L -o emdawn.zip \ + "https://github.com/google/dawn/releases/download/${DAWN_TAG}/${EMDAWN_PKG}" + unzip emdawn.zip + + - name: Build WASM WebGPU + run: | + source emsdk/emsdk_env.sh + emcmake cmake -B build-wasm \ + -DGGML_WEBGPU=ON \ + -DLLAMA_CURL=OFF \ + -DEMDAWNWEBGPU_DIR=emdawnwebgpu_pkg + + cmake --build build-wasm --target test-backend-ops -j $(nproc) + ubuntu-22-cmake-hip: runs-on: ubuntu-22.04 container: rocm/dev-ubuntu-22.04:6.1.2 diff --git a/CMakeLists.txt b/CMakeLists.txt index ccde73afb4b0c..1c69a865b93b9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,8 +38,9 @@ if (EMSCRIPTEN) # Use 64-bit memory to support backend_get_memory queries # TODO: analyze performance impact, see https://spidermonkey.dev/blog/2025/01/15/is-memory64-actually-worth-using - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -sMEMORY64=1") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -sMEMORY64=1") + add_compile_options("-sMEMORY64=1") + add_link_options("-sMEMORY64=1") + add_link_options("-sALLOW_MEMORY_GROWTH=1") option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" OFF) option(LLAMA_BUILD_HTML "llama: build HTML file" ON) diff --git a/ggml/src/ggml-webgpu/CMakeLists.txt b/ggml/src/ggml-webgpu/CMakeLists.txt index a19cf2e4581f9..3ccce58aa39ec 100644 --- a/ggml/src/ggml-webgpu/CMakeLists.txt +++ b/ggml/src/ggml-webgpu/CMakeLists.txt @@ -56,8 +56,6 @@ if(EMSCRIPTEN) target_compile_options(ggml-webgpu PRIVATE "-fexceptions") target_link_options(ggml-webgpu INTERFACE "-sASYNCIFY" "-exceptions") endif() - - set(DawnWebGPU_TARGET webgpu_cpp) else() find_package(Dawn REQUIRED) set(DawnWebGPU_TARGET dawn::webgpu_dawn)
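
For reference, the emscripten build exercised by the new CI job boils down to the following local flow (a sketch assembled from the workflow steps above; the Dawn release tag is the one pinned in the job and may need updating):

    git clone https://github.com/emscripten-core/emsdk.git
    ./emsdk/emsdk install latest && ./emsdk/emsdk activate latest
    source emsdk/emsdk_env.sh

    DAWN_TAG="v20251027.212519"
    curl -L -o emdawn.zip "https://github.com/google/dawn/releases/download/${DAWN_TAG}/emdawnwebgpu_pkg-${DAWN_TAG}.zip"
    unzip emdawn.zip

    emcmake cmake -B build-wasm \
        -DGGML_WEBGPU=ON \
        -DLLAMA_CURL=OFF \
        -DEMDAWNWEBGPU_DIR=emdawnwebgpu_pkg
    cmake --build build-wasm --target test-backend-ops -j $(nproc)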