Skip to content

Commit 19695b4

Browse files
committed
[MLIR][Vector]Generalize DropUnitDimFromElementwiseOps
1 parent 3efaf9c commit 19695b4

File tree

2 files changed

+51
-24
lines changed

2 files changed

+51
-24
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1607,7 +1607,24 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
16071607
}
16081608
};
16091609

1610-
/// For vectors with either leading or trailing unit dim, replaces:
1610+
FailureOr<VectorType> dropNonScalableUnitDimType(VectorType VT) {
1611+
VectorType newVT = VT;
1612+
int removed = 0;
1613+
auto shape = VT.getShape();
1614+
for (unsigned i = 0; i < shape.size(); i++) {
1615+
if (shape[i] == 1 && !VT.getScalableDims()[i]) {
1616+
newVT = VectorType::Builder(newVT).dropDim(i - removed);
1617+
removed++;
1618+
}
1619+
}
1620+
1621+
if (removed == 0)
1622+
return failure();
1623+
return newVT;
1624+
}
1625+
1626+
1627+
/// For vectors with at least an unit dim, replaces:
16111628
/// elementwise(a, b)
16121629
/// with:
16131630
/// sc_a = shape_cast(a)
@@ -1641,7 +1658,9 @@ struct DropUnitDimFromElementwiseOps final
16411658
using OpTraitRewritePattern::OpTraitRewritePattern;
16421659
LogicalResult matchAndRewrite(Operation *op,
16431660
PatternRewriter &rewriter) const override {
1644-
if (op->getNumResults() != 1 || op->getNumRegions() != 0)
1661+
if (op->getNumResults() != 1)
1662+
return failure();
1663+
if (op->getNumRegions() != 0)
16451664
return failure();
16461665

16471666
auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
@@ -1652,42 +1671,30 @@ struct DropUnitDimFromElementwiseOps final
16521671
// guaranteed to have identical shapes (with some exceptions such as
16531672
// `arith.select`) and it suffices to only check one of them.
16541673
auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
1655-
if (!sourceVectorType)
1656-
return failure();
1657-
if (sourceVectorType.getRank() < 2)
1658-
return failure();
1659-
1660-
bool hasTrailingDimUnitFixed =
1661-
((sourceVectorType.getShape().back() == 1) &&
1662-
(!sourceVectorType.getScalableDims().back()));
1663-
bool hasLeadingDimUnitFixed =
1664-
((sourceVectorType.getShape().front() == 1) &&
1665-
(!sourceVectorType.getScalableDims().front()));
1666-
if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
1674+
if (!sourceVectorType || sourceVectorType.getRank() < 2)
16671675
return failure();
16681676

1669-
// Drop leading/trailing unit dim by applying vector.shape_cast to all
1670-
// operands
1671-
int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
1672-
16731677
SmallVector<Value> newOperands;
16741678
auto loc = op->getLoc();
16751679
for (auto operand : op->getOperands()) {
16761680
auto opVectorType = cast<VectorType>(operand.getType());
1677-
VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
1678-
auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
1681+
auto newVType = dropNonScalableUnitDimType(opVectorType);
1682+
if (failed(newVType)) {
1683+
return failure();
1684+
}
1685+
auto opSC =
1686+
rewriter.create<vector::ShapeCastOp>(loc, newVType.value(), operand);
16791687
newOperands.push_back(opSC);
16801688
}
16811689

16821690
VectorType newResultVectorType =
1683-
VectorType::Builder(resultVectorType).dropDim(dim);
1684-
// Create an updated elementwise Op without leading/trailing unit dim
1691+
dropNonScalableUnitDimType(resultVectorType).value();
1692+
// Create an updated elementwise Op without unit dim
16851693
Operation *elementwiseOp =
16861694
rewriter.create(loc, op->getName().getIdentifier(), newOperands,
16871695
newResultVectorType, op->getAttrs());
16881696

1689-
// Restore the leading/trailing unit dim by applying vector.shape_cast
1690-
// to the result
1697+
// Restore the unit dim by applying vector.shape_cast to the result
16911698
rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
16921699
elementwiseOp->getResult(0));
16931700

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,26 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
459459
// CHECK-128B-LABEL: func @fold_unit_dims_entirely(
460460
// CHECK-128B-NOT: memref.collapse_shape
461461

462+
// -----
463+
464+
func.func @fold_unit_center_dim_scalable(%arg0 : vector<8x1x[1]xf128>,
465+
%arg1 : vector<1x8x[1]xf128>) -> vector<8x[1]xf128> {
466+
%sc_arg1 = vector.shape_cast %arg1 : vector<1x8x[1]xf128> to vector<8x1x[1]xf128>
467+
%add = arith.mulf %arg0, %sc_arg1 : vector<8x1x[1]xf128>
468+
%res = vector.shape_cast %add : vector<8x1x[1]xf128> to vector<8x[1]xf128>
469+
return %res : vector<8x[1]xf128>
470+
}
471+
472+
// CHECK-LABEL: func.func @fold_unit_center_dim_scalable(
473+
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x[1]xf128>,
474+
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[1]xf128>) -> vector<8x[1]xf128> {
475+
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x[1]xf128> to vector<8x[1]xf128>
476+
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[1]xf128> to vector<8x[1]xf128>
477+
// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[1]xf128>
478+
// CHECK: return %[[VAL_4]] : vector<8x[1]xf128>
479+
480+
481+
462482

463483
// -----
464484

0 commit comments

Comments
 (0)