@@ -4837,15 +4837,17 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
48374837 // Already a constant
48384838 newMixedTileSizes.push_back (std::get<1 >(it));
48394839 } else {
4840- int64_t tileSize = getConstantIntValue (std::get<1 >(it)).value ();
4841- assert (tileSize == shape && " tile size and dim size don't match!" );
4842- (void )tileSize;
4840+ assert (getConstantIntValue (std::get<1 >(it)).value () == shape &&
4841+ " tile size and dim size don't match!" );
48434842 newMixedTileSizes.push_back (
48444843 (rewriter.getIntegerAttr (rewriter.getIndexType (), shape)));
48454844 }
48464845 }
48474846
48484847 // Clone op.
4848+ // TODO: Strictly speaking, discardable attributes should be _discarded_ at
4849+ // this point. However, in practice, we use them for things that we'd like
4850+ // to preserve. Implement a better abstraction.
48494851 PackOp newOp = rewriter.create <PackOp>(
48504852 op.getLoc (), newOperands[0 ], newOperands[1 ], op.getInnerDimsPos (),
48514853 newMixedTileSizes, op.getPaddingValue (), op.getOuterDimsPerm ());
@@ -4865,6 +4867,83 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
48654867 }
48664868};
48674869
4870+ // / Folds a tensor.cast op into a consuming tensor::UnPackOp op if the
4871+ // / `tensor.cast` has source that is more static than the consuming op.
4872+ // /
4873+ // / Example:
4874+ // / ```mlir
4875+ // / %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
4876+ // / %2 = tensor.unpack %1 ... : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
4877+ // / ```
4878+ // /
4879+ // / folds into:
4880+ // /
4881+ // / ```mlir
4882+ // / %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
4883+ // / ```
4884+ struct FoldTensorCastUnPackOp : public OpRewritePattern <UnPackOp> {
4885+ using OpRewritePattern<UnPackOp>::OpRewritePattern;
4886+
4887+ LogicalResult matchAndRewrite (UnPackOp op,
4888+ PatternRewriter &rewriter) const override {
4889+ if (!foldTensorCastPrecondition (op))
4890+ return failure ();
4891+
4892+ SmallVector<Type> newResultTypes (op->getResultTypes ());
4893+ SmallVector<Value> newOperands = getNewOperands (op, newResultTypes);
4894+ Value sourceTensor = newOperands[0 ];
4895+
4896+ // Get the updated mixed-tile-sizes attribute.
4897+ SmallVector<OpFoldResult> newMixedTileSizes;
4898+ for (auto it : llvm::zip (cast<ShapedType>(sourceTensor.getType ())
4899+ .getShape ()
4900+ .take_back (op.getMixedTiles ().size ()),
4901+ op.getMixedTiles ())) {
4902+ int64_t shape = std::get<0 >(it);
4903+ // If the current source shape is dynamic, just preserve this mixed
4904+ // size.
4905+ if (shape == ShapedType::kDynamic ) {
4906+ newMixedTileSizes.push_back (std::get<1 >(it));
4907+ continue ;
4908+ }
4909+
4910+ // If the current source is static, update the dynamic mixed-size
4911+ // (provided the original value is dynamic).
4912+ if (Attribute attr =
4913+ llvm::dyn_cast_if_present<Attribute>(std::get<1 >(it))) {
4914+ // Already a constant
4915+ newMixedTileSizes.push_back (std::get<1 >(it));
4916+ } else {
4917+ assert (getConstantIntValue (std::get<1 >(it)).value () == shape &&
4918+ " tile size and dim size don't match!" );
4919+ newMixedTileSizes.push_back (
4920+ (rewriter.getIntegerAttr (rewriter.getIndexType (), shape)));
4921+ }
4922+ }
4923+
4924+ // Clone op.
4925+ // TODO: Strictly speaking, discardable attributes should be _discarded_ at
4926+ // this point. However, in practice, we use them for things that we'd like
4927+ // to preserve. Implement a better abstraction.
4928+ UnPackOp newOp = rewriter.create <UnPackOp>(
4929+ op.getLoc (), sourceTensor, newOperands[1 ], op.getInnerDimsPos (),
4930+ newMixedTileSizes, op.getOuterDimsPerm ());
4931+ newOp->setDiscardableAttrs (op->getDiscardableAttrDictionary ());
4932+
4933+ // Replace op.
4934+ Value oldResult = op.getResult ();
4935+ Value newResult = newOp.getResult ();
4936+ Value replacement = (newResult.getType () != oldResult.getType ())
4937+ ? rewriter.create <tensor::CastOp>(
4938+ op->getLoc (), oldResult.getType (), newResult)
4939+ : newResult;
4940+
4941+ rewriter.replaceOp (op, {replacement});
4942+
4943+ return success ();
4944+ }
4945+ };
4946+
48684947// / Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
48694948// / the `tensor.cast` has source that is more static than the consuming op.
48704949// /
@@ -4890,7 +4969,8 @@ struct FoldTensorCastProducerOp
48904969 PatternRewriter &rewriter) const override {
48914970
48924971 // Reject tensor::PackOp - there's dedicated pattern for that instead.
4893- if (!foldTensorCastPrecondition (op) || dyn_cast<tensor::PackOp>(*op))
4972+ if (!foldTensorCastPrecondition (op) ||
4973+ isa<tensor::PackOp, tensor::UnPackOp>(*op))
48944974 return failure ();
48954975
48964976 SmallVector<Type> newResultTypes (op->getResultTypes ());
@@ -4923,6 +5003,7 @@ struct FoldTensorCastProducerOp
49235003void TensorDialect::getCanonicalizationPatterns (
49245004 RewritePatternSet &results) const {
49255005 results.add <FoldTensorCastPackOp>(getContext ());
5006+ results.add <FoldTensorCastUnPackOp>(getContext ());
49265007 results.add <FoldTensorCastProducerOp>(getContext ());
49275008}
49285009
0 commit comments