diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h index 5fda62e3584c7..1e29bfeb9c392 100644 --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -24,6 +24,9 @@ void populateVectorToLLVMConversionPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions = false, bool force32BitVectorIndices = false); +namespace vector { +void registerConvertVectorToLLVMInterface(DialectRegistry ®istry); +} } // namespace mlir #endif // MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_ diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h index 14a6a2787b3a5..887db344ed88b 100644 --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -26,6 +26,7 @@ #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" @@ -76,6 +77,7 @@ inline void registerAllExtensions(DialectRegistry ®istry) { registerConvertAMXToLLVMInterface(registry); gpu::registerConvertGpuToLLVMInterface(registry); NVVM::registerConvertGpuToNVVMInterface(registry); + vector::registerConvertVectorToLLVMInterface(registry); // Register all transform dialect extensions. affine::registerTransformDialectExtension(registry); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index baed98c13adc7..df22de97ffb40 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" +#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/PrintCallHelper.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" @@ -1942,3 +1943,27 @@ void mlir::populateVectorToLLVMMatrixConversionPatterns( patterns.add(converter); patterns.add(converter); } + +namespace { +struct VectorToLLVMDialectInterface : public ConvertToLLVMPatternInterface { + using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; + void loadDependentDialects(MLIRContext *context) const final { + context->loadDialect(); + } + + /// Hook for derived dialect interface to provide conversion patterns + /// and mark dialect legal for the conversion target. + void populateConvertToLLVMConversionPatterns( + ConversionTarget &target, LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns) const final { + populateVectorToLLVMConversionPatterns(typeConverter, patterns); + } +}; +} // namespace + +void mlir::vector::registerConvertVectorToLLVMInterface( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) { + dialect->addInterfaces(); + }); +} diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 6a329499c7110..a4a6f90f7e51e 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" @@ -430,6 +431,7 @@ void VectorDialect::initialize() { TransferWriteOp>(); declarePromisedInterface(); declarePromisedInterface(); + declarePromisedInterface(); } /// Materialize a single constant operation from a given attribute value with diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir new file mode 100644 index 0000000000000..5252bb25ecab5 --- /dev/null +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir @@ -0,0 +1,14 @@ +// Most of the vector lowering is tested in vector-to-llvm.mlir, this file only for the interface smoke test +// RUN: mlir-opt --convert-to-llvm="filter-dialects=vector" --split-input-file %s | FileCheck %s + +func.func @bitcast_f32_to_i32_vector_0d(%arg0: vector) -> vector { + %0 = vector.bitcast %arg0 : vector to vector + return %0 : vector +} + +// CHECK-LABEL: @bitcast_f32_to_i32_vector_0d +// CHECK-SAME: %[[ARG_0:.*]]: vector +// CHECK: %[[VEC_F32_1D:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : vector to vector<1xf32> +// CHECK: %[[VEC_I32_1D:.*]] = llvm.bitcast %[[VEC_F32_1D]] : vector<1xf32> to vector<1xi32> +// CHECK: %[[VEC_I32_0D:.*]] = builtin.unrealized_conversion_cast %[[VEC_I32_1D]] : vector<1xi32> to vector +// CHECK: return %[[VEC_I32_0D]] : vector