diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 527fbe5cf628a..890706bf1bb2e 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -906,6 +906,43 @@ struct VectorReductionToFPDotProd final } }; +struct VectorStepOpConvert final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const auto &typeConverter = *getTypeConverter(); + Type dstType = typeConverter.convertType(stepOp.getType()); + if (!dstType) + return failure(); + + Location loc = stepOp.getLoc(); + int64_t numElements = stepOp.getType().getNumElements(); + auto intType = + rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth()); + + // Input vectors of size 1 are converted to scalars by the type converter. + // We just create a constant in this case. + if (numElements == 1) { + Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter); + rewriter.replaceOp(stepOp, zero); + return success(); + } + + SmallVector source; + source.reserve(numElements); + for (int64_t i = 0; i < numElements; ++i) { + Attribute intAttr = rewriter.getIntegerAttr(intType, i); + Value constOp = rewriter.create(loc, intType, intAttr); + source.push_back(constOp); + } + rewriter.replaceOpWithNewOp(stepOp, dstType, + source); + return success(); + } +}; + } // namespace #define CL_INT_MAX_MIN_OPS \ spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp @@ -929,8 +966,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, VectorReductionFloatMinMax, VectorShapeCast, VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, VectorInterleaveOpConvert, VectorDeinterleaveOpConvert, - VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>( - typeConverter, patterns.getContext(), PatternBenefit(1)); + VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter, + VectorStepOpConvert>(typeConverter, patterns.getContext(), + PatternBenefit(1)); // Make sure that the more specialized dot product pattern has higher benefit // than the generic one that extracts all elements. diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index edad208749930..dd0ed77470a25 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -794,6 +794,32 @@ func.func @shape_cast_size1_vector(%arg0 : vector) -> vector<1xf32> { // ----- +// CHECK-LABEL: @step() +// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32 +// CHECK: %[[CST1:.*]] = spirv.Constant 1 : i32 +// CHECK: %[[CST2:.*]] = spirv.Constant 2 : i32 +// CHECK: %[[CST3:.*]] = spirv.Constant 3 : i32 +// CHECK: %[[CONSTRUCT:.*]] = spirv.CompositeConstruct %[[CST0]], %[[CST1]], %[[CST2]], %[[CST3]] : (i32, i32, i32, i32) -> vector<4xi32> +// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CONSTRUCT]] : vector<4xi32> to vector<4xindex> +// CHECK: return %[[CAST]] : vector<4xindex> +func.func @step() -> vector<4xindex> { + %0 = vector.step : vector<4xindex> + return %0 : vector<4xindex> +} + +// ----- + +// CHECK-LABEL: @step_size1() +// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32 +// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CST0]] : i32 to vector<1xindex> +// CHECK: return %[[CAST]] : vector<1xindex> +func.func @step_size1() -> vector<1xindex> { + %0 = vector.step : vector<1xindex> + return %0 : vector<1xindex> +} + +// ----- + module attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>>