From 2b147c8c71e26fed679b978766ea34899a6ba0d4 Mon Sep 17 00:00:00 2001 From: Angel Zhang Date: Thu, 25 Jul 2024 20:49:20 +0000 Subject: [PATCH 1/4] [mlir][spirv] Support vector.step in vector to spirv conversion --- .../VectorToSPIRV/VectorToSPIRV.cpp | 41 ++++++++++++++++++- .../VectorToSPIRV/vector-to-spirv.mlir | 28 +++++++++++++ 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 527fbe5cf628a..8b5789f9e8497 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -906,6 +906,42 @@ struct VectorReductionToFPDotProd final } }; +struct VectorStepOpConvert final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const auto &typeConverter = *getTypeConverter(); + Type dstType = typeConverter.convertType(stepOp.getType()); + if (!dstType) + return failure(); + + Location loc = stepOp.getLoc(); + int64_t numElements = stepOp.getType().getNumElements(); + auto intType = + rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth()); + + // Input vectors of size 1 are converted to scalars by the type converter. + // We just create a constant in this case. + if (numElements == 1) { + Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter); + rewriter.replaceOp(stepOp, zero); + return success(); + } + + SmallVector source; + for (int64_t i = 0; i < numElements; ++i) { + Attribute intAttr = rewriter.getIntegerAttr(intType, i); + Value constOp = rewriter.create(loc, intType, intAttr); + source.push_back(constOp); + } + rewriter.replaceOpWithNewOp(stepOp, dstType, + source); + return success(); + } +}; + } // namespace #define CL_INT_MAX_MIN_OPS \ spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp @@ -929,8 +965,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, VectorReductionFloatMinMax, VectorShapeCast, VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, VectorInterleaveOpConvert, VectorDeinterleaveOpConvert, - VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>( - typeConverter, patterns.getContext(), PatternBenefit(1)); + VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter, + VectorStepOpConvert>(typeConverter, patterns.getContext(), + PatternBenefit(1)); // Make sure that the more specialized dot product pattern has higher benefit // than the generic one that extracts all elements. diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index edad208749930..016c9e141a712 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -794,6 +794,34 @@ func.func @shape_cast_size1_vector(%arg0 : vector) -> vector<1xf32> { // ----- +// CHECK-LABEL: @step +// CHECK-SAME: () +// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32 +// CHECK: %[[CST1:.*]] = spirv.Constant 1 : i32 +// CHECK: %[[CST2:.*]] = spirv.Constant 2 : i32 +// CHECK: %[[CST3:.*]] = spirv.Constant 3 : i32 +// CHECK: %[[CONSTRUCT:.*]] = spirv.CompositeConstruct %[[CST0]], %[[CST1]], %[[CST2]], %[[CST3]] : (i32, i32, i32, i32) -> vector<4xi32> +// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CONSTRUCT]] : vector<4xi32> to vector<4xindex> +// CHECK: return %[[CAST]] : vector<4xindex> +func.func @step() -> vector<4xindex> { + %0 = vector.step : vector<4xindex> + return %0 : vector<4xindex> +} + +// ----- + +// CHECK-LABEL: @step_size1 +// CHECK-SAME: () +// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32 +// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CST0]] : i32 to vector<1xindex> +// CHECK: return %[[CAST]] : vector<1xindex> +func.func @step_size1() -> vector<1xindex> { + %0 = vector.step : vector<1xindex> + return %0 : vector<1xindex> +} + +// ----- + module attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> From a33b5f1a199357fe8a0c4799791aa0ca9f163e20 Mon Sep 17 00:00:00 2001 From: Angel Zhang Date: Thu, 25 Jul 2024 17:04:06 -0400 Subject: [PATCH 2/4] Update mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir Co-authored-by: Jakub Kuderski --- mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index 016c9e141a712..a684643f713d3 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -794,8 +794,7 @@ func.func @shape_cast_size1_vector(%arg0 : vector) -> vector<1xf32> { // ----- -// CHECK-LABEL: @step -// CHECK-SAME: () +// CHECK-LABEL: @step() // CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32 // CHECK: %[[CST1:.*]] = spirv.Constant 1 : i32 // CHECK: %[[CST2:.*]] = spirv.Constant 2 : i32 From 9ad98c62042cf38d5f9bfe1b231dc71585d043e5 Mon Sep 17 00:00:00 2001 From: Angel Zhang Date: Thu, 25 Jul 2024 17:04:12 -0400 Subject: [PATCH 3/4] Update mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir Co-authored-by: Jakub Kuderski --- mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index a684643f713d3..dd0ed77470a25 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -809,8 +809,7 @@ func.func @step() -> vector<4xindex> { // ----- -// CHECK-LABEL: @step_size1 -// CHECK-SAME: () +// CHECK-LABEL: @step_size1() // CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32 // CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CST0]] : i32 to vector<1xindex> // CHECK: return %[[CAST]] : vector<1xindex> From 7d94be73c99f82720e6371eda9ee0b3a2b6edfd6 Mon Sep 17 00:00:00 2001 From: Angel Zhang Date: Thu, 25 Jul 2024 17:13:37 -0400 Subject: [PATCH 4/4] Update mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp Co-authored-by: Jakub Kuderski --- mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 8b5789f9e8497..890706bf1bb2e 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -931,6 +931,7 @@ struct VectorStepOpConvert final : OpConversionPattern { } SmallVector source; + source.reserve(numElements); for (int64_t i = 0; i < numElements; ++i) { Attribute intAttr = rewriter.getIntegerAttr(intType, i); Value constOp = rewriter.create(loc, intType, intAttr);