From 469274f99c2128d2fd606b7ab9642c79714c56eb Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Sat, 21 Sep 2024 16:14:25 +0100 Subject: [PATCH 1/4] [mlir][tensor] Add new helper hooks to RelayoutOp Implements two helper hooks for PackOp and UnPackOP, `getAllOuterDims` and `getTiledOuterDims`, and adds them to RelayoutOp (that both PackOp an UnPackOp inherit from). This improves code re-use and also clarifies the meaning of "outer dims" and "tiled outer dims". --- .../mlir/Dialect/Tensor/IR/TensorOps.td | 19 ++++++++++++++- .../Dialect/Linalg/Transforms/Transforms.cpp | 23 ++++++++----------- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 22 ++++++++++++++++++ 3 files changed, 50 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index cafc3d91fd1e9..9fee75c6a2ca3 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -1814,7 +1814,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [ } //===----------------------------------------------------------------------===// -// PackOp +// RelayoutOp //===----------------------------------------------------------------------===// class Tensor_RelayoutOp traits = []> : @@ -1851,11 +1851,28 @@ class Tensor_RelayoutOp traits = []> : /// a sentinel `kDynamic` is introduced at that position in /// the returned vector. SmallVector getStaticTiles(); + + /// Retrieve all outer dims for this Pack/UnPack Op, i.e. all the leading + /// dims excluding the trailing dims corresponding to `innerTiles`. Note + /// that this will include both tiled and non-tiled dimensions. + ArrayRef getAllOuterDims() { + ShapedType inputType = getSourceType(); + int64_t inputRank = inputType.getRank(); + return getDestType().getShape().take_front(inputRank); + } + + /// Similar to `getAllOuterDims`, but only retrieve the outer dims that + /// have been tiled. + SmallVector getTiledOuterDims(); }]; let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// PackOp +//===----------------------------------------------------------------------===// + def Tensor_PackOp : Tensor_RelayoutOp<"pack", [ AttrSizedOperandSegments]> { let summary = "tensor pack operation"; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 77f0ea9d2236e..e0dea8e78d55c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1030,11 +1030,13 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder, return input; } + assert(llvm::all_of(packOp.getAllOuterDims(), + [](int64_t val) { return val == 1; }) && + "some outer dims are != 1"); + Location loc = packOp.getLoc(); ShapedType inputType = packOp.getSourceType(); int64_t inputRank = inputType.getRank(); - assert(llvm::all_of(packOp.getDestType().getShape().take_front(inputRank), - [](int64_t val) { return val == 1; })); SmallVector paddedShape; DenseMap tileAndPosMapping = @@ -1126,12 +1128,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite( // TODO: support the case that outer dimensions are not all 1s. A // tensor.expand_shape will be generated in this case. - auto innerDimsPos = packOp.getInnerDimsPos(); - int64_t srcRank = packOp.getSourceRank(); - auto destShape = packOp.getDestType().getShape(); - if (llvm::any_of(innerDimsPos, [destShape](int64_t index) { - return destShape[index] != 1; - })) { + if (llvm::any_of(packOp.getTiledOuterDims(), + [](int64_t dim) { return dim != 1; })) { return rewriter.notifyMatchFailure( packOp, "require the tiled outer dimensions of the result are all 1s"); } @@ -1145,6 +1143,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite( packOp.getDimAndTileMapping(); Attribute zeroIdxAttr = rewriter.getIndexAttr(0); Attribute oneIdxAttr = rewriter.getIndexAttr(1); + int64_t srcRank = packOp.getSourceRank(); SmallVector readOffsets(srcRank, zeroIdxAttr); SmallVector readStrides(srcRank, oneIdxAttr); SmallVector readSizes; @@ -1173,9 +1172,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite( loc, readType, input, readOffsets, readSizes, readStrides); // 2. Transpose the tile to match the inner tile order. - SmallVector perm = getPackUnpackRankReducedPerm( - inputShape, innerDimsPos, packOp.getOuterDimsPerm()); + inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm()); LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"; llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL();); @@ -1208,9 +1206,8 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite( int64_t destRank = unpackOp.getDestRank(); ArrayRef srcShape = unpackOp.getSourceType().getShape(); ArrayRef innerDimsPos = unpackOp.getInnerDimsPos(); - if (llvm::any_of(innerDimsPos, [srcShape](int64_t index) { - return srcShape[index] != 1; - })) { + if (llvm::any_of(unpackOp.getTiledOuterDims(), + [](int64_t dim) { return dim != 1; })) { return rewriter.notifyMatchFailure( unpackOp, "require the tiled outer dimensions of the result are all 1s"); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 47f540e092e99..bc7deb1614d18 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -3987,6 +3987,17 @@ SmallVector PackOp::getStaticTiles() { return getStaticTilesImpl(*this); } +SmallVector PackOp::getTiledOuterDims() { + auto innerDimsPos = getInnerDimsPos(); + auto destShape = getDestType().getShape(); + SmallVector res; + + for (auto index : innerDimsPos) + res.push_back(destShape[index]); + + return res; +} + bool PackOp::requirePaddingValue(ArrayRef inputShape, ArrayRef innerDimsPos, ArrayRef outputShape, @@ -4411,6 +4422,17 @@ SmallVector UnPackOp::getStaticTiles() { return getStaticTilesImpl(*this); } +SmallVector UnPackOp::getTiledOuterDims() { + auto innerDimsPos = getInnerDimsPos(); + auto destShape = getSourceType().getShape(); + SmallVector res; + + for (auto index : innerDimsPos) + res.push_back(destShape[index]); + + return res; +} + LogicalResult UnPackOp::verify() { return commonVerifierPackAndUnPackOp(*this); } From d6bc07fa1f7c67f4873df0986aedc28fd3f26f1a Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Mon, 23 Sep 2024 14:58:20 +0100 Subject: [PATCH 2/4] fixup! [mlir][tensor] Add new helper hooks to RelayoutOp Remove empty space --- mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 9fee75c6a2ca3..7b57b503ea56f 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -1861,7 +1861,7 @@ class Tensor_RelayoutOp traits = []> : return getDestType().getShape().take_front(inputRank); } - /// Similar to `getAllOuterDims`, but only retrieve the outer dims that + /// Similar to `getAllOuterDims`, but only retrieve the outer dims that /// have been tiled. SmallVector getTiledOuterDims(); }]; From 534e096b789c78410ad1e15f777e06d4e6a59d4c Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Mon, 23 Sep 2024 18:21:08 +0100 Subject: [PATCH 3/4] fixup! fixup! [mlir][tensor] Add new helper hooks to RelayoutOp Add comments, specialize getAllOuterDims --- mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 13 ++++++------- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 12 ++++++++++++ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 7b57b503ea56f..3170115883e2b 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -1854,15 +1854,14 @@ class Tensor_RelayoutOp traits = []> : /// Retrieve all outer dims for this Pack/UnPack Op, i.e. all the leading /// dims excluding the trailing dims corresponding to `innerTiles`. Note - /// that this will include both tiled and non-tiled dimensions. - ArrayRef getAllOuterDims() { - ShapedType inputType = getSourceType(); - int64_t inputRank = inputType.getRank(); - return getDestType().getShape().take_front(inputRank); - } + /// that this will include both tiled and non-tiled dimensions. The order + /// of the output dimensions is consistent with the shape of the packed + /// tensor. + ArrayRef getAllOuterDims(); /// Similar to `getAllOuterDims`, but only retrieve the outer dims that - /// have been tiled. + /// have been tiled. Also, the order of the output dimensions is consistent + /// with `inner_dims_pos` rather than the packed tensor. SmallVector getTiledOuterDims(); }]; diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index bc7deb1614d18..d0ddde96b0b23 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -3987,6 +3987,12 @@ SmallVector PackOp::getStaticTiles() { return getStaticTilesImpl(*this); } +ArrayRef PackOp::getAllOuterDims() { + ShapedType inputType = getSourceType(); + int64_t inputRank = inputType.getRank(); + return getDestType().getShape().take_front(inputRank); +} + SmallVector PackOp::getTiledOuterDims() { auto innerDimsPos = getInnerDimsPos(); auto destShape = getDestType().getShape(); @@ -4422,6 +4428,12 @@ SmallVector UnPackOp::getStaticTiles() { return getStaticTilesImpl(*this); } +ArrayRef UnPackOp::getAllOuterDims() { + ShapedType destType = getDestType(); + int64_t destRank = destType.getRank(); + return getSourceType().getShape().take_front(destRank); +} + SmallVector UnPackOp::getTiledOuterDims() { auto innerDimsPos = getInnerDimsPos(); auto destShape = getSourceType().getShape(); From 7fd53ba47384464795aa116d67bebffbbbf0a034 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Tue, 24 Sep 2024 13:07:31 +0100 Subject: [PATCH 4/4] fixup! fixup! fixup! [mlir][tensor] Add new helper hooks to RelayoutOp Use Adam's suggestion for var names --- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index d0ddde96b0b23..1ac96756e22b5 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -3995,11 +3995,11 @@ ArrayRef PackOp::getAllOuterDims() { SmallVector PackOp::getTiledOuterDims() { auto innerDimsPos = getInnerDimsPos(); - auto destShape = getDestType().getShape(); + auto packedShape = getDestType().getShape(); SmallVector res; for (auto index : innerDimsPos) - res.push_back(destShape[index]); + res.push_back(packedShape[index]); return res; } @@ -4436,11 +4436,11 @@ ArrayRef UnPackOp::getAllOuterDims() { SmallVector UnPackOp::getTiledOuterDims() { auto innerDimsPos = getInnerDimsPos(); - auto destShape = getSourceType().getShape(); + auto packedShape = getSourceType().getShape(); SmallVector res; for (auto index : innerDimsPos) - res.push_back(destShape[index]); + res.push_back(packedShape[index]); return res; }