3
3
#include < future>
4
4
#include < random>
5
5
#include < cstdlib>
6
+ #include < exception>
7
+ #include < iostream>
6
8
7
9
#include " gpu.hpp" // createContext, createTensor, createKernel, dispatchKernel,
8
10
// wait, resetCommandBuffer, toCPU
@@ -615,64 +617,76 @@ inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const si
615
617
616
618
inline KernelCode createMatmul12 (const char *shaderTemplate, const size_t M,
617
619
const size_t K, const size_t N,
620
+ const size_t TM, const size_t TN,
621
+ const Shape &workgroupSize = {256 , 1 , 1 },
618
622
NumType precision = kf32) {
619
623
std::string codeString (shaderTemplate);
620
624
replaceAll (codeString, {{" {{precision}}" , toString (precision)},
621
625
{" {{M}}" , toString (M)},
622
626
{" {{K}}" , toString (K)},
623
- {" {{N}}" , toString (N)}});
624
- return {codeString, {256 , 1 , 1 }, precision};
627
+ {" {{N}}" , toString (N)},
628
+ {" {{TM}}" , toString (TM)},
629
+ {" {{TN}}" , toString (TN)}
630
+ });
631
+ return {codeString, workgroupSize, precision};
625
632
}
626
633
627
-
628
-
629
634
// ─────────────────────────────────────────────────────────────────────────────
// Optimised WGSL matrix-multiply kernel using subgroupMatrixLoad/Store
// and subgroupMatrixMultiplyAccumulate
//
// Template placeholders: {{precision}}, {{workgroupSize}}, {{M}}, {{K}},
// {{N}}, {{TM}}, {{TN}}.
//
// Each workgroup starts at row wg.x * 8 * TM and column wg.y * 8 * TN of C
// and keeps TM * TN accumulators of 8x8 each. The K dimension is walked in
// steps of 8; every step loads TM left tiles from A and TN right tiles from B
// and accumulates all TM x TN products. All loads/stores are row-major
// (the col_major argument is false) with strides K (A) and N (B, C).
//
// NOTE(review): only the tile origin is bounds-checked and the K loop has no
// tail handling, so this presumably requires M % (8*TM) == 0,
// N % (8*TN) == 0 and K % 8 == 0 -- confirm against callers.
// ─────────────────────────────────────────────────────────────────────────────
const char *kShaderSubgroupMatrixMultiply = R"(
enable subgroups;
enable chromium_experimental_subgroup_matrix;

@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;

@compute @workgroup_size({{workgroupSize}})
fn main(@builtin(workgroup_id) wg: vec3<u32>) {

  let rowStart: u32 = wg.x * 8u * {{TM}};
  let colStart: u32 = wg.y * 8u * {{TN}};

  if (rowStart >= u32({{M}}) || colStart >= u32({{N}})) { return; }

  let baseA: u32 = rowStart * {{K}};
  let baseB: u32 = colStart;
  let cBase: u32 = rowStart * {{N}} + colStart;

  var Ax: array<subgroup_matrix_left<{{precision}}, 8, 8>, {{TM}}>;
  var Bx: array<subgroup_matrix_right<{{precision}}, 8, 8>, {{TN}}>;

  // TM x TN accumulators (8x8 each)
  var accxx: array<subgroup_matrix_result<{{precision}}, 8, 8>, {{TM}} * {{TN}}>;

  for (var k: u32 = 0u; k < {{K}}; k = k + 8u) {
    workgroupBarrier();
    for (var i: u32 = 0; i < {{TM}}; i++) {
      Ax[i] = subgroupMatrixLoad<subgroup_matrix_left<{{precision}},8,8>>(&A, baseA + i * 8u*{{K}} + k, false, {{K}});
    }

    for (var i: u32 = 0; i < {{TN}}; i++) {
      Bx[i] = subgroupMatrixLoad<subgroup_matrix_right<{{precision}},8,8>>(&B, baseB + k*{{N}} + 8u * i, false, {{N}});
    }

    for (var i: u32 = 0; i < {{TM}}; i++) {
      for (var j: u32 = 0; j < {{TN}}; j++) {
        accxx[i+j*{{TM}}] = subgroupMatrixMultiplyAccumulate(Ax[i], Bx[j], accxx[i+j*{{TM}}]);
      }
    }
  }

  workgroupBarrier();
  for (var i: u32 = 0; i < {{TM}}; i++) {
    for (var j: u32 = 0; j < {{TN}}; j++) {
      subgroupMatrixStore(&C, cBase + i * 8u * {{N}} + 8u * j, accxx[i+j*{{TM}}], false, {{N}});
    }
  }
}
)";
674
689
675
-
676
690
/* *
677
691
* @brief No-Op shader with matmul bindings for performance testing
678
692
*/
@@ -743,26 +757,30 @@ Kernel selectMatmul(Context &ctx, int version,
743
757
const Bindings</* input, weights, output */ 3 > &bindings,
744
758
size_t M, size_t K, size_t N, NumType numtype) {
745
759
Kernel kernel;
760
+ CompilationInfo info;
746
761
if (version == 1 ) {
747
762
Shape wgSize = {256 , 1 , 1 };
748
763
Shape nWorkgroups = cdiv ({M, N, 1 }, {16 , 16 , 1 });
749
764
KernelCode matmul = createNoOp (kShaderNoOp , /* wgsize*/ wgSize);
750
765
kernel = createKernel (ctx, matmul, bindings,
751
- /* nWorkgroups*/ nWorkgroups);
766
+ /* nWorkgroups*/ nWorkgroups,
767
+ NoParam{}, &info);
752
768
} else if (version == 2 ) {
753
769
Shape wgSize = {16 , 16 , 1 };
754
770
LOG (kDefLog , kInfo , " wgSize: %s" , toString (wgSize).c_str ());
755
771
KernelCode matmul =
756
772
createMatmul1 (kShaderMatmul1 , M, K, N, /* wgsize*/ wgSize, numtype);
757
773
kernel = createKernel (ctx, matmul, bindings,
758
- /* nWorkgroups*/ cdiv ({M, N, 1 }, wgSize));
774
+ /* nWorkgroups*/ cdiv ({M, N, 1 }, wgSize),
775
+ NoParam{}, &info);
759
776
} else if (version == 3 ) {
760
777
static constexpr size_t tileSize = 16 ;
761
778
KernelCode matmul = createMatmul2 (kShaderMatmul2 , M, K, N,
762
779
/* wgSize*/ {tileSize * tileSize, 1 , 1 }, numtype);
763
780
kernel =
764
781
createKernel (ctx, matmul, bindings,
765
- /* nWorkgroups*/ cdiv ({M, N, 1 }, {tileSize, tileSize, 1 }));
782
+ /* nWorkgroups*/ cdiv ({M, N, 1 }, {tileSize, tileSize, 1 }),
783
+ NoParam{}, &info);
766
784
} else if (version == 4 || version == 6 ) {
767
785
static constexpr size_t BM = 64 ;
768
786
static constexpr size_t BK = 4 ;
@@ -781,7 +799,8 @@ Kernel selectMatmul(Context &ctx, int version,
781
799
numtype,
782
800
/* Loop unrolling*/ version == 6 ? true : false );
783
801
kernel = createKernel (ctx, matmul, bindings,
784
- /* nWorkgroups*/ nWorkgroups);
802
+ /* nWorkgroups*/ nWorkgroups,
803
+ NoParam{}, &info);
785
804
} else if (version == 5 || version == 7 ) {
786
805
static constexpr size_t BM = 64 ;
787
806
static constexpr size_t BK = 8 ;
@@ -799,7 +818,8 @@ Kernel selectMatmul(Context &ctx, int version,
799
818
numtype,
800
819
/* Loop unrolling*/ version == 7 ? true : false );
801
820
kernel = createKernel (ctx, matmul, bindings,
802
- /* nWorkgroups*/ nWorkgroups);
821
+ /* nWorkgroups*/ nWorkgroups,
822
+ NoParam{}, &info);
803
823
} else if (version == 8 || version == 10 ) {
804
824
static constexpr size_t BM = 64 ;
805
825
static constexpr size_t BK = 8 ;
@@ -817,7 +837,8 @@ Kernel selectMatmul(Context &ctx, int version,
817
837
numtype,
818
838
/* Loop unrolling*/ true );
819
839
kernel = createKernel (ctx, matmul, bindings,
820
- /* nWorkgroups*/ nWorkgroups);
840
+ /* nWorkgroups*/ nWorkgroups,
841
+ NoParam{}, &info);
821
842
} else if (version == 9 || version == 11 ) {
822
843
static constexpr size_t BM = 64 ;
823
844
static constexpr size_t BK = 8 ;
@@ -834,18 +855,37 @@ Kernel selectMatmul(Context &ctx, int version,
834
855
/* wgSize*/ wgSize,
835
856
numtype);
836
857
kernel = createKernel (ctx, matmul, bindings,
837
- /* nWorkgroups*/ nWorkgroups);
858
+ /* nWorkgroups*/ nWorkgroups,
859
+ NoParam{}, &info);
838
860
} else if (version == 12 ) {
839
861
// f32: Subgroup matrix multiply
840
- Shape wgSize = {256 , 1 , 1 }; // One subgroup per workgroup
841
- Shape nWorkgroups = {cdiv (N, 16 ), cdiv (M, 16 ), 1 };
862
+ static constexpr size_t TM = 2 ;
863
+ static constexpr size_t TN = 4 ;
864
+ Shape wgSize = {64 , 1 , 1 }; // One subgroup per workgroup
865
+ Shape nWorkgroups = {cdiv (M, 8 * TM), cdiv (N, 8 * TN), 1 };
842
866
LOG (kDefLog , kInfo , " M: %zu, K: %zu, N: %zu" , M, K, N);
843
867
LOG (kDefLog , kInfo , " wgSize: ( %s )" , toString (wgSize).c_str ());
844
868
LOG (kDefLog , kInfo , " nWorkgroups: ( %s )" , toString (nWorkgroups).c_str ());
845
- KernelCode matmul =
846
- createMatmul12 (kShaderSubgroupMatrixMultiply , M, K, N, numtype);
847
- kernel = createKernel (ctx, matmul, bindings, nWorkgroups);
869
+ KernelCode matmul = createMatmul12 (kShaderSubgroupMatrixMultiply , M, K, N, TM, TN, wgSize, numtype);
870
+ kernel = createKernel (ctx, matmul, bindings, nWorkgroups,
871
+ NoParam{}, &info);
872
+ }
873
+
874
+ if (info.status != WGPUCompilationInfoRequestStatus_Success) {
875
+ LOG (kDefLog , kError , " Failed to compile shader" );
876
+ for (size_t i = 0 ; i < info.messages .size (); i++) {
877
+ LOG (kDefLog , kError , " Line %llu, Pos %llu: %s" , info.lineNums [i],
878
+ info.linePos [i], info.messages [i].c_str ());
879
+ }
880
+ exit (1 );
881
+ } else {
882
+ LOG (kDefLog , kInfo , " Shader compiled successfully" );
883
+ for (size_t i = 0 ; i < info.messages .size (); i++) {
884
+ LOG (kDefLog , kInfo , " Line %llu, Pos %llu: %s" , info.lineNums [i],
885
+ info.linePos [i], info.messages [i].c_str ());
886
+ }
848
887
}
888
+
849
889
return kernel;
850
890
}
851
891
@@ -866,36 +906,49 @@ void runTest(int version, size_t M, size_t K, size_t N,
866
906
devDescriptor.requiredFeatureCount = 1 ;
867
907
devDescriptor.requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data ();
868
908
869
- Context ctx;
870
- if (numtype == kf16) {
871
- ctx = createContext (
872
- {}, {},
873
- /* device descriptor, enabling f16 in WGSL*/
874
- {
875
- .requiredFeatureCount = 1 ,
876
- .requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data ()
877
- });
878
- if (ctx.adapterStatus != WGPURequestAdapterStatus_Success) {
879
- LOG (kDefLog , kError , " Failed to create adapter with f16 support, try running an f32 test instead (`export MATMUL_VERSION=9)." );
880
- exit (1 );
909
+ WGPUDawnTogglesDescriptor toggles = {};
910
+ toggles.chain .sType = WGPUSType_DawnTogglesDescriptor;
911
+ const char * enableList[] = {" allow_unsafe_apis" };
912
+ toggles.enabledToggles = enableList;
913
+ toggles.enabledToggleCount = 1 ;
914
+
915
+ WGPUDeviceDescriptor devDesc = {};
916
+ devDesc.nextInChain = &toggles.chain ;
917
+ devDesc.requiredFeatureCount = 3 ,
918
+ devDesc.requiredFeatures = std::array{
919
+ WGPUFeatureName_ShaderF16,
920
+ WGPUFeatureName_Subgroups,
921
+ WGPUFeatureName_ChromiumExperimentalSubgroupMatrix
922
+ }.data ();
923
+ devDesc.uncapturedErrorCallbackInfo = WGPUUncapturedErrorCallbackInfo {
924
+ .callback = [](WGPUDevice const * device, WGPUErrorType type, WGPUStringView msg, void *, void *) {
925
+ LOG (kDefLog , kError , " [Uncaptured %d] %.*s\n " , (int )type, (int )msg.length , msg.data );
881
926
}
882
- if (ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
883
- LOG (kDefLog , kError , " Failed to create device with f16 support, try running an f32 test instead. (`export MATMUL_VERSION=9)" );
884
- exit (1 );
927
+ };
928
+ devDesc.deviceLostCallbackInfo = WGPUDeviceLostCallbackInfo {
929
+ .mode = WGPUCallbackMode_AllowSpontaneous,
930
+ .callback = [](WGPUDevice const * device, WGPUDeviceLostReason reason, WGPUStringView msg, void *, void *) {
931
+ LOG (kDefLog , kError , " [DeviceLost %d] %.*s\n " , (int )reason, (int )msg.length , msg.data );
885
932
}
886
- }
887
-
888
- if (numtype == kf32) {
889
- ctx = createContext ({}, {}, {});
890
- if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
891
- ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
892
- LOG (kDefLog , kError , " Failed to create adapter or device" );
893
- // stop execution
894
- exit (1 );
895
- } else {
896
- LOG (kDefLog , kInfo , " Successfully created adapter and device" );
933
+ };
934
+
935
+ Context ctx = createContext ({}, {}, devDesc);
936
+
937
+ WGPULoggingCallbackInfo logCb{
938
+ .callback = [](WGPULoggingType type, WGPUStringView msg, void *, void *) {
939
+ LOG (kDefLog , kError , " [WGPU %d] %.*s\n " , (int )type, (int )msg.length , msg.data );
897
940
}
898
- }
941
+ };
942
+ wgpuDeviceSetLoggingCallback (ctx.device , logCb);
943
+
944
+ if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
945
+ ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
946
+ LOG (kDefLog , kError , " Failed to create adapter or device" );
947
+ // stop execution
948
+ exit (1 );
949
+ } else {
950
+ LOG (kDefLog , kInfo , " Successfully created adapter and device" );
951
+ }
899
952
900
953
Tensor input = createTensor (ctx, Shape{M, K}, numtype, inputPtr.get ());
901
954
Tensor weights = createTensor (ctx, Shape{N, K}, numtype, weightsPtr.get ()); // column-major
@@ -983,14 +1036,15 @@ const std::string versionToStr(int version){
983
1036
case 9 : return " f32: 2D blocktiling with loop unrolling, vectorization and transpose" ;
984
1037
case 10 : return " f16: 2D blocktiling with loop unrolling and vectorization (default)" ;
985
1038
case 11 : return " f16: 2D blocktiling with loop unrolling, vectorization and transpose" ;
986
- case 12 : return " f32 : Subgroup matrix multiply" ;
1039
+ case 12 : return " f16 : Subgroup matrix multiply with transpose " ;
987
1040
default : return " Not specified" ;
988
1041
}
989
1042
}
990
1043
991
1044
int main () {
1045
+ std::cout << " Starting matmul test..." << std::endl;
992
1046
char * version_str = getenv (" MATMUL_VERSION" );
993
- int version = version_str == NULL ? 12 : atoi (version_str);
1047
+ int version = version_str == NULL ? 11 : atoi (version_str);
994
1048
// 1 == f32: No-Op
995
1049
// 2 == f32: naive matmul
996
1050
// 3 == f32: tiling
@@ -1002,8 +1056,8 @@ int main() {
1002
1056
// 9 == f32: 2D blocktiling with loop unrolling, vectorization and transpose
1003
1057
// 10 == f16: 2D blocktiling with loop unrolling and vectorization (default)
1004
1058
// 11 == f16: 2D blocktiling with loop unrolling, vectorization and transpose
1005
- // 12 == f32 : Subgroup matrix multiply
1006
- bool enableF16 = version == 10 || version ==11 ;
1059
+ // 12 == f16 : Subgroup matrix multiply with transpose
1060
+ bool enableF16 = version == 10 || version ==11 || version == 12 ;
1007
1061
bool transposedInput = version == 9 || version == 11 || version == 12 ;
1008
1062
NumType numtype = enableF16 ? kf16 : kf32;
1009
1063
0 commit comments