diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td index 6fd992afbf043..23eab706c856d 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td @@ -369,7 +369,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [ ``` }]; let arguments = (ins - Arg:$base, + Arg, "the reference to load from", [MemRead]>:$base, Variadic:$indices, Optional:$padding, Optional:$mask, ArmSME_TileSliceLayoutAttr:$layout @@ -443,7 +443,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [ ``` }]; let arguments = (ins SMETile:$valueToStore, - Arg:$base, + Arg, "the reference to store to", [MemWrite]>:$base, Variadic:$indices, Optional:$mask, ArmSME_TileSliceLayoutAttr:$layout ); diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp index 6ed29903ea407..9bdafb7d8c501 100644 --- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp +++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp @@ -33,20 +33,15 @@ SmallVector getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex, Value tileSliceNumElts, Location loc, PatternRewriter &rewriter) { - assert((rank == 1 || rank == 2) && "memref has unexpected rank!"); + assert(rank == 2 && "memref has unexpected rank!"); SmallVector outIndices; auto tileSliceOffset = tileSliceIndex; - if (rank == 1) - tileSliceOffset = - rewriter.create(loc, tileSliceOffset, tileSliceNumElts); auto baseIndexPlusTileSliceOffset = rewriter.create(loc, indices[0], tileSliceOffset); outIndices.push_back(baseIndexPlusTileSliceOffset); - - if (rank == 2) - outIndices.push_back(indices[1]); + outIndices.push_back(indices[1]); return outIndices; } @@ -60,6 +55,10 @@ FailureOr createLoadStoreForOverTileSlices( makeLoopBody) { PatternRewriter::InsertionGuard guard(rewriter); + // TODO: This case should be captured and rejected by a verifier. + if (memrefIndices.size() != 2) + return rewriter.notifyMatchFailure(loc, "invalid number of indices"); + auto minTileSlices = rewriter.create( loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType())); auto vscale = diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir index 700b2412ff7a7..c015fe7cf1641 100644 --- a/mlir/test/Dialect/ArmSME/invalid.mlir +++ b/mlir/test/Dialect/ArmSME/invalid.mlir @@ -111,6 +111,15 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref, %pad : f64 return } +// ----- + +func.func @arm_sme_tile_load__bad_memref_rank(%src : memref, %pad : f64) { + %c0 = arith.constant 0 : index + // expected-error@+1 {{op operand #0 must be 2D memref of any type values, but got 'memref'}} + %tile = arm_sme.tile_load %src[%c0], %pad, : memref, vector<[2]x[2]xf64> + return +} + //===----------------------------------------------------------------------===// // arm_sme.load_tile_slice //===----------------------------------------------------------------------===// @@ -138,6 +147,15 @@ func.func @arm_sme_tile_store__bad_mask_type(%tile : vector<[16]x[16]xi8>, %mask return } +// ----- + +func.func @arm_sme_tile_store__bad_memref_rank(%tile : vector<[16]x[16]xi8>, %dest : memref) { + %c0 = arith.constant 0 : index + // expected-error@+1 {{op operand #1 must be 2D memref of any type values, but got 'memref'}} + arm_sme.tile_store %tile, %dest[%c0] : memref, vector<[16]x[16]xi8> + return +} + //===----------------------------------------------------------------------===// // arm_sme.store_tile_slice //===----------------------------------------------------------------------===//