Skip to content

Conversation

@krzysz00
Copy link
Contributor

There is a pattern that rewrites
elementwise_op(broadcast(x1 : T to U), broadcast(x2 : T to U), ...) to broadcast(elementwise_op(x1, x2, ...) : T to U).

This pattern did not, however, account for the case where a broadcast constant is represented as a SplatElementsAttr, which can safely be reshaped or scalarized but is not a vector.broadcast or vector.splat operation.

This patch fixes this oversight, prenting premature broadcasting.

This did result in the need to update some linalg dialect tests, which now feature a less-broadcast computation and/or more constant folding.

There is a pattern that rewrites
elementwise_op(broadcast(x1 : T to U), broadcast(x2 : T to U), ...)
to broadcast(elementwise_op(x1, x2, ...) : T to U).

This pattern did not, however, account for the case where a broadcast
constant is represented as a SplatElementsAttr, which can safely be reshaped
or scalarized but is not a `vector.broadcast` or `vector.splat` operation.

This patch fixes this oversight, prenting premature broadcasting.

This did result in the need to update some linalg dialect tests,
which now feature a less-broadcast computation and/or more constant
folding.
@llvmbot
Copy link
Member

llvmbot commented Jul 28, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-vector

Author: Krzysztof Drewniak (krzysz00)

Changes

There is a pattern that rewrites
elementwise_op(broadcast(x1 : T to U), broadcast(x2 : T to U), ...) to broadcast(elementwise_op(x1, x2, ...) : T to U).

This pattern did not, however, account for the case where a broadcast constant is represented as a SplatElementsAttr, which can safely be reshaped or scalarized but is not a vector.broadcast or vector.splat operation.

This patch fixes this oversight, prenting premature broadcasting.

This did result in the need to update some linalg dialect tests, which now feature a less-broadcast computation and/or more constant folding.


Full diff: https://github.com/llvm/llvm-project/pull/150867.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+44-16)
  • (modified) mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir (+21-30)
  • (modified) mlir/test/Dialect/Vector/vector-sink.mlir (+64)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 8de87fef904fa..c51c7b7270fae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1005,26 +1005,39 @@ struct ReorderElementwiseOpsOnBroadcast final
           "might be a scalar");
     }
 
-    // Get the type of the lhs operand
-    auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
-    if (!lhsBcastOrSplat ||
-        !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
+    // Get the type of the first non-constant operand
+    Operation *firstBroadcastOrSplat = nullptr;
+    for (Value operand : op->getOperands()) {
+      Operation *definingOp = operand.getDefiningOp();
+      if (!definingOp)
+        return failure();
+      if (definingOp->hasTrait<OpTrait::ConstantLike>())
+        continue;
+      if (!isa<vector::BroadcastOp, vector::SplatOp>(*definingOp))
+        return failure();
+      firstBroadcastOrSplat = definingOp;
+      break;
+    }
+    if (!firstBroadcastOrSplat)
       return failure();
-    auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
+    Type firstBroadcastOrSplatType =
+        firstBroadcastOrSplat->getOperand(0).getType();
 
     // Make sure that all operands are broadcast from identical types:
     //  * scalar (`vector.broadcast` + `vector.splat`), or
     //  * vector (`vector.broadcast`).
     // Otherwise the re-ordering wouldn't be safe.
-    if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
-          auto bcast = val.getDefiningOp<vector::BroadcastOp>();
-          if (bcast)
-            return (bcast.getOperand().getType() == lhsBcastOrSplatType);
-          auto splat = val.getDefiningOp<vector::SplatOp>();
-          if (splat)
-            return (splat.getOperand().getType() == lhsBcastOrSplatType);
-          return false;
-        })) {
+    if (!llvm::all_of(
+            op->getOperands(), [&firstBroadcastOrSplatType](Value val) {
+              if (auto bcastOp = val.getDefiningOp<vector::BroadcastOp>())
+                return (bcastOp.getOperand().getType() ==
+                        firstBroadcastOrSplatType);
+              if (auto splatOp = val.getDefiningOp<vector::SplatOp>())
+                return (splatOp.getOperand().getType() ==
+                        firstBroadcastOrSplatType);
+              SplatElementsAttr splatConst;
+              return matchPattern(val, m_Constant(&splatConst));
+            })) {
       return failure();
     }
 
@@ -1032,13 +1045,28 @@ struct ReorderElementwiseOpsOnBroadcast final
     SmallVector<Value> srcValues;
     srcValues.reserve(op->getNumOperands());
     for (Value operand : op->getOperands()) {
-      srcValues.push_back(operand.getDefiningOp()->getOperand(0));
+      SplatElementsAttr splatConst;
+      if (matchPattern(operand, m_Constant(&splatConst))) {
+        Attribute newConst;
+        if (auto shapedTy = dyn_cast<ShapedType>(firstBroadcastOrSplatType)) {
+          newConst = splatConst.resizeSplat(shapedTy);
+        } else {
+          newConst = splatConst.getSplatValue<Attribute>();
+        }
+        Operation *newConstOp =
+            operand.getDefiningOp()->getDialect()->materializeConstant(
+                rewriter, newConst, firstBroadcastOrSplatType,
+                operand.getLoc());
+        srcValues.push_back(newConstOp->getResult(0));
+      } else {
+        srcValues.push_back(operand.getDefiningOp()->getOperand(0));
+      }
     }
 
     // Create the "elementwise" Op
     Operation *elementwiseOp =
         rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
-                        lhsBcastOrSplatType, op->getAttrs());
+                        firstBroadcastOrSplatType, op->getAttrs());
 
     // Replace the original Op with the elementwise Op
     auto vectorType = op->getResultTypes()[0];
diff --git a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
index c3ee8929dc3f3..d7722eac2b91f 100644
--- a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
@@ -230,18 +230,17 @@ func.func @vectorize_nd_tensor_extract_index_from_tensor(%arg0: tensor<3x3xf32>,
 // CHECK-SAME: %[[ARG4:.*]]: tensor<4x7x3x2xf32>
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[PV:.*]] = ub.poison : i32
-// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<7x2x4x3xindex>
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<4x3xindex>
 // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<true> : vector<4x7x3x2xi1>
 // CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32>
 // CHECK:    %[[V0:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
 // CHECK:    %[[V1:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
 // CHECK:    %[[CAST:.*]] = arith.index_cast %[[V0]] : vector<4x3xi32> to vector<4x3xindex>
-// CHECK:    %[[B1:.*]] = vector.broadcast %[[CAST]] : vector<4x3xindex> to vector<7x2x4x3xindex>
 // CHECK:    %[[CAST_1:.*]] = arith.index_cast %[[V1]] : vector<4x3xi32> to vector<4x3xindex>
-// CHECK:    %[[B2:.*]] = vector.broadcast %[[CAST_1]] : vector<4x3xindex> to vector<7x2x4x3xindex>
-// CHECK:    %[[MULI:.*]] = arith.muli %[[B1]], %[[CST]] : vector<7x2x4x3xindex>
-// CHECK:    %[[ADDI:.*]] = arith.addi %[[B2]], %[[MULI]] : vector<7x2x4x3xindex>
-// CHECK:    %[[T:.*]] = vector.transpose %[[ADDI]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex>
+// CHECK:    %[[MULI:.*]] = arith.muli %[[CAST]], %[[CST]] : vector<4x3xindex>
+// CHECK:    %[[ADDI:.*]] = arith.addi %[[CAST_1]], %[[MULI]] : vector<4x3xindex>
+// CHECK:    %[[B:.*]] = vector.broadcast %[[ADDI]] : vector<4x3xindex> to vector<7x2x4x3xindex>
+// CHECK:    %[[T:.*]] = vector.transpose %[[B]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex>
 // CHECK:    %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[T]]], %[[CST_1]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<4x7x3x2xindex>, vector<4x7x3x2xi1>, vector<4x7x3x2xf32> into vector<4x7x3x2xf32>
 // CHECK:    vector.transfer_write %[[GATHER]], %[[ARG4]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x7x3x2xf32>, tensor<4x7x3x2xf32>
 
@@ -270,20 +269,16 @@ func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(%
 // CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
 // CHECK-SAME: %[[ARG1:.*]]: index
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex>
-// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex>
 // CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
-// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1>
-// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<true> : vector<8x1xi1>
+// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<[0, 98304, 196608, 294912, 393216, 491520, 589824, 688128]> : vector<8xindex>
 // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
-// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
 // CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
-// CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
-// CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
-// CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_1]] : vector<8xindex> to vector<1x8xindex>
+// CHECK: %[[T:.*]] = vector.transpose %[[B1]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
 // CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<8x1xindex>
 // CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[T]] : vector<8x1xindex>
-// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_0]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
 // CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
 
 // -----
@@ -309,15 +304,13 @@ func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>)
 
 // CHECK-LABEL:   func.func @index_from_output_column_vector_gather_load(
 // CHECK-SAME:      %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> {
-// CHECK:           %[[C128:.*]] = arith.constant dense<128> : vector<1x8xindex>
+// CHECK:           %[[IDX_VEC:.*]] = arith.constant dense<[0, 128, 256, 384, 512, 640, 768, 896]> : vector<8xindex>
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
 // CHECK:           %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1>
-// CHECK:           %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
 // CHECK:           %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32>
 // CHECK:           %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex>
-// CHECK:           %[[MUL:.*]] = arith.muli %[[B]], %[[C128]] : vector<1x8xindex>
-// CHECK:           %[[TR:.*]] = vector.transpose %[[MUL]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK:           %[[TR:.*]] = vector.transpose %[[B]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
 // CHECK:           %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
 // CHECK:           %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
 // CHECK:           return %[[RES]] : tensor<8x1xf32>
@@ -420,12 +413,12 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1>
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant dense<16> : vector<1x4xindex>
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant dense<16> : vector<4xindex>
 // CHECK:           %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
 // CHECK:           %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex>
-// CHECK:           %[[VAL_10:.*]] = vector.broadcast %[[VAL_9]] : vector<4xindex> to vector<1x4xindex>
-// CHECK:           %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_7]] : vector<1x4xindex>
-// CHECK:           %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_7]] : vector<1x4xindex>
+// CHECK:           %[[VAL_10:.*]] = arith.muli %[[VAL_9]], %[[VAL_7]] : vector<4xindex>
+// CHECK:           %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_7]] : vector<4xindex>
+// CHECK:           %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : vector<4xindex> to vector<1x4xindex>
 // CHECK:           %[[VAL_13:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_12]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
 // CHECK:           %[[VAL_14:.*]] = vector.transfer_write %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
 // CHECK:           return %[[VAL_14]] : tensor<1x4xf32>
@@ -450,14 +443,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_gather(%arg0: tensor<80x16xf32
 // CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_with_maxsi_gather(
 // CHECK-SAME:                                                             %[[VAL_0:.*]]: tensor<80x16xf32>,
 // CHECK-SAME:                                                             %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant dense<1264> : vector<1x4xindex>
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<[1264, 1265, 1266, 1267]> : vector<4xindex>
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1>
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_7:.*]] = vector.broadcast %[[VAL_2]] : vector<4xindex> to vector<1x4xindex>
-// CHECK:           %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_3]] : vector<1x4xindex>
-// CHECK:           %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_8]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK:           %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_7]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
 // CHECK:           %[[VAL_10:.*]] = vector.transfer_write %[[VAL_9]], %[[VAL_1]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
 // CHECK:           return %[[VAL_10]] : tensor<1x4xf32>
 // CHECK:         }
@@ -519,13 +510,13 @@ func.func @vectorize_reverse_like_tensor_extract(%arg0: tensor<1x2x3xf32>, %arg1
 // CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]*]]
 // CHECK-SAME:    %[[ARG1:[0-9a-zA-Z]*]]
 // CHECK-SAME:    %[[ARG2:[0-9a-zA-Z]*]]
-// CHECK-DAG:    %[[CST:.+]] = arith.constant dense<3> : vector<1x1x3xindex>
+// CHECK-DAG:    %[[C3:.+]] = arith.constant 3 : index
 // CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
 // CHECK-DAG:    %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
 // CHECK-DAG:    %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
 // CHECK-DAG:    %[[INIT_IDX:.+]] = arith.constant dense<[2, 1, 0]> : vector<3xindex>
-// CHECK:        %[[T0:.+]] = vector.broadcast %[[ARG2]] : index to vector<1x1x3xindex>
-// CHECK:        %[[T1:.+]] = arith.muli %[[T0]], %[[CST]] : vector<1x1x3xindex>
+// CHECK:        %[[T0:.+]] = arith.muli %[[ARG2]], %[[C3]] : index
+// CHECK:        %[[T1:.+]] = vector.broadcast %[[T0]] : index to vector<1x1x3xindex>
 // CHECK:        %[[T2:.+]] = vector.broadcast %[[INIT_IDX]]
 // CHECK:        %[[T3:.+]] = arith.addi %[[T2]], %[[T1]]
 // CHECK:        %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[T3]]], %[[MASK]], %[[PASSTHRU]]
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index b826cdca134e6..f8638ab843ecb 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -257,6 +257,70 @@ func.func @broadcast_scalar_extsi_scalable(%a : i8) -> vector<2x[4]xi32> {
   return %r : vector<2x[4]xi32>
 }
 
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_scalar_and_splat_const(
+// CHECK-SAME:     %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
+// CHECK:           %[[NEW_CST:.*]] = arith.constant 2 : index
+// CHECK:           %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[NEW_CST]] : index
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
+// CHECK:           return %[[BCAST]] : vector<1x4xindex>
+
+func.func @broadcast_scalar_and_splat_const(%arg0: index) -> vector<1x4xindex> {
+  %0 = vector.broadcast %arg0 : index to vector<1x4xindex>
+  %cst = arith.constant dense<2> : vector<1x4xindex>
+  %2 = arith.addi %0, %cst : vector<1x4xindex>
+  return %2 : vector<1x4xindex>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_scalar_and_splat_const_const_first(
+// CHECK-SAME:     %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
+// CHECK:           %[[NEW_CST:.*]] = arith.constant 2 : index
+// CHECK:           %[[SUB:.*]] = arith.subi %[[NEW_CST]], %[[ARG_0]] : index
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[SUB]] : index to vector<1x4xindex>
+// CHECK:           return %[[BCAST]] : vector<1x4xindex>
+
+func.func @broadcast_scalar_and_splat_const_const_first(%arg0: index) -> vector<1x4xindex> {
+  %0 = vector.broadcast %arg0 : index to vector<1x4xindex>
+  %cst = arith.constant dense<2> : vector<1x4xindex>
+  %2 = arith.subi %cst, %0 : vector<1x4xindex>
+  return %2 : vector<1x4xindex>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_vector_and_splat_const(
+// CHECK-SAME:     %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> {
+// CHECK:           %[[NEW_CST:.*]] = arith.constant dense<2.000000e+00> : vector<4xf32>
+// CHECK:           %[[ADD:.*]] = arith.mulf %[[ARG_0]], %[[NEW_CST]] : vector<4xf32>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADD]] : vector<4xf32> to vector<3x4xf32>
+// CHECK:           return %[[BCAST]] : vector<3x4xf32>
+
+func.func @broadcast_vector_and_splat_const(%arg0: vector<4xf32>) -> vector<3x4xf32> {
+  %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32>
+  %cst = arith.constant dense<2.000000e+00> : vector<3x4xf32>
+  %2 = arith.mulf %0, %cst : vector<3x4xf32>
+  return %2 : vector<3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @negative_broadcast_with_non_splat_const(
+// CHECK-SAME:     %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
+// CHECK-DAG:       %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : index to vector<1x4xindex>
+// CHECK-DAG:       %[[CST:.*]] = arith.constant dense<{{\[}}[0, 1, 2, 3]]> : vector<1x4xindex>
+// CHECK:           %[[ADD:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<1x4xindex>
+// CHECK:           return %[[ADD]] : vector<1x4xindex>
+
+func.func @negative_broadcast_with_non_splat_const(%arg0: index) -> vector<1x4xindex> {
+  %0 = vector.broadcast %arg0 : index to vector<1x4xindex>
+  %cst = arith.constant dense<[[0, 1, 2, 3]]> : vector<1x4xindex>
+  %2 = arith.addi %0, %cst : vector<1x4xindex>
+  return %2 : vector<1x4xindex>
+}
+
 //===----------------------------------------------------------------------===//
 // [Pattern: ReorderElementwiseOpsOnTranspose]
 //===----------------------------------------------------------------------===//

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks

@krzysz00 krzysz00 merged commit 330a7e1 into llvm:main Jul 29, 2025
14 checks passed
Comment on lines +260 to +321
// -----

// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const(
// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[NEW_CST]] : index
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
// CHECK: return %[[BCAST]] : vector<1x4xindex>

func.func @broadcast_scalar_and_splat_const(%arg0: index) -> vector<1x4xindex> {
%0 = vector.broadcast %arg0 : index to vector<1x4xindex>
%cst = arith.constant dense<2> : vector<1x4xindex>
%2 = arith.addi %0, %cst : vector<1x4xindex>
return %2 : vector<1x4xindex>
}

// -----

// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_const_first(
// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index
// CHECK: %[[SUB:.*]] = arith.subi %[[NEW_CST]], %[[ARG_0]] : index
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SUB]] : index to vector<1x4xindex>
// CHECK: return %[[BCAST]] : vector<1x4xindex>

func.func @broadcast_scalar_and_splat_const_const_first(%arg0: index) -> vector<1x4xindex> {
%0 = vector.broadcast %arg0 : index to vector<1x4xindex>
%cst = arith.constant dense<2> : vector<1x4xindex>
%2 = arith.subi %cst, %0 : vector<1x4xindex>
return %2 : vector<1x4xindex>
}

// -----

// CHECK-LABEL: func.func @broadcast_vector_and_splat_const(
// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> {
// CHECK: %[[NEW_CST:.*]] = arith.constant dense<2.000000e+00> : vector<4xf32>
// CHECK: %[[ADD:.*]] = arith.mulf %[[ARG_0]], %[[NEW_CST]] : vector<4xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : vector<4xf32> to vector<3x4xf32>
// CHECK: return %[[BCAST]] : vector<3x4xf32>

func.func @broadcast_vector_and_splat_const(%arg0: vector<4xf32>) -> vector<3x4xf32> {
%0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32>
%cst = arith.constant dense<2.000000e+00> : vector<3x4xf32>
%2 = arith.mulf %0, %cst : vector<3x4xf32>
return %2 : vector<3x4xf32>
}

// -----

// CHECK-LABEL: func.func @negative_broadcast_with_non_splat_const(
// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : index to vector<1x4xindex>
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<{{\[}}[0, 1, 2, 3]]> : vector<1x4xindex>
// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<1x4xindex>
// CHECK: return %[[ADD]] : vector<1x4xindex>

func.func @negative_broadcast_with_non_splat_const(%arg0: index) -> vector<1x4xindex> {
%0 = vector.broadcast %arg0 : index to vector<1x4xindex>
%cst = arith.constant dense<[[0, 1, 2, 3]]> : vector<1x4xindex>
%2 = arith.addi %0, %cst : vector<1x4xindex>
return %2 : vector<1x4xindex>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] Could you move these tests near other tests for ReorderElementwiseOpsOnBroadcast, i.e. here:

//-----------------------------------------------------------------------------
// [Pattern: ReorderElementwiseOpsOnBroadcast]
//-----------------------------------------------------------------------------
// CHECK-LABEL: func.func @broadcast_scalar_with_bcast(
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> {
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
// CHECK: return %[[BCAST]] : vector<1x4xindex>
func.func @broadcast_scalar_with_bcast(%arg1: index, %arg2: index) -> vector<1x4xindex> {
%0 = vector.broadcast %arg1 : index to vector<1x4xindex>
%1 = vector.broadcast %arg2 : index to vector<1x4xindex>
%2 = arith.addi %0, %1 : vector<1x4xindex>
return %2 : vector<1x4xindex>
}
// CHECK-LABEL: func.func @broadcast_scalar_with_bcast_scalable(
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x[4]xindex> {
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x[4]xindex>
// CHECK: return %[[BCAST]] : vector<1x[4]xindex>
func.func @broadcast_scalar_with_bcast_scalable(%arg1: index, %arg2: index) -> vector<1x[4]xindex> {
%0 = vector.broadcast %arg1 : index to vector<1x[4]xindex>
%1 = vector.broadcast %arg2 : index to vector<1x[4]xindex>
%2 = arith.addi %0, %1 : vector<1x[4]xindex>
return %2 : vector<1x[4]xindex>
}
// -----
// CHECK-LABEL: func.func @broadcast_scalar_with_bcast_and_splat(
// CHECK-SAME: %[[ARG1:.*]]: index,
// CHECK-SAME: %[[ARG2:.*]]: index) -> vector<1x4xindex> {
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
// CHECK: return %[[BCAST]] : vector<1x4xindex>
func.func @broadcast_scalar_with_bcast_and_splat(%arg1: index, %arg2: index) -> vector<1x4xindex> {
%0 = vector.splat %arg1 : vector<1x4xindex>
%1 = vector.broadcast %arg2 : index to vector<1x4xindex>
%2 = arith.addi %0, %1 : vector<1x4xindex>
return %2 : vector<1x4xindex>
}
// CHECK-LABEL: func.func @broadcast_scalar_with_bcast_and_splat_scalable(
// CHECK-SAME: %[[ARG1:.*]]: index,
// CHECK-SAME: %[[ARG2:.*]]: index) -> vector<1x[4]xindex> {
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x[4]xindex>
// CHECK: return %[[BCAST]] : vector<1x[4]xindex>
func.func @broadcast_scalar_with_bcast_and_splat_scalable(%arg1: index, %arg2: index) -> vector<1x[4]xindex> {
%0 = vector.splat %arg1 : vector<1x[4]xindex>
%1 = vector.broadcast %arg2 : index to vector<1x[4]xindex>
%2 = arith.addi %0, %1 : vector<1x[4]xindex>
return %2 : vector<1x[4]xindex>
}
// -----
// CHECK-LABEL: func.func @broadcast_vector(
// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>,
// CHECK-SAME: %[[ARG_1:.*]]: vector<4xf32>) -> vector<3x4xf32> {
// CHECK: %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<4xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<4xf32> to vector<3x4xf32>
// CHECK: return %[[BCAST]] : vector<3x4xf32>
func.func @broadcast_vector(%arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
%arg1_bcast = vector.broadcast %arg1 : vector<4xf32> to vector<3x4xf32>
%arg2_bcast = vector.broadcast %arg2 : vector<4xf32> to vector<3x4xf32>
%2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x4xf32>
return %2 : vector<3x4xf32>
}
// CHECK-LABEL: func.func @broadcast_vector_scalable(
// CHECK-SAME: %[[ARG_0:.*]]: vector<[4]xf32>,
// CHECK-SAME: %[[ARG_1:.*]]: vector<[4]xf32>) -> vector<3x[4]xf32> {
// CHECK: %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<[4]xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<[4]xf32> to vector<3x[4]xf32>
// CHECK: return %[[BCAST]] : vector<3x[4]xf32>
func.func @broadcast_vector_scalable(%arg1: vector<[4]xf32>, %arg2: vector<[4]xf32>) -> vector<3x[4]xf32> {
%arg1_bcast = vector.broadcast %arg1 : vector<[4]xf32> to vector<3x[4]xf32>
%arg2_bcast = vector.broadcast %arg2 : vector<[4]xf32> to vector<3x[4]xf32>
%2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x[4]xf32>
return %2 : vector<3x[4]xf32>
}
// -----
// CHECK-LABEL: func.func @broadcast_scalar_and_vec(
// CHECK-SAME: %[[ARG1:.*]]: index,
// CHECK-SAME: %[[ARG2:.*]]: vector<4xindex>) -> vector<1x4xindex> {
// CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG1]] : vector<1x4xindex>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG2]] : vector<4xindex> to vector<1x4xindex>
// CHECK: %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x4xindex>
// CHECK: return %[[ADD]] : vector<1x4xindex>
func.func @broadcast_scalar_and_vec(%arg1: index, %arg2: vector<4xindex>) -> vector<1x4xindex> {
%0 = vector.splat %arg1 : vector<1x4xindex>
%1 = vector.broadcast %arg2 : vector<4xindex> to vector<1x4xindex>
%2 = arith.addi %0, %1 : vector<1x4xindex>
return %2 : vector<1x4xindex>
}
// CHECK-LABEL: func.func @broadcast_scalar_and_vec_scalable(
// CHECK-SAME: %[[ARG1:.*]]: index,
// CHECK-SAME: %[[ARG2:.*]]: vector<[4]xindex>) -> vector<1x[4]xindex> {
// CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG1]] : vector<1x[4]xindex>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG2]] : vector<[4]xindex> to vector<1x[4]xindex>
// CHECK: %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x[4]xindex>
// CHECK: return %[[ADD]] : vector<1x[4]xindex>
func.func @broadcast_scalar_and_vec_scalable(%arg1: index, %arg2: vector<[4]xindex>) -> vector<1x[4]xindex> {
%0 = vector.splat %arg1 : vector<1x[4]xindex>
%1 = vector.broadcast %arg2 : vector<[4]xindex> to vector<1x[4]xindex>
%2 = arith.addi %0, %1 : vector<1x[4]xindex>
return %2 : vector<1x[4]xindex>
}
// -----
// CHECK-LABEL: func.func @broadcast_vector_and_scalar(
// CHECK-SAME: %[[ARG_0:.*]]: i32,
// CHECK-SAME: %[[ARG_1:.*]]: vector<4xi32>) -> vector<4xi32> {
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : i32 to vector<4xi32>
// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<4xi32>
// CHECK: return %[[ADD]] : vector<4xi32>
func.func @broadcast_vector_and_scalar(%arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
%arg1_bcast = vector.broadcast %arg1 : i32 to vector<4xi32>
%2 = arith.addi %arg1_bcast, %arg2 : vector<4xi32>
return %2 : vector<4xi32>
}
// CHECK-LABEL: func.func @broadcast_vector_and_scalar_scalable(
// CHECK-SAME: %[[ARG_0:.*]]: i32,
// CHECK-SAME: %[[ARG_1:.*]]: vector<[4]xi32>) -> vector<[4]xi32> {
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : i32 to vector<[4]xi32>
// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<[4]xi32>
// CHECK: return %[[ADD]] : vector<[4]xi32>
func.func @broadcast_vector_and_scalar_scalable(%arg1: i32, %arg2: vector<[4]xi32>) -> vector<[4]xi32> {
%arg1_bcast = vector.broadcast %arg1 : i32 to vector<[4]xi32>
%2 = arith.addi %arg1_bcast, %arg2 : vector<[4]xi32>
return %2 : vector<[4]xi32>
}
// -----
#matmat_accesses = [
affine_map<(i, j, k) -> (i, k)>,
affine_map<(i, j, k) -> (k, j)>,
affine_map<(i, j, k) -> (i, j)>
]
#matmat_trait = {
indexing_maps = #matmat_accesses,
iterator_types = ["parallel", "parallel", "reduction"]
}
// CHECK-LABEL: func.func @negative_not_elementwise
// CHECK-DAG: %[[F1:.*]] = arith.constant dense<1.000000e+00> : vector<2x2xf32>
// CHECK-DAG: %[[F2:.*]] = arith.constant dense<2.000000e+00> : vector<2x2xf32>
// CHECK-DAG: %[[F3:.*]] = arith.constant dense<3.000000e+00> : vector<2x2xf32>
// CHECK: %[[RES:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[F1]], %[[F2]], %[[F3]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
func.func @negative_not_elementwise() -> vector<2x2xf32> {
%f1 = arith.constant 1.0: f32
%f2 = arith.constant 2.0: f32
%f3 = arith.constant 3.0: f32
%A = vector.broadcast %f1 : f32 to vector<2x2xf32>
%B = vector.broadcast %f2 : f32 to vector<2x2xf32>
%C = vector.broadcast %f3 : f32 to vector<2x2xf32>
%res = vector.contract #matmat_trait %A, %B, %C
: vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
return %res : vector<2x2xf32>
}
// -----
// The source and the result for arith.cmp have different types - not supported
// CHECK-LABEL: func.func @negative_source_and_result_mismatch
// CHECK: %[[BROADCAST:.+]] = vector.broadcast
// CHECK: %[[RETURN:.+]] = arith.cmpf uno, %[[BROADCAST]], %[[BROADCAST]]
// CHECK: return %[[RETURN]]
func.func @negative_source_and_result_mismatch(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> {
%0 = vector.broadcast %arg0 : f32 to vector<1xf32>
%1 = arith.cmpf uno, %0, %0 : vector<1xf32>
return %1 : vector<1xi1>
}
// -----
// vector.fma only supports vectors - currently it's not possible to replace this with e.g.:
// %scalar_res = vector.fma %scalar_1, %scalar2
// %vec_res = vector.broadcast %scalar_res
//
// TODO: It should be possible to support this case
// CHECK-LABEL: func.func @negative_op_only_supports_vectors
// CHECK: %[[BROADCAST:.+]] = vector.broadcast
// CHECK: %[[RESULT:.+]] = vector.fma %[[BROADCAST]]
// CHECK: return %[[RESULT]]
func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> {
%0 = vector.broadcast %arg0 : f32 to vector<1xf32>
%1 = vector.fma %0, %0, %0 : vector<1xf32>
return %1 : vector<1xf32>
}

Thanks and sorry for not taking a look earlier!

@hanhanW
Copy link
Contributor

hanhanW commented Jul 29, 2025

I don't have a clean repro now, but it breaks IREE: iree-org/iree#21522

The assertion error is:

iree/third_party/llvm-project/mlir/lib/IR/BuiltinAttributes.cpp:1264: DenseElementsAttr mlir::DenseElementsAttr::resizeSplat(ShapedType): Assertion `newType.getElementType() == curType.getElementType() && "expected the same element t
ype"' failed.

hanhanW added a commit to iree-org/llvm-project that referenced this pull request Jul 29, 2025
krzysz00 added a commit to krzysz00/llvm-project that referenced this pull request Jul 30, 2025
This patch extends the operation that rewrites elementwise operations
whose inputs are all broadcast from the same shape to handle
mixed-types, such as when the result and input types don't match, or
when the inputs have multiple types.

PR llvm#150867 failed to check for the possibility of type mismatches when
rewriting splat constants. In order to fix that issue, we add support
for mixed-type operations more generally.
krzysz00 added a commit that referenced this pull request Jul 30, 2025
…#151274)

This patch extends the operation that rewrites elementwise operations
whose inputs are all broadcast from the same shape to handle
mixed-types, such as when the result and input types don't match, or
when the inputs have multiple types.

PR #150867 failed to check for the possibility of type mismatches when
rewriting splat constants. In order to fix that issue, we add support
for mixed-type operations more generally.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants