Skip to content

Commit 3165df5

Browse files
Fix the wgsl code of subgroup-matix-multiplication
1 parent 4ef6361 commit 3165df5

File tree

3 files changed

+135
-80
lines changed

3 files changed

+135
-80
lines changed

examples/matmul/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ run: ./build/$(TARGET)
1616
$(LIBSPEC) && ./build/$(TARGET)
1717

1818
debug: run.cpp
19-
mkdir -p build && $(CXX) $(FLAGS) -g -fsanitize=address -fno-omit-frame-pointer -Wall -o ./build/$(TARGET)
19+
mkdir -p build && $(CXX) $(FLAGS) -g -fsanitize=address -fno-omit-frame-pointer -fasynchronous-unwind-tables -Wall -o ./build/$(TARGET)
2020

2121
run_with_metal_profiler: ./build/$(TARGET)_with_metal_profiler
2222
$(LIBSPEC) && export METAL_CAPTURE_ENABLED=1 && ./build/$(TARGET)_with_metal_profiler

examples/matmul/run.cpp

Lines changed: 128 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
#include <future>
44
#include <random>
55
#include <cstdlib>
6+
#include <exception>
7+
#include <iostream>
68

79
#include "gpu.hpp" // createContext, createTensor, createKernel, dispatchKernel,
810
// wait, resetCommandBuffer, toCPU
@@ -615,64 +617,76 @@ inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const si
615617

616618
inline KernelCode createMatmul12(const char *shaderTemplate, const size_t M,
617619
const size_t K, const size_t N,
620+
const size_t TM, const size_t TN,
621+
const Shape &workgroupSize = {256, 1, 1},
618622
NumType precision = kf32) {
619623
std::string codeString(shaderTemplate);
620624
replaceAll(codeString, {{"{{precision}}", toString(precision)},
621625
{"{{M}}", toString(M)},
622626
{"{{K}}", toString(K)},
623-
{"{{N}}", toString(N)}});
624-
return {codeString, {256, 1, 1}, precision};
627+
{"{{N}}", toString(N)},
628+
{"{{TM}}", toString(TM)},
629+
{"{{TN}}", toString(TN)}
630+
});
631+
return {codeString, workgroupSize, precision};
625632
}
626633

627-
628-
629634
// ─────────────────────────────────────────────────────────────────────────────
630635
// Optimised WGSL matrix‑multiply kernel using subgroupMatrixLoad/Store
631636
// and subgroupMatrixMultiplyAccumulate
632637
// ─────────────────────────────────────────────────────────────────────────────
633638
const char* kShaderSubgroupMatrixMultiply = R"(
639+
enable subgroups;
634640
enable chromium_experimental_subgroup_matrix;
635641
636-
@group(0) @binding(0) var<storage, read> A: array<{{precision}}>;
637-
@group(0) @binding(1) var<storage, read> B: array<{{precision}}>;
638-
@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
642+
@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
643+
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
644+
@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
645+
646+
@compute @workgroup_size({{workgroupSize}})
647+
fn main(@builtin(workgroup_id) wg: vec3<u32>) {
639648
640-
// Each workgroup computes one 16x16 tile of C.
641-
@compute @workgroup_size(256, 1, 1)
642-
fn main(@builtin(workgroup_id) groupID: vec3<u32>) {
649+
let rowStart: u32 = wg.x * 8u * {{TM}};
650+
let colStart: u32 = wg.y * 8u * {{TN}};
643651
644-
let tileRow = groupID.y;
645-
let tileCol = groupID.x;
652+
if (rowStart >= u32({{M}}) || colStart >= u32({{N}})) { return; }
646653
647-
let outRowStart = tileRow * 16u;
648-
let outColStart = tileCol * 16u;
654+
let baseA: u32 = rowStart * {{K}};
655+
let baseB: u32 = colStart;
656+
let cBase: u32 = rowStart * {{N}} + colStart;
649657
650-
if (outRowStart >= {{M}} || outColStart >= {{N}}) {
651-
return;
652-
}
658+
var Ax: array<subgroup_matrix_left<{{precision}}, 8, 8>, {{TM}}>;
659+
var Bx: array<subgroup_matrix_right<{{precision}}, 8, 8>, {{TN}}>;
653660
654-
var acc: subgroup_matrix_result<{{precision}}, 16, 16>;
661+
// 4x4 accumulators (8x8 each)
662+
var accxx: array<subgroup_matrix_result<{{precision}}, 8, 8>, {{TM}} * {{TN}}>;
655663
656-
let kTiles = ({{K}} + 15u) / 16u;
664+
for (var k: u32 = 0u; k < {{K}}; k = k + 8u) {
665+
workgroupBarrier();
666+
for (var i: u32 = 0; i < {{TM}}; i++) {
667+
Ax[i] = subgroupMatrixLoad<subgroup_matrix_left<{{precision}},8,8>>(&A, baseA + i * 8u*{{K}} + k, false, {{K}});
668+
}
657669
658-
// Load the first tile and multiply to initialize accumulator
659-
let a_tile_0 = subgroupMatrixLoad<subgroup_matrix_left<{{precision}}, 16, 16>>(A, outRowStart * {{K}}, true, {{K}});
660-
let b_tile_0 = subgroupMatrixLoad<subgroup_matrix_right<{{precision}}, 16, 16>>(B, outColStart, true, {{N}});
661-
acc = subgroupMatrixMultiply<{{precision}}>(a_tile_0, b_tile_0);
670+
for (var i: u32 = 0; i < {{TN}}; i++) {
671+
Bx[i] = subgroupMatrixLoad<subgroup_matrix_right<{{precision}},8,8>>(&B, baseB + k*{{N}} + 8u * i, false, {{N}});
672+
}
662673
663-
// Loop over the rest of the K-dimension
664-
for (var kTile: u32 = 1u; kTile < kTiles; kTile = kTile + 1u) {
665-
let k = kTile * 16u;
666-
let a_tile = subgroupMatrixLoad<subgroup_matrix_left<{{precision}}, 16, 16>>(A, outRowStart * {{K}} + k, true, {{K}});
667-
let b_tile = subgroupMatrixLoad<subgroup_matrix_right<{{precision}}, 16, 16>>(B, k * {{N}} + outColStart, true, {{N}});
668-
acc = subgroupMatrixMultiplyAccumulate(a_tile, b_tile, acc);
674+
for (var i: u32 = 0; i < {{TM}}; i++) {
675+
for (var j: u32 = 0; j < {{TN}}; j++) {
676+
accxx[i+j*{{TM}}] = subgroupMatrixMultiplyAccumulate(Ax[i], Bx[j], accxx[i+j*{{TM}}]);
677+
}
669678
}
679+
}
670680
671-
subgroupMatrixStore(C, outRowStart * {{N}} + outColStart, acc, true, {{N}});
681+
workgroupBarrier();
682+
for (var i: u32 = 0; i < {{TM}}; i++) {
683+
for (var j: u32 = 0; j < {{TN}}; j++) {
684+
subgroupMatrixStore(&C, cBase + i * 8u * {{N}} + 8u * j, accxx[i+j*{{TM}}], false, {{N}});
685+
}
686+
}
672687
}
673688
)";
674689

675-
676690
/**
677691
* @brief No-Op shader with matmul bindings for performance testing
678692
*/
@@ -743,26 +757,30 @@ Kernel selectMatmul(Context &ctx, int version,
743757
const Bindings</* input, weights, output */ 3> &bindings,
744758
size_t M, size_t K, size_t N, NumType numtype) {
745759
Kernel kernel;
760+
CompilationInfo info;
746761
if (version == 1) {
747762
Shape wgSize = {256, 1, 1};
748763
Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
749764
KernelCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
750765
kernel = createKernel(ctx, matmul, bindings,
751-
/*nWorkgroups*/ nWorkgroups);
766+
/*nWorkgroups*/ nWorkgroups,
767+
NoParam{}, &info);
752768
} else if (version == 2) {
753769
Shape wgSize = {16, 16, 1};
754770
LOG(kDefLog, kInfo, "wgSize: %s", toString(wgSize).c_str());
755771
KernelCode matmul =
756772
createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize, numtype);
757773
kernel = createKernel(ctx, matmul, bindings,
758-
/*nWorkgroups*/ cdiv({M, N, 1}, wgSize));
774+
/*nWorkgroups*/ cdiv({M, N, 1}, wgSize),
775+
NoParam{}, &info);
759776
} else if (version == 3) {
760777
static constexpr size_t tileSize = 16;
761778
KernelCode matmul = createMatmul2(kShaderMatmul2, M, K, N,
762779
/*wgSize*/ {tileSize * tileSize, 1, 1}, numtype);
763780
kernel =
764781
createKernel(ctx, matmul, bindings,
765-
/* nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}));
782+
/* nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}),
783+
NoParam{}, &info);
766784
} else if (version == 4 || version == 6) {
767785
static constexpr size_t BM = 64;
768786
static constexpr size_t BK = 4;
@@ -781,7 +799,8 @@ Kernel selectMatmul(Context &ctx, int version,
781799
numtype,
782800
/*Loop unrolling*/ version == 6 ? true: false);
783801
kernel = createKernel(ctx, matmul, bindings,
784-
/*nWorkgroups*/ nWorkgroups);
802+
/*nWorkgroups*/ nWorkgroups,
803+
NoParam{}, &info);
785804
} else if (version == 5 || version == 7) {
786805
static constexpr size_t BM = 64;
787806
static constexpr size_t BK = 8;
@@ -799,7 +818,8 @@ Kernel selectMatmul(Context &ctx, int version,
799818
numtype,
800819
/*Loop unrolling*/ version == 7 ? true: false);
801820
kernel = createKernel(ctx, matmul, bindings,
802-
/*nWorkgroups*/ nWorkgroups);
821+
/*nWorkgroups*/ nWorkgroups,
822+
NoParam{}, &info);
803823
} else if (version == 8 || version == 10) {
804824
static constexpr size_t BM = 64;
805825
static constexpr size_t BK = 8;
@@ -817,7 +837,8 @@ Kernel selectMatmul(Context &ctx, int version,
817837
numtype,
818838
/*Loop unrolling*/ true);
819839
kernel = createKernel(ctx, matmul, bindings,
820-
/*nWorkgroups*/ nWorkgroups);
840+
/*nWorkgroups*/ nWorkgroups,
841+
NoParam{}, &info);
821842
} else if (version == 9 || version == 11) {
822843
static constexpr size_t BM = 64;
823844
static constexpr size_t BK = 8;
@@ -834,18 +855,37 @@ Kernel selectMatmul(Context &ctx, int version,
834855
/*wgSize*/ wgSize,
835856
numtype);
836857
kernel = createKernel(ctx, matmul, bindings,
837-
/*nWorkgroups*/ nWorkgroups);
858+
/*nWorkgroups*/ nWorkgroups,
859+
NoParam{}, &info);
838860
} else if (version == 12) {
839861
// f32: Subgroup matrix multiply
840-
Shape wgSize = {256, 1, 1}; // One subgroup per workgroup
841-
Shape nWorkgroups = {cdiv(N, 16), cdiv(M, 16), 1};
862+
static constexpr size_t TM = 2;
863+
static constexpr size_t TN = 4;
864+
Shape wgSize = {64, 1, 1}; // One subgroup per workgroup
865+
Shape nWorkgroups = {cdiv(M, 8 * TM), cdiv(N, 8 * TN), 1};
842866
LOG(kDefLog, kInfo, "M: %zu, K: %zu, N: %zu", M, K, N);
843867
LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
844868
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
845-
KernelCode matmul =
846-
createMatmul12(kShaderSubgroupMatrixMultiply, M, K, N, numtype);
847-
kernel = createKernel(ctx, matmul, bindings, nWorkgroups);
869+
KernelCode matmul = createMatmul12(kShaderSubgroupMatrixMultiply, M, K, N, TM, TN, wgSize, numtype);
870+
kernel = createKernel(ctx, matmul, bindings, nWorkgroups,
871+
NoParam{}, &info);
872+
}
873+
874+
if (info.status != WGPUCompilationInfoRequestStatus_Success) {
875+
LOG(kDefLog, kError, "Failed to compile shader");
876+
for (size_t i = 0; i < info.messages.size(); i++) {
877+
LOG(kDefLog, kError, "Line %llu, Pos %llu: %s", info.lineNums[i],
878+
info.linePos[i], info.messages[i].c_str());
879+
}
880+
exit(1);
881+
} else {
882+
LOG(kDefLog, kInfo, "Shader compiled successfully");
883+
for (size_t i = 0; i < info.messages.size(); i++) {
884+
LOG(kDefLog, kInfo, "Line %llu, Pos %llu: %s", info.lineNums[i],
885+
info.linePos[i], info.messages[i].c_str());
886+
}
848887
}
888+
849889
return kernel;
850890
}
851891

@@ -866,36 +906,49 @@ void runTest(int version, size_t M, size_t K, size_t N,
866906
devDescriptor.requiredFeatureCount = 1;
867907
devDescriptor.requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data();
868908

869-
Context ctx;
870-
if (numtype == kf16) {
871-
ctx = createContext(
872-
{}, {},
873-
/*device descriptor, enabling f16 in WGSL*/
874-
{
875-
.requiredFeatureCount = 1,
876-
.requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data()
877-
});
878-
if (ctx.adapterStatus != WGPURequestAdapterStatus_Success) {
879-
LOG(kDefLog, kError, "Failed to create adapter with f16 support, try running an f32 test instead (`export MATMUL_VERSION=9).");
880-
exit(1);
909+
WGPUDawnTogglesDescriptor toggles = {};
910+
toggles.chain.sType = WGPUSType_DawnTogglesDescriptor;
911+
const char* enableList[] = {"allow_unsafe_apis"};
912+
toggles.enabledToggles = enableList;
913+
toggles.enabledToggleCount = 1;
914+
915+
WGPUDeviceDescriptor devDesc = {};
916+
devDesc.nextInChain = &toggles.chain;
917+
devDesc.requiredFeatureCount = 3,
918+
devDesc.requiredFeatures = std::array{
919+
WGPUFeatureName_ShaderF16,
920+
WGPUFeatureName_Subgroups,
921+
WGPUFeatureName_ChromiumExperimentalSubgroupMatrix
922+
}.data();
923+
devDesc.uncapturedErrorCallbackInfo = WGPUUncapturedErrorCallbackInfo {
924+
.callback = [](WGPUDevice const * device, WGPUErrorType type, WGPUStringView msg, void*, void*) {
925+
LOG(kDefLog, kError, "[Uncaptured %d] %.*s\n", (int)type, (int)msg.length, msg.data);
881926
}
882-
if (ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
883-
LOG(kDefLog, kError, "Failed to create device with f16 support, try running an f32 test instead. (`export MATMUL_VERSION=9)");
884-
exit(1);
927+
};
928+
devDesc.deviceLostCallbackInfo = WGPUDeviceLostCallbackInfo {
929+
.mode = WGPUCallbackMode_AllowSpontaneous,
930+
.callback = [](WGPUDevice const * device, WGPUDeviceLostReason reason, WGPUStringView msg, void*, void*) {
931+
LOG(kDefLog, kError, "[DeviceLost %d] %.*s\n", (int)reason, (int)msg.length, msg.data);
885932
}
886-
}
887-
888-
if (numtype == kf32) {
889-
ctx = createContext({}, {}, {});
890-
if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
891-
ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
892-
LOG(kDefLog, kError, "Failed to create adapter or device");
893-
// stop execution
894-
exit(1);
895-
} else {
896-
LOG(kDefLog, kInfo, "Successfully created adapter and device");
933+
};
934+
935+
Context ctx = createContext({}, {}, devDesc);
936+
937+
WGPULoggingCallbackInfo logCb{
938+
.callback = [](WGPULoggingType type, WGPUStringView msg, void*, void*) {
939+
LOG(kDefLog, kError, "[WGPU %d] %.*s\n", (int)type, (int)msg.length, msg.data);
897940
}
898-
}
941+
};
942+
wgpuDeviceSetLoggingCallback(ctx.device, logCb);
943+
944+
if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
945+
ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
946+
LOG(kDefLog, kError, "Failed to create adapter or device");
947+
// stop execution
948+
exit(1);
949+
} else {
950+
LOG(kDefLog, kInfo, "Successfully created adapter and device");
951+
}
899952

900953
Tensor input = createTensor(ctx, Shape{M, K}, numtype, inputPtr.get());
901954
Tensor weights = createTensor(ctx, Shape{N, K}, numtype, weightsPtr.get()); // column-major
@@ -983,14 +1036,15 @@ const std::string versionToStr(int version){
9831036
case 9: return "f32: 2D blocktiling with loop unrolling, vectorization and transpose";
9841037
case 10: return "f16: 2D blocktiling with loop unrolling and vectorization (default)";
9851038
case 11: return "f16: 2D blocktiling with loop unrolling, vectorization and transpose";
986-
case 12: return "f32: Subgroup matrix multiply";
1039+
case 12: return "f16: Subgroup matrix multiply with transpose";
9871040
default: return "Not specified";
9881041
}
9891042
}
9901043

9911044
int main() {
1045+
std::cout << "Starting matmul test..." << std::endl;
9921046
char* version_str = getenv("MATMUL_VERSION");
993-
int version = version_str == NULL ? 12 : atoi(version_str);
1047+
int version = version_str == NULL ? 11 : atoi(version_str);
9941048
// 1 == f32: No-Op
9951049
// 2 == f32: naive matmul
9961050
// 3 == f32: tiling
@@ -1002,8 +1056,8 @@ int main() {
10021056
// 9 == f32: 2D blocktiling with loop unrolling, vectorization and transpose
10031057
// 10 == f16: 2D blocktiling with loop unrolling and vectorization (default)
10041058
// 11 == f16: 2D blocktiling with loop unrolling, vectorization and transpose
1005-
// 12 == f32: Subgroup matrix multiply
1006-
bool enableF16 = version == 10 || version ==11;
1059+
// 12 == f16: Subgroup matrix multiply with transpose
1060+
bool enableF16 = version == 10 || version ==11 || version == 12;
10071061
bool transposedInput = version == 9 || version == 11 || version == 12;
10081062
NumType numtype = enableF16 ? kf16 : kf32;
10091063

gpu.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ struct KernelCode {
412412
}
413413
replaceAll(data, "{{workgroupSize}}", toString(workgroupSize));
414414
replaceAll(data, "{{precision}}", toString(precision));
415-
LOG(kDefLog, kInfo, "Shader code:\n%s", data.c_str());
415+
LOG(kDefLog, kTrace, "Shader code:\n%s", data.c_str());
416416
}
417417

418418
/**
@@ -438,7 +438,7 @@ struct KernelCode {
438438
replaceAll(data, "{{workgroupSize}}", toString(workgroupSize));
439439
replaceAll(data, "{{precision}}", toString(precision));
440440
replaceAll(data, "{{totalWorkgroups}}", toString(totalWorkgroups));
441-
LOG(kDefLog, kInfo, "Shader code:\n%s", data.c_str());
441+
LOG(kDefLog, kTrace, "Shader code:\n%s", data.c_str());
442442
}
443443

444444
/**
@@ -464,7 +464,7 @@ struct KernelCode {
464464
replaceAll(data, "{{workgroupSize}}", toString({workgroupSize, 1, 1}));
465465
replaceAll(data, "{{precision}}", toString(precision));
466466
replaceAll(data, "{{totalWorkgroups}}", toString(totalWorkgroups));
467-
LOG(kDefLog, kInfo, "Shader code:\n%s", data.c_str());
467+
LOG(kDefLog, kTrace, "Shader code:\n%s", data.c_str());
468468
}
469469

470470
std::string data;
@@ -1309,6 +1309,7 @@ createContextAsync(const WGPUInstanceDescriptor &desc = {},
13091309
ctx.device = wait(ctx, deviceFuture);
13101310
ctx.deviceStatus = WGPURequestDeviceStatus_Success;
13111311
} catch (const std::exception &ex) {
1312+
LOG(kDefLog, kTrace, "requestDeviceAsync: %s", ex.what());
13121313
promise->set_exception(std::make_exception_ptr(ex));
13131314
return promise->get_future();
13141315
}
@@ -1594,7 +1595,7 @@ inline void bufferMapCallback(WGPUMapAsyncStatus status, WGPUStringView message,
15941595
* and a promise to signal completion.
15951596
* @param userdata2 Unused.
15961597
*/
1597-
inline void queueWorkDoneCallback(WGPUQueueWorkDoneStatus status,
1598+
inline void queueWorkDoneCallback(WGPUQueueWorkDoneStatus status, WGPUStringView message,
15981599
void *userdata1, void * /*userdata2*/) {
15991600
const CallbackData *cbData = static_cast<CallbackData *>(userdata1);
16001601
// Ensure the queue work finished successfully.
@@ -2837,7 +2838,7 @@ Kernel createKernel(Context &ctx, const KernelCode &code,
28372838
* when the work is done.
28382839
* @param userdata2 Unused.
28392840
*/
2840-
inline void dispatchKernelCallback(WGPUQueueWorkDoneStatus status,
2841+
inline void dispatchKernelCallback(WGPUQueueWorkDoneStatus status, WGPUStringView message,
28412842
void *userdata1, void * /*userdata2*/) {
28422843
// Cast the userdata pointer back to our heap‑allocated promise.
28432844
auto *p = reinterpret_cast<std::promise<void> *>(userdata1);

0 commit comments

Comments
 (0)