Skip to content

Commit 00092f9

Browse files
[MLIR] [Vector] Added canonicalizer for folding from_elements + transpose (#161841)
## Description Adds a new canonicalizer that folds `vector.from_elements(vector.transpose))` => `vector.from_elements`. This canonicalization reorders the input elements for `vector.from_elements`, adjusts the output shape to match the effect of the transpose op and eliminating its need. ## Testing Added a 2D vector lit test that verifies the working of the rewrite. --------- Signed-off-by: Keshav Vinayak Jha <[email protected]>
1 parent f188c97 commit 00092f9

File tree

2 files changed

+125
-1
lines changed

2 files changed

+125
-1
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6835,6 +6835,73 @@ class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
68356835
}
68366836
};
68376837

6838+
/// Folds transpose(from_elements(...)) into a new from_elements with permuted
6839+
/// operands matching the transposed shape.
6840+
///
6841+
/// Example:
6842+
///
6843+
/// %v = vector.from_elements %a00, %a01, %a02, %a10, %a11, %a12 :
6844+
/// vector<2x3xi32> %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to
6845+
/// vector<3x2xi32>
6846+
///
6847+
/// becomes ->
6848+
///
6849+
/// %r = vector.from_elements %a00, %a10, %a01, %a11, %a02, %a12 :
6850+
/// vector<3x2xi32>
6851+
///
6852+
class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
6853+
public:
6854+
using Base::Base;
6855+
LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6856+
PatternRewriter &rewriter) const override {
6857+
auto fromElementsOp =
6858+
transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
6859+
if (!fromElementsOp)
6860+
return failure();
6861+
6862+
VectorType srcTy = fromElementsOp.getDest().getType();
6863+
VectorType dstTy = transposeOp.getType();
6864+
6865+
ArrayRef<int64_t> permutation = transposeOp.getPermutation();
6866+
int64_t rank = srcTy.getRank();
6867+
6868+
// Build inverse permutation to map destination indices back to source.
6869+
SmallVector<int64_t> inversePerm(rank, 0);
6870+
for (int64_t i = 0; i < rank; ++i)
6871+
inversePerm[permutation[i]] = i;
6872+
6873+
ArrayRef<int64_t> srcShape = srcTy.getShape();
6874+
ArrayRef<int64_t> dstShape = dstTy.getShape();
6875+
SmallVector<int64_t> srcIdx(rank, 0);
6876+
SmallVector<int64_t> dstIdx(rank, 0);
6877+
SmallVector<int64_t> srcStrides = computeStrides(srcShape);
6878+
SmallVector<int64_t> dstStrides = computeStrides(dstShape);
6879+
6880+
auto elementsOld = fromElementsOp.getElements();
6881+
SmallVector<Value> elementsNew;
6882+
int64_t dstNumElements = dstTy.getNumElements();
6883+
elementsNew.reserve(dstNumElements);
6884+
6885+
// For each element in destination row-major order, pick the corresponding
6886+
// source element.
6887+
for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
6888+
// Pick the destination element index.
6889+
dstIdx = delinearize(linearIdx, dstStrides);
6890+
// Map the destination element index to the source element index.
6891+
for (int64_t j = 0; j < rank; ++j)
6892+
srcIdx[j] = dstIdx[inversePerm[j]];
6893+
// Linearize the source element index.
6894+
int64_t srcLin = linearize(srcIdx, srcStrides);
6895+
// Add the source element to the new elements.
6896+
elementsNew.push_back(elementsOld[srcLin]);
6897+
}
6898+
6899+
rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy,
6900+
elementsNew);
6901+
return success();
6902+
}
6903+
};
6904+
68386905
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
68396906
/// 'order preserving', where 'order preserving' means the flattened
68406907
/// inputs and outputs of the transpose have identical (numerical) values.
@@ -6935,7 +7002,8 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
69357002
void vector::TransposeOp::getCanonicalizationPatterns(
69367003
RewritePatternSet &results, MLIRContext *context) {
69377004
results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
6938-
FoldTransposeSplat, FoldTransposeBroadcast>(context);
7005+
FoldTransposeSplat, FoldTransposeFromElements,
7006+
FoldTransposeBroadcast>(context);
69397007
}
69407008

69417009
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3530,6 +3530,62 @@ func.func @from_elements_index_to_i64_conversion() -> vector<3xi64> {
35303530

35313531
// -----
35323532

3533+
// +---------------------------------------------------------------------------
3534+
// Tests for FoldTransposeFromElements
3535+
// +---------------------------------------------------------------------------
3536+
3537+
// CHECK-LABEL: transpose_from_elements_1d
3538+
// CHECK-SAME: %[[EL_0:.*]]: i32, %[[EL_1:.*]]: i32
3539+
func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> {
3540+
%v = vector.from_elements %el_0, %el_1 : vector<2xi32>
3541+
%t = vector.transpose %v, [0] : vector<2xi32> to vector<2xi32>
3542+
return %t : vector<2xi32>
3543+
// CHECK: %[[R:.*]] = vector.from_elements %[[EL_0]], %[[EL_1]] : vector<2xi32>
3544+
// CHECK-NOT: vector.transpose
3545+
// CHECK: return %[[R]]
3546+
}
3547+
3548+
// CHECK-LABEL: transpose_from_elements_2d
3549+
// CHECK-SAME: %[[EL_0_0:.*]]: i32, %[[EL_0_1:.*]]: i32, %[[EL_0_2:.*]]: i32, %[[EL_1_0:.*]]: i32, %[[EL_1_1:.*]]: i32, %[[EL_1_2:.*]]: i32
3550+
func.func @transpose_from_elements_2d(
3551+
%el_0_0: i32, %el_0_1: i32, %el_0_2: i32,
3552+
%el_1_0: i32, %el_1_1: i32, %el_1_2: i32
3553+
) -> vector<3x2xi32> {
3554+
%v = vector.from_elements %el_0_0, %el_0_1, %el_0_2, %el_1_0, %el_1_1, %el_1_2 : vector<2x3xi32>
3555+
%t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
3556+
return %t : vector<3x2xi32>
3557+
// CHECK: %[[R:.*]] = vector.from_elements %[[EL_0_0:.*]], %[[EL_1_0:.*]], %[[EL_0_1:.*]], %[[EL_1_1:.*]], %[[EL_0_2:.*]], %[[EL_1_2:.*]] : vector<3x2xi32>
3558+
// CHECK-NOT: vector.transpose
3559+
// CHECK: return %[[R]]
3560+
}
3561+
3562+
// CHECK-LABEL: transpose_from_elements_3d
3563+
// CHECK-SAME: %[[EL_0_0_0:.*]]: i32, %[[EL_0_0_1:.*]]: i32, %[[EL_0_1_0:.*]]: i32, %[[EL_0_1_1:.*]]: i32, %[[EL_0_2_0:.*]]: i32, %[[EL_0_2_1:.*]]: i32, %[[EL_1_0_0:.*]]: i32, %[[EL_1_0_1:.*]]: i32, %[[EL_1_1_0:.*]]: i32, %[[EL_1_1_1:.*]]: i32, %[[EL_1_2_0:.*]]: i32, %[[EL_1_2_1:.*]]: i32
3564+
func.func @transpose_from_elements_3d(
3565+
%el_0_0_0: i32, %el_0_0_1: i32, %el_0_1_0: i32, %el_0_1_1: i32, %el_0_2_0: i32, %el_0_2_1: i32,
3566+
%el_1_0_0: i32, %el_1_0_1: i32, %el_1_1_0: i32, %el_1_1_1: i32, %el_1_2_0: i32, %el_1_2_1: i32
3567+
) -> vector<2x2x3xi32> {
3568+
%v = vector.from_elements
3569+
%el_0_0_0, %el_0_0_1,
3570+
%el_0_1_0, %el_0_1_1,
3571+
%el_0_2_0, %el_0_2_1,
3572+
%el_1_0_0, %el_1_0_1,
3573+
%el_1_1_0, %el_1_1_1,
3574+
%el_1_2_0, %el_1_2_1
3575+
: vector<2x3x2xi32>
3576+
%t = vector.transpose %v, [0, 2, 1] : vector<2x3x2xi32> to vector<2x2x3xi32>
3577+
return %t : vector<2x2x3xi32>
3578+
// CHECK: %[[R:.*]] = vector.from_elements %[[EL_0_0_0:.*]], %[[EL_0_1_0:.*]], %[[EL_0_2_0:.*]], %[[EL_0_0_1:.*]], %[[EL_0_1_1:.*]], %[[EL_0_2_1:.*]], %[[EL_1_0_0:.*]], %[[EL_1_1_0:.*]], %[[EL_1_2_0:.*]], %[[EL_1_0_1:.*]], %[[EL_1_1_1:.*]], %[[EL_1_2_1:.*]] : vector<2x2x3xi32>
3579+
// CHECK-NOT: vector.transpose
3580+
// CHECK: return %[[R]]
3581+
}
3582+
3583+
// +---------------------------------------------------------------------------
3584+
// End of Tests for FoldTransposeFromElements
3585+
// +---------------------------------------------------------------------------
3586+
3587+
// -----
3588+
35333589
// Not a DenseElementsAttr, don't fold.
35343590

35353591
// CHECK-LABEL: func @negative_insert_llvm_undef(

0 commit comments

Comments
 (0)