From dab57eb7626b30cacacf8a63e4707457b6788f7e Mon Sep 17 00:00:00 2001 From: Hugo Date: Mon, 20 May 2024 18:58:54 +0800 Subject: [PATCH 1/6] [mlir][vector] Add ElementwiseToOuterproduct --- .../mlir/Dialect/Vector/IR/VectorOps.h | 4 + .../Vector/TransformOps/VectorTransformOps.td | 11 +++ .../TransformOps/VectorTransformOps.cpp | 5 ++ .../Vector/Transforms/VectorTransforms.cpp | 75 +++++++++++++++++++ .../test/Dialect/Vector/transform-vector.mlir | 38 ++++++++++ 5 files changed, 133 insertions(+) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h index 4603953cb40fa..ac55433fadb2f 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -80,6 +80,10 @@ void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, /// into vector contract for the backends with native support. void populateFoldArithExtensionPatterns(RewritePatternSet &patterns); +/// Collect a set of patterns that fold elementwise op on vectors to the vector +/// dialect. +void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns); + /// Returns the integer type required for subscripts in the vector dialect. IntegerType getVectorSubscriptType(Builder &builder); diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td index bc3c16d40520e..e1da09fba73a7 100644 --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -392,6 +392,17 @@ def ApplyFoldArithExtensionPatternsOp : Op]> { + let description = [{ + Collect a set of patterns that fold elementwise op on vectors to the vector + dialect. + }]; + + let assemblyFormat = "attr-dict"; +} + def ApplyVectorReductionToContractPatternsOp : Op]> { diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 61fd6bd972e3a..6e13749a66415 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -59,6 +59,11 @@ void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns( vector::populateFoldArithExtensionPatterns(patterns); } +void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + vector::populateElementwiseToVectorOpsPatterns(patterns); +} + void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::populateVectorReductionToContractPatterns(patterns); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index f29eba90c3ceb..d7ccfc4986068 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1795,6 +1795,75 @@ struct BreakDownVectorReduction final : OpRewritePattern { unsigned maxNumElementsToExtract = 0; }; +/// Pattern aiming to fold a series of ops mulf(tr(broadcast(A)), broadcast(B)) +/// into vector.outerproduct(A, B) such as : +/// ```mlir +/// %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32> +/// %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to +/// vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to +/// vector<4x4xi32> %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32> +///``` +/// Becomes : +///```mlir +/// %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32> +///``` +/// Edge Cases where broadcast ops are not 1D to 2D as follow are not handled. +/// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32> +/// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32> +/// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32> + +template +struct ElementwiseToOuterproduct : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MulOpType mulOp, + PatternRewriter &rewriter) const override { + auto VT = llvm::cast(mulOp.getResult().getType()); + if (!VT) + return failure(); + if (VT.getRank() != 2) + return failure(); + + auto canonicalize = [&](Value OperandA, + Value OperandB) -> vector::OuterProductOp { + vector::TransposeOp transposedLhs = + dyn_cast_or_null(OperandA.getDefiningOp()); + if (!transposedLhs) + return vector::OuterProductOp(); + // Fail unless this is a true 2-D matrix transpose. + ArrayRef permutation = transposedLhs.getPermutation(); + if (permutation[0] != 1 || permutation[1] != 0) + return vector::OuterProductOp(); + + // Fail in case it is not a 1-to-2 dimension to broadcast to avoid + // generating shape_casts/broadcasts which do not belong in this pattern. + vector::BroadcastOp broadcastedLhs = dyn_cast( + transposedLhs.getVector().getDefiningOp()); + if (!broadcastedLhs || + !broadcastedLhs.computeBroadcastedUnitDims().empty()) + return vector::OuterProductOp(); + // Avoid broadcast f32 or vector -> ResType + auto srcVT = dyn_cast(broadcastedLhs.getSourceType()); + if (!srcVT || srcVT.getRank() != 1) + return vector::OuterProductOp(); + + vector::BroadcastOp broadcastedRhs = + dyn_cast(OperandB.getDefiningOp()); + if (!broadcastedRhs || broadcastedRhs.getSourceType() != srcVT) + return vector::OuterProductOp(); + + return rewriter.replaceOpWithNewOp( + mulOp, VT, broadcastedLhs.getSource(), broadcastedRhs.getSource(), + Value(), vector::CombiningKind::ADD); + }; + Value a = mulOp->getOperand(0), b = mulOp->getOperand(1); + vector::OuterProductOp outerP = canonicalize(a, b); + // Handle commutativity, the transposed op is the outerproduct LHS. + outerP = outerP ? outerP : canonicalize(b, a); + return outerP ? success() : failure(); + } +}; + } // namespace void mlir::vector::populateFoldArithExtensionPatterns( @@ -1882,6 +1951,12 @@ void mlir::vector::populateBreakDownVectorReductionPatterns( maxNumElementsToExtract, benefit); } +void mlir::vector::populateElementwiseToVectorOpsPatterns( + RewritePatternSet &patterns) { + patterns.add, + ElementwiseToOuterproduct>(patterns.getContext()); +} + //===----------------------------------------------------------------------===// // TableGen'd enum attribute definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir index 75b29e22b4d2c..c170486f6ce27 100644 --- a/mlir/test/Dialect/Vector/transform-vector.mlir +++ b/mlir/test/Dialect/Vector/transform-vector.mlir @@ -92,3 +92,41 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +// CHECK-LABEL: func.func @ewise_outerproduct +// CHECK-SAME: %[[LHS:.*]]: vector<[4]xi32>, +// CHECK-SAME: %[[RHS:.*]]: vector<[4]xi32>) -> vector<[4]x[4]xi32> { +// CHECK: %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<[4]xi32>, vector<[4]xi32> +// CHECK: return %[[RES]] : vector<[4]x[4]xi32> +func.func @ewise_outerproduct(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> { + %lhsBcast = vector.broadcast %lhs : vector<[4]xi32> to vector<[4]x[4]xi32> + %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32> + %rhsBcast = vector.broadcast %rhs : vector<[4]xi32> to vector<[4]x[4]xi32> + %mul = arith.muli %lhsT, %rhsBcast : vector<[4]x[4]xi32> + return %mul: vector<[4]x[4]xi32> +} + +// CHECK-LABEL: func.func @ewise_outerproduct_transposed_rhs +// CHECK-SAME: %[[LHS:.*]]: vector<16xf32>, +// CHECK-SAME: %[[RHS:.*]]: vector<16xf32>) -> vector<16x16xf32> { +// CHECK: %[[RES:.*]] = vector.outerproduct %[[RHS]], %[[LHS]] : vector<16xf32>, vector<16xf32> +// CHECK: return %[[RES]] : vector<16x16xf32> +func.func @ewise_outerproduct_transposed_rhs(%lhs: vector<16xf32>, %rhs: vector<16xf32>) -> vector<16x16xf32> { + %rhsBcast = vector.broadcast %rhs : vector<16xf32> to vector<16x16xf32> + %rhsT = vector.transpose %rhsBcast, [1, 0] : vector<16x16xf32> to vector<16x16xf32> + %lhsBcast = vector.broadcast %lhs : vector<16xf32> to vector<16x16xf32> + %mul = arith.mulf %lhsBcast, %rhsT : vector<16x16xf32> + return %mul: vector<16x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.vector.elementwise_to_vector + } : !transform.any_op + transform.yield + } +} From 0f454b045f27c27c57601454854c8add7e147fb3 Mon Sep 17 00:00:00 2001 From: Hugo Date: Thu, 30 May 2024 17:30:03 +0800 Subject: [PATCH 2/6] Add support for different sizes rhs/lhs --- .../Vector/Transforms/VectorTransforms.cpp | 23 +++++++++++-------- .../test/Dialect/Vector/transform-vector.mlir | 13 +++++++++++ 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index d7ccfc4986068..a48101699c4f7 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1815,6 +1815,18 @@ struct BreakDownVectorReduction final : OpRewritePattern { template struct ElementwiseToOuterproduct : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; + // Helper function returning the source of the input broadcast if it matches requirements for an outerproduct pattern. + Value getValidBroadcastSource(vector::BroadcastOp broadcastOp) const { + // Fail if it is not a 1-to-2 dimension to broadcast to avoid generating + // shape_casts/broadcasts which does not belong in this pattern. + if (!broadcastOp.computeBroadcastedUnitDims().empty()) + return Value(); + // Avoid broadcast like f32 or vector -> ResType + auto srcVT = dyn_cast(broadcastOp.getSourceType()); + if (!srcVT || srcVT.getRank() != 1) + return Value(); + return broadcastOp.getSource(); + } LogicalResult matchAndRewrite(MulOpType mulOp, PatternRewriter &rewriter) const override { @@ -1835,21 +1847,14 @@ struct ElementwiseToOuterproduct : public OpRewritePattern { if (permutation[0] != 1 || permutation[1] != 0) return vector::OuterProductOp(); - // Fail in case it is not a 1-to-2 dimension to broadcast to avoid - // generating shape_casts/broadcasts which do not belong in this pattern. vector::BroadcastOp broadcastedLhs = dyn_cast( transposedLhs.getVector().getDefiningOp()); - if (!broadcastedLhs || - !broadcastedLhs.computeBroadcastedUnitDims().empty()) - return vector::OuterProductOp(); - // Avoid broadcast f32 or vector -> ResType - auto srcVT = dyn_cast(broadcastedLhs.getSourceType()); - if (!srcVT || srcVT.getRank() != 1) + if (!broadcastedLhs || !getValidBroadcastSource(broadcastedLhs)) return vector::OuterProductOp(); vector::BroadcastOp broadcastedRhs = dyn_cast(OperandB.getDefiningOp()); - if (!broadcastedRhs || broadcastedRhs.getSourceType() != srcVT) + if (!broadcastedRhs || !getValidBroadcastSource(broadcastedRhs)) return vector::OuterProductOp(); return rewriter.replaceOpWithNewOp( diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir index c170486f6ce27..783deb276f3cc 100644 --- a/mlir/test/Dialect/Vector/transform-vector.mlir +++ b/mlir/test/Dialect/Vector/transform-vector.mlir @@ -121,6 +121,19 @@ func.func @ewise_outerproduct_transposed_rhs(%lhs: vector<16xf32>, %rhs: vector< return %mul: vector<16x16xf32> } +// CHECK-LABEL: func.func @ewise_outerproduct_different_sizes +// CHECK-SAME: %[[LHS:.*]]: vector<8xf32>, +// CHECK-SAME: %[[RHS:.*]]: vector<4xf32>) -> vector<8x4xf32> { +// CHECK: %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<8xf32>, vector<4xf32> +// CHECK: return %[[RES]] : vector<8x4xf32> +func.func @ewise_outerproduct_different_sizes(%lhs: vector<8xf32>, %rhs: vector<4xf32>) -> vector<8x4xf32> { + %lhsBcast = vector.broadcast %lhs : vector<8xf32> to vector<4x8xf32> + %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x8xf32> to vector<8x4xf32> + %rhsBcast = vector.broadcast %rhs : vector<4xf32> to vector<8x4xf32> + %mul = arith.mulf %lhsT, %rhsBcast : vector<8x4xf32> + return %mul: vector<8x4xf32> +} + module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op From 9889dc21ea2fcb951dc97b585f24b4831206cb29 Mon Sep 17 00:00:00 2001 From: Hugo Date: Thu, 30 May 2024 17:37:13 +0800 Subject: [PATCH 3/6] fix formattign --- mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index a48101699c4f7..0bbdffeb5d9a6 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1815,7 +1815,8 @@ struct BreakDownVectorReduction final : OpRewritePattern { template struct ElementwiseToOuterproduct : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - // Helper function returning the source of the input broadcast if it matches requirements for an outerproduct pattern. + // Helper function returning the source of the input broadcast if it matches + // requirements for an outerproduct pattern. Value getValidBroadcastSource(vector::BroadcastOp broadcastOp) const { // Fail if it is not a 1-to-2 dimension to broadcast to avoid generating // shape_casts/broadcasts which does not belong in this pattern. From de63fd6f2f54867e3280c76808f984e4ec2ca17e Mon Sep 17 00:00:00 2001 From: Hugo Trachino Date: Mon, 3 Jun 2024 09:23:19 +0100 Subject: [PATCH 4/6] Update mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp Co-authored-by: Han-Chung Wang --- mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 0bbdffeb5d9a6..c7874d4506892 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1797,21 +1797,20 @@ struct BreakDownVectorReduction final : OpRewritePattern { /// Pattern aiming to fold a series of ops mulf(tr(broadcast(A)), broadcast(B)) /// into vector.outerproduct(A, B) such as : -/// ```mlir +/// /// %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32> /// %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to /// vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to /// vector<4x4xi32> %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32> -///``` +/// /// Becomes : -///```mlir +/// /// %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32> -///``` +/// /// Edge Cases where broadcast ops are not 1D to 2D as follow are not handled. /// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32> /// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32> /// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32> - template struct ElementwiseToOuterproduct : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; From aa165d2715046e3eb903d87f0cda95b2e0facc2c Mon Sep 17 00:00:00 2001 From: Hugo Date: Mon, 3 Jun 2024 19:13:23 +0800 Subject: [PATCH 5/6] fix review comments --- .../Vector/Transforms/VectorTransforms.cpp | 94 ++++++++++--------- .../test/Dialect/Vector/transform-vector.mlir | 37 +++----- 2 files changed, 63 insertions(+), 68 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index c7874d4506892..827c789df7f50 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1795,9 +1795,9 @@ struct BreakDownVectorReduction final : OpRewritePattern { unsigned maxNumElementsToExtract = 0; }; -/// Pattern aiming to fold a series of ops mulf(tr(broadcast(A)), broadcast(B)) -/// into vector.outerproduct(A, B) such as : -/// +/// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A, +/// B)`. +/// Example: /// %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32> /// %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to /// vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to @@ -1807,65 +1807,72 @@ struct BreakDownVectorReduction final : OpRewritePattern { /// /// %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32> /// -/// Edge Cases where broadcast ops are not 1D to 2D as follow are not handled. +/// Supports only 1D-to-2D broadcasts. The following cases are not supported. /// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32> /// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32> /// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32> template -struct ElementwiseToOuterproduct : public OpRewritePattern { +struct FoldArithToVectorOuterProduct : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - // Helper function returning the source of the input broadcast if it matches - // requirements for an outerproduct pattern. - Value getValidBroadcastSource(vector::BroadcastOp broadcastOp) const { + // Returns whether a vector.broadcast matches requirements for an outerproduct + // pattern. aka a 1D-to-2D broadcastOp without broadcasted unit dimension. + bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const { // Fail if it is not a 1-to-2 dimension to broadcast to avoid generating // shape_casts/broadcasts which does not belong in this pattern. if (!broadcastOp.computeBroadcastedUnitDims().empty()) - return Value(); + return false; // Avoid broadcast like f32 or vector -> ResType - auto srcVT = dyn_cast(broadcastOp.getSourceType()); - if (!srcVT || srcVT.getRank() != 1) - return Value(); - return broadcastOp.getSource(); + auto srcType = dyn_cast(broadcastOp.getSourceType()); + if (!srcType || srcType.getRank() == 2) + return false; + return true; } LogicalResult matchAndRewrite(MulOpType mulOp, PatternRewriter &rewriter) const override { - auto VT = llvm::cast(mulOp.getResult().getType()); - if (!VT) + auto resType = llvm::cast(mulOp.getResult().getType()); + if (!resType) return failure(); - if (VT.getRank() != 2) + if (resType.getRank() != 2) return failure(); - - auto canonicalize = [&](Value OperandA, - Value OperandB) -> vector::OuterProductOp { + /// If operandA can be written as tr(broadcast(A)) and operandB as + /// broadcast(B) where broadcasts are 1D-to-2D, create and return + /// vector.outerproduct(A, B). Returns failure() otherwise. + auto matchOuterProduct = + [&](Value operandA, + Value operandB) -> FailureOr { vector::TransposeOp transposedLhs = - dyn_cast_or_null(OperandA.getDefiningOp()); + dyn_cast_or_null(operandA.getDefiningOp()); if (!transposedLhs) - return vector::OuterProductOp(); + return failure(); // Fail unless this is a true 2-D matrix transpose. ArrayRef permutation = transposedLhs.getPermutation(); - if (permutation[0] != 1 || permutation[1] != 0) - return vector::OuterProductOp(); - - vector::BroadcastOp broadcastedLhs = dyn_cast( - transposedLhs.getVector().getDefiningOp()); - if (!broadcastedLhs || !getValidBroadcastSource(broadcastedLhs)) - return vector::OuterProductOp(); - - vector::BroadcastOp broadcastedRhs = - dyn_cast(OperandB.getDefiningOp()); - if (!broadcastedRhs || !getValidBroadcastSource(broadcastedRhs)) - return vector::OuterProductOp(); - - return rewriter.replaceOpWithNewOp( - mulOp, VT, broadcastedLhs.getSource(), broadcastedRhs.getSource(), - Value(), vector::CombiningKind::ADD); + if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0) + return failure(); + + auto broadcastedLhs = + transposedLhs.getVector().getDefiningOp(); + if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs)) + return failure(); + + auto broadcastedRhs = operandB.getDefiningOp(); + if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs)) + return failure(); + + return rewriter.create( + mulOp->getLoc(), resType, broadcastedLhs.getSource(), + broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD); }; - Value a = mulOp->getOperand(0), b = mulOp->getOperand(1); - vector::OuterProductOp outerP = canonicalize(a, b); + + Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1); + auto maybeOuterP = matchOuterProduct(lhs, rhs); // Handle commutativity, the transposed op is the outerproduct LHS. - outerP = outerP ? outerP : canonicalize(b, a); - return outerP ? success() : failure(); + if (failed(maybeOuterP)) + maybeOuterP = matchOuterProduct(rhs, lhs); + if (failed(maybeOuterP)) + return failure(); + rewriter.replaceOp(mulOp, maybeOuterP->getResult()); + return success(); } }; @@ -1958,8 +1965,9 @@ void mlir::vector::populateBreakDownVectorReductionPatterns( void mlir::vector::populateElementwiseToVectorOpsPatterns( RewritePatternSet &patterns) { - patterns.add, - ElementwiseToOuterproduct>(patterns.getContext()); + patterns.add, + FoldArithToVectorOuterProduct>( + patterns.getContext()); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir index 783deb276f3cc..4b38db79bff3e 100644 --- a/mlir/test/Dialect/Vector/transform-vector.mlir +++ b/mlir/test/Dialect/Vector/transform-vector.mlir @@ -95,12 +95,12 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: func.func @ewise_outerproduct +// CHECK-LABEL: func.func @arith_to_outerproduct_scalable_i32 // CHECK-SAME: %[[LHS:.*]]: vector<[4]xi32>, // CHECK-SAME: %[[RHS:.*]]: vector<[4]xi32>) -> vector<[4]x[4]xi32> { // CHECK: %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<[4]xi32>, vector<[4]xi32> // CHECK: return %[[RES]] : vector<[4]x[4]xi32> -func.func @ewise_outerproduct(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> { +func.func @arith_to_outerproduct_scalable_i32(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> { %lhsBcast = vector.broadcast %lhs : vector<[4]xi32> to vector<[4]x[4]xi32> %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32> %rhsBcast = vector.broadcast %rhs : vector<[4]xi32> to vector<[4]x[4]xi32> @@ -108,30 +108,17 @@ func.func @ewise_outerproduct(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> v return %mul: vector<[4]x[4]xi32> } -// CHECK-LABEL: func.func @ewise_outerproduct_transposed_rhs +// CHECK-LABEL: func.func @arith_to_outerproduct_trans_rhs_f32 // CHECK-SAME: %[[LHS:.*]]: vector<16xf32>, -// CHECK-SAME: %[[RHS:.*]]: vector<16xf32>) -> vector<16x16xf32> { -// CHECK: %[[RES:.*]] = vector.outerproduct %[[RHS]], %[[LHS]] : vector<16xf32>, vector<16xf32> -// CHECK: return %[[RES]] : vector<16x16xf32> -func.func @ewise_outerproduct_transposed_rhs(%lhs: vector<16xf32>, %rhs: vector<16xf32>) -> vector<16x16xf32> { - %rhsBcast = vector.broadcast %rhs : vector<16xf32> to vector<16x16xf32> - %rhsT = vector.transpose %rhsBcast, [1, 0] : vector<16x16xf32> to vector<16x16xf32> - %lhsBcast = vector.broadcast %lhs : vector<16xf32> to vector<16x16xf32> - %mul = arith.mulf %lhsBcast, %rhsT : vector<16x16xf32> - return %mul: vector<16x16xf32> -} - -// CHECK-LABEL: func.func @ewise_outerproduct_different_sizes -// CHECK-SAME: %[[LHS:.*]]: vector<8xf32>, -// CHECK-SAME: %[[RHS:.*]]: vector<4xf32>) -> vector<8x4xf32> { -// CHECK: %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<8xf32>, vector<4xf32> -// CHECK: return %[[RES]] : vector<8x4xf32> -func.func @ewise_outerproduct_different_sizes(%lhs: vector<8xf32>, %rhs: vector<4xf32>) -> vector<8x4xf32> { - %lhsBcast = vector.broadcast %lhs : vector<8xf32> to vector<4x8xf32> - %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x8xf32> to vector<8x4xf32> - %rhsBcast = vector.broadcast %rhs : vector<4xf32> to vector<8x4xf32> - %mul = arith.mulf %lhsT, %rhsBcast : vector<8x4xf32> - return %mul: vector<8x4xf32> +// CHECK-SAME: %[[RHS:.*]]: vector<8xf32>) -> vector<8x16xf32> { +// CHECK: %[[RES:.*]] = vector.outerproduct %[[RHS]], %[[LHS]] : vector<8xf32>, vector<16xf32> +// CHECK: return %[[RES]] : vector<8x16xf32> +func.func @arith_to_outerproduct_trans_rhs_f32(%lhs: vector<16xf32>, %rhs: vector<8xf32>) -> vector<8x16xf32> { + %rhsBcast = vector.broadcast %rhs : vector<8xf32> to vector<16x8xf32> + %rhsT = vector.transpose %rhsBcast, [1, 0] : vector<16x8xf32> to vector<8x16xf32> + %lhsBcast = vector.broadcast %lhs : vector<16xf32> to vector<8x16xf32> + %mul = arith.mulf %lhsBcast, %rhsT : vector<8x16xf32> + return %mul: vector<8x16xf32> } module attributes {transform.with_named_sequence} { From ec3cd832d96f535bca33028dc5dba516f3577bd6 Mon Sep 17 00:00:00 2001 From: Hugo Date: Fri, 21 Jun 2024 15:51:30 +0800 Subject: [PATCH 6/6] fixup : coding style improvements (nfc) --- mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 827c789df7f50..1d124261d8eff 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1823,9 +1823,7 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern { return false; // Avoid broadcast like f32 or vector -> ResType auto srcType = dyn_cast(broadcastOp.getSourceType()); - if (!srcType || srcType.getRank() == 2) - return false; - return true; + return srcType && srcType.getRank() != 2; } LogicalResult matchAndRewrite(MulOpType mulOp, @@ -1841,8 +1839,7 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern { auto matchOuterProduct = [&](Value operandA, Value operandB) -> FailureOr { - vector::TransposeOp transposedLhs = - dyn_cast_or_null(operandA.getDefiningOp()); + auto transposedLhs = operandA.getDefiningOp(); if (!transposedLhs) return failure(); // Fail unless this is a true 2-D matrix transpose.