@@ -1139,37 +1139,14 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
11391139 return perm;
11401140}
11411141
1142- // A helper function to generate a dim-and-size pair for Ops like
1143- // ExtractSliceOp that require both:
1144- // * dims to specify the output shape, and
1145- // * sizes for the sizes attribute (or similar).
1146- // For dynamic sizes, if the corresponding size is a compile time constant:
1147- // * the return size becomes the attribute encapsulating the known size, and
1148- // * dim is updated from kDynamic to its actual known value.
1149- static std::pair<int64_t , OpFoldResult>
1150- getSimplifiedDimSizePair (OpFoldResult tileSizeOfr, Builder &b) {
1151- int64_t tileSizeForShape =
1152- getConstantIntValue (tileSizeOfr).value_or (ShapedType::kDynamic );
1153-
1154- OpFoldResult tileSizeOfrSimplified;
1155- if (tileSizeForShape != ShapedType::kDynamic ) {
1156- tileSizeOfrSimplified = b.getIndexAttr (tileSizeForShape);
1157- } else {
1158- tileSizeOfrSimplified = tileSizeOfr;
1159- }
1160-
1161- return std::pair<int64_t , OpFoldResult>(tileSizeForShape,
1162- tileSizeOfrSimplified);
1163- }
1164-
11651142LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite (
11661143 tensor::PackOp packOp, PatternRewriter &rewriter) const {
11671144 // TODO: support the case that outer dimensions are not all 1s. A
11681145 // tensor.expand_shape will be generated in this case.
1169- if (llvm::any_of (packOp.getTiledOuterDims (),
1146+ if (llvm::any_of (packOp.getAllOuterDims (),
11701147 [](int64_t dim) { return dim != 1 ; })) {
11711148 return rewriter.notifyMatchFailure (
1172- packOp, " require the tiled outer dimensions of the result are all 1s" );
1149+ packOp, " not all outer dimensions of the result are 1s" );
11731150 }
11741151
11751152 Attribute zeroIdxAttr = rewriter.getIndexAttr (0 );
@@ -1202,7 +1179,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
12021179 for (auto i : llvm::seq<unsigned >(0 , srcRank)) {
12031180 if (dimAndTileMapping.count (i)) {
12041181 auto [tileSize, tileSizeOfr] =
1205- getSimplifiedDimSizePair (dimAndTileMapping[i], rewriter);
1182+ getSimplifiedOfrAndStaticSizePair (dimAndTileMapping[i], rewriter);
12061183 extractSliceSizes.push_back (tileSizeOfr);
12071184 outputShapeForExtractSlice.push_back (tileSize);
12081185 }
@@ -1236,8 +1213,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
12361213 }
12371214
12381215 applyPermutationToVector<OpFoldResult>(transShapeForEmptyOpDynamic, perm);
1239- Value empty = rewriter.create <tensor::EmptyOp>(
1240- loc, transShapeForEmptyOpDynamic, elemType);
1216+ Value empty = rewriter.create <tensor::EmptyOp>(
1217+ loc, transShapeForEmptyOpDynamic, elemType);
12411218
12421219 // 2.2 Create linalg.transpose
12431220 auto transposedOp =
@@ -1254,7 +1231,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
12541231
12551232 for (auto tileSize : packOp.getMixedTiles ()) {
12561233 auto [tileSizeStatic, tileSizeOfr] =
1257- getSimplifiedDimSizePair (tileSize, rewriter);
1234+ getSimplifiedOfrAndStaticSizePair (tileSize, rewriter);
12581235 writeSizes.push_back (tileSizeOfr);
12591236 writeShape.push_back (tileSizeStatic);
12601237 }
0 commit comments