diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index de2af69eba9ec..21d8e1d9f1156 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -1022,6 +1022,51 @@ struct VectorStepOpConvert final : OpConversionPattern { } }; +struct VectorToElementOpConvert final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + SmallVector results(toElementsOp->getNumResults()); + Location loc = toElementsOp.getLoc(); + + // Input vectors of size 1 are converted to scalars by the type converter. + // We cannot use `spirv::CompositeExtractOp` directly in this case. + // For a scalar source, the result is just the scalar itself. + if (isa(adaptor.getSource().getType())) { + results[0] = adaptor.getSource(); + rewriter.replaceOp(toElementsOp, results); + return success(); + } + + Type srcElementType = toElementsOp.getElements().getType().front(); + Type elementType = getTypeConverter()->convertType(srcElementType); + if (!elementType) + return rewriter.notifyMatchFailure( + toElementsOp, + llvm::formatv("failed to convert element type '{0}' to SPIR-V", + srcElementType)); + + for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) { + // Create an CompositeExtract operation only for results that are not + // dead. + if (element.use_empty()) + continue; + + Value result = rewriter.create( + loc, elementType, adaptor.getSource(), + rewriter.getI32ArrayAttr({static_cast(idx)})); + results[idx] = result; + } + + rewriter.replaceOp(toElementsOp, results); + return success(); + } +}; + } // namespace #define CL_INT_MAX_MIN_OPS \ spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp @@ -1039,8 +1084,8 @@ void mlir::populateVectorToSPIRVPatterns( VectorExtractElementOpConvert, VectorExtractOpConvert, VectorExtractStridedSliceOpConvert, VectorFmaOpConvert, VectorFmaOpConvert, VectorFromElementsOpConvert, - VectorInsertElementOpConvert, VectorInsertOpConvert, - VectorReductionPattern, + VectorToElementOpConvert, VectorInsertElementOpConvert, + VectorInsertOpConvert, VectorReductionPattern, VectorReductionPattern, VectorReductionFloatMinMax, VectorReductionFloatMinMax, VectorShapeCast, diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index 4701ac5d96009..99ab0e1dc4eef 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -246,6 +246,41 @@ func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 { // ----- +// CHECK-LABEL: func.func @to_elements_one_element +// CHECK-SAME: %[[A:.*]]: vector<1xf32>) +// CHECK: %[[ELEM0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<1xf32> to f32 +// CHECK: return %[[ELEM0]] : f32 +func.func @to_elements_one_element(%a: vector<1xf32>) -> (f32) { + %0:1 = vector.to_elements %a : vector<1xf32> + return %0#0 : f32 +} + +// CHECK-LABEL: func.func @to_elements_no_dead_elements +// CHECK-SAME: %[[A:.*]]: vector<4xf32>) +// CHECK: %[[ELEM0:.*]] = spirv.CompositeExtract %[[A]][0 : i32] : vector<4xf32> +// CHECK: %[[ELEM1:.*]] = spirv.CompositeExtract %[[A]][1 : i32] : vector<4xf32> +// CHECK: %[[ELEM2:.*]] = spirv.CompositeExtract %[[A]][2 : i32] : vector<4xf32> +// CHECK: %[[ELEM3:.*]] = spirv.CompositeExtract %[[A]][3 : i32] : vector<4xf32> +// CHECK: return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32 +func.func @to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) { + %0:4 = vector.to_elements %a : vector<4xf32> + return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 +} + +// CHECK-LABEL: func.func @to_elements_dead_elements +// CHECK-SAME: %[[A:.*]]: vector<4xf32>) +// CHECK-NOT: spirv.CompositeExtract %[[A]][0 : i32] +// CHECK: %[[ELEM1:.*]] = spirv.CompositeExtract %[[A]][1 : i32] : vector<4xf32> +// CHECK-NOT: spirv.CompositeExtract %[[A]][2 : i32] +// CHECK: %[[ELEM3:.*]] = spirv.CompositeExtract %[[A]][3 : i32] : vector<4xf32> +// CHECK: return %[[ELEM1]], %[[ELEM3]] : f32, f32 +func.func @to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) { + %0:4 = vector.to_elements %a : vector<4xf32> + return %0#1, %0#3 : f32, f32 +} + +// ----- + // CHECK-LABEL: @from_elements_0d // CHECK-SAME: %[[ARG0:.+]]: f32 // CHECK: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]