@@ -3983,6 +3983,41 @@ static bool paddingIsNotNeeded(PackOp op) {
39833983 op.getMixedTiles ());
39843984}
39853985
3986+ // / Returns true if the `srcShape` or `destShape` is different from the one in
3987+ // / `packOp` and populates each with the inferred static shape.
3988+ static bool inferStaticShape (PackOp packOp, SmallVectorImpl<int64_t > &srcShape,
3989+ SmallVectorImpl<int64_t > &destShape) {
3990+ bool changeNeeded = false ;
3991+ srcShape.assign (packOp.getSourceType ().getShape ().begin (),
3992+ packOp.getSourceType ().getShape ().end ());
3993+ destShape.assign (packOp.getDestType ().getShape ().begin (),
3994+ packOp.getDestType ().getShape ().end ());
3995+ llvm::SmallSetVector<int64_t , 4 > innerDims;
3996+ innerDims.insert (packOp.getInnerDimsPos ().begin (),
3997+ packOp.getInnerDimsPos ().end ());
3998+ auto outerDimsPerm = packOp.getOuterDimsPerm ();
3999+ int srcRank = packOp.getSourceRank ();
4000+ for (auto i : llvm::seq<int64_t >(0 , srcRank)) {
4001+ if (innerDims.contains (i))
4002+ continue ;
4003+ int64_t srcPos = i;
4004+ int64_t destPos = i;
4005+ if (!outerDimsPerm.empty ())
4006+ destPos = outerDimsPerm[srcPos];
4007+ if (ShapedType::isDynamic (srcShape[srcPos]) ==
4008+ ShapedType::isDynamic (destShape[destPos])) {
4009+ continue ;
4010+ }
4011+ int64_t size = srcShape[srcPos];
4012+ if (ShapedType::isDynamic (size))
4013+ size = destShape[destPos];
4014+ srcShape[srcPos] = size;
4015+ destShape[destPos] = size;
4016+ changeNeeded = true ;
4017+ }
4018+ return changeNeeded;
4019+ }
4020+
39864021LogicalResult PackOp::canonicalize (PackOp packOp, PatternRewriter &rewriter) {
39874022 // Fold an unpack(pack(x)) to x.
39884023 if (auto unPackOp = packOp.getSource ().getDefiningOp <UnPackOp>()) {
@@ -4003,6 +4038,31 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
40034038 rewriter.finalizeOpModification (packOp);
40044039 return success ();
40054040 }
4041+
4042+ // Insert tensor.cast ops if static shape inference is available..
4043+ SmallVector<int64_t > srcShape, destShape;
4044+ if (inferStaticShape (packOp, srcShape, destShape)) {
4045+ Location loc = packOp.getLoc ();
4046+ Value source = packOp.getSource ();
4047+ if (srcShape != packOp.getSourceType ().getShape ()) {
4048+ auto newSrcType = packOp.getSourceType ().clone (srcShape);
4049+ source =
4050+ rewriter.create <tensor::CastOp>(loc, newSrcType, packOp.getSource ());
4051+ }
4052+ Value dest = packOp.getDest ();
4053+ if (destShape != packOp.getDestType ().getShape ()) {
4054+ auto newDestType = packOp.getDestType ().clone (destShape);
4055+ dest =
4056+ rewriter.create <tensor::CastOp>(loc, newDestType, packOp.getDest ());
4057+ }
4058+ Value newOp = rewriter.create <tensor::PackOp>(
4059+ loc, source, dest, packOp.getInnerDimsPos (), packOp.getMixedTiles (),
4060+ packOp.getPaddingValue (), packOp.getOuterDimsPerm ());
4061+ rewriter.replaceOpWithNewOp <tensor::CastOp>(
4062+ packOp, packOp.getResult ().getType (), newOp);
4063+ return success ();
4064+ }
4065+
40064066 return failure ();
40074067}
40084068
0 commit comments