@@ -40,22 +40,43 @@ using llvm::SetVector;
4040const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
4141 " __internal_linalg_transform__" ;
4242
43- LogicalResult mlir::linalg::tileLinalgOpAndSetMarker (
44- PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t > sizes,
45- StringRef linalgMarker, ArrayRef<unsigned > permutation) {
43+ using TileFn = Optional<TiledLinalgOp>(OpBuilder &, LinalgOp, ArrayRef<int64_t >,
44+ ArrayRef<unsigned >, OperationFolder *);
45+
46+ static LogicalResult
47+ tileLinalgOpAndSetMarkerImpl (TileFn tileFn, PatternRewriter &rewriter,
48+ Operation *op, ArrayRef<int64_t > sizes,
49+ StringRef linalgMarker,
50+ ArrayRef<unsigned > permutation) {
4651 assert (permutation.empty () || permutation.size () == sizes.size ());
47- auto tileRes = tileLinalgOperation (rewriter, op, sizes, permutation);
52+ auto tileRes = tileFn (rewriter, op, sizes, permutation, /* folder= */ nullptr );
4853 if (!tileRes)
4954 return failure ();
5055 tileRes->op .setAttr (LinalgTransforms::kLinalgTransformMarker ,
5156 rewriter.getStringAttr (linalgMarker));
5257 return success ();
5358}
5459
55- LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker (
60+ LogicalResult mlir::linalg::tileLinalgOpAndSetMarker (
5661 PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t > sizes,
57- ArrayRef<int64_t > operandIndicesToFuse, StringRef linalgMarker) {
58- auto tileRes = tileLinalgOperation (rewriter, op, sizes);
62+ StringRef linalgMarker, ArrayRef<unsigned > permutation) {
63+ return tileLinalgOpAndSetMarkerImpl (tileLinalgOp, rewriter, op, sizes,
64+ linalgMarker, permutation);
65+ }
66+ LogicalResult mlir::linalg::tileLinalgOpToParallelLoopsAndSetMarker (
67+ PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t > sizes,
68+ StringRef linalgMarker, ArrayRef<unsigned > permutation) {
69+ return tileLinalgOpAndSetMarkerImpl (tileLinalgOpToParallelLoops, rewriter, op,
70+ sizes, linalgMarker, permutation);
71+ }
72+
73+ static LogicalResult
74+ tileAndFuseLinalgOpAndSetMarkerImpl (TileFn tileFn, PatternRewriter &rewriter,
75+ Operation *op, ArrayRef<int64_t > sizes,
76+ ArrayRef<int64_t > operandIndicesToFuse,
77+ StringRef linalgMarker) {
78+ auto tileRes =
79+ tileFn (rewriter, op, sizes, /* permutation=*/ {}, /* folder=*/ nullptr );
5980 if (!tileRes)
6081 return failure ();
6182 tileRes->op .setAttr (LinalgTransforms::kLinalgTransformMarker ,
@@ -89,6 +110,20 @@ LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker(
89110 return success ();
90111}
91112
113+ LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker (
114+ PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t > sizes,
115+ ArrayRef<int64_t > operandIndicesToFuse, StringRef linalgMarker) {
116+ return tileAndFuseLinalgOpAndSetMarkerImpl (
117+ tileLinalgOp, rewriter, op, sizes, operandIndicesToFuse, linalgMarker);
118+ }
119+ LogicalResult mlir::linalg::tileAndFuseLinalgOpToParallelLoopsAndSetMarker (
120+ PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t > sizes,
121+ ArrayRef<int64_t > operandIndicesToFuse, StringRef linalgMarker) {
122+ return tileAndFuseLinalgOpAndSetMarkerImpl (
123+ tileLinalgOpToParallelLoops, rewriter, op, sizes, operandIndicesToFuse,
124+ linalgMarker);
125+ }
126+
92127bool mlir::linalg::detail::isProducedByOpOfTypeImpl (
93128 Operation *consumerOp, Value consumedView,
94129 function_ref<bool (Operation *)> isaOpType) {
0 commit comments