diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 0f42ffb3a8026..7600d723c3093 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1440,7 +1440,11 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { Option<"x86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " - "dialect."> + "dialect.">, + Option<"vectorTransformsOptions", "vector-transform-options", + "vector::VectorTransformsOptions", + /*default=*/"vector::VectorTransformsOptions()", + "Options to lower some operations like contractions and transposes.">, ]; } diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h index 4661d31b6364d..410b881db7959 100644 --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h @@ -9,6 +9,7 @@ #define MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVMPASS_H_ #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" namespace mlir { class Pass; diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 2c4c5ada9815d..e3a81bd20212d 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -69,12 +69,11 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorToVectorCanonicalizationPatterns(patterns); populateVectorBitCastLoweringPatterns(patterns); populateVectorBroadcastLoweringPatterns(patterns); - populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions()); + populateVectorContractLoweringPatterns(patterns, vectorTransformsOptions); populateVectorMaskOpLoweringPatterns(patterns); populateVectorShapeCastLoweringPatterns(patterns); populateVectorInterleaveLoweringPatterns(patterns); - populateVectorTransposeLoweringPatterns(patterns, - VectorTransformsOptions()); + populateVectorTransposeLoweringPatterns(patterns, vectorTransformsOptions); // Vector transfer ops with rank > 1 should be lowered with VectorToSCF. populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); populateVectorMaskMaterializationPatterns(patterns,