Skip to content

Conversation

@obtuseangleAZ
Copy link
Contributor

@obtuseangleAZ obtuseangleAZ commented Jul 25, 2024

Added a conversion pattern and LIT tests for lowering vector.step to SPIR-V.
Fixes: #100602

@llvmbot
Copy link
Member

llvmbot commented Jul 25, 2024

@llvm/pr-subscribers-mlir

Author: Angel Zhang (angelz913)

Changes

Added a conversion pattern and LIT tests for lowering vector.step to SPIR-V. Related issue: #100602


Full diff: https://github.com/llvm/llvm-project/pull/100651.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+39-2)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+28)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 527fbe5cf628a..8b5789f9e8497 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -906,6 +906,42 @@ struct VectorReductionToFPDotProd final
   }
 };
 
+struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    Type dstType = typeConverter.convertType(stepOp.getType());
+    if (!dstType)
+      return failure();
+
+    Location loc = stepOp.getLoc();
+    int64_t numElements = stepOp.getType().getNumElements();
+    auto intType =
+        rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
+
+    // Input vectors of size 1 are converted to scalars by the type converter.
+    // We just create a constant in this case.
+    if (numElements == 1) {
+      Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
+      rewriter.replaceOp(stepOp, zero);
+      return success();
+    }
+
+    SmallVector<Value> source;
+    for (int64_t i = 0; i < numElements; ++i) {
+      Attribute intAttr = rewriter.getIntegerAttr(intType, i);
+      Value constOp = rewriter.create<spirv::ConstantOp>(loc, intType, intAttr);
+      source.push_back(constOp);
+    }
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
+                                                             source);
+    return success();
+  }
+};
+
 } // namespace
 #define CL_INT_MAX_MIN_OPS                                                     \
   spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -929,8 +965,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
       VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
       VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
-      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
-      typeConverter, patterns.getContext(), PatternBenefit(1));
+      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
+      VectorStepOpConvert>(typeConverter, patterns.getContext(),
+                           PatternBenefit(1));
 
   // Make sure that the more specialized dot product pattern has higher benefit
   // than the generic one that extracts all elements.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index edad208749930..016c9e141a712 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -794,6 +794,34 @@ func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
 
 // -----
 
+// CHECK-LABEL: @step
+//  CHECK-SAME: ()
+//       CHECK:   %[[CST0:.*]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST1:.*]] = spirv.Constant 1 : i32
+//       CHECK:   %[[CST2:.*]] = spirv.Constant 2 : i32
+//       CHECK:   %[[CST3:.*]] = spirv.Constant 3 : i32
+//       CHECK:   %[[CONSTRUCT:.*]] = spirv.CompositeConstruct %[[CST0]], %[[CST1]], %[[CST2]], %[[CST3]] : (i32, i32, i32, i32) -> vector<4xi32>
+//       CHECK:   %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CONSTRUCT]] : vector<4xi32> to vector<4xindex>
+//       CHECK:   return %[[CAST]] : vector<4xindex>
+func.func @step() -> vector<4xindex> {
+  %0 = vector.step : vector<4xindex>
+  return %0 : vector<4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @step_size1
+//  CHECK-SAME: ()
+//       CHECK:   %[[CST0:.*]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CST0]] : i32 to vector<1xindex>
+//       CHECK:   return %[[CAST]] : vector<1xindex>
+func.func @step_size1() -> vector<1xindex> {
+  %0 = vector.step : vector<1xindex>
+  return %0 : vector<1xindex>
+}
+
+// -----
+
 module attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>

@llvmbot
Copy link
Member

llvmbot commented Jul 25, 2024

@llvm/pr-subscribers-mlir-spirv

Author: Angel Zhang (angelz913)

Changes

Added a conversion pattern and LIT tests for lowering vector.step to SPIR-V. Related issue: #100602


Full diff: https://github.com/llvm/llvm-project/pull/100651.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+39-2)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+28)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 527fbe5cf628a..8b5789f9e8497 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -906,6 +906,42 @@ struct VectorReductionToFPDotProd final
   }
 };
 
+struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    Type dstType = typeConverter.convertType(stepOp.getType());
+    if (!dstType)
+      return failure();
+
+    Location loc = stepOp.getLoc();
+    int64_t numElements = stepOp.getType().getNumElements();
+    auto intType =
+        rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
+
+    // Input vectors of size 1 are converted to scalars by the type converter.
+    // We just create a constant in this case.
+    if (numElements == 1) {
+      Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
+      rewriter.replaceOp(stepOp, zero);
+      return success();
+    }
+
+    SmallVector<Value> source;
+    for (int64_t i = 0; i < numElements; ++i) {
+      Attribute intAttr = rewriter.getIntegerAttr(intType, i);
+      Value constOp = rewriter.create<spirv::ConstantOp>(loc, intType, intAttr);
+      source.push_back(constOp);
+    }
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
+                                                             source);
+    return success();
+  }
+};
+
 } // namespace
 #define CL_INT_MAX_MIN_OPS                                                     \
   spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -929,8 +965,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
       VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
       VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
-      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
-      typeConverter, patterns.getContext(), PatternBenefit(1));
+      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
+      VectorStepOpConvert>(typeConverter, patterns.getContext(),
+                           PatternBenefit(1));
 
   // Make sure that the more specialized dot product pattern has higher benefit
   // than the generic one that extracts all elements.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index edad208749930..016c9e141a712 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -794,6 +794,34 @@ func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
 
 // -----
 
+// CHECK-LABEL: @step
+//  CHECK-SAME: ()
+//       CHECK:   %[[CST0:.*]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST1:.*]] = spirv.Constant 1 : i32
+//       CHECK:   %[[CST2:.*]] = spirv.Constant 2 : i32
+//       CHECK:   %[[CST3:.*]] = spirv.Constant 3 : i32
+//       CHECK:   %[[CONSTRUCT:.*]] = spirv.CompositeConstruct %[[CST0]], %[[CST1]], %[[CST2]], %[[CST3]] : (i32, i32, i32, i32) -> vector<4xi32>
+//       CHECK:   %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CONSTRUCT]] : vector<4xi32> to vector<4xindex>
+//       CHECK:   return %[[CAST]] : vector<4xindex>
+func.func @step() -> vector<4xindex> {
+  %0 = vector.step : vector<4xindex>
+  return %0 : vector<4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @step_size1
+//  CHECK-SAME: ()
+//       CHECK:   %[[CST0:.*]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CST0]] : i32 to vector<1xindex>
+//       CHECK:   return %[[CAST]] : vector<1xindex>
+func.func @step_size1() -> vector<1xindex> {
+  %0 = vector.step : vector<1xindex>
+  return %0 : vector<1xindex>
+}
+
+// -----
+
 module attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM % nits

@kuhar kuhar merged commit 599a91a into llvm:main Jul 26, 2024
@obtuseangleAZ obtuseangleAZ deleted the vector-step-to-spirv branch August 8, 2024 20:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[mlir][spirv] Support vector.step in vector to spirv conversion

3 participants