@@ -27,8 +27,8 @@ namespace {
2727#include " ShapeCanonicalization.inc"
2828}
2929
30- RankedTensorType shape::getExtentTensorType (MLIRContext *ctx) {
31- return RankedTensorType::get ({ShapedType:: kDynamicSize }, IndexType::get (ctx));
30+ RankedTensorType shape::getExtentTensorType (MLIRContext *ctx, int64_t rank ) {
31+ return RankedTensorType::get ({rank }, IndexType::get (ctx));
3232}
3333
3434bool shape::isExtentTensorType (Type type) {
@@ -660,11 +660,42 @@ struct CanonicalizeCastExtentTensorOperandsPattern
660660 return success ();
661661 }
662662};
663+
664+ struct BroadcastConcretizeResultTypePattern
665+ : public OpRewritePattern<BroadcastOp> {
666+ using OpRewritePattern<BroadcastOp>::OpRewritePattern;
667+
668+ LogicalResult matchAndRewrite (BroadcastOp op,
669+ PatternRewriter &rewriter) const override {
670+ // Only concretize dynamic extent tensor result types.
671+ auto resultTy = op.getType ().dyn_cast <RankedTensorType>();
672+ if (!resultTy || !resultTy.isDynamicDim (0 ))
673+ return failure ();
674+
675+ // Infer resulting shape rank if possible.
676+ int64_t maxRank = 0 ;
677+ for (Value shape : op.shapes ()) {
678+ if (auto extentTensorTy = shape.getType ().dyn_cast <RankedTensorType>()) {
679+ // Cannot infer resulting shape rank if any operand is dynamically
680+ // ranked.
681+ if (extentTensorTy.isDynamicDim (0 ))
682+ return failure ();
683+ maxRank = std::max (maxRank, extentTensorTy.getDimSize (0 ));
684+ }
685+ }
686+
687+ auto newOp = rewriter.create <BroadcastOp>(
688+ op.getLoc (), getExtentTensorType (getContext (), maxRank), op.shapes ());
689+ rewriter.replaceOpWithNewOp <tensor::CastOp>(op, op.getType (), newOp);
690+ return success ();
691+ }
692+ };
663693} // namespace
664694
665695void BroadcastOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
666696 MLIRContext *context) {
667- patterns.add <BroadcastFoldConstantOperandsPattern,
697+ patterns.add <BroadcastConcretizeResultTypePattern,
698+ BroadcastFoldConstantOperandsPattern,
668699 BroadcastForwardSingleOperandPattern,
669700 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
670701 RemoveDuplicateOperandsPattern<BroadcastOp>,
0 commit comments