@@ -4795,6 +4795,44 @@ static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
   return newOperands;
 }

+// Given the (potentially) updated packed type, `newPackedTy`, generates an
+// updated mixed-tile-sizes attribute. A tile size is updated only
+// when:
+//  * a dim from newPackedTy is static, and
+//  * the corresponding size from mixedTiles is still dynamic.
+// Otherwise, the original tile size is preserved.
+// Note - packed-type-dim and mixed-tile-size should always match!
+static SmallVector<OpFoldResult>
+getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
+                     SmallVector<OpFoldResult> mixedTiles) {
+  SmallVector<OpFoldResult> newMixedTileSizes;
+  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
+                               .getShape()
+                               .take_back(mixedTiles.size()),
+                           mixedTiles)) {
+    int64_t shape = std::get<0>(it);
+    if (shape == ShapedType::kDynamic) {
+      newMixedTileSizes.push_back(std::get<1>(it));
+      continue;
+    }
+
+    // If the current result dim is static, update the dynamic mixed-size
+    // (provided the original value is dynamic).
+    OpFoldResult tile = std::get<1>(it);
+    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
+      // Already a constant
+      newMixedTileSizes.push_back(tile);
+    } else {
+      assert(getConstantIntValue(tile).value() == shape &&
+             "tile size and dim size don't match!");
+      newMixedTileSizes.push_back(
+          (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+    }
+  }
+
+  return newMixedTileSizes;
+}
+
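Since this helper is shared by both cast-folding patterns below, a brief sketch of its behaviour at the IR level may help. The shapes, `%c8`, and `%dest` are illustrative assumptions, not taken from the patch or its tests:

```mlir
// Illustrative only.
%c8 = arith.constant 8 : index
%1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
%2 = tensor.unpack %1 inner_dims_pos = [0, 1] inner_tiles = [%c8, 1]
    into %dest : tensor<1x1x?x1xi32> -> tensor<7x1xi32>

// When the cast is folded away, newPackedTy becomes tensor<1x1x8x1xi32> and
// mixedTiles is [%c8, 1]. The helper walks the trailing dims 8x1:
//   * dim 8 is static and the tile %c8 is dynamic -> replaced by the
//     constant index attribute 8 (the assert checks that %c8 is indeed 8),
//   * dim 1 is static and the tile 1 is already constant -> kept as-is.
// Result: [8, 1].
```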
 /// Folds a tensor.cast op into a consuming tensor::PackOp op if the
 /// `tensor.cast` has source that is more static than the consuming op.
 ///
@@ -4821,31 +4859,13 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
     SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);

     // Get the updated mixed-tile-sizes attribute.
-    SmallVector<OpFoldResult> newMixedTileSizes;
-    for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
-                                 .getShape()
-                                 .take_back(op.getMixedTiles().size()),
-                             op.getMixedTiles())) {
-      int64_t shape = std::get<0>(it);
-      if (shape == ShapedType::kDynamic) {
-        newMixedTileSizes.push_back(std::get<1>(it));
-        continue;
-      }
-
-      if (Attribute attr =
-              llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
-        // Already a constant
-        newMixedTileSizes.push_back(std::get<1>(it));
-      } else {
-        int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
-        assert(tileSize == shape && "tile size and dim size don't match!");
-        (void)tileSize;
-        newMixedTileSizes.push_back(
-            (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
-      }
-    }
+    SmallVector<OpFoldResult> newMixedTileSizes =
+        getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());

     // Clone op.
+    // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+    // this point. However, in practice, we use them for things that we'd like
+    // to preserve. Implement a better abstraction.
     PackOp newOp = rewriter.create<PackOp>(
         op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
         newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
@@ -4865,6 +4885,59 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
   }
 };
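For comparison with the UnPack pattern added below, this is the kind of rewrite `FoldTensorCastPackOp` performs (its own doc-comment example sits outside this hunk). The shapes are illustrative assumptions, with attributes elided as `...` in the same style as the doc comments:

```mlir
// A cast that makes the pack's source more static ...
%1 = tensor.cast %0 : tensor<7x9xi32> to tensor<?x9xi32>
%2 = tensor.pack %1 ... : tensor<?x9xi32> -> tensor<1x9x8x1xi32>

// ... folds so that the pack reads the original, more static value:
%2 = tensor.pack %0 ... : tensor<7x9xi32> -> tensor<1x9x8x1xi32>
```

If one of the elided tile sizes were dynamic but matched by a now-static packed dim, `getNewMixedTileSizes` would also fold it to a constant.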

+/// Folds a tensor.cast op into a consuming tensor::UnPackOp op if the
+/// `tensor.cast` has source that is more static than the consuming op.
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+///   %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+/// ```
+struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
+  using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(UnPackOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!foldTensorCastPrecondition(op))
+      return failure();
+
+    SmallVector<Type> newResultTypes(op->getResultTypes());
+    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+    Value sourceTensor = newOperands[0];
+
+    // Get the updated mixed-tile-sizes attribute.
+    SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
+        rewriter, sourceTensor.getType(), op.getMixedTiles());
+
+    // Clone op.
+    // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+    // this point. However, in practice, we use them for things that we'd like
+    // to preserve. Implement a better abstraction.
+    UnPackOp newOp = rewriter.create<UnPackOp>(
+        op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
+        newMixedTileSizes, op.getOuterDimsPerm());
+    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
+
+    // Replace op.
+    Value oldResult = op.getResult();
+    Value newResult = newOp.getResult();
+    Value replacement = (newResult.getType() != oldResult.getType())
+                            ? rewriter.create<tensor::CastOp>(
+                                  op->getLoc(), oldResult.getType(), newResult)
+                            : newResult;
+
+    rewriter.replaceOp(op, {replacement});
+
+    return success();
+  }
+};
+
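The cast-back at the end of `matchAndRewrite` covers the case where folding a cast on the destination operand makes the unpack's result type more static than what its existing users expect. A sketch with illustrative shapes (assumptions, not from the patch):

```mlir
// Before: the destination is a cast that is more static than the unpack's
// declared result type.
%dest = tensor.cast %d : tensor<7x1xi32> to tensor<7x?xi32>
%r = tensor.unpack %src ... into %dest : tensor<1x1x8x1xi32> -> tensor<7x?xi32>

// After: the new unpack writes into %d and yields the static type; a
// tensor.cast back to the old type keeps existing users of %r well-typed.
%new = tensor.unpack %src ... into %d : tensor<1x1x8x1xi32> -> tensor<7x1xi32>
%r = tensor.cast %new : tensor<7x1xi32> to tensor<7x?xi32>
```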
 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
 /// the `tensor.cast` has source that is more static than the consuming op.
 ///
@@ -4890,7 +4963,8 @@ struct FoldTensorCastProducerOp
                                 PatternRewriter &rewriter) const override {

     // Reject tensor::PackOp - there's dedicated pattern for that instead.
-    if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
+    if (!foldTensorCastPrecondition(op) ||
+        isa<tensor::PackOp, tensor::UnPackOp>(*op))
       return failure();

     SmallVector<Type> newResultTypes(op->getResultTypes());
@@ -4923,6 +4997,7 @@ struct FoldTensorCastProducerOp
 void TensorDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
   results.add<FoldTensorCastPackOp>(getContext());
+  results.add<FoldTensorCastUnPackOp>(getContext());
   results.add<FoldTensorCastProducerOp>(getContext());
 }

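For reference, patterns registered here run as part of the tensor dialect's canonicalization, so the new fold is exercised by ordinary canonicalization (e.g. `mlir-opt --canonicalize`) on IR like the examples above; no dedicated pass is involved.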