@@ -37,8 +37,10 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
3737struct AssumeAlignmentOpInterface
3838 : public RuntimeVerifiableOpInterface::ExternalModel<
3939 AssumeAlignmentOpInterface, AssumeAlignmentOp> {
40- void generateRuntimeVerification (Operation *op, OpBuilder &builder,
41- Location loc) const {
40+ void
41+ generateRuntimeVerification (Operation *op, OpBuilder &builder, Location loc,
42+ function_ref<std::string(Operation *, StringRef)>
43+ generateErrorMessage) const {
4244 auto assumeOp = cast<AssumeAlignmentOp>(op);
4345 Value ptr = ExtractAlignedPointerAsIndexOp::create (builder, loc,
4446 assumeOp.getMemref ());
@@ -48,18 +50,20 @@ struct AssumeAlignmentOpInterface
4850 Value isAligned =
4951 arith::CmpIOp::create (builder, loc, arith::CmpIPredicate::eq, rest,
5052 arith::ConstantIndexOp::create (builder, loc, 0 ));
51- cf::AssertOp::create (builder, loc, isAligned,
52- RuntimeVerifiableOpInterface::generateErrorMessage (
53- op, " memref is not aligned to " +
53+ cf::AssertOp::create (
54+ builder, loc, isAligned,
55+ generateErrorMessage ( op, " memref is not aligned to " +
5456 std::to_string (assumeOp.getAlignment ())));
5557 }
5658};
5759
5860struct CastOpInterface
5961 : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
6062 CastOp> {
61- void generateRuntimeVerification (Operation *op, OpBuilder &builder,
62- Location loc) const {
63+ void
64+ generateRuntimeVerification (Operation *op, OpBuilder &builder, Location loc,
65+ function_ref<std::string(Operation *, StringRef)>
66+ generateErrorMessage) const {
6367 auto castOp = cast<CastOp>(op);
6468 auto srcType = cast<BaseMemRefType>(castOp.getSource ().getType ());
6569
@@ -76,8 +80,7 @@ struct CastOpInterface
7680 Value isSameRank = arith::CmpIOp::create (
7781 builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
7882 cf::AssertOp::create (builder, loc, isSameRank,
79- RuntimeVerifiableOpInterface::generateErrorMessage (
80- op, " rank mismatch" ));
83+ generateErrorMessage (op, " rank mismatch" ));
8184 }
8285
8386 // Get source offset and strides. We do not have an op to get offsets and
@@ -116,8 +119,8 @@ struct CastOpInterface
116119 builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
117120 cf::AssertOp::create (
118121 builder, loc, isSameSz,
119- RuntimeVerifiableOpInterface:: generateErrorMessage (
120- op, " size mismatch of dim " + std::to_string (it.index ())));
122+ generateErrorMessage (op, " size mismatch of dim " +
123+ std::to_string (it.index ())));
121124 }
122125
123126 // Get result offset and strides.
@@ -135,8 +138,7 @@ struct CastOpInterface
135138 Value isSameOffset = arith::CmpIOp::create (
136139 builder, loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
137140 cf::AssertOp::create (builder, loc, isSameOffset,
138- RuntimeVerifiableOpInterface::generateErrorMessage (
139- op, " offset mismatch" ));
141+ generateErrorMessage (op, " offset mismatch" ));
140142 }
141143
142144 // Check strides.
@@ -153,17 +155,19 @@ struct CastOpInterface
153155 builder, loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
154156 cf::AssertOp::create (
155157 builder, loc, isSameStride,
156- RuntimeVerifiableOpInterface:: generateErrorMessage (
157- op, " stride mismatch of dim " + std::to_string (it.index ())));
158+ generateErrorMessage (op, " stride mismatch of dim " +
159+ std::to_string (it.index ())));
158160 }
159161 }
160162};
161163
162164struct CopyOpInterface
163165 : public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
164166 CopyOp> {
165- void generateRuntimeVerification (Operation *op, OpBuilder &builder,
166- Location loc) const {
167+ void
168+ generateRuntimeVerification (Operation *op, OpBuilder &builder, Location loc,
169+ function_ref<std::string(Operation *, StringRef)>
170+ generateErrorMessage) const {
167171 auto copyOp = cast<CopyOp>(op);
168172 BaseMemRefType sourceType = copyOp.getSource ().getType ();
169173 BaseMemRefType targetType = copyOp.getTarget ().getType ();
@@ -193,9 +197,9 @@ struct CopyOpInterface
193197 Value targetDim = getDimSize (copyOp.getTarget (), rankedTargetType, i);
194198 Value sameDimSize = arith::CmpIOp::create (
195199 builder, loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
196- cf::AssertOp::create (builder, loc, sameDimSize,
197- RuntimeVerifiableOpInterface::generateErrorMessage (
198- op, " size of " + std::to_string (i) +
200+ cf::AssertOp::create (
201+ builder, loc, sameDimSize,
202+ generateErrorMessage ( op, " size of " + std::to_string (i) +
199203 " -th source/target dim does not match" ));
200204 }
201205 }
@@ -204,16 +208,17 @@ struct CopyOpInterface
204208struct DimOpInterface
205209 : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
206210 DimOp> {
207- void generateRuntimeVerification (Operation *op, OpBuilder &builder,
208- Location loc) const {
211+ void
212+ generateRuntimeVerification (Operation *op, OpBuilder &builder, Location loc,
213+ function_ref<std::string(Operation *, StringRef)>
214+ generateErrorMessage) const {
209215 auto dimOp = cast<DimOp>(op);
210216 Value rank = RankOp::create (builder, loc, dimOp.getSource ());
211217 Value zero = arith::ConstantIndexOp::create (builder, loc, 0 );
212218 cf::AssertOp::create (
213219 builder, loc,
214220 generateInBoundsCheck (builder, loc, dimOp.getIndex (), zero, rank),
215- RuntimeVerifiableOpInterface::generateErrorMessage (
216- op, " index is out of bounds" ));
221+ generateErrorMessage (op, " index is out of bounds" ));
217222 }
218223};
219224
@@ -223,8 +228,10 @@ template <typename LoadStoreOp>
223228struct LoadStoreOpInterface
224229 : public RuntimeVerifiableOpInterface::ExternalModel<
225230 LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
226- void generateRuntimeVerification (Operation *op, OpBuilder &builder,
227- Location loc) const {
231+ void
232+ generateRuntimeVerification (Operation *op, OpBuilder &builder, Location loc,
233+ function_ref<std::string(Operation *, StringRef)>
234+ generateErrorMessage) const {
228235 auto loadStoreOp = cast<LoadStoreOp>(op);
229236
230237 auto memref = loadStoreOp.getMemref ();
@@ -245,16 +252,17 @@ struct LoadStoreOpInterface
245252 : inBounds;
246253 }
247254 cf::AssertOp::create (builder, loc, assertCond,
248- RuntimeVerifiableOpInterface::generateErrorMessage (
249- op, " out-of-bounds access" ));
255+ generateErrorMessage (op, " out-of-bounds access" ));
250256 }
251257};
252258
253259struct SubViewOpInterface
254260 : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
255261 SubViewOp> {
256- void generateRuntimeVerification (Operation *op, OpBuilder &builder,
257- Location loc) const {
262+ void
263+ generateRuntimeVerification (Operation *op, OpBuilder &builder, Location loc,
264+ function_ref<std::string(Operation *, StringRef)>
265+ generateErrorMessage) const {
258266 auto subView = cast<SubViewOp>(op);
259267 MemRefType sourceType = subView.getSource ().getType ();
260268
@@ -277,10 +285,10 @@ struct SubViewOpInterface
277285 Value dimSize = metadataOp.getSizes ()[i];
278286 Value offsetInBounds =
279287 generateInBoundsCheck (builder, loc, offset, zero, dimSize);
280- cf::AssertOp::create (
281- builder, loc, offsetInBounds,
282- RuntimeVerifiableOpInterface::generateErrorMessage (
283- op, " offset " + std::to_string (i) + " is out-of-bounds" ));
288+ cf::AssertOp::create (builder, loc, offsetInBounds,
289+ generateErrorMessage (op, " offset " +
290+ std::to_string (i) +
291+ " is out-of-bounds" ));
284292
285293 // Verify that slice does not run out-of-bounds.
286294 Value sizeMinusOne = arith::SubIOp::create (builder, loc, size, one);
@@ -292,18 +300,20 @@ struct SubViewOpInterface
292300 generateInBoundsCheck (builder, loc, lastPos, zero, dimSize);
293301 cf::AssertOp::create (
294302 builder, loc, lastPosInBounds,
295- RuntimeVerifiableOpInterface:: generateErrorMessage (
296- op, " subview runs out-of-bounds along dimension " +
297- std::to_string (i)));
303+ generateErrorMessage (op,
304+ " subview runs out-of-bounds along dimension " +
305+ std::to_string (i)));
298306 }
299307 }
300308};
301309
302310struct ExpandShapeOpInterface
303311 : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
304312 ExpandShapeOp> {
305- void generateRuntimeVerification (Operation *op, OpBuilder &builder,
306- Location loc) const {
313+ void
314+ generateRuntimeVerification (Operation *op, OpBuilder &builder, Location loc,
315+ function_ref<std::string(Operation *, StringRef)>
316+ generateErrorMessage) const {
307317 auto expandShapeOp = cast<ExpandShapeOp>(op);
308318
309319 // Verify that the expanded dim sizes are a product of the collapsed dim
@@ -333,9 +343,9 @@ struct ExpandShapeOpInterface
333343 Value isModZero = arith::CmpIOp::create (
334344 builder, loc, arith::CmpIPredicate::eq, mod,
335345 arith::ConstantIndexOp::create (builder, loc, 0 ));
336- cf::AssertOp::create (builder, loc, isModZero,
337- RuntimeVerifiableOpInterface::generateErrorMessage (
338- op, " static result dims in reassoc group do not "
346+ cf::AssertOp::create (
347+ builder, loc, isModZero,
348+ generateErrorMessage ( op, " static result dims in reassoc group do not "
339349 " divide src dim evenly" ));
340350 }
341351 }
0 commit comments