[mlir][Vector] Make elementwise-on-broadcast sinking handle splat consts #150867
Conversation
There is a pattern that rewrites elementwise_op(broadcast(x1 : T to U), broadcast(x2 : T to U), ...) to broadcast(elementwise_op(x1, x2, ...) : T to U). This pattern did not, however, account for the case where a broadcast constant is represented as a SplatElementsAttr, which can safely be reshaped or scalarized but is not a `vector.broadcast` or `vector.splat` operation. This patch fixes that oversight, preventing premature broadcasting. As a result, some linalg dialect tests needed updating; they now feature a less-broadcast computation and/or more constant folding.
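As an illustration (mirroring the new test cases added in this PR), the rewrite now scalarizes a splat constant so that the elementwise op runs before the broadcast:

```mlir
// Before: the addi operates on the broadcast type, even though one
// operand is a splat constant that could be scalarized.
%0   = vector.broadcast %arg0 : index to vector<1x4xindex>
%cst = arith.constant dense<2> : vector<1x4xindex>
%1   = arith.addi %0, %cst : vector<1x4xindex>

// After: the splat constant is rebuilt at the pre-broadcast type and
// the elementwise op is sunk below a single broadcast.
%c2  = arith.constant 2 : index
%add = arith.addi %arg0, %c2 : index
%1   = vector.broadcast %add : index to vector<1x4xindex>
```

For vector-to-vector broadcasts the same idea applies, with the splat constant resized to the pre-broadcast vector type instead of scalarized.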
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector — Author: Krzysztof Drewniak (krzysz00). Changes: There is a pattern that rewrites elementwise_op(broadcast(x1 : T to U), broadcast(x2 : T to U), ...) to broadcast(elementwise_op(x1, x2, ...) : T to U). This pattern did not, however, account for the case where a broadcast constant is represented as a SplatElementsAttr, which can safely be reshaped or scalarized but is not a `vector.broadcast` or `vector.splat` operation. This patch fixes that oversight, preventing premature broadcasting. As a result, some linalg dialect tests needed updating; they now feature a less-broadcast computation and/or more constant folding. Full diff: https://github.com/llvm/llvm-project/pull/150867.diff — 3 files affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 8de87fef904fa..c51c7b7270fae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1005,26 +1005,39 @@ struct ReorderElementwiseOpsOnBroadcast final
"might be a scalar");
}
- // Get the type of the lhs operand
- auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
- if (!lhsBcastOrSplat ||
- !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
+ // Get the type of the first non-constant operand
+ Operation *firstBroadcastOrSplat = nullptr;
+ for (Value operand : op->getOperands()) {
+ Operation *definingOp = operand.getDefiningOp();
+ if (!definingOp)
+ return failure();
+ if (definingOp->hasTrait<OpTrait::ConstantLike>())
+ continue;
+ if (!isa<vector::BroadcastOp, vector::SplatOp>(*definingOp))
+ return failure();
+ firstBroadcastOrSplat = definingOp;
+ break;
+ }
+ if (!firstBroadcastOrSplat)
return failure();
- auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
+ Type firstBroadcastOrSplatType =
+ firstBroadcastOrSplat->getOperand(0).getType();
// Make sure that all operands are broadcast from identical types:
// * scalar (`vector.broadcast` + `vector.splat`), or
// * vector (`vector.broadcast`).
// Otherwise the re-ordering wouldn't be safe.
- if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
- auto bcast = val.getDefiningOp<vector::BroadcastOp>();
- if (bcast)
- return (bcast.getOperand().getType() == lhsBcastOrSplatType);
- auto splat = val.getDefiningOp<vector::SplatOp>();
- if (splat)
- return (splat.getOperand().getType() == lhsBcastOrSplatType);
- return false;
- })) {
+ if (!llvm::all_of(
+ op->getOperands(), [&firstBroadcastOrSplatType](Value val) {
+ if (auto bcastOp = val.getDefiningOp<vector::BroadcastOp>())
+ return (bcastOp.getOperand().getType() ==
+ firstBroadcastOrSplatType);
+ if (auto splatOp = val.getDefiningOp<vector::SplatOp>())
+ return (splatOp.getOperand().getType() ==
+ firstBroadcastOrSplatType);
+ SplatElementsAttr splatConst;
+ return matchPattern(val, m_Constant(&splatConst));
+ })) {
return failure();
}
@@ -1032,13 +1045,28 @@ struct ReorderElementwiseOpsOnBroadcast final
SmallVector<Value> srcValues;
srcValues.reserve(op->getNumOperands());
for (Value operand : op->getOperands()) {
- srcValues.push_back(operand.getDefiningOp()->getOperand(0));
+ SplatElementsAttr splatConst;
+ if (matchPattern(operand, m_Constant(&splatConst))) {
+ Attribute newConst;
+ if (auto shapedTy = dyn_cast<ShapedType>(firstBroadcastOrSplatType)) {
+ newConst = splatConst.resizeSplat(shapedTy);
+ } else {
+ newConst = splatConst.getSplatValue<Attribute>();
+ }
+ Operation *newConstOp =
+ operand.getDefiningOp()->getDialect()->materializeConstant(
+ rewriter, newConst, firstBroadcastOrSplatType,
+ operand.getLoc());
+ srcValues.push_back(newConstOp->getResult(0));
+ } else {
+ srcValues.push_back(operand.getDefiningOp()->getOperand(0));
+ }
}
// Create the "elementwise" Op
Operation *elementwiseOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
- lhsBcastOrSplatType, op->getAttrs());
+ firstBroadcastOrSplatType, op->getAttrs());
// Replace the original Op with the elementwise Op
auto vectorType = op->getResultTypes()[0];
diff --git a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
index c3ee8929dc3f3..d7722eac2b91f 100644
--- a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
@@ -230,18 +230,17 @@ func.func @vectorize_nd_tensor_extract_index_from_tensor(%arg0: tensor<3x3xf32>,
// CHECK-SAME: %[[ARG4:.*]]: tensor<4x7x3x2xf32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[PV:.*]] = ub.poison : i32
-// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<7x2x4x3xindex>
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<4x3xindex>
// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<true> : vector<4x7x3x2xi1>
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32>
// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
// CHECK: %[[CAST:.*]] = arith.index_cast %[[V0]] : vector<4x3xi32> to vector<4x3xindex>
-// CHECK: %[[B1:.*]] = vector.broadcast %[[CAST]] : vector<4x3xindex> to vector<7x2x4x3xindex>
// CHECK: %[[CAST_1:.*]] = arith.index_cast %[[V1]] : vector<4x3xi32> to vector<4x3xindex>
-// CHECK: %[[B2:.*]] = vector.broadcast %[[CAST_1]] : vector<4x3xindex> to vector<7x2x4x3xindex>
-// CHECK: %[[MULI:.*]] = arith.muli %[[B1]], %[[CST]] : vector<7x2x4x3xindex>
-// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[MULI]] : vector<7x2x4x3xindex>
-// CHECK: %[[T:.*]] = vector.transpose %[[ADDI]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex>
+// CHECK: %[[MULI:.*]] = arith.muli %[[CAST]], %[[CST]] : vector<4x3xindex>
+// CHECK: %[[ADDI:.*]] = arith.addi %[[CAST_1]], %[[MULI]] : vector<4x3xindex>
+// CHECK: %[[B:.*]] = vector.broadcast %[[ADDI]] : vector<4x3xindex> to vector<7x2x4x3xindex>
+// CHECK: %[[T:.*]] = vector.transpose %[[B]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex>
// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[T]]], %[[CST_1]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<4x7x3x2xindex>, vector<4x7x3x2xi1>, vector<4x7x3x2xf32> into vector<4x7x3x2xf32>
// CHECK: vector.transfer_write %[[GATHER]], %[[ARG4]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x7x3x2xf32>, tensor<4x7x3x2xf32>
@@ -270,20 +269,16 @@ func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(%
// CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex>
-// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex>
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
-// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1>
-// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<true> : vector<8x1xi1>
+// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<[0, 98304, 196608, 294912, 393216, 491520, 589824, 688128]> : vector<8xindex>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
-// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
// CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
-// CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
-// CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
-// CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_1]] : vector<8xindex> to vector<1x8xindex>
+// CHECK: %[[T:.*]] = vector.transpose %[[B1]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
// CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<8x1xindex>
// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[T]] : vector<8x1xindex>
-// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_0]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
// CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
// -----
@@ -309,15 +304,13 @@ func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>)
// CHECK-LABEL: func.func @index_from_output_column_vector_gather_load(
// CHECK-SAME: %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> {
-// CHECK: %[[C128:.*]] = arith.constant dense<128> : vector<1x8xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 128, 256, 384, 512, 640, 768, 896]> : vector<8xindex>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1>
-// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
// CHECK: %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32>
// CHECK: %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex>
-// CHECK: %[[MUL:.*]] = arith.muli %[[B]], %[[C128]] : vector<1x8xindex>
-// CHECK: %[[TR:.*]] = vector.transpose %[[MUL]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK: %[[TR:.*]] = vector.transpose %[[B]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
// CHECK: %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
// CHECK: %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
// CHECK: return %[[RES]] : tensor<8x1xf32>
@@ -420,12 +413,12 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1>
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<1x4xindex>
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<4xindex>
// CHECK: %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex>
-// CHECK: %[[VAL_10:.*]] = vector.broadcast %[[VAL_9]] : vector<4xindex> to vector<1x4xindex>
-// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_7]] : vector<1x4xindex>
-// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_7]] : vector<1x4xindex>
+// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_9]], %[[VAL_7]] : vector<4xindex>
+// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_7]] : vector<4xindex>
+// CHECK: %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : vector<4xindex> to vector<1x4xindex>
// CHECK: %[[VAL_13:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_12]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: %[[VAL_14:.*]] = vector.transfer_write %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[VAL_14]] : tensor<1x4xf32>
@@ -450,14 +443,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_gather(%arg0: tensor<80x16xf32
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_maxsi_gather(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<1264> : vector<1x4xindex>
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[1264, 1265, 1266, 1267]> : vector<4xindex>
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1>
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_7:.*]] = vector.broadcast %[[VAL_2]] : vector<4xindex> to vector<1x4xindex>
-// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_3]] : vector<1x4xindex>
-// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_8]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_7]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: %[[VAL_10:.*]] = vector.transfer_write %[[VAL_9]], %[[VAL_1]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[VAL_10]] : tensor<1x4xf32>
// CHECK: }
@@ -519,13 +510,13 @@ func.func @vectorize_reverse_like_tensor_extract(%arg0: tensor<1x2x3xf32>, %arg1
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]
// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]
-// CHECK-DAG: %[[CST:.+]] = arith.constant dense<3> : vector<1x1x3xindex>
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
// CHECK-DAG: %[[INIT_IDX:.+]] = arith.constant dense<[2, 1, 0]> : vector<3xindex>
-// CHECK: %[[T0:.+]] = vector.broadcast %[[ARG2]] : index to vector<1x1x3xindex>
-// CHECK: %[[T1:.+]] = arith.muli %[[T0]], %[[CST]] : vector<1x1x3xindex>
+// CHECK: %[[T0:.+]] = arith.muli %[[ARG2]], %[[C3]] : index
+// CHECK: %[[T1:.+]] = vector.broadcast %[[T0]] : index to vector<1x1x3xindex>
// CHECK: %[[T2:.+]] = vector.broadcast %[[INIT_IDX]]
// CHECK: %[[T3:.+]] = arith.addi %[[T2]], %[[T1]]
// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[T3]]], %[[MASK]], %[[PASSTHRU]]
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index b826cdca134e6..f8638ab843ecb 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -257,6 +257,70 @@ func.func @broadcast_scalar_extsi_scalable(%a : i8) -> vector<2x[4]xi32> {
return %r : vector<2x[4]xi32>
}
+// -----
+
+// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const(
+// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index
+// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[NEW_CST]] : index
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
+// CHECK: return %[[BCAST]] : vector<1x4xindex>
+
+func.func @broadcast_scalar_and_splat_const(%arg0: index) -> vector<1x4xindex> {
+ %0 = vector.broadcast %arg0 : index to vector<1x4xindex>
+ %cst = arith.constant dense<2> : vector<1x4xindex>
+ %2 = arith.addi %0, %cst : vector<1x4xindex>
+ return %2 : vector<1x4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_const_first(
+// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index
+// CHECK: %[[SUB:.*]] = arith.subi %[[NEW_CST]], %[[ARG_0]] : index
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SUB]] : index to vector<1x4xindex>
+// CHECK: return %[[BCAST]] : vector<1x4xindex>
+
+func.func @broadcast_scalar_and_splat_const_const_first(%arg0: index) -> vector<1x4xindex> {
+ %0 = vector.broadcast %arg0 : index to vector<1x4xindex>
+ %cst = arith.constant dense<2> : vector<1x4xindex>
+ %2 = arith.subi %cst, %0 : vector<1x4xindex>
+ return %2 : vector<1x4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_vector_and_splat_const(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant dense<2.000000e+00> : vector<4xf32>
+// CHECK: %[[ADD:.*]] = arith.mulf %[[ARG_0]], %[[NEW_CST]] : vector<4xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : vector<4xf32> to vector<3x4xf32>
+// CHECK: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @broadcast_vector_and_splat_const(%arg0: vector<4xf32>) -> vector<3x4xf32> {
+ %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32>
+ %cst = arith.constant dense<2.000000e+00> : vector<3x4xf32>
+ %2 = arith.mulf %0, %cst : vector<3x4xf32>
+ return %2 : vector<3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @negative_broadcast_with_non_splat_const(
+// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
+// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : index to vector<1x4xindex>
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<{{\[}}[0, 1, 2, 3]]> : vector<1x4xindex>
+// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<1x4xindex>
+// CHECK: return %[[ADD]] : vector<1x4xindex>
+
+func.func @negative_broadcast_with_non_splat_const(%arg0: index) -> vector<1x4xindex> {
+ %0 = vector.broadcast %arg0 : index to vector<1x4xindex>
+ %cst = arith.constant dense<[[0, 1, 2, 3]]> : vector<1x4xindex>
+ %2 = arith.addi %0, %cst : vector<1x4xindex>
+ return %2 : vector<1x4xindex>
+}
+
//===----------------------------------------------------------------------===//
// [Pattern: ReorderElementwiseOpsOnTranspose]
//===----------------------------------------------------------------------===//
hanhanW
left a comment
LGTM, thanks
[nit] Could you move these tests near other tests for ReorderElementwiseOpsOnBroadcast, i.e. here:
llvm-project/mlir/test/Dialect/Vector/vector-sink.mlir
Lines 4 to 211 in 860b1e6
Thanks and sorry for not taking a look earlier!
I don't have a clean repro now, but it breaks IREE: iree-org/iree#21522. The assertion error is:
Revert "[mlir][Vector] Make elementwise-on-broadcast sinking handle splat consts (llvm#150867)". This reverts commit 330a7e1.
This patch extends the pattern that rewrites elementwise operations whose inputs are all broadcast from the same shape so that it handles mixed types, such as when the result and input types don't match, or when the inputs have multiple types. PR llvm#150867 failed to check for the possibility of type mismatches when rewriting splat constants. To fix that issue, we add support for mixed-type operations more generally.
…(#151274) This patch extends the pattern that rewrites elementwise operations whose inputs are all broadcast from the same shape so that it handles mixed types, such as when the result and input types don't match, or when the inputs have multiple types. PR #150867 failed to check for the possibility of type mismatches when rewriting splat constants. To fix that issue, we add support for mixed-type operations more generally.
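To see why type mismatches matter here, consider a comparison (a minimal sketch modeled on this PR's own `arith.cmpf` negative test, not taken from the follow-up PR): the operand and result element types differ, so the pre-broadcast type of a splat constant cannot be inferred from the result type alone.

```mlir
// The operands are f32 vectors but the result is an i1 vector; resizing
// a splat constant based on the op's result type would produce the wrong
// element type. The follow-up handles such mixed-type elementwise ops.
%0   = vector.broadcast %arg0 : f32 to vector<4xf32>
%cst = arith.constant dense<1.0> : vector<4xf32>
%1   = arith.cmpf ogt, %0, %cst : vector<4xf32>  // result: vector<4xi1>
```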