Skip to content

Conversation

@dcaballe
Copy link
Contributor

Only elements with at least one use are lowered to llvm.extractelement op.

@llvmbot
Copy link
Member

llvmbot commented Jun 25, 2025

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes

Only elements with at least one use are lowered to llvm.extractelement op.


Full diff: https://github.com/llvm/llvm-project/pull/145766.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+33-1)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (+35-1)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d53d11f87efe8..f1543200fb56f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1985,6 +1985,37 @@ struct VectorFromElementsLowering
   }
 };
 
+/// Conversion pattern for a `vector.to_elements`.
+struct VectorToElementsLowering
+    : public ConvertOpToLLVMPattern<vector::ToElementsOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = toElementsOp.getLoc();
+    auto idxType = typeConverter->convertType(rewriter.getIndexType());
+    Value source = adaptor.getSource();
+
+    SmallVector<Value> results(toElementsOp->getNumResults());
+    for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
+      // Create an extractelement operation only for results that are not dead.
+      if (!element.use_empty()) {
+        auto constIdx = rewriter.create<LLVM::ConstantOp>(
+            loc, idxType, rewriter.getIntegerAttr(idxType, idx));
+        auto llvmType = typeConverter->convertType(element.getType());
+
+        Value result = rewriter.create<LLVM::ExtractElementOp>(
+            loc, llvmType, source, constIdx);
+        results[idx] = result;
+      }
+    }
+
+    rewriter.replaceOp(toElementsOp, results);
+    return success();
+  }
+};
+
 /// Conversion pattern for vector.step.
 struct VectorScalableStepOpLowering
     : public ConvertOpToLLVMPattern<vector::StepOp> {
@@ -2035,7 +2066,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
                VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
                MaskedReductionOpConversion, VectorInterleaveOpLowering,
                VectorDeinterleaveOpLowering, VectorFromElementsLowering,
-               VectorScalableStepOpLowering>(converter);
+               VectorToElementsLowering, VectorScalableStepOpLowering>(
+      converter);
 }
 
 void mlir::populateVectorToLLVMMatrixConversionPatterns(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 3df14528bac39..8f73e79d7bfc2 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -1875,7 +1875,7 @@ func.func @store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
 // CHECK: %[[CAST_MEMREF:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<200x100xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[CST:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
 // CHECK: %[[VAL:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<f32> to vector<1xf32>
-// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
 // CHECK: %[[MUL:.*]] = llvm.mul %[[I]], %[[C100]] : i64
 // CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %[[J]] : i64
@@ -2421,6 +2421,40 @@ func.func @from_elements_0d(%arg0: f32) -> vector<f32> {
 
 // -----
 
+// CHECK-LABEL: func.func @vector_to_elements_no_dead_elements
+ // CHECK-SAME:     %[[A:.*]]: vector<4xf32>)
+ //      CHECK:   %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+ //      CHECK:   %[[ELEM0:.*]] = llvm.extractelement %[[A]][%[[C0]] : i64] : vector<4xf32>
+ //      CHECK:   %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+ //      CHECK:   %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
+ //      CHECK:   %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
+ //      CHECK:   %[[ELEM2:.*]] = llvm.extractelement %[[A]][%[[C2]] : i64] : vector<4xf32>
+ //      CHECK:   %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
+ //      CHECK:   %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
+ //      CHECK:   return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32
+func.func @vector_to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %a : vector<4xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_to_elements_dead_elements
+ // CHECK-SAME:     %[[A:.*]]: vector<4xf32>)
+ //  CHECK-NOT:   llvm.mlir.constant(0 : i64) : i64
+ //      CHECK:   %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+ //      CHECK:   %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
+ //  CHECK-NOT:   llvm.mlir.constant(2 : i64) : i64
+ //      CHECK:   %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
+ //      CHECK:   %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
+ //      CHECK:   return %[[ELEM1]], %[[ELEM3]] : f32, f32
+func.func @vector_to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) {
+  %0:4 = vector.to_elements %a : vector<4xf32>
+  return %0#1, %0#3 : f32, f32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.step
 //===----------------------------------------------------------------------===//

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG % minor nits, thanks!

Comment on lines 2424 to 2458
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] Add block comment, remove vector from function names (for consistency)

Suggested change
// CHECK-LABEL: func.func @vector_to_elements_no_dead_elements
// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[A]][%[[C0]] : i64] : vector<4xf32>
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
// CHECK: %[[ELEM2:.*]] = llvm.extractelement %[[A]][%[[C2]] : i64] : vector<4xf32>
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
// CHECK: return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32
func.func @vector_to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) {
%0:4 = vector.to_elements %a : vector<4xf32>
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
}
// -----
// CHECK-LABEL: func.func @vector_to_elements_dead_elements
// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
// CHECK-NOT: llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
// CHECK-NOT: llvm.mlir.constant(2 : i64) : i64
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
// CHECK: return %[[ELEM1]], %[[ELEM3]] : f32, f32
func.func @vector_to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) {
%0:4 = vector.to_elements %a : vector<4xf32>
return %0#1, %0#3 : f32, f32
}
//===----------------------------------------------------------------------===//
// vector.to_elements
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func.func @to_elements_no_dead_elements
// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[A]][%[[C0]] : i64] : vector<4xf32>
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
// CHECK: %[[ELEM2:.*]] = llvm.extractelement %[[A]][%[[C2]] : i64] : vector<4xf32>
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
// CHECK: return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32
func.func @vector_to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) {
%0:4 = vector.to_elements %a : vector<4xf32>
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
}
// -----
// CHECK-LABEL: func.func @to_elements_dead_elements
// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
// CHECK-NOT: llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
// CHECK-NOT: llvm.mlir.constant(2 : i64) : i64
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
// CHECK: return %[[ELEM1]], %[[ELEM3]] : f32, f32
func.func @vector_to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) {
%0:4 = vector.to_elements %a : vector<4xf32>
return %0#1, %0#3 : f32, f32
}

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dcaballe do you plan to add SPIR-V lowering on your own? If not, could you open an issue and tag me so that I can assign it to someone?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you can invert this check and continue instead to reduce nesting

@dcaballe
Copy link
Contributor Author

@dcaballe do you plan to add SPIR-V lowering on your own? If not, could you open an issue and tag me so that I can assign it to someone?

#145929

dcaballe added 2 commits June 26, 2025 17:08
Only elements with at least one use are lowered to `llvm.extractelement` op.
@dcaballe dcaballe force-pushed the vector-to-elements-llvm-lowering branch from 2c40823 to cc50307 Compare June 26, 2025 17:25
@dcaballe dcaballe merged commit 7842e9e into llvm:main Jun 26, 2025
5 of 7 checks passed
anthonyhatran pushed a commit to anthonyhatran/llvm-project that referenced this pull request Jun 26, 2025
Only elements with at least one use are lowered to `llvm.extractelement`
op.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants