diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index cafc3d91fd1e9..3170115883e2b 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -1814,7 +1814,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [ } //===----------------------------------------------------------------------===// -// PackOp +// RelayoutOp //===----------------------------------------------------------------------===// class Tensor_RelayoutOp traits = []> : @@ -1851,11 +1851,27 @@ class Tensor_RelayoutOp traits = []> : /// a sentinel `kDynamic` is introduced at that position in /// the returned vector. SmallVector getStaticTiles(); + + /// Retrieve all outer dims for this Pack/UnPack Op, i.e. all the leading + /// dims excluding the trailing dims corresponding to `innerTiles`. Note + /// that this will include both tiled and non-tiled dimensions. The order + /// of the output dimensions is consistent with the shape of the packed + /// tensor. + ArrayRef getAllOuterDims(); + + /// Similar to `getAllOuterDims`, but only retrieve the outer dims that + /// have been tiled. Also, the order of the output dimensions is consistent + /// with `inner_dims_pos` rather than the packed tensor. + SmallVector getTiledOuterDims(); }]; let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// PackOp +//===----------------------------------------------------------------------===// + def Tensor_PackOp : Tensor_RelayoutOp<"pack", [ AttrSizedOperandSegments]> { let summary = "tensor pack operation"; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 77f0ea9d2236e..e0dea8e78d55c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1030,11 +1030,13 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder, return input; } + assert(llvm::all_of(packOp.getAllOuterDims(), + [](int64_t val) { return val == 1; }) && + "some outer dims are != 1"); + Location loc = packOp.getLoc(); ShapedType inputType = packOp.getSourceType(); int64_t inputRank = inputType.getRank(); - assert(llvm::all_of(packOp.getDestType().getShape().take_front(inputRank), - [](int64_t val) { return val == 1; })); SmallVector paddedShape; DenseMap tileAndPosMapping = @@ -1126,12 +1128,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite( // TODO: support the case that outer dimensions are not all 1s. A // tensor.expand_shape will be generated in this case. - auto innerDimsPos = packOp.getInnerDimsPos(); - int64_t srcRank = packOp.getSourceRank(); - auto destShape = packOp.getDestType().getShape(); - if (llvm::any_of(innerDimsPos, [destShape](int64_t index) { - return destShape[index] != 1; - })) { + if (llvm::any_of(packOp.getTiledOuterDims(), + [](int64_t dim) { return dim != 1; })) { return rewriter.notifyMatchFailure( packOp, "require the tiled outer dimensions of the result are all 1s"); } @@ -1145,6 +1143,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite( packOp.getDimAndTileMapping(); Attribute zeroIdxAttr = rewriter.getIndexAttr(0); Attribute oneIdxAttr = rewriter.getIndexAttr(1); + int64_t srcRank = packOp.getSourceRank(); SmallVector readOffsets(srcRank, zeroIdxAttr); SmallVector readStrides(srcRank, oneIdxAttr); SmallVector readSizes; @@ -1173,9 +1172,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite( loc, readType, input, readOffsets, readSizes, readStrides); // 2. Transpose the tile to match the inner tile order. - SmallVector perm = getPackUnpackRankReducedPerm( - inputShape, innerDimsPos, packOp.getOuterDimsPerm()); + inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm()); LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"; llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL();); @@ -1208,9 +1206,8 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite( int64_t destRank = unpackOp.getDestRank(); ArrayRef srcShape = unpackOp.getSourceType().getShape(); ArrayRef innerDimsPos = unpackOp.getInnerDimsPos(); - if (llvm::any_of(innerDimsPos, [srcShape](int64_t index) { - return srcShape[index] != 1; - })) { + if (llvm::any_of(unpackOp.getTiledOuterDims(), + [](int64_t dim) { return dim != 1; })) { return rewriter.notifyMatchFailure( unpackOp, "require the tiled outer dimensions of the result are all 1s"); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 47f540e092e99..1ac96756e22b5 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -3987,6 +3987,23 @@ SmallVector PackOp::getStaticTiles() { return getStaticTilesImpl(*this); } +ArrayRef PackOp::getAllOuterDims() { + ShapedType inputType = getSourceType(); + int64_t inputRank = inputType.getRank(); + return getDestType().getShape().take_front(inputRank); +} + +SmallVector PackOp::getTiledOuterDims() { + auto innerDimsPos = getInnerDimsPos(); + auto packedShape = getDestType().getShape(); + SmallVector res; + + for (auto index : innerDimsPos) + res.push_back(packedShape[index]); + + return res; +} + bool PackOp::requirePaddingValue(ArrayRef inputShape, ArrayRef innerDimsPos, ArrayRef outputShape, @@ -4411,6 +4428,23 @@ SmallVector UnPackOp::getStaticTiles() { return getStaticTilesImpl(*this); } +ArrayRef UnPackOp::getAllOuterDims() { + ShapedType destType = getDestType(); + int64_t destRank = destType.getRank(); + return getSourceType().getShape().take_front(destRank); +} + +SmallVector UnPackOp::getTiledOuterDims() { + auto innerDimsPos = getInnerDimsPos(); + auto packedShape = getSourceType().getShape(); + SmallVector res; + + for (auto index : innerDimsPos) + res.push_back(packedShape[index]); + + return res; +} + LogicalResult UnPackOp::verify() { return commonVerifierPackAndUnPackOp(*this); }