@@ -1030,11 +1030,13 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
10301030 return input;
10311031 }
10321032
1033+ assert (llvm::all_of (packOp.getAllOuterDims (),
1034+ [](int64_t val) { return val == 1 ; }) &&
1035+ " some outer dims are != 1" );
1036+
10331037 Location loc = packOp.getLoc ();
10341038 ShapedType inputType = packOp.getSourceType ();
10351039 int64_t inputRank = inputType.getRank ();
1036- assert (llvm::all_of (packOp.getDestType ().getShape ().take_front (inputRank),
1037- [](int64_t val) { return val == 1 ; }));
10381040
10391041 SmallVector<int64_t > paddedShape;
10401042 DenseMap<int64_t , OpFoldResult> tileAndPosMapping =
@@ -1126,12 +1128,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
11261128
11271129 // TODO: support the case that outer dimensions are not all 1s. A
11281130 // tensor.expand_shape will be generated in this case.
1129- auto innerDimsPos = packOp.getInnerDimsPos ();
1130- int64_t srcRank = packOp.getSourceRank ();
1131- auto destShape = packOp.getDestType ().getShape ();
1132- if (llvm::any_of (innerDimsPos, [destShape](int64_t index) {
1133- return destShape[index] != 1 ;
1134- })) {
1131+ if (llvm::any_of (packOp.getTiledOuterDims (),
1132+ [](int64_t dim) { return dim != 1 ; })) {
11351133 return rewriter.notifyMatchFailure (
11361134 packOp, " require the tiled outer dimensions of the result are all 1s" );
11371135 }
@@ -1145,6 +1143,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
11451143 packOp.getDimAndTileMapping ();
11461144 Attribute zeroIdxAttr = rewriter.getIndexAttr (0 );
11471145 Attribute oneIdxAttr = rewriter.getIndexAttr (1 );
1146+ int64_t srcRank = packOp.getSourceRank ();
11481147 SmallVector<OpFoldResult> readOffsets (srcRank, zeroIdxAttr);
11491148 SmallVector<OpFoldResult> readStrides (srcRank, oneIdxAttr);
11501149 SmallVector<OpFoldResult> readSizes;
@@ -1173,9 +1172,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
11731172 loc, readType, input, readOffsets, readSizes, readStrides);
11741173
11751174 // 2. Transpose the tile to match the inner tile order.
1176-
11771175 SmallVector<int64_t > perm = getPackUnpackRankReducedPerm (
1178- inputShape, innerDimsPos , packOp.getOuterDimsPerm ());
1176+ inputShape, packOp. getInnerDimsPos () , packOp.getOuterDimsPerm ());
11791177
11801178 LLVM_DEBUG (DBGS () << " Pack permutation: " << packOp << " \n " ;
11811179 llvm::interleaveComma (perm, DBGS () << " perm: " ); DBGSNL (););
@@ -1208,9 +1206,8 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
12081206 int64_t destRank = unpackOp.getDestRank ();
12091207 ArrayRef<int64_t > srcShape = unpackOp.getSourceType ().getShape ();
12101208 ArrayRef<int64_t > innerDimsPos = unpackOp.getInnerDimsPos ();
1211- if (llvm::any_of (innerDimsPos, [srcShape](int64_t index) {
1212- return srcShape[index] != 1 ;
1213- })) {
1209+ if (llvm::any_of (unpackOp.getTiledOuterDims (),
1210+ [](int64_t dim) { return dim != 1 ; })) {
12141211 return rewriter.notifyMatchFailure (
12151212 unpackOp,
12161213 " require the tiled outer dimensions of the result are all 1s" );
0 commit comments