Skip to content

Commit 0f90448

Browse files
committed
[mlir][amdgpu][rocdl] Add gfx1250 wmma ops
Update `amdgpu.wmma` op definition and implement amdgpu to rocdl conversion for new variants.
1 parent f248010 commit 0f90448

File tree

6 files changed

+306
-32
lines changed

6 files changed

+306
-32
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -912,9 +912,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN
912912
VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
913913
def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
914914
// wmma
915-
def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
916-
VectorOfLengthAndType<[4, 8, 16], [I8, SI8, UI8]>,
917-
VectorOfLengthAndType<[4, 8], [F8E4M3FN, F8E5M2]>,
915+
def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[2], [F32]>,
916+
VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
917+
VectorOfLengthAndType<[4, 8, 16, 32], [I8, SI8, UI8]>,
918+
VectorOfLengthAndType<[4, 8, 32, 64], [F8E4M3FN, F8E5M2]>,
918919
VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>;
919920
def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
920921
VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
@@ -992,7 +993,7 @@ def AMDGPU_WMMAOp :
992993
Arguments<(ins
993994
ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$m,
994995
ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
995-
ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$k,
996+
ConfinedAttr<I32Attr, [IntIsOneOf<[4, 16, 32, 64, 128]>]>:$k,
996997
WMMAInTypes:$sourceA,
997998
WMMAInTypes:$sourceB,
998999
WMMAOutTypes:$destC,
@@ -1005,8 +1006,14 @@ def AMDGPU_WMMAOp :
10051006
let description = [{
10061007
The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma`
10071008
instructions in the AMDGPU architecture, which perform matrix multiplication.
1008-
Note that all wmma intrinsics have M=N=16 dimensions but vary by in allowed K
1009-
dimensions.
1009+
1010+
On gfx11/RDNA3, wmma intrinsics have M=N=K=16 dimensions.
1011+
1012+
On gfx12/RDNA4, wmma intrinsics have M=N=16 dimensions and support K=16 for
1013+
all element types, and K=32 for i4 sources.
1014+
1015+
On gfx1250, wmma intrinsics have M=N=16 and K dimensions of 4, 32, 64, or 128,
1016+
depending on the element types.
10101017

10111018
On gfx11/RDNA3, when emitting an f16->f16 (or bf16->bf16) wmma, the output is a 16xf16
10121019
(or 16xbf16) vector containing only 8 valid values:
@@ -1022,7 +1029,13 @@ def AMDGPU_WMMAOp :
10221029

10231030
Example:
10241031
```mlir
1025-
%0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16>
1032+
%0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<8xf16>, vector<8xf16>, vector<8xf16>
1033+
1034+
%1 = amdgpu.wmma 16x16x64 %matD * %matE + %matF : vector<32xi8>, vector<32xi8>, vector<8xi32>
1035+
1036+
%2 = amdgpu.wmma 16x16x128 %matG * %matH + %matI : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32>
1037+
1038+
%3 = amdgpu.wmma 16x16x4 %matJ * %matK + %matL : vector<2xf32>, vector<2xf32>, vector<8xf32>
10261039
```
10271040
}];
10281041
let assemblyFormat = [{

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 106 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,8 +1002,13 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
10021002
Type elemDestType = destVectorType.getElementType();
10031003

10041004
const uint32_t k = wmma.getK();
1005+
const bool isRDNA3 = chipset.majorVersion == 11;
1006+
const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
10051007

10061008
if (k == 16) {
1009+
if (!isRDNA3 && !isRDNA4) // gfx1250 does not have any wmma ops with k=16.
1010+
return std::nullopt;
1011+
10071012
if (elemSourceType.isF16() && elemDestType.isF32())
10081013
return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
10091014
if (elemSourceType.isBF16() && elemDestType.isF32())
@@ -1019,34 +1024,124 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
10191024
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
10201025
}
10211026
}
1022-
if (chipset.majorVersion < 12)
1027+
if (isRDNA3)
10231028
return std::nullopt;
10241029

1030+
using fp8 = Float8E4M3FNType;
1031+
using bf8 = Float8E5M2Type;
1032+
10251033
// gfx12+
10261034
if (k == 16) {
1027-
if (isa<Float8E4M3FNType>(elemSourceType) &&
1028-
isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
1035+
if (!isRDNA4) // gfx1250 does not have any wmma ops with k=16.
1036+
return std::nullopt;
1037+
1038+
if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1039+
elemDestType.isF32())
10291040
return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
1030-
if (isa<Float8E4M3FNType>(elemSourceType) &&
1031-
isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
1041+
if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1042+
elemDestType.isF32())
10321043
return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
1033-
if (isa<Float8E5M2Type>(elemSourceType) &&
1034-
isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
1044+
if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1045+
elemDestType.isF32())
10351046
return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
1036-
if (isa<Float8E5M2Type>(elemSourceType) &&
1037-
isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
1047+
if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1048+
elemDestType.isF32())
10381049
return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
1050+
10391051
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
10401052
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
10411053

10421054
return std::nullopt;
10431055
}
10441056
if (k == 32) {
1045-
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1046-
return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
1057+
if (isRDNA4) {
1058+
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1059+
return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
1060+
return std::nullopt;
1061+
}
1062+
1063+
// gfx1250
1064+
if (elemSourceType.isF16() && elemDestType.isF32())
1065+
return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
1066+
if (elemSourceType.isBF16() && elemDestType.isF32())
1067+
return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
1068+
if (elemSourceType.isF16() && elemDestType.isF16())
1069+
return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
1070+
if (elemSourceType.isBF16() && elemDestType.isBF16())
1071+
return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
1072+
1073+
return std::nullopt;
1074+
}
1075+
1076+
if (isRDNA4)
1077+
return std::nullopt;
1078+
1079+
// gfx1250
1080+
if (k == 4) {
1081+
if (elemSourceType.isF32() && elemDestType.isF32())
1082+
return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
10471083
return std::nullopt;
10481084
}
10491085

1086+
if (k == 64) {
1087+
if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1088+
if (elemDestType.isF32())
1089+
return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
1090+
if (elemDestType.isF16())
1091+
return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
1092+
}
1093+
if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1094+
if (elemDestType.isF32())
1095+
return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
1096+
if (elemDestType.isF16())
1097+
return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
1098+
}
1099+
if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1100+
if (elemDestType.isF32())
1101+
return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
1102+
if (elemDestType.isF16())
1103+
return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
1104+
}
1105+
if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1106+
if (elemDestType.isF32())
1107+
return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
1108+
if (elemDestType.isF16())
1109+
return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
1110+
}
1111+
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
1112+
return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
1113+
1114+
return std::nullopt;
1115+
}
1116+
1117+
if (k == 128) {
1118+
if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1119+
if (elemDestType.isF32())
1120+
return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
1121+
if (elemDestType.isF16())
1122+
return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
1123+
}
1124+
if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1125+
if (elemDestType.isF32())
1126+
return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
1127+
if (elemDestType.isF16())
1128+
return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
1129+
}
1130+
if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1131+
if (elemDestType.isF32())
1132+
return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
1133+
if (elemDestType.isF16())
1134+
return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
1135+
}
1136+
if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1137+
if (elemDestType.isF32())
1138+
return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
1139+
if (elemDestType.isF16())
1140+
return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
1141+
}
1142+
1143+
return std::nullopt;
1144+
}
10501145
llvm_unreachable("unhandled WMMA case");
10511146
}
10521147

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -399,13 +399,15 @@ LogicalResult WMMAOp::verify() {
399399

400400
if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
401401
return emitOpError(
402-
"source element types much match (except for fp8) but have ")
402+
"source element types much match (except for fp8/bf8) but have ")
403403
<< sourceAType << " and " << sourceBType;
404404
}
405405

406-
if (!sourceAElemType.isInteger(4) && getK() != 16) {
407-
return emitOpError("K dimension must be 16 for source element type ")
408-
<< sourceAElemType;
406+
if (isSrcFloat) {
407+
if (getClamp())
408+
return emitOpError("clamp flag is not supported for float types");
409+
if (getUnsignedA() || getUnsignedB())
410+
return emitOpError("unsigned flags are not supported for float types");
409411
}
410412
return success();
411413
}
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 --allow-unregistered-dialect | FileCheck %s
2+
3+
// CHECK-LABEL: @wmma_k4
4+
func.func @wmma_k4(%arg0 : vector<2xf32>, %arg1 : vector<8xf32>) {
5+
// CHECK: rocdl.wmma.f32.16x16x4.f32 %arg0, %arg0, %arg1
6+
amdgpu.wmma 16x16x4 %arg0 * %arg0 + %arg1 : vector<2xf32>, vector<2xf32>, vector<8xf32>
7+
func.return
8+
}
9+
10+
// CHECK-LABEL: @wmma_k32
11+
func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vector<8xf32>,
12+
%arg3 : vector<8xf16>, %arg4 : vector<8xbf16>) {
13+
// CHECK: rocdl.wmma.f32.16x16x32.f16 %arg0, %arg0, %arg2
14+
amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<8xf32>
15+
16+
// CHECK: rocdl.wmma.f16.16x16x32.f16 %arg0, %arg0, {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1)
17+
amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg3 : vector<16xf16>, vector<16xf16>, vector<8xf16>
18+
19+
// CHECK: rocdl.wmma.f32.16x16x32.bf16 {{.*}}, {{.*}}, %arg2
20+
amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
21+
22+
// CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}}, {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1)
23+
amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg4 : vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
24+
25+
func.return
26+
}
27+
28+
// CHECK-LABEL: @wmma_k64
29+
func.func @wmma_k64(%arg0 : vector<32xi8>, %arg1 : vector<32xf8E4M3FN>, %arg2 : vector<32xf8E5M2>,
30+
%arg3 : vector<8xi32>, %arg4 : vector<8xf32>, %arg5 : vector<8xf16>) {
31+
// CHECK: rocdl.wmma.i32.16x16x64.iu8 {{.*}}, {{.*}}, {{.*}}, {{.*}}, %arg3, {{.*}}
32+
amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg3 {clamp} : vector<32xi8>, vector<32xi8>, vector<8xi32>
33+
34+
// CHECK: rocdl.wmma.f32.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg4
35+
amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf32>
36+
37+
// CHECK: rocdl.wmma.f16.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
38+
amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf16>
39+
40+
// CHECK: rocdl.wmma.f32.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg4
41+
amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf32>
42+
43+
// CHECK: rocdl.wmma.f16.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
44+
amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf16>
45+
46+
// CHECK: rocdl.wmma.f32.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg4
47+
amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg4 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf32>
48+
49+
// CHECK: rocdl.wmma.f16.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
50+
amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg5 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf16>
51+
52+
// CHECK: rocdl.wmma.f32.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg4
53+
amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg4 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf32>
54+
55+
// CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
56+
amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg5 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf16>
57+
58+
func.return
59+
}
60+
61+
// CHECK-LABEL: @wmma_k128
62+
func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
63+
%arg2 : vector<8xf32>, %arg3 : vector<8xf16>) {
64+
// CHECK: rocdl.wmma.f32.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg2
65+
amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32>
66+
67+
// CHECK: rocdl.wmma.f16.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
68+
amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf16>
69+
70+
// CHECK: rocdl.wmma.f32.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg2
71+
amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf32>
72+
73+
// CHECK: rocdl.wmma.f16.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
74+
amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf16>
75+
76+
// CHECK: rocdl.wmma.f32.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg2
77+
amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg2 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf32>
78+
79+
// CHECK: rocdl.wmma.f16.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
80+
amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg3 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf16>
81+
82+
// CHECK: rocdl.wmma.f32.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg2
83+
amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg2 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf32>
84+
85+
// CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
86+
amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg3 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf16>
87+
88+
func.return
89+
}

mlir/test/Dialect/AMDGPU/invalid.mlir

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -156,14 +156,6 @@ func.func @wmma_no_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector
156156

157157
// -----
158158

159-
func.func @wmma_wrong_m_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
160-
// expected-error@+1 {{'amdgpu.wmma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}}
161-
%0 = amdgpu.wmma 32x16x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
162-
func.return %0 : vector<8xi32>
163-
}
164-
165-
// -----
166-
167159
func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
168160
// expected-error@+1 {{'amdgpu.wmma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}}
169161
%0 = amdgpu.wmma 16x32x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
@@ -173,14 +165,62 @@ func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vec
173165
// -----
174166

175167
func.func @wmma_wrong_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
176-
// expected-error@+1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}}
168+
// expected-error@+1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {4, 16, 32, 64, 128}}}
177169
%0 = amdgpu.wmma 16x16x24 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
178170
func.return %0 : vector<8xi32>
179171
}
180172

181173
// -----
182174

183-
// Missinng `resetOffset`
175+
func.func @wmma_source_length_mismatch(%arg0 : vector<8xf16>, %arg1 : vector<16xf16>, %arg2 : vector<8xf32>) -> vector<8xf32> {
176+
// expected-error@+1 {{'amdgpu.wmma' op source vectors have different lengths}}
177+
%0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xf16>, vector<16xf16>, vector<8xf32>
178+
func.return %0 : vector<8xf32>
179+
}
180+
181+
// -----
182+
183+
func.func @wmma_mismatched_float_types(%arg0 : vector<8xf16>, %arg1 : vector<8xbf16>, %arg2 : vector<8xf32>) -> vector<8xf32> {
184+
// expected-error@+1 {{'amdgpu.wmma' op source element types much match (except for fp8/bf8)}}
185+
%0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
186+
func.return %0 : vector<8xf32>
187+
}
188+
189+
// -----
190+
191+
func.func @wmma_mismatched_int_types(%arg0 : vector<8xi8>, %arg1 : vector<8xi4>, %arg2 : vector<8xi32>) -> vector<8xi32> {
192+
// expected-error@+1 {{'amdgpu.wmma' op source element types much match (except for fp8/bf8)}}
193+
%0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xi8>, vector<8xi4>, vector<8xi32>
194+
func.return %0 : vector<8xi32>
195+
}
196+
197+
// -----
198+
199+
func.func @wmma_clamp_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> {
200+
// expected-error@+1 {{'amdgpu.wmma' op clamp flag is not supported for float types}}
201+
%0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {clamp} : vector<8xf16>, vector<8xf16>, vector<8xf32>
202+
func.return %0 : vector<8xf32>
203+
}
204+
205+
// -----
206+
207+
func.func @wmma_unsignedA_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> {
208+
// expected-error@+1 {{'amdgpu.wmma' op unsigned flags are not supported for float types}}
209+
%0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {unsignedA} : vector<8xf16>, vector<8xf16>, vector<8xf32>
210+
func.return %0 : vector<8xf32>
211+
}
212+
213+
// -----
214+
215+
func.func @wmma_unsignedB_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> {
216+
// expected-error@+1 {{'amdgpu.wmma' op unsigned flags are not supported for float types}}
217+
%0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {unsignedB} : vector<8xf16>, vector<8xf16>, vector<8xf32>
218+
func.return %0 : vector<8xf32>
219+
}
220+
221+
// -----
222+
223+
// Missing `resetOffset`
184224
func.func @fat_raw_buffer_cast_stripped_offset(%m: memref<8xi32, strided<[1], offset: ?>, #gpu.address_space<global>>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
185225
// expected-error@+1 {{'amdgpu.fat_raw_buffer_cast' op expected result type to be 'memref<8xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>' but got 'memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>'}}
186226
%ret = amdgpu.fat_raw_buffer_cast %m : memref<8xi32, strided<[1], offset: ?>, #gpu.address_space<global>> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>

0 commit comments

Comments
 (0)