@@ -59,6 +59,30 @@ vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
                      ArrayRef<bool> inputVecScalableFlags = {},
                      bool flatten1DDepthwiseConv = false);

+/// Vectorize tensor::InsertSliceOp with:
+///   * vector::TransferReadOp + vector::TransferWriteOp
+/// The vector sizes are either:
+///   * user-provided in `inputVectorSizes`, or
+///   * inferred from the static dims in the input and output tensors.
+/// Bails out if:
+///   * vector sizes are not user-provided, and
+///   * at least one dim is dynamic (in both the input and output tensors).
+///
+/// Before:
+///     !t_in_type = tensor<1x2x3xf32>
+///     !t_out_type = tensor<9x8x7x1x2x3xf32>
+///     !v_type = vector<1x2x3xf32>
+///     %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
+///     into !t_out_type
+/// After:
+///     %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
+///     %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
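+///
+/// When `inputVectorSizes` are provided, the transfer_read is masked instead.
+/// Illustrative sketch (assumes sizes [8, 2, 3] for a tensor<?x2x3xf32> source
+/// whose dynamic dim reifies to %d0):
+///     %mask = vector.create_mask %d0, %c2, %c3 : vector<8x2x3xi1>
+///     %read = vector.mask %mask {
+///       vector.transfer_read %src[...], %pad
+///         : tensor<?x2x3xf32>, vector<8x2x3xf32>
+///     } : vector<8x2x3xi1> -> vector<8x2x3xf32>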
+static LogicalResult
+vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
+                         ArrayRef<int64_t> inputVectorSizes,
+                         SmallVectorImpl<Value> &newResults);
+
 /// Return the unique instance of OpType in `block` if it is indeed unique.
 /// Return null if none or more than 1 instances exist.
 template <typename OpType>
@@ -1557,6 +1581,7 @@ static LogicalResult
 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
                         ArrayRef<int64_t> inputVectorSizes,
                         SmallVectorImpl<Value> &newResults) {
+  // TODO: Introduce a parent class that will handle the insertion point update.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(packOp);

@@ -1633,6 +1658,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
                           SmallVectorImpl<Value> &newResults) {

+  // TODO: Introduce a parent class that will handle the insertion point update.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unpackOp);

@@ -1763,7 +1789,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   auto padValue = padOp.getConstantPaddingValue();
   Location loc = padOp.getLoc();

-  // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
+  // TODO: Introduce a parent class that will handle the insertion point update.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(padOp);

@@ -1874,6 +1900,15 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
   return success();
 }

+/// Preconditions for vectorizing tensor::InsertSliceOp. For now this is a
+/// no-op; the actual checks still live in the vectorization logic below.
+static LogicalResult
+vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
+                                   ArrayRef<int64_t> inputVectorSizes) {
+
+  // TODO: Move pre-conditions from the vectorization logic
+  return success();
+}
+
 static LogicalResult vectorizeLinalgOpPrecondition(
     LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
     bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
@@ -2144,6 +2179,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
       .Case<tensor::UnPackOp>([&](auto unpackOp) {
         return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
       })
+      .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
+        return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
+      })
       .Default([](auto) { return failure(); });
 }

@@ -2163,8 +2201,8 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
 }

 bool mlir::linalg::hasVectorizationImpl(Operation *op) {
-  return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
-      op);
+  return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp,
+             tensor::InsertSliceOp>(op);
 }

 /// Emit a suitable vector form for an operation. If provided,
@@ -2244,6 +2283,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
                                            results);
           })
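+          // `inputVectorSizes`, if non-empty, are typically user-provided
+          // (e.g. via `transform.structured.vectorize ... vector_sizes [...]`)
+          // and enable the masked code path in vectorizeAsInsertSliceOp.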
+          .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
+            return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
+                                            results);
+          })
           .Case<tensor::UnPackOp>([&](auto unpackOp) {
             return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                              inputVectorSizes, results);
@@ -2583,113 +2626,139 @@ static Value getStaticPadVal(Operation *op) {
   return {};
 }

-/// Rewrite tensor.insert.slice as a vector.transfer_read +
-/// vector.transfer_write pair. The vector size is inferred from the static
-/// dims in the input and output tensors. If a dim is dynamic in both the input
-/// and output tensors, bails out.
-///
-/// Before:
-///     !t_in_type = tensor<1x2x3xf32>
-///     !t_out_type = tensor<9x8x7x1x2x3xf32>
-///     !v_type = vector<1x2x3xf32>
-///     %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
-///     into !t_out_type
-/// After:
-///     %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
-///     %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
-///
-/// TODO: Support masking
-struct InsertSliceVectorizePattern
-    : public OpRewritePattern<tensor::InsertSliceOp> {
-  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+static LogicalResult
+vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
+                         ArrayRef<int64_t> inputVectorSizes,
+                         SmallVectorImpl<Value> &newResults) {
+  // TODO: Introduce a parent class that will handle the insertion point update.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(sliceOp);

-  LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
-                                PatternRewriter &rewriter) const final {
-    auto sourceType = sliceOp.getSource().getType();
-    if (!VectorType::isValidElementType(sourceType.getElementType()))
-      return failure();
+  TypedValue<RankedTensorType> source = sliceOp.getSource();
+  auto sourceType = source.getType();
+  if (!VectorType::isValidElementType(sourceType.getElementType()))
+    return failure();

-    auto resultType = sliceOp.getResultType();
-
-    // 1. Get the pad value.
-    // TransferReadOp requires a scalar padding value. Note that:
-    // * for in-bounds access, the value is actually irrelevant.
-    // There are 2 cases in which xfer.read accesses are known to be in-bounds:
-    //  1. The source shape is static (output vector sizes would be based on
-    //     the source shape and hence all memory accesses would be in-bounds),
-    //  2. Masking is used (output vector sizes would be user-provided, in
-    //     which case it is assumed that all memory accesses are in-bounds).
-    //     This remains a TODO.
-    //
-    // When the value is not known and not needed, use 0. Otherwise, bail out.
-    Value padValue = getStaticPadVal(sliceOp);
-    bool isOutOfBoundsRead = !sourceType.hasStaticShape();
-
-    if (!padValue && isOutOfBoundsRead) {
-      LDBG("Failed to get a pad value for out-of-bounds read access\n");
+  auto resultType = sliceOp.getResultType();
+
+  // 1. Get the pad value.
+  // TransferReadOp requires a scalar padding value. Note that:
+  // * for in-bounds access, the value is actually irrelevant.
+  // There are 2 cases in which xfer.read accesses are known to be in-bounds:
+  //  1. The source shape is static (output vector sizes would be based on
+  //     the source shape and hence all memory accesses would be in-bounds),
+  //  2. Masking is used, i.e. the output vector sizes are user-provided and
+  //     the out-of-bounds lanes are masked off (masking the xfer_write
+  //     remains a TODO).
+  //
+  // When the value is not known and not needed, use 0. Otherwise, bail out.
+  Value padValue = getStaticPadVal(sliceOp);
+  bool isOutOfBoundsRead =
+      !sourceType.hasStaticShape() && inputVectorSizes.empty();
+
+  if (!padValue && isOutOfBoundsRead) {
+    LDBG("Failed to get a pad value for out-of-bounds read access\n");
+    return failure();
+  }
+
+  if (!padValue) {
+    auto elemType = sourceType.getElementType();
+    padValue = rewriter.create<arith::ConstantOp>(
+        sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
+  }
+
+  // 2. Get the vector shape and in-bounds attributes.
+  SmallVector<int64_t> vecShape;
+  SmallVector<bool> readInBounds;
+  SmallVector<bool> writeInBounds;
+  size_t rankDiff = resultType.getRank() - sourceType.getRank();
+  for (unsigned i = 0; i < sourceType.getRank(); ++i) {
+    if (!inputVectorSizes.empty()) {
+      vecShape.push_back(inputVectorSizes[i]);
+      readInBounds.push_back(false);
+      writeInBounds.push_back(false);
+    } else if (!sourceType.isDynamicDim(i)) {
+      vecShape.push_back(sourceType.getDimSize(i));
+      // Source shape is statically known: Neither read nor write are
+      // out-of-bounds.
+      readInBounds.push_back(true);
+      writeInBounds.push_back(true);
+    } else if (!resultType.isDynamicDim(i)) {
+      // Source shape is not statically known, but result shape is.
+      // Vectorize with size of result shape. This may be larger than the
+      // source size.
+      // FIXME: Using rankDiff implies that the source tensor is inserted at
+      // the end of the destination tensor. However, that's not required.
+      vecShape.push_back(resultType.getDimSize(rankDiff + i));
+      // Read may be out-of-bounds because the result size could be larger
+      // than the source size.
+      readInBounds.push_back(false);
+      // Write will be in-bounds provided that the corresponding write idx is
+      // 0. To keep this logic simple, conservatively mark as out-of-bounds.
+      writeInBounds.push_back(false);
+    } else {
+      // Neither the source nor the result dim is static. Cannot vectorize
+      // the copy.
+      // TODO: Add support for masking.
       return failure();
     }
+  }
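+  // E.g., inserting a tensor<?x3xf32> source into a tensor<9x8x7x5x3xf32>
+  // result (rankDiff = 3) with no user-provided sizes yields
+  // vecShape = [5, 3] and readInBounds = writeInBounds = [false, true].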
+  auto vecType = VectorType::get(vecShape, sourceType.getElementType());

-    if (!padValue) {
-      auto elemType = sourceType.getElementType();
-      padValue = rewriter.create<arith::ConstantOp>(
-          sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
-    }
+  // 3. Generate TransferReadOp.
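+  // For the fully static case from the documentation above (source
+  // tensor<1x2x3xf32>), this emits roughly:
+  //   %read = vector.transfer_read %src[%c0, %c0, %c0], %pad
+  //       {in_bounds = [true, true, true]}
+  //       : tensor<1x2x3xf32>, vector<1x2x3xf32>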
+  SmallVector<Value> readIndices(
+      vecType.getRank(),
+      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
+  Operation *read = rewriter.create<vector::TransferReadOp>(
+      sliceOp.getLoc(), vecType, source, readIndices, padValue,
+      ArrayRef<bool>{readInBounds});

-    // 2. Get the vector shape and in-bounds attributes
-    SmallVector<int64_t> vecShape;
-    SmallVector<bool> readInBounds;
-    SmallVector<bool> writeInBounds;
-    size_t rankDiff = resultType.getRank() - sourceType.getRank();
-    for (unsigned i = 0; i < sourceType.getRank(); ++i) {
-      if (!sourceType.isDynamicDim(i)) {
-        vecShape.push_back(sourceType.getDimSize(i));
-        // Source shape is statically known: Neither read nor write are
-        // out-of-bounds.
-        readInBounds.push_back(true);
-        writeInBounds.push_back(true);
-      } else if (!resultType.isDynamicDim(i)) {
-        // Source shape is not statically known, but result shape is.
-        // Vectorize with size of result shape. This may be larger than the
-        // source size.
-        // FIXME: Using rankDiff implies that the source tensor is inserted at
-        // the end of the destination tensor. However, that's not required.
-        vecShape.push_back(resultType.getDimSize(rankDiff + i));
-        // Read may be out-of-bounds because the result size could be larger
-        // than the source size.
-        readInBounds.push_back(false);
-        // Write will in-bounds provided that the corresponding write idx is 0.
-        // To keep this logic simple, conservatively mark as out-of-bounds.
-        writeInBounds.push_back(false);
-      } else {
-        // Neither source nor result dim of padOp is static. Cannot vectorize
-        // the copy.
-        // TODO: Add support for masking
-        return failure();
-      }
+  // If vector sizes are user provided, make sure to mask xfer_read.
+  if (!inputVectorSizes.empty()) {
+    auto *srcDefOp = source.getDefiningOp();
+    if (!srcDefOp) {
+      LDBG("Unable to get the defining Op of " << sliceOp);
+      return failure();
     }
-    auto vecType = VectorType::get(vecShape, sourceType.getElementType());

-    // 3. Generate TransferReadOp.
-    SmallVector<Value> readIndices(
-        vecType.getRank(),
-        rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
-    auto read = rewriter.create<vector::TransferReadOp>(
-        sliceOp.getLoc(), vecType, sliceOp.getSource(), readIndices, padValue,
-        ArrayRef<bool>{readInBounds});
+    ReifiedRankedShapedTypeDims reifiedSrcSizes;
+    LogicalResult status =
+        cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
+            rewriter, reifiedSrcSizes);
+    if (status.failed()) {
+      LDBG("Unable to reify result shapes of " << sliceOp);
+      return failure();
+    }

-    // 4. Generate TransferWriteOp.
-    auto writeIndices = getValueOrCreateConstantIndexOp(
-        rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
+    // Create the mask.
+    auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
+    Value maskOp = rewriter.create<vector::CreateMaskOp>(
+        sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
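+    // E.g., for inputVectorSizes = [8, 3] and a tensor<?x3xf32> source whose
+    // dim 0 reifies to %d0, this is roughly:
+    //   %mask = vector.create_mask %d0, %c3 : vector<8x3xi1>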

-    // 5. Finalize
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        sliceOp, read, sliceOp.getDest(), writeIndices,
-        ArrayRef<bool>{writeInBounds});
+    // Mask the xfer_read Op.
+    read = mlir::vector::maskOperation(rewriter, read, maskOp);
+  }

-    return success();
+  // 4. Generate TransferWriteOp.
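+  // Masking of the transfer_write is not supported yet; bail out when it
+  // would be needed, i.e. when vector sizes are user-provided and the
+  // destination shape is dynamic.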
+  if (!inputVectorSizes.empty() &&
+      ShapedType::isDynamicShape(resultType.getShape())) {
+    LDBG("TODO: Masking of xfer_write when vectorising " << sliceOp);
+    return failure();
   }
-};
+
+  auto writeIndices = getValueOrCreateConstantIndexOp(
+      rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
+
+  // 5. Finalize.
+  Operation *write = rewriter.create<vector::TransferWriteOp>(
+      sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
+      ArrayRef<bool>{writeInBounds});
+  newResults.push_back(write->getResult(0));
+
+  return success();
+}

 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
 /// ```
@@ -2778,11 +2847,6 @@ struct PadOpVectorizationWithInsertSlicePattern
   }
 };

-void mlir::linalg::populateInsertSliceVectorizationPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<InsertSliceVectorizePattern>(patterns.getContext());
-}
-
 void mlir::linalg::populatePadOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
   patterns.add<PadOpVectorizationWithTransferReadPattern,