@@ -230,6 +230,8 @@ class IRRewrite {
230230 // / Erase the given block (unless it was already erased).
231231 void eraseBlock (Block *block);
232232
233+ const ConversionConfig &getConfig () const ;
234+
233235 const Kind kind;
234236 ConversionPatternRewriterImpl &rewriterImpl;
235237};
@@ -735,8 +737,9 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
735737namespace mlir {
736738namespace detail {
737739struct ConversionPatternRewriterImpl : public RewriterBase ::Listener {
738- explicit ConversionPatternRewriterImpl (PatternRewriter &rewriter)
739- : eraseRewriter(rewriter.getContext()) {}
740+ explicit ConversionPatternRewriterImpl (MLIRContext *ctx,
741+ const ConversionConfig &config)
742+ : eraseRewriter(ctx), config(config) {}
740743
741744 // ===--------------------------------------------------------------------===//
742745 // State Management
@@ -936,14 +939,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
936939 // / converting the arguments of blocks within that region.
937940 DenseMap<Region *, const TypeConverter *> regionToConverter;
938941
939- // / This allows the user to collect the match failure message.
940- function_ref<void (Diagnostic &)> notifyCallback;
941-
942- // / A set of pre-existing operations. When mode == OpConversionMode::Analysis,
943- // / this is populated with ops found to be legalizable to the target.
944- // / When mode == OpConversionMode::Partial, this is populated with ops found
945- // / *not* to be legalizable to the target.
946- DenseSet<Operation *> *trackedOps = nullptr ;
942+ // / Dialect conversion configuration.
943+ const ConversionConfig &config;
947944
948945#ifndef NDEBUG
949946 // / A set of operations that have pending updates. This tracking isn't
@@ -966,6 +963,10 @@ void IRRewrite::eraseBlock(Block *block) {
966963 rewriterImpl.eraseRewriter .eraseBlock (block);
967964}
968965
966+ const ConversionConfig &IRRewrite::getConfig () const {
967+ return rewriterImpl.config ;
968+ }
969+
969970void BlockTypeConversionRewrite::commit () {
970971 // Process the remapping for each of the original arguments.
971972 for (auto [origArg, info] :
@@ -1085,8 +1086,8 @@ void ReplaceOperationRewrite::commit() {
10851086 if (Value newValue =
10861087 rewriterImpl.mapping .lookupOrNull (result, result.getType ()))
10871088 result.replaceAllUsesWith (newValue);
1088- if (rewriterImpl. trackedOps )
1089- rewriterImpl. trackedOps ->erase (op);
1089+ if (getConfig (). unlegalizedOps )
1090+ getConfig (). unlegalizedOps ->erase (op);
10901091 // Do not erase the operation yet. It may still be referenced in `mapping`.
10911092 op->getBlock ()->getOperations ().remove (op);
10921093}
@@ -1514,18 +1515,19 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
15141515 Diagnostic diag (loc, DiagnosticSeverity::Remark);
15151516 reasonCallback (diag);
15161517 logger.startLine () << " ** Failure : " << diag.str () << " \n " ;
1517- if (notifyCallback)
1518- notifyCallback (diag);
1518+ if (config. notifyCallback )
1519+ config. notifyCallback (diag);
15191520 });
15201521}
15211522
15221523// ===----------------------------------------------------------------------===//
15231524// ConversionPatternRewriter
15241525// ===----------------------------------------------------------------------===//
15251526
1526- ConversionPatternRewriter::ConversionPatternRewriter (MLIRContext *ctx)
1527+ ConversionPatternRewriter::ConversionPatternRewriter (
1528+ MLIRContext *ctx, const ConversionConfig &config)
15271529 : PatternRewriter(ctx),
1528- impl(new detail::ConversionPatternRewriterImpl(* this )) {
1530+ impl(new detail::ConversionPatternRewriterImpl(ctx, config )) {
15291531 setListener (impl.get ());
15301532}
15311533
@@ -1994,12 +1996,12 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
19941996 assert (rewriterImpl.pendingRootUpdates .empty () && " dangling root updates" );
19951997 LLVM_DEBUG ({
19961998 logFailure (rewriterImpl.logger , " pattern failed to match" );
1997- if (rewriterImpl.notifyCallback ) {
1999+ if (rewriterImpl.config . notifyCallback ) {
19982000 Diagnostic diag (op->getLoc (), DiagnosticSeverity::Remark);
19992001 diag << " Failed to apply pattern \" " << pattern.getDebugName ()
20002002 << " \" on op:\n "
20012003 << *op;
2002- rewriterImpl.notifyCallback (diag);
2004+ rewriterImpl.config . notifyCallback (diag);
20032005 }
20042006 });
20052007 rewriterImpl.resetState (curState);
@@ -2387,14 +2389,12 @@ namespace mlir {
23872389struct OperationConverter {
23882390 explicit OperationConverter (const ConversionTarget &target,
23892391 const FrozenRewritePatternSet &patterns,
2390- OpConversionMode mode ,
2391- DenseSet<Operation *> *trackedOps = nullptr )
2392- : opLegalizer(target, patterns), mode(mode ), trackedOps(trackedOps ) {}
2392+ const ConversionConfig &config ,
2393+ OpConversionMode mode )
2394+ : opLegalizer(target, patterns), config(config ), mode(mode ) {}
23932395
23942396 // / Converts the given operations to the conversion target.
2395- LogicalResult
2396- convertOperations (ArrayRef<Operation *> ops,
2397- function_ref<void (Diagnostic &)> notifyCallback = nullptr );
2397+ LogicalResult convertOperations (ArrayRef<Operation *> ops);
23982398
23992399private:
24002400 // / Converts an operation with the given rewriter.
@@ -2431,14 +2431,11 @@ struct OperationConverter {
24312431 // / The legalizer to use when converting operations.
24322432 OperationLegalizer opLegalizer;
24332433
2434+ // / Dialect conversion configuration.
2435+ ConversionConfig config;
2436+
24342437 // / The conversion mode to use when legalizing operations.
24352438 OpConversionMode mode;
2436-
2437- // / A set of pre-existing operations. When mode == OpConversionMode::Analysis,
2438- // / this is populated with ops found to be legalizable to the target.
2439- // / When mode == OpConversionMode::Partial, this is populated with ops found
2440- // / *not* to be legalizable to the target.
2441- DenseSet<Operation *> *trackedOps;
24422439};
24432440} // namespace mlir
24442441
@@ -2452,28 +2449,27 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
24522449 return op->emitError ()
24532450 << " failed to legalize operation '" << op->getName () << " '" ;
24542451 // Partial conversions allow conversions to fail iff the operation was not
2455- // explicitly marked as illegal. If the user provided a nonlegalizableOps
2456- // set, non-legalizable ops are included .
2452+ // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2453+ // set, non-legalizable ops are added to that set .
24572454 if (mode == OpConversionMode::Partial) {
24582455 if (opLegalizer.isIllegal (op))
24592456 return op->emitError ()
24602457 << " failed to legalize operation '" << op->getName ()
24612458 << " ' that was explicitly marked illegal" ;
2462- if (trackedOps )
2463- trackedOps ->insert (op);
2459+ if (config. unlegalizedOps )
2460+ config. unlegalizedOps ->insert (op);
24642461 }
24652462 } else if (mode == OpConversionMode::Analysis) {
24662463 // Analysis conversions don't fail if any operations fail to legalize,
24672464 // they are only interested in the operations that were successfully
24682465 // legalized.
2469- trackedOps->insert (op);
2466+ if (config.legalizableOps )
2467+ config.legalizableOps ->insert (op);
24702468 }
24712469 return success ();
24722470}
24732471
2474- LogicalResult OperationConverter::convertOperations (
2475- ArrayRef<Operation *> ops,
2476- function_ref<void (Diagnostic &)> notifyCallback) {
2472+ LogicalResult OperationConverter::convertOperations (ArrayRef<Operation *> ops) {
24772473 if (ops.empty ())
24782474 return success ();
24792475 const ConversionTarget &target = opLegalizer.getTarget ();
@@ -2494,10 +2490,8 @@ LogicalResult OperationConverter::convertOperations(
24942490 }
24952491
24962492 // Convert each operation and discard rewrites on failure.
2497- ConversionPatternRewriter rewriter (ops.front ()->getContext ());
2493+ ConversionPatternRewriter rewriter (ops.front ()->getContext (), config );
24982494 ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl ();
2499- rewriterImpl.notifyCallback = notifyCallback;
2500- rewriterImpl.trackedOps = trackedOps;
25012495
25022496 for (auto *op : toConvert)
25032497 if (failed (convert (rewriter, op)))
@@ -3484,57 +3478,51 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
34843478// ===----------------------------------------------------------------------===//
34853479// Partial Conversion
34863480
3487- LogicalResult
3488- mlir::applyPartialConversion (ArrayRef<Operation *> ops,
3489- const ConversionTarget &target,
3490- const FrozenRewritePatternSet &patterns,
3491- DenseSet<Operation *> *unconvertedOps) {
3492- OperationConverter opConverter (target, patterns, OpConversionMode::Partial,
3493- unconvertedOps);
3481+ LogicalResult mlir::applyPartialConversion (
3482+ ArrayRef<Operation *> ops, const ConversionTarget &target,
3483+ const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3484+ OperationConverter opConverter (target, patterns, config,
3485+ OpConversionMode::Partial);
34943486 return opConverter.convertOperations (ops);
34953487}
34963488LogicalResult
34973489mlir::applyPartialConversion (Operation *op, const ConversionTarget &target,
34983490 const FrozenRewritePatternSet &patterns,
3499- DenseSet<Operation *> *unconvertedOps) {
3500- return applyPartialConversion (llvm::ArrayRef (op), target, patterns,
3501- unconvertedOps);
3491+ ConversionConfig config) {
3492+ return applyPartialConversion (llvm::ArrayRef (op), target, patterns, config);
35023493}
35033494
35043495// ===----------------------------------------------------------------------===//
35053496// Full Conversion
35063497
3507- LogicalResult
3508- mlir::applyFullConversion (ArrayRef<Operation *> ops,
3509- const ConversionTarget &target,
3510- const FrozenRewritePatternSet &patterns) {
3511- OperationConverter opConverter (target, patterns, OpConversionMode::Full);
3498+ LogicalResult mlir::applyFullConversion (ArrayRef<Operation *> ops,
3499+ const ConversionTarget &target,
3500+ const FrozenRewritePatternSet &patterns,
3501+ ConversionConfig config) {
3502+ OperationConverter opConverter (target, patterns, config,
3503+ OpConversionMode::Full);
35123504 return opConverter.convertOperations (ops);
35133505}
3514- LogicalResult
3515- mlir::applyFullConversion (Operation *op, const ConversionTarget &target,
3516- const FrozenRewritePatternSet &patterns) {
3517- return applyFullConversion (llvm::ArrayRef (op), target, patterns);
3506+ LogicalResult mlir::applyFullConversion (Operation *op,
3507+ const ConversionTarget &target,
3508+ const FrozenRewritePatternSet &patterns,
3509+ ConversionConfig config) {
3510+ return applyFullConversion (llvm::ArrayRef (op), target, patterns, config);
35183511}
35193512
35203513// ===----------------------------------------------------------------------===//
35213514// Analysis Conversion
35223515
3523- LogicalResult
3524- mlir::applyAnalysisConversion (ArrayRef<Operation *> ops,
3525- ConversionTarget &target,
3526- const FrozenRewritePatternSet &patterns,
3527- DenseSet<Operation *> &convertedOps,
3528- function_ref<void (Diagnostic &)> notifyCallback) {
3529- OperationConverter opConverter (target, patterns, OpConversionMode::Analysis,
3530- &convertedOps);
3531- return opConverter.convertOperations (ops, notifyCallback);
3516+ LogicalResult mlir::applyAnalysisConversion (
3517+ ArrayRef<Operation *> ops, ConversionTarget &target,
3518+ const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3519+ OperationConverter opConverter (target, patterns, config,
3520+ OpConversionMode::Analysis);
3521+ return opConverter.convertOperations (ops);
35323522}
35333523LogicalResult
35343524mlir::applyAnalysisConversion (Operation *op, ConversionTarget &target,
35353525 const FrozenRewritePatternSet &patterns,
3536- DenseSet<Operation *> &convertedOps,
3537- function_ref<void (Diagnostic &)> notifyCallback) {
3538- return applyAnalysisConversion (llvm::ArrayRef (op), target, patterns,
3539- convertedOps, notifyCallback);
3526+ ConversionConfig config) {
3527+ return applyAnalysisConversion (llvm::ArrayRef (op), target, patterns, config);
35403528}
0 commit comments