@@ -1607,7 +1607,24 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
1607
1607
}
1608
1608
};
1609
1609
1610
- // / For vectors with either leading or trailing unit dim, replaces:
1610
+ FailureOr<VectorType> dropNonScalableUnitDimType (VectorType VT) {
1611
+ VectorType newVT = VT;
1612
+ int removed = 0 ;
1613
+ auto shape = VT.getShape ();
1614
+ for (unsigned i = 0 ; i < shape.size (); i++) {
1615
+ if (shape[i] == 1 && !VT.getScalableDims ()[i]) {
1616
+ newVT = VectorType::Builder (newVT).dropDim (i - removed);
1617
+ removed++;
1618
+ }
1619
+ }
1620
+
1621
+ if (removed == 0 )
1622
+ return failure ();
1623
+ return newVT;
1624
+ }
1625
+
1626
+
1627
+ // / For vectors with at least an unit dim, replaces:
1611
1628
// / elementwise(a, b)
1612
1629
// / with:
1613
1630
// / sc_a = shape_cast(a)
@@ -1641,7 +1658,9 @@ struct DropUnitDimFromElementwiseOps final
1641
1658
using OpTraitRewritePattern::OpTraitRewritePattern;
1642
1659
LogicalResult matchAndRewrite (Operation *op,
1643
1660
PatternRewriter &rewriter) const override {
1644
- if (op->getNumResults () != 1 || op->getNumRegions () != 0 )
1661
+ if (op->getNumResults () != 1 )
1662
+ return failure ();
1663
+ if (op->getNumRegions () != 0 )
1645
1664
return failure ();
1646
1665
1647
1666
auto resultVectorType = dyn_cast<VectorType>(op->getResult (0 ).getType ());
@@ -1652,42 +1671,30 @@ struct DropUnitDimFromElementwiseOps final
1652
1671
// guaranteed to have identical shapes (with some exceptions such as
1653
1672
// `arith.select`) and it suffices to only check one of them.
1654
1673
auto sourceVectorType = dyn_cast<VectorType>(op->getOperand (0 ).getType ());
1655
- if (!sourceVectorType)
1656
- return failure ();
1657
- if (sourceVectorType.getRank () < 2 )
1658
- return failure ();
1659
-
1660
- bool hasTrailingDimUnitFixed =
1661
- ((sourceVectorType.getShape ().back () == 1 ) &&
1662
- (!sourceVectorType.getScalableDims ().back ()));
1663
- bool hasLeadingDimUnitFixed =
1664
- ((sourceVectorType.getShape ().front () == 1 ) &&
1665
- (!sourceVectorType.getScalableDims ().front ()));
1666
- if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
1674
+ if (!sourceVectorType || sourceVectorType.getRank () < 2 )
1667
1675
return failure ();
1668
1676
1669
- // Drop leading/trailing unit dim by applying vector.shape_cast to all
1670
- // operands
1671
- int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank () - 1 ;
1672
-
1673
1677
SmallVector<Value> newOperands;
1674
1678
auto loc = op->getLoc ();
1675
1679
for (auto operand : op->getOperands ()) {
1676
1680
auto opVectorType = cast<VectorType>(operand.getType ());
1677
- VectorType newVType = VectorType::Builder (opVectorType).dropDim (dim);
1678
- auto opSC = rewriter.create <vector::ShapeCastOp>(loc, newVType, operand);
1681
+ auto newVType = dropNonScalableUnitDimType (opVectorType);
1682
+ if (failed (newVType)) {
1683
+ return failure ();
1684
+ }
1685
+ auto opSC =
1686
+ rewriter.create <vector::ShapeCastOp>(loc, newVType.value (), operand);
1679
1687
newOperands.push_back (opSC);
1680
1688
}
1681
1689
1682
1690
VectorType newResultVectorType =
1683
- VectorType::Builder (resultVectorType).dropDim (dim );
1684
- // Create an updated elementwise Op without leading/trailing unit dim
1691
+ dropNonScalableUnitDimType (resultVectorType).value ( );
1692
+ // Create an updated elementwise Op without unit dim
1685
1693
Operation *elementwiseOp =
1686
1694
rewriter.create (loc, op->getName ().getIdentifier (), newOperands,
1687
1695
newResultVectorType, op->getAttrs ());
1688
1696
1689
- // Restore the leading/trailing unit dim by applying vector.shape_cast
1690
- // to the result
1697
+ // Restore the unit dim by applying vector.shape_cast to the result
1691
1698
rewriter.replaceOpWithNewOp <ShapeCastOp>(op, resultVectorType,
1692
1699
elementwiseOp->getResult (0 ));
1693
1700
0 commit comments