Skip to content

Commit 4d7ea22

Browse files
committed
[Matrix][IR] Don't crash when verifying strides with more than 64 bits
1 parent a42546e commit 4d7ea22

File tree

2 files changed

+39
-5
lines changed

2 files changed

+39
-5
lines changed

llvm/lib/IR/Verifier.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6479,9 +6479,17 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
64796479
NumRows->getZExtValue() * NumColumns->getZExtValue(),
64806480
"Result of a matrix operation does not fit in the returned vector!");
64816481

6482-
if (Stride)
6483-
Check(Stride->getZExtValue() >= NumRows->getZExtValue(),
6482+
if (Stride) {
6483+
// Stride can occupy an arbitrary bit-width, while rows and columns are
6484+
// always 32-bit, so zero extend to the largest common bit-width to
6485+
// compare.
6486+
APInt StrideVal = Stride->getValue();
6487+
APInt NumRowsVal = NumRows->getValue();
6488+
unsigned BitWidth =
6489+
std::max(StrideVal.getBitWidth(), NumRowsVal.getBitWidth());
6490+
Check(StrideVal.zext(BitWidth).uge(NumRowsVal.zext(BitWidth)),
64846491
"Stride must be greater or equal than the number of rows!", IF);
6492+
}
64856493

64866494
break;
64876495
}

llvm/test/Verifier/matrix-intrinsics.ll

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
; RUN: not llvm-as < %s -o /dev/null 2>&1 | FileCheck %s
1+
; RUN: not opt -S %s 2>&1 | FileCheck %s
22

33
define <4 x float> @transpose(<4 x float> %m, i32 %arg) {
4-
; CHECK: assembly parsed, but does not verify as correct!
5-
; CHECK-NEXT: Result of a matrix operation does not fit in the returned vector!
4+
; CHECK: Result of a matrix operation does not fit in the returned vector!
65
; CHECK-NEXT: Result of a matrix operation does not fit in the returned vector!
76
; CHECK-NEXT: Result of a matrix operation does not fit in the returned vector!
87
; CHECK-NEXT: immarg operand has non-immediate parameter
@@ -118,16 +117,43 @@ define void @column.major_store_stride_too_small(ptr %m, i64 %arg) {
118117
ret void
119118
}
120119

120+
; This test ensures that verifier correctly handles very wide and very narrows
121+
; strides.
122+
123+
define <4 x float> @column.major_load_stride_i8(ptr %m, i32 %arg) {
124+
%result.1 = call <4 x float> @llvm.matrix.column.major.load.v4f32.i128(ptr %m, i8 16, i1 false, i32 2, i32 2)
125+
ret <4 x float> %result.1
126+
}
127+
128+
define <4 x float> @column.major_load_stride_i128(ptr %m, i32 %arg) {
129+
%result.1 = call <4 x float> @llvm.matrix.column.major.load.v4f32.i128(ptr %m, i128 u0x10000000000000000, i1 false, i32 2, i32 2)
130+
ret <4 x float> %result.1
131+
}
132+
133+
define void @column.major_store_stride_i8(ptr %m, i64 %arg) {
134+
call void @llvm.matrix.column.major.store.v4f32.i128(<4 x float> zeroinitializer, ptr %m, i8 16, i1 false, i32 2, i32 2)
135+
ret void
136+
}
137+
138+
define void @column.major_store_stride_i128(ptr %m, i64 %arg) {
139+
call void @llvm.matrix.column.major.store.v4f32.i128(<4 x float> zeroinitializer, ptr %m, i128 u0x10000000000000000, i1 false, i32 2, i32 2)
140+
ret void
141+
}
142+
121143
declare <4 x i32> @llvm.matrix.column.major.load.v4i32.i64(ptr, i64, i1, i32, i32)
122144
declare <4 x float> @llvm.matrix.column.major.load.v4f32.p0(ptr, i64, i1, i32, i32)
123145
declare <4 x float> @llvm.matrix.column.major.load.v4f32.i64(ptr, i64, i1, i32, i32)
124146
declare <6 x float> @llvm.matrix.column.major.load.v6f32.i64(ptr, i64, i1, i32, i32)
147+
declare <6 x float> @llvm.matrix.column.major.load.v6f32.i8(ptr, i8, i1, i32, i32)
148+
declare <6 x float> @llvm.matrix.column.major.load.v6f32.i128(ptr, i64, i1, i32, i32)
125149

126150
declare void @llvm.matrix.column.major.store.v4f32.i64(<4 x float>, ptr, i64, i1, i32, i32)
127151
declare void @llvm.matrix.column.major.store.v6f32.i64(<6 x float>, ptr, i64, i1, i32, i32)
128152
declare void @llvm.matrix.column.major.store.v4i32.vi32(<4 x i32>, ptr, i64, i1, i32, i32)
129153
declare void @llvm.matrix.column.major.store.v4f32.p0(<4 x float>, ptr, i64, i1, i32, i32)
130154
declare void @llvm.matrix.column.major.store.v4p0.i64(<4 x ptr>, ptr, i64, i1, i32, i32)
155+
declare void @llvm.matrix.column.major.store.v4p0.i8(<4 x ptr>, ptr, i8, i1, i32, i32)
156+
declare void @llvm.matrix.column.major.store.v4p0.i128(<4 x ptr>, ptr, i64, i1, i32, i32)
131157

132158
declare <4 x i32> @llvm.matrix.transpose.v4i32.v4f32(<4 x float>, i32, i32)
133159
declare <4 x float> @llvm.matrix.transpose.v4f32(<4 x float>, i32, i32)

0 commit comments

Comments
 (0)