From 69100604c8883af26d64eb47495aa4401fdef86f Mon Sep 17 00:00:00 2001 From: Artemiy Bulavin Date: Fri, 14 Feb 2025 15:02:11 +0000 Subject: [PATCH 1/9] Explicitly speciy all vector transform options on ConvertVectorToLLVMPass Refactor ConvertVectorToLLVMPass options --- mlir/include/mlir/Conversion/Passes.td | 32 ++++++++-- .../Vector/Transforms/LoweringPatterns.h | 11 ++-- .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 4 +- .../SPIRV/Transforms/SPIRVConversion.cpp | 5 +- .../TransformOps/VectorTransformOps.cpp | 9 +-- .../Vector/Transforms/LowerVectorContract.cpp | 64 +++++++++---------- .../Transforms/LowerVectorTranspose.cpp | 28 ++++---- .../VectorToLLVM/test-serialisable.mlir | 16 +++++ 8 files changed, 101 insertions(+), 68 deletions(-) create mode 100644 mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index cccdf0a8518bf..606a38f7d98eb 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -10,7 +10,7 @@ #define MLIR_CONVERSION_PASSES include "mlir/Pass/PassBase.td" - +include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td" //===----------------------------------------------------------------------===// // ToLLVM @@ -1410,10 +1410,32 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " "dialect.">, - Option<"vectorTransformsOptions", "vector-transform-options", - "vector::VectorTransformsOptions", - /*default=*/"vector::VectorTransformsOptions()", - "Options to lower some operations like contractions and transposes.">, + Option<"vectorContractLowering", "vector-contract-lowering", + "vector::VectorContractLowering", + /*default=*/"vector::VectorContractLowering::Dot", + VectorContractLoweringAttr.summary, [{::llvm::cl::values( + clEnumValN(::mlir::vector::VectorContractLowering::Dot, "dot", + "Progressively lower to finer grained `vector.contract` and dot-products. (default)"), + clEnumValN(::mlir::vector::VectorContractLowering::Matmul, "matmul", + "Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics."), + clEnumValN(::mlir::vector::VectorContractLowering::OuterProduct, "outerproduct", + "Lower to `vector.outerproduct`."), + clEnumValN(::mlir::vector::VectorContractLowering::ParallelArith, "parallelarith", + "Lower contract with all reduction dimensions unrolled to 1 to a vector elementwise operations.") + )}]>, + Option<"vectorTransposeLowering", "vector-transpose-lowering", + "vector::VectorTransposeLowering", + /*default=*/"vector::VectorTransposeLowering::EltWise", + VectorTransposeLoweringAttr.summary, [{::llvm::cl::values( + clEnumValN(::mlir::vector::VectorTransposeLowering::EltWise, "eltwise", + "Lower transpose into element-wise extract and inserts (default)"), + clEnumValN(::mlir::vector::VectorTransposeLowering::Flat, "flat", + "Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix intrinsics"), + clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle1D, "shuffle1d", + "Lower 2-D transpose to `vector.shuffle` on 1-D vector."), + clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle16x16, "shuffle16x16", + "Lower 2-D transpose to `vector.shuffle` on 16x16 vector.") + )}]>, ]; } diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index 6aeae30a0a6c0..601a65333d026 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H #define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H +#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" namespace mlir { @@ -47,7 +48,8 @@ namespace vector { /// Progressively lower a `vector.contract` with row-major matmul semantics to /// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`. void populateVectorContractLoweringPatterns( - RewritePatternSet &patterns, VectorTransformsOptions options, + RewritePatternSet &patterns, + VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit = 1, bool disableOuterProductLowering = false); /// Populate the pattern set with the following patterns: @@ -142,9 +144,10 @@ void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, /// /// [TransposeOp2DToShuffleLowering] /// -void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, - VectorTransformsOptions options, - PatternBenefit benefit = 1); +void populateVectorTransposeLoweringPatterns( + RewritePatternSet &patterns, + VectorTransposeLowering vectorTransposeLowering, + PatternBenefit benefit = 1); /// Populate the pattern set with the following patterns: /// diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index e3a81bd20212d..eb1555df5d574 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -69,11 +69,11 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorToVectorCanonicalizationPatterns(patterns); populateVectorBitCastLoweringPatterns(patterns); populateVectorBroadcastLoweringPatterns(patterns); - populateVectorContractLoweringPatterns(patterns, vectorTransformsOptions); + populateVectorContractLoweringPatterns(patterns, vectorContractLowering); populateVectorMaskOpLoweringPatterns(patterns); populateVectorShapeCastLoweringPatterns(patterns); populateVectorInterleaveLoweringPatterns(patterns); - populateVectorTransposeLoweringPatterns(patterns, vectorTransformsOptions); + populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering); // Vector transfer ops with rank > 1 should be lowered with VectorToSCF. populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); populateVectorMaskMaterializationPatterns(patterns, diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index c56dbcca2175d..a60410d01ac57 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -1374,9 +1374,8 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) { // further transformations to canonicalize/cancel. { RewritePatternSet patterns(context); - auto options = vector::VectorTransformsOptions().setVectorTransposeLowering( - vector::VectorTransposeLowering::EltWise); - vector::populateVectorTransposeLoweringPatterns(patterns, options); + vector::populateVectorTransposeLoweringPatterns( + patterns, vector::VectorTransposeLowering::EltWise); vector::populateVectorShapeCastLoweringPatterns(patterns); if (failed(applyPatternsGreedily(op, std::move(patterns)))) return failure(); diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 241e83e234d62..20c577273d786 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -102,9 +102,7 @@ void transform::ApplyLowerBroadcastPatternsOp::populatePatterns( void transform::ApplyLowerContractionPatternsOp::populatePatterns( RewritePatternSet &patterns) { - vector::VectorTransformsOptions vectorTransformOptions; - vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy()); - populateVectorContractLoweringPatterns(patterns, vectorTransformOptions, + populateVectorContractLoweringPatterns(patterns, getLoweringStrategy(), /*benefit=*/1, /*disableOuterProductLowering=*/true); } @@ -161,9 +159,8 @@ void transform::ApplyLowerTransferPatternsOp::populatePatterns( void transform::ApplyLowerTransposePatternsOp::populatePatterns( RewritePatternSet &patterns) { - vector::populateVectorTransposeLoweringPatterns( - patterns, vector::VectorTransformsOptions().setVectorTransposeLowering( - getLoweringStrategy())); + vector::populateVectorTransposeLoweringPatterns(patterns, + getLoweringStrategy()); if (getAvx2LoweringStrategy()) { auto avx2LoweringOptions = x86vector::avx2::LoweringOptions().setTransposeOptions( diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp index 21261478f0648..d2f60a55fb4a6 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -221,7 +221,7 @@ namespace { /// ``` /// `vector.matmul` later lowers to `llvm.matrix.multiply`. // -/// This only kicks in when VectorTransformsOptions is set to OuterProduct and +/// This only kicks in when VectorTransformsOptions is set to Matmul and /// the vector.contract op is a row-major matrix multiply. class ContractionOpToMatmulOpLowering : public vector::MaskableOpRewritePattern { @@ -236,11 +236,11 @@ class ContractionOpToMatmulOpLowering } ContractionOpToMatmulOpLowering( - vector::VectorTransformsOptions vectorTransformOptions, + vector::VectorContractLowering vectorContractLowering, MLIRContext *context, PatternBenefit benefit = 1, FilterConstraintType constraint = defaultFilter) : MaskableOpRewritePattern(context, benefit), - vectorTransformOptions(vectorTransformOptions), + vectorContractLowering(vectorContractLowering), filter(std::move(constraint)) {} FailureOr @@ -249,7 +249,7 @@ class ContractionOpToMatmulOpLowering private: /// Options to control the vector patterns. - vector::VectorTransformsOptions vectorTransformOptions; + vector::VectorContractLowering vectorContractLowering; FilterConstraintType filter; }; @@ -281,11 +281,11 @@ class ContractionOpToOuterProductOpLowering } ContractionOpToOuterProductOpLowering( - vector::VectorTransformsOptions vectorTransformOptions, + vector::VectorContractLowering vectorContractLowering, MLIRContext *context, PatternBenefit benefit = 1, FilterConstraintType constraint = defaultFilter) : MaskableOpRewritePattern(context, benefit), - vectorTransformOptions(vectorTransformOptions), + vectorContractLowering(vectorContractLowering), filter(std::move(constraint)) {} FailureOr @@ -294,7 +294,7 @@ class ContractionOpToOuterProductOpLowering private: /// Options to control the vector patterns. - vector::VectorTransformsOptions vectorTransformOptions; + vector::VectorContractLowering vectorContractLowering; FilterConstraintType filter; }; @@ -329,11 +329,11 @@ class ContractionOpToDotLowering } ContractionOpToDotLowering( - vector::VectorTransformsOptions vectorTransformOptions, + vector::VectorContractLowering vectorContractLowering, MLIRContext *context, PatternBenefit benefit = 1, const FilterConstraintType &constraint = defaultFilter) : MaskableOpRewritePattern(context, benefit), - vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {} + vectorContractLowering(vectorContractLowering), filter(defaultFilter) {} FailureOr matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp, @@ -341,7 +341,7 @@ class ContractionOpToDotLowering private: /// Options to control the vector patterns. - vector::VectorTransformsOptions vectorTransformOptions; + vector::VectorContractLowering vectorContractLowering; FilterConstraintType filter; }; @@ -370,11 +370,12 @@ class ContractionOpLowering return success(); } - ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions, - MLIRContext *context, PatternBenefit benefit = 1, - FilterConstraintType constraint = defaultFilter) + ContractionOpLowering( + vector::VectorContractLowering vectorContractLoweringOption, + MLIRContext *context, PatternBenefit benefit = 1, + FilterConstraintType constraint = defaultFilter) : MaskableOpRewritePattern(context, benefit), - vectorTransformOptions(vectorTransformOptions), + vectorContractLoweringOption(vectorContractLoweringOption), filter(std::move(constraint)) {} FailureOr @@ -383,7 +384,7 @@ class ContractionOpLowering private: /// Options to control the vector patterns. - vector::VectorTransformsOptions vectorTransformOptions; + vector::VectorContractLowering vectorContractLoweringOption; FilterConstraintType filter; // Lower one parallel dimension. FailureOr lowerParallel(PatternRewriter &rewriter, @@ -641,8 +642,7 @@ FailureOr ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp( vector::ContractionOp op, MaskingOpInterface maskOp, PatternRewriter &rewriter) const { - if (vectorTransformOptions.vectorContractLowering != - vector::VectorContractLowering::OuterProduct) + if (vectorContractLowering != vector::VectorContractLowering::OuterProduct) return failure(); if (failed(filter(op))) @@ -672,8 +672,7 @@ FailureOr ContractionOpToDotLowering::matchAndRewriteMaskableOp( if (failed(filter(op))) return failure(); - if (vectorTransformOptions.vectorContractLowering != - vector::VectorContractLowering::Dot) + if (vectorContractLowering != vector::VectorContractLowering::Dot) return failure(); auto iteratorTypes = op.getIteratorTypes().getValue(); @@ -789,11 +788,11 @@ struct ContractOpToElementwise return success(); } ContractOpToElementwise( - vector::VectorTransformsOptions vectorTransformOptions, + vector::VectorContractLowering vectorContractLowering, MLIRContext *context, PatternBenefit benefit = 1, const FilterConstraintType &constraint = defaultFilter) : MaskableOpRewritePattern(context, benefit), - vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {} + vectorContractLowering(vectorContractLowering), filter(defaultFilter) {} FailureOr matchAndRewriteMaskableOp(vector::ContractionOp contractOp, @@ -806,8 +805,7 @@ struct ContractOpToElementwise if (failed(filter(contractOp))) return failure(); - if (vectorTransformOptions.vectorContractLowering != - vector::VectorContractLowering::ParallelArith) + if (vectorContractLowering != vector::VectorContractLowering::ParallelArith) return failure(); ArrayRef lhsShape = contractOp.getLhsType().getShape(); @@ -898,7 +896,7 @@ struct ContractOpToElementwise private: /// Options to control the vector patterns. - vector::VectorTransformsOptions vectorTransformOptions; + vector::VectorContractLowering vectorContractLowering; FilterConstraintType filter; }; @@ -941,25 +939,25 @@ FailureOr ContractionOpLowering::matchAndRewriteMaskableOp( // TODO: implement benefits, cost models. MLIRContext *ctx = op.getContext(); - ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx); + ContractionOpToMatmulOpLowering pat1(vectorContractLoweringOption, ctx); FailureOr newVal1 = pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter); if (!failed(newVal1)) return newVal1; - ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx); + ContractionOpToOuterProductOpLowering pat2(vectorContractLoweringOption, ctx); FailureOr newVal2 = pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter); if (!failed(newVal2)) return newVal2; - ContractionOpToDotLowering pat3(vectorTransformOptions, ctx); + ContractionOpToDotLowering pat3(vectorContractLoweringOption, ctx); FailureOr newVal3 = pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter); if (!failed(newVal3)) return newVal3; - ContractOpToElementwise pat4(vectorTransformOptions, ctx); + ContractOpToElementwise pat4(vectorContractLoweringOption, ctx); FailureOr newVal4 = pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter); if (!failed(newVal4)) @@ -1292,8 +1290,7 @@ FailureOr ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp( if (maskOp) return failure(); - if (vectorTransformOptions.vectorContractLowering != - vector::VectorContractLowering::Matmul) + if (vectorContractLowering != vector::VectorContractLowering::Matmul) return failure(); if (failed(filter(op))) return failure(); @@ -1382,13 +1379,14 @@ FailureOr ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp( } // namespace void mlir::vector::populateVectorContractLoweringPatterns( - RewritePatternSet &patterns, VectorTransformsOptions options, - PatternBenefit benefit, bool disableOuterProductLowering) { + RewritePatternSet &patterns, + VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit, + bool disableOuterProductLowering) { if (!disableOuterProductLowering) patterns.add(patterns.getContext(), benefit); patterns.add( - options, patterns.getContext(), benefit); + vectorContractLoweringOption, patterns.getContext(), benefit); } void mlir::vector::populateVectorOuterProductLoweringPatterns( diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp index fb4dee33bc5f5..732e316c93381 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -304,10 +304,10 @@ class TransposeOpLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions, + TransposeOpLowering(vector::VectorTransposeLowering vectorTransposeLowering, MLIRContext *context, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), - vectorTransformOptions(vectorTransformOptions) {} + vectorTransposeLowering(vectorTransposeLowering) {} LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override { @@ -324,14 +324,13 @@ class TransposeOpLowering : public OpRewritePattern { // Set up convenience transposition table. ArrayRef transp = op.getPermutation(); - if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) && + if (isShuffleLike(vectorTransposeLowering) && succeeded(isTranspose2DSlice(op))) return rewriter.notifyMatchFailure( op, "Options specifies lowering to shuffle"); // Handle a true 2-D matrix transpose differently when requested. - if (vectorTransformOptions.vectorTransposeLowering == - vector::VectorTransposeLowering::Flat && + if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat && resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) { Type flattenedType = VectorType::get(resType.getNumElements(), resType.getElementType()); @@ -380,7 +379,7 @@ class TransposeOpLowering : public OpRewritePattern { private: /// Options to control the vector patterns. - vector::VectorTransformsOptions vectorTransformOptions; + vector::VectorTransposeLowering vectorTransposeLowering; }; /// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied @@ -454,14 +453,14 @@ class TransposeOp2DToShuffleLowering using OpRewritePattern::OpRewritePattern; TransposeOp2DToShuffleLowering( - vector::VectorTransformsOptions vectorTransformOptions, + vector::VectorTransposeLowering vectorTransposeLowering, MLIRContext *context, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), - vectorTransformOptions(vectorTransformOptions) {} + vectorTransposeLowering(vectorTransposeLowering) {} LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override { - if (!isShuffleLike(vectorTransformOptions.vectorTransposeLowering)) + if (!isShuffleLike(vectorTransposeLowering)) return rewriter.notifyMatchFailure( op, "not using vector shuffle based lowering"); @@ -487,8 +486,7 @@ class TransposeOp2DToShuffleLowering op.getVector()); Value res; - if (vectorTransformOptions.vectorTransposeLowering == - VectorTransposeLowering::Shuffle16x16 && + if (vectorTransposeLowering == VectorTransposeLowering::Shuffle16x16 && m == 16 && n == 16) { reshInput = rewriter.create(loc, reshInputType, reshInput); @@ -506,15 +504,15 @@ class TransposeOp2DToShuffleLowering private: /// Options to control the vector patterns. - vector::VectorTransformsOptions vectorTransformOptions; + vector::VectorTransposeLowering vectorTransposeLowering; }; } // namespace void mlir::vector::populateVectorTransposeLoweringPatterns( - RewritePatternSet &patterns, VectorTransformsOptions options, - PatternBenefit benefit) { + RewritePatternSet &patterns, + VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) { patterns.add(patterns.getContext(), benefit); patterns.add( - options, patterns.getContext(), benefit); + vectorTransposeLowering, patterns.getContext(), benefit); } diff --git a/mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir b/mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir new file mode 100644 index 0000000000000..d641c715ad74e --- /dev/null +++ b/mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-opt --convert-vector-to-llvm --dump-pass-pipeline %s 2>&1 | FileCheck %s + +// Simple regression test that ensures ConvertVectorToLLVMPass options remain +// serialisable. We don't need to actually parse any IR to print the pass +// options. We just need to provide --dump-pass-pipeline + +// CHECK: builtin.module( +// CHECK-SAME: convert-vector-to-llvm{ +// CHECK-SAME: enable-amx={{[aA-zZ0-9]+}} +// CHECK-SAME: enable-arm-neon={{[aA-zZ0-9]+}} +// CHECK-SAME: enable-arm-sve={{[aA-zZ0-9]+}} +// CHECK-SAME: enable-x86vector={{[aA-zZ0-9]+}} +// CHECK-SAME: force-32bit-vector-indices={{[aA-zZ0-9]+}} +// CHECK-SAME: reassociate-fp-reductions={{[aA-zZ0-9]+}} +// CHECK-SAME: vector-contract-lowering={{[aA-zZ0-9]+}} +// CHECK-SAME: vector-transpose-lowering={{[aA-zZ0-9]+}}}) From 4b8967bb44fab17130005aceb8d1e4742b357bcc Mon Sep 17 00:00:00 2001 From: Artemiy Bulavin Date: Tue, 25 Feb 2025 17:23:25 +0000 Subject: [PATCH 2/9] fixup: Reference vectorContractOptions instead of VectorTransformOptions --- .../Dialect/Vector/Transforms/LowerVectorContract.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp index d2f60a55fb4a6..6527062009df0 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -221,7 +221,7 @@ namespace { /// ``` /// `vector.matmul` later lowers to `llvm.matrix.multiply`. // -/// This only kicks in when VectorTransformsOptions is set to Matmul and +/// This only kicks in when vectorContractLowering is set to Matmul and /// the vector.contract op is a row-major matrix multiply. class ContractionOpToMatmulOpLowering : public vector::MaskableOpRewritePattern { @@ -266,7 +266,7 @@ class ContractionOpToMatmulOpLowering /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1 /// ``` /// -/// This only kicks in when VectorTransformsOptions is set to OuterProduct and +/// This only kicks in when vectorContractLowering is set to OuterProduct and /// the vector.contract op is a row-major matrix multiply. class ContractionOpToOuterProductOpLowering : public MaskableOpRewritePattern { @@ -636,7 +636,7 @@ struct UnrolledOuterProductGenerator /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1 /// ``` /// -/// This only kicks in when VectorTransformsOptions is set to OuterProduct but +/// This only kicks in when vectorContractLowering is set to OuterProduct but /// otherwise supports any layout permutation of the matrix-multiply. FailureOr ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp( @@ -911,7 +911,7 @@ struct ContractOpToElementwise /// until a pure contraction is reached (no free/batch dimensions), /// which is replaced by a dot-product. /// -/// This only kicks in when either VectorTransformsOptions is set +/// This only kicks in when either vectorContractLoweringOption is set /// to DOT or when other contraction patterns fail. // // TODO: break down into transpose/reshape/cast ops @@ -1278,7 +1278,7 @@ class OuterProductOpLowering : public OpRewritePattern { /// ``` /// `vector.matmul` later lowers to `llvm.matrix.multiply`. // -/// This only kicks in when VectorTransformsOptions is set to `Matmul`. +/// This only kicks in when vectorContractLowering is set to `Matmul`. /// vector.transpose operations are inserted if the vector.contract op is not a /// row-major matrix multiply. /// From 1e61aaeae281b0077e8fa7a034e5e5214915546f Mon Sep 17 00:00:00 2001 From: Artemiy Bulavin Date: Wed, 26 Feb 2025 18:29:36 +0000 Subject: [PATCH 3/9] fixup: add check that non-default options are set correctly --- .../Conversion/VectorToLLVM/test-serialisable.mlir | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir b/mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir index d641c715ad74e..07e221c5ee327 100644 --- a/mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir +++ b/mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir @@ -1,9 +1,12 @@ -// RUN: mlir-opt --convert-vector-to-llvm --dump-pass-pipeline %s 2>&1 | FileCheck %s - // Simple regression test that ensures ConvertVectorToLLVMPass options remain // serialisable. We don't need to actually parse any IR to print the pass // options. We just need to provide --dump-pass-pipeline +// RUN: mlir-opt --convert-vector-to-llvm --dump-pass-pipeline %s 2>&1 | FileCheck %s --check-prefix=DEFAULT + +// RUN: mlir-opt --convert-vector-to-llvm='vector-contract-lowering=matmul vector-transpose-lowering=flat' \ +// RUN: --dump-pass-pipeline 2>&1 | FileCheck %s --check-prefix=CHANGED + // CHECK: builtin.module( // CHECK-SAME: convert-vector-to-llvm{ // CHECK-SAME: enable-amx={{[aA-zZ0-9]+}} @@ -12,5 +15,7 @@ // CHECK-SAME: enable-x86vector={{[aA-zZ0-9]+}} // CHECK-SAME: force-32bit-vector-indices={{[aA-zZ0-9]+}} // CHECK-SAME: reassociate-fp-reductions={{[aA-zZ0-9]+}} -// CHECK-SAME: vector-contract-lowering={{[aA-zZ0-9]+}} -// CHECK-SAME: vector-transpose-lowering={{[aA-zZ0-9]+}}}) +// DEFAULT: vector-contract-lowering=dot +// DEFAULT: vector-transpose-lowering=eltwise +// CHANGED: vector-contract-lowering=matmul +// CHANGED: vector-transpose-lowering=flat From 4cc7f7d601c7baac896cfb6e119e18d66057f386 Mon Sep 17 00:00:00 2001 From: Artemiy Bulavin Date: Wed, 26 Feb 2025 21:39:38 +0000 Subject: [PATCH 4/9] fixup: Use the clearer NON-DEFAULT check-prefix --- mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir b/mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir index 07e221c5ee327..3757c9c685ca8 100644 --- a/mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir +++ b/mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir @@ -5,7 +5,7 @@ // RUN: mlir-opt --convert-vector-to-llvm --dump-pass-pipeline %s 2>&1 | FileCheck %s --check-prefix=DEFAULT // RUN: mlir-opt --convert-vector-to-llvm='vector-contract-lowering=matmul vector-transpose-lowering=flat' \ -// RUN: --dump-pass-pipeline 2>&1 | FileCheck %s --check-prefix=CHANGED +// RUN: --dump-pass-pipeline 2>&1 | FileCheck %s --check-prefix=NON-DEFAULT // CHECK: builtin.module( // CHECK-SAME: convert-vector-to-llvm{ @@ -17,5 +17,5 @@ // CHECK-SAME: reassociate-fp-reductions={{[aA-zZ0-9]+}} // DEFAULT: vector-contract-lowering=dot // DEFAULT: vector-transpose-lowering=eltwise -// CHANGED: vector-contract-lowering=matmul -// CHANGED: vector-transpose-lowering=flat +// NON-DEFAULT: vector-contract-lowering=matmul +// NON-DEFAULT: vector-transpose-lowering=flat From d2cca0cb46fb8338ba79b72594cb36e8b3af1d61 Mon Sep 17 00:00:00 2001 From: Artemiy Bulavin Date: Thu, 27 Feb 2025 22:04:03 +0000 Subject: [PATCH 5/9] fixup: correct vector.matmul typo to vector.matrix_multiply --- .../lib/Dialect/Vector/Transforms/LowerVectorContract.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp index 6527062009df0..c74d0622b3828 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -215,11 +215,11 @@ namespace { /// ``` /// %flattened_a = vector.shape_cast %a /// %flattened_b = vector.shape_cast %b -/// %flattened_d = vector.matmul %flattened_a, %flattened_b +/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b /// %d = vector.shape_cast %%flattened_d /// %e = add %c, %d /// ``` -/// `vector.matmul` later lowers to `llvm.matrix.multiply`. +/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`. // /// This only kicks in when vectorContractLowering is set to Matmul and /// the vector.contract op is a row-major matrix multiply. @@ -1271,12 +1271,12 @@ class OuterProductOpLowering : public OpRewritePattern { /// %mtb = maybe_transpose /// %flattened_a = vector.shape_cast %mta /// %flattened_b = vector.shape_cast %mtb -/// %flattened_d = vector.matmul %flattened_a, %flattened_b +/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b /// %mtd = vector.shape_cast %flattened_d /// %d = maybe_untranspose %mtd /// %e = add %c, %d /// ``` -/// `vector.matmul` later lowers to `llvm.matrix.multiply`. +/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`. // /// This only kicks in when vectorContractLowering is set to `Matmul`. /// vector.transpose operations are inserted if the vector.contract op is not a From 0523d0309b30d495052b296b1238f077aab99a50 Mon Sep 17 00:00:00 2001 From: Artemiy Bulavin Date: Thu, 27 Feb 2025 22:04:24 +0000 Subject: [PATCH 6/9] fixup: remove unused tablegen include --- mlir/include/mlir/Conversion/Passes.td | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 606a38f7d98eb..f05947a6734c4 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -10,7 +10,6 @@ #define MLIR_CONVERSION_PASSES include "mlir/Pass/PassBase.td" -include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td" //===----------------------------------------------------------------------===// // ToLLVM From e3c44b489e9ce2a8964fd9fb33db4c8905616532 Mon Sep 17 00:00:00 2001 From: Artemiy Bulavin Date: Thu, 27 Feb 2025 22:05:31 +0000 Subject: [PATCH 7/9] fixup: remove redundnancy from lit test filename --- .../{test-serialisable.mlir => pass-option-serialization.mlir} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename mlir/test/Conversion/VectorToLLVM/{test-serialisable.mlir => pass-option-serialization.mlir} (100%) diff --git a/mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir b/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir similarity index 100% rename from mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir rename to mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir From ad267775fb58090a4693c19d97e720b648aa9ba2 Mon Sep 17 00:00:00 2001 From: Artemiy Bulavin Date: Thu, 27 Feb 2025 22:11:18 +0000 Subject: [PATCH 8/9] fixup: add documentation justifying the serialization test --- .../VectorToLLVM/pass-option-serialization.mlir | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir b/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir index 3757c9c685ca8..ebf06c57a1b3b 100644 --- a/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir +++ b/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir @@ -1,6 +1,15 @@ -// Simple regression test that ensures ConvertVectorToLLVMPass options remain -// serialisable. We don't need to actually parse any IR to print the pass -// options. We just need to provide --dump-pass-pipeline +// Ensure that ConvertVectorToLLVMPass options remain serialisable. + +// This test also allows us to exercise these options (to some extent) even if we +// don't use them in other Vector to LLVM conversion tests. This is quite relevant +// for the `Vector` Dialect (and `--convert-vector-to-llvm` pass) as in many cases +// we use the Transform Dialect (TD) rather than `--convert-vector-to-llvm` for +// testing. So here we don't check the correctness of the passes, as they're +// covered by other tests that use TD, but we still provide some test coverage of +// these pass options. + +// We don't need to actually parse any IR to print the pass options. We just need +// to provide --dump-pass-pipeline // RUN: mlir-opt --convert-vector-to-llvm --dump-pass-pipeline %s 2>&1 | FileCheck %s --check-prefix=DEFAULT From c65d80365873a984b3dd57be4a893b4d0d9c2336 Mon Sep 17 00:00:00 2001 From: Artemiy Bulavin Date: Fri, 28 Feb 2025 14:49:32 +0000 Subject: [PATCH 9/9] fixup: add back include for VectorTransformOptions.summary --- mlir/include/mlir/Conversion/Passes.td | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index f05947a6734c4..b3ab4069f4ff1 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -10,6 +10,7 @@ #define MLIR_CONVERSION_PASSES include "mlir/Pass/PassBase.td" +include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td" //===----------------------------------------------------------------------===// // ToLLVM @@ -1413,7 +1414,7 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { "vector::VectorContractLowering", /*default=*/"vector::VectorContractLowering::Dot", VectorContractLoweringAttr.summary, [{::llvm::cl::values( - clEnumValN(::mlir::vector::VectorContractLowering::Dot, "dot", + clEnumValN(::mlir::vector::VectorContractLowering::Dot, "dot", "Progressively lower to finer grained `vector.contract` and dot-products. (default)"), clEnumValN(::mlir::vector::VectorContractLowering::Matmul, "matmul", "Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics."),