From fd2fd347cb185722265f27b7a2aa8a32315e3cbf Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Mon, 15 Jul 2024 14:50:47 +0100 Subject: [PATCH 1/2] [mlir] Add RewriterBase to the C API --- mlir/include/mlir-c/Rewrite.h | 257 +++++++++++++ mlir/include/mlir/CAPI/Rewrite.h | 23 ++ mlir/lib/CAPI/Transforms/Rewrite.cpp | 249 ++++++++++++ mlir/test/CAPI/CMakeLists.txt | 9 + mlir/test/CAPI/rewrite.c | 551 +++++++++++++++++++++++++++ mlir/test/CMakeLists.txt | 1 + mlir/test/lit.cfg.py | 1 + 7 files changed, 1091 insertions(+) create mode 100644 mlir/include/mlir/CAPI/Rewrite.h create mode 100644 mlir/test/CAPI/rewrite.c diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h index bed93045f4b50..09f8a72a0c599 100644 --- a/mlir/include/mlir-c/Rewrite.h +++ b/mlir/include/mlir-c/Rewrite.h @@ -33,10 +33,263 @@ extern "C" { }; \ typedef struct name name +DEFINE_C_API_STRUCT(MlirRewriterBase, void); DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void); DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void); DEFINE_C_API_STRUCT(MlirRewritePatternSet, void); +//===----------------------------------------------------------------------===// +/// RewriterBase API inherited from OpBuilder +//===----------------------------------------------------------------------===// + +/// Get the MLIR context referenced by the rewriter. +MLIR_CAPI_EXPORTED MlirContext +mlirRewriterBaseGetContext(MlirRewriterBase rewriter); + +//===----------------------------------------------------------------------===// +/// Insertion points methods + +// They do not include functions using Block::iterator or Region::iterator, as +// they are not exposed by the C API yet. This includes methods using +// `InsertPoint` directly. + +/// Reset the insertion point to no location. Creating an operation without a +/// set insertion point is an error, but this can still be useful when the +/// current insertion point a builder refers to is being removed. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter); + +/// Sets the insertion point to the specified operation, which will cause +/// subsequent insertions to go right before it. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter, + MlirOperation op); + +/// Sets the insertion point to the node after the specified operation, which +/// will cause subsequent insertions to go right after it. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter, + MlirOperation op); + +/// Sets the insertion point to the node after the specified value. If value +/// has a defining operation, sets the insertion point to the node after such +/// defining operation. This will cause subsequent insertions to go right +/// after it. Otherwise, value is a BlockArgument. Sets the insertion point to +/// the start of its block. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter, + MlirValue value); + +/// Sets the insertion point to the start of the specified block. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter, + MlirBlock block); + +/// Sets the insertion point to the end of the specified block. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter, + MlirBlock block); + +/// Return the block the current insertion point belongs to. Note that the +/// insertion point is not necessarily the end of the block. +MLIR_CAPI_EXPORTED MlirBlock +mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter); + +/// Returns the current block of the rewriter. +MLIR_CAPI_EXPORTED MlirBlock +mlirRewriterBaseGetBlock(MlirRewriterBase rewriter); + +//===----------------------------------------------------------------------===// +/// Block and operation creation/insertion/cloning + +/// Add new block with 'argTypes' arguments and set the insertion point to the +/// end of it. The block is placed before 'insertBefore'. `locs` contains the +/// locations of the inserted arguments, and should match the size of +/// `argTypes`. +MLIR_CAPI_EXPORTED MlirBlock mlirRewriterBaseCreateBlockBefore( + MlirRewriterBase rewriter, MlirBlock insertBefore, intptr_t nArgTypes, + MlirType const *argTypes, MlirLocation const *locations); + +/// Insert the given operation at the current insertion point and return it. +MLIR_CAPI_EXPORTED MlirOperation +mlirRewriterBaseInsert(MlirRewriterBase rewriter, MlirOperation op); + +// The IRMapper is not yet exposed in the CAPI +MLIR_CAPI_EXPORTED MlirOperation +mlirRewriterBaseClone(MlirRewriterBase rewriter, MlirOperation op); + +// The IRMapper is not yet exposed in the CAPI +MLIR_CAPI_EXPORTED MlirOperation mlirRewriterBaseCloneWithoutRegions( + MlirRewriterBase rewriter, MlirOperation op); + +// The IRMapper is not yet exposed in the CAPI, nor Region::iterator. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter, MlirRegion region, + MlirBlock before); + +//===----------------------------------------------------------------------===// +/// RewriterBase API +//===----------------------------------------------------------------------===// + +/// Move the blocks that belong to "region" before the given position in +/// another region "parent". The two regions must be different. The caller +/// is responsible for creating or updating the operation transferring flow +/// of control to the region and passing it the correct block arguments. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter, MlirRegion region, + MlirBlock before); + +/// Replace the results of the given (original) operation with the specified +/// list of values (replacements). The result types of the given op and the +/// replacements must match. The original op is erased. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter, MlirOperation op, + intptr_t nValues, MlirValue const *values); + +/// Replace the results of the given (original) operation with the specified +/// new op (replacement). The result types of the two ops must match. The +/// original op is erased. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter, + MlirOperation op, MlirOperation newOp); + +/// Erases an operation that is known to have no uses. +MLIR_CAPI_EXPORTED void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, + MlirOperation op); + +/// Erases a block along with all operations inside it. +MLIR_CAPI_EXPORTED void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter, + MlirBlock block); + +/// Inline the operations of block 'source' before the operation 'op'. The +/// source block will be deleted and must have no uses. 'argValues' is used to +/// replace the block arguments of 'source' +/// +/// The source block must have no successors. Otherwise, the resulting IR +/// would have unreachable operations. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter, MlirBlock source, + MlirOperation op, intptr_t nArgValues, + MlirValue const *argValues); + +/// Inline the operations of block 'source' into the end of block 'dest'. The +/// source block will be deleted and must have no uses. 'argValues' is used to +/// replace the block arguments of 'source' +/// +/// The dest block must have no successors. Otherwise, the resulting IR would +/// have unreachable operation. +MLIR_CAPI_EXPORTED void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, + MlirBlock source, + MlirBlock dest, + intptr_t nArgValues, + MlirValue const *argValues); + +// splitBlock is not implemented as Block::iterator is not exposed by the CAPI + +/// Unlink this operation from its current block and insert it right before +/// `existingOp` which may be in the same or another block in the same +/// function. +MLIR_CAPI_EXPORTED void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, + MlirOperation op, + MlirOperation existingOp); + +/// Unlink this operation from its current block and insert it right after +/// `existingOp` which may be in the same or another block in the same +/// function. +MLIR_CAPI_EXPORTED void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter, + MlirOperation op, + MlirOperation existingOp); + +/// Unlink this block and insert it right before `existingBlock`. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block, + MlirBlock existingBlock); + +/// This method is used to notify the rewriter that an in-place operation +/// modification is about to happen. A call to this function *must* be +/// followed by a call to either `finalizeOpModification` or +/// `cancelOpModification`. This is a minor efficiency win (it avoids creating +/// a new operation and removing the old one) but also often allows simpler +/// code in the client. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter, + MlirOperation op); + +/// This method is used to signal the end of an in-place modification of the +/// given operation. This can only be called on operations that were provided +/// to a call to `startOpModification`. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter, + MlirOperation op); + +/// This method cancels a pending in-place modification. This can only be +/// called on operations that were provided to a call to +/// `startOpModification`. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter, + MlirOperation op); + +/// Find uses of `from` and replace them with `to`. Also notify the listener +/// about every in-place op modification (for every use that was replaced). +MLIR_CAPI_EXPORTED void +mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter, MlirValue from, + MlirValue to); + +/// Find uses of `from` and replace them with `to`. Also notify the listener +/// about every in-place op modification (for every use that was replaced). +MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceAllValueRangeUsesWith( + MlirRewriterBase rewriter, intptr_t nValues, MlirValue const *from, + MlirValue const *to); + +/// Find uses of `from` and replace them with `to`. Also notify the listener +/// about every in-place op modification (for every use that was replaced) +/// and that the `from` operation is about to be replaced. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter, + MlirOperation from, intptr_t nTo, + MlirValue const *to); + +/// Find uses of `from` and replace them with `to`. Also notify the listener +/// about every in-place op modification (for every use that was replaced) +/// and that the `from` operation is about to be replaced. +MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceAllOpUsesWithOperation( + MlirRewriterBase rewriter, MlirOperation from, MlirOperation to); + +/// Find uses of `from` within `block` and replace them with `to`. Also notify +/// the listener about every in-place op modification (for every use that was +/// replaced). The optional `allUsesReplaced` flag is set to "true" if all +/// uses were replaced. +MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceOpUsesWithinBlock( + MlirRewriterBase rewriter, MlirOperation op, intptr_t nNewValues, + MlirValue const *newValues, MlirBlock block); + +/// Find uses of `from` and replace them with `to` except if the user is +/// `exceptedUser`. Also notify the listener about every in-place op +/// modification (for every use that was replaced). +MLIR_CAPI_EXPORTED void +mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter, MlirValue from, + MlirValue to, MlirOperation exceptedUser); + +//===----------------------------------------------------------------------===// +/// IRRewriter API +//===----------------------------------------------------------------------===// + +/// Create an IRRewriter and transfer ownership to the caller. +MLIR_CAPI_EXPORTED MlirRewriterBase mlirIRRewriterCreate(MlirContext context); + +/// Create an IRRewriter and transfer ownership to the caller. Additionally +/// set the insertion point before the operation. +MLIR_CAPI_EXPORTED MlirRewriterBase +mlirIRRewriterCreateFromOp(MlirOperation op); + +/// Takes an IRRewriter owned by the caller and destroys it. It is the +/// responsibility of the user to only pass an IRRewriter class. +MLIR_CAPI_EXPORTED void mlirIRRewriterDestroy(MlirRewriterBase rewriter); + +//===----------------------------------------------------------------------===// +/// FrozenRewritePatternSet API +//===----------------------------------------------------------------------===// + MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op); @@ -47,6 +300,10 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily( MlirModule op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig); +//===----------------------------------------------------------------------===// +/// PDLPatternModule API +//===----------------------------------------------------------------------===// + #if MLIR_ENABLE_PDL_IN_PATTERNMATCH DEFINE_C_API_STRUCT(MlirPDLPatternModule, void); diff --git a/mlir/include/mlir/CAPI/Rewrite.h b/mlir/include/mlir/CAPI/Rewrite.h new file mode 100644 index 0000000000000..0e6dcb2477626 --- /dev/null +++ b/mlir/include/mlir/CAPI/Rewrite.h @@ -0,0 +1,23 @@ +//===- Rewrite.h - C API Utils for Core MLIR classes ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains declarations of implementation details of the C API for +// rewrite patterns. This file should not be included from C++ code other than +// C API implementation nor from C code. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CAPI_REWRITE_H +#define MLIR_CAPI_REWRITE_H + +#include "mlir/CAPI/Wrap.h" +#include "mlir/IR/PatternMatch.h" + +DEFINE_C_API_PTR_METHODS(MlirRewriterBase, mlir::RewriterBase); + +#endif // MLIR_CAPIREWRITER_H diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 0de1958398f63..7f3c833df0910 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -7,15 +7,260 @@ //===----------------------------------------------------------------------===// #include "mlir-c/Rewrite.h" + #include "mlir-c/Transforms.h" #include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Rewrite.h" #include "mlir/CAPI/Support.h" +#include "mlir/CAPI/Wrap.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; +//===----------------------------------------------------------------------===// +/// RewriterBase API inherited from OpBuilder +//===----------------------------------------------------------------------===// + +MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) { + return wrap(unwrap(rewriter)->getContext()); +} + +//===----------------------------------------------------------------------===// +/// Insertion points methods + +void mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter) { + unwrap(rewriter)->clearInsertionPoint(); +} + +void mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter, + MlirOperation op) { + unwrap(rewriter)->setInsertionPoint(unwrap(op)); +} + +void mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter, + MlirOperation op) { + unwrap(rewriter)->setInsertionPointAfter(unwrap(op)); +} + +void mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter, + MlirValue value) { + unwrap(rewriter)->setInsertionPointAfterValue(unwrap(value)); +} + +void mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter, + MlirBlock block) { + unwrap(rewriter)->setInsertionPointToStart(unwrap(block)); +} + +void mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter, + MlirBlock block) { + unwrap(rewriter)->setInsertionPointToEnd(unwrap(block)); +} + +MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter) { + return wrap(unwrap(rewriter)->getInsertionBlock()); +} + +MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) { + return wrap(unwrap(rewriter)->getBlock()); +} + +//===----------------------------------------------------------------------===// +/// Block and operation creation/insertion/cloning + +MlirBlock mlirRewriterBaseCreateBlockBefore(MlirRewriterBase rewriter, + MlirBlock insertBefore, + intptr_t nArgTypes, + MlirType const *argTypes, + MlirLocation const *locations) { + SmallVector args; + ArrayRef unwrappedArgs = unwrapList(nArgTypes, argTypes, args); + SmallVector locs; + ArrayRef unwrappedLocs = unwrapList(nArgTypes, locations, locs); + return wrap(unwrap(rewriter)->createBlock(unwrap(insertBefore), unwrappedArgs, + unwrappedLocs)); +} + +MlirOperation mlirRewriterBaseInsert(MlirRewriterBase rewriter, + MlirOperation op) { + return wrap(unwrap(rewriter)->insert(unwrap(op))); +} + +// Other methods of OpBuilder + +// The IRMapper is not yet exposed in the CAPI +MlirOperation mlirRewriterBaseClone(MlirRewriterBase rewriter, + MlirOperation op) { + return wrap(unwrap(rewriter)->clone(*unwrap(op))); +} + +// The IRMapper is not yet exposed in the CAPI +MlirOperation mlirRewriterBaseCloneWithoutRegions(MlirRewriterBase rewriter, + MlirOperation op) { + return wrap(unwrap(rewriter)->cloneWithoutRegions(*unwrap(op))); +} + +// The IRMapper is not yet exposed in the CAPI, nor Region::iterator. +void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter, + MlirRegion region, MlirBlock before) { + + unwrap(rewriter)->cloneRegionBefore(*unwrap(region), unwrap(before)); +} + +//===----------------------------------------------------------------------===// +/// RewriterBase API +//===----------------------------------------------------------------------===// + +// Region::iterator is not yet exposed in the CAPI. +void mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter, + MlirRegion region, MlirBlock before) { + unwrap(rewriter)->inlineRegionBefore(*unwrap(region), unwrap(before)); +} + +void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter, + MlirOperation op, intptr_t nValues, + MlirValue const *values) { + SmallVector vals; + ArrayRef unwrappedVals = unwrapList(nValues, values, vals); + unwrap(rewriter)->replaceOp(unwrap(op), unwrappedVals); +} + +void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter, + MlirOperation op, + MlirOperation newOp) { + unwrap(rewriter)->replaceOp(unwrap(op), unwrap(newOp)); +} + +void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op) { + unwrap(rewriter)->eraseOp(unwrap(op)); +} + +void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter, MlirBlock block) { + unwrap(rewriter)->eraseBlock(unwrap(block)); +} + +void mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter, + MlirBlock source, MlirOperation op, + intptr_t nArgValues, + MlirValue const *argValues) { + SmallVector vals; + ArrayRef unwrappedVals = unwrapList(nArgValues, argValues, vals); + + unwrap(rewriter)->inlineBlockBefore(unwrap(source), unwrap(op), + unwrappedVals); +} + +void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, MlirBlock source, + MlirBlock dest, intptr_t nArgValues, + MlirValue const *argValues) { + SmallVector args; + ArrayRef unwrappedArgs = unwrapList(nArgValues, argValues, args); + unwrap(rewriter)->mergeBlocks(unwrap(source), unwrap(dest), unwrappedArgs); +} + +// splitBlock is not implemented as Block::iterator is not exposed by the CAPI + +void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, MlirOperation op, + MlirOperation existingOp) { + unwrap(rewriter)->moveOpBefore(unwrap(op), unwrap(existingOp)); +} + +void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter, MlirOperation op, + MlirOperation existingOp) { + unwrap(rewriter)->moveOpAfter(unwrap(op), unwrap(existingOp)); +} + +void mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block, + MlirBlock existingBlock) { + unwrap(rewriter)->moveBlockBefore(unwrap(block), unwrap(existingBlock)); +} + +void mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter, + MlirOperation op) { + unwrap(rewriter)->startOpModification(unwrap(op)); +} + +void mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter, + MlirOperation op) { + unwrap(rewriter)->finalizeOpModification(unwrap(op)); +} + +void mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter, + MlirOperation op) { + unwrap(rewriter)->cancelOpModification(unwrap(op)); +} + +void mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter, + MlirValue from, MlirValue to) { + unwrap(rewriter)->replaceAllUsesWith(unwrap(from), unwrap(to)); +} + +void mlirRewriterBaseReplaceAllValueRangeUsesWith(MlirRewriterBase rewriter, + intptr_t nValues, + MlirValue const *from, + MlirValue const *to) { + SmallVector fromVals; + ArrayRef unwrappedFromVals = unwrapList(nValues, from, fromVals); + SmallVector toVals; + ArrayRef unwrappedToVals = unwrapList(nValues, to, toVals); + unwrap(rewriter)->replaceAllUsesWith(unwrappedFromVals, unwrappedToVals); +} + +void mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter, + MlirOperation from, + intptr_t nTo, + MlirValue const *to) { + SmallVector toVals; + ArrayRef unwrappedToVals = unwrapList(nTo, to, toVals); + unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrappedToVals); +} + +void mlirRewriterBaseReplaceAllOpUsesWithOperation(MlirRewriterBase rewriter, + MlirOperation from, + MlirOperation to) { + unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrap(to)); +} + +void mlirRewriterBaseReplaceOpUsesWithinBlock(MlirRewriterBase rewriter, + MlirOperation op, + intptr_t nNewValues, + MlirValue const *newValues, + MlirBlock block) { + SmallVector vals; + ArrayRef unwrappedVals = unwrapList(nNewValues, newValues, vals); + unwrap(rewriter)->replaceOpUsesWithinBlock(unwrap(op), unwrappedVals, + unwrap(block)); +} + +void mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter, + MlirValue from, MlirValue to, + MlirOperation exceptedUser) { + unwrap(rewriter)->replaceAllUsesExcept(unwrap(from), unwrap(to), + unwrap(exceptedUser)); +} + +//===----------------------------------------------------------------------===// +/// IRRewriter API +//===----------------------------------------------------------------------===// + +MlirRewriterBase mlirIRRewriterCreate(MlirContext context) { + return wrap(new IRRewriter(unwrap(context))); +} + +MlirRewriterBase mlirIRRewriterCreateFromOp(MlirOperation op) { + return wrap(new IRRewriter(unwrap(op))); +} + +void mlirIRRewriterDestroy(MlirRewriterBase rewriter) { + delete static_cast(unwrap(rewriter)); +} + +//===----------------------------------------------------------------------===// +/// RewritePatternSet and FrozenRewritePatternSet API +//===----------------------------------------------------------------------===// + inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) { assert(module.ptr && "unexpected null module"); return *(static_cast(module.ptr)); @@ -54,6 +299,10 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op, mlir::applyPatternsAndFoldGreedily(unwrap(op), *unwrap(patterns))); } +//===----------------------------------------------------------------------===// +/// PDLPatternModule API +//===----------------------------------------------------------------------===// + #if MLIR_ENABLE_PDL_IN_PATTERNMATCH inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) { assert(module.ptr && "unexpected null module"); diff --git a/mlir/test/CAPI/CMakeLists.txt b/mlir/test/CAPI/CMakeLists.txt index ad312764b3e06..e795672bce5d1 100644 --- a/mlir/test/CAPI/CMakeLists.txt +++ b/mlir/test/CAPI/CMakeLists.txt @@ -89,6 +89,15 @@ _add_capi_test_executable(mlir-capi-quant-test MLIRCAPIQuant ) +_add_capi_test_executable(mlir-capi-rewrite-test + rewrite.c + LINK_LIBS PRIVATE + MLIRCAPIIR + MLIRCAPIRegisterEverything + MLIRCAPITransforms +) + + _add_capi_test_executable(mlir-capi-transform-test transform.c LINK_LIBS PRIVATE diff --git a/mlir/test/CAPI/rewrite.c b/mlir/test/CAPI/rewrite.c new file mode 100644 index 0000000000000..a8b593eabb781 --- /dev/null +++ b/mlir/test/CAPI/rewrite.c @@ -0,0 +1,551 @@ +//===- rewrite.c - Test of the rewriting C API ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// RUN: mlir-capi-rewrite-test 2>&1 | FileCheck %s + +#include "mlir-c/Rewrite.h" +#include "mlir-c/BuiltinTypes.h" +#include "mlir-c/IR.h" + +#include +#include + +MlirOperation createOperationWithName(MlirContext ctx, const char *name) { + MlirStringRef nameRef = mlirStringRefCreateFromCString(name); + MlirLocation loc = mlirLocationUnknownGet(ctx); + MlirOperationState state = mlirOperationStateGet(nameRef, loc); + MlirType indexType = mlirIndexTypeGet(ctx); + mlirOperationStateAddResults(&state, 1, &indexType); + return mlirOperationCreate(&state); +} + +void testInsertionPoint(MlirContext ctx) { + // CHECK-LABEL: @testInsertionPoint + fprintf(stderr, "@testInsertionPoint\n"); + + const char *moduleString = "\"dialect.op1\"() : () -> ()\n"; + MlirModule module = + mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString)); + MlirOperation op = mlirModuleGetOperation(module); + MlirBlock body = mlirModuleGetBody(module); + MlirOperation op1 = mlirBlockGetFirstOperation(body); + + // IRRewriter create + MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx); + + // Insert before op + mlirRewriterBaseSetInsertionPointBefore(rewriter, op1); + MlirOperation op2 = createOperationWithName(ctx, "dialect.op2"); + mlirRewriterBaseInsert(rewriter, op2); + + // Insert after op + mlirRewriterBaseSetInsertionPointAfter(rewriter, op2); + MlirOperation op3 = createOperationWithName(ctx, "dialect.op3"); + mlirRewriterBaseInsert(rewriter, op3); + MlirValue op3Res = mlirOperationGetResult(op3, 0); + + // Insert after value + mlirRewriterBaseSetInsertionPointAfterValue(rewriter, op3Res); + MlirOperation op4 = createOperationWithName(ctx, "dialect.op4"); + mlirRewriterBaseInsert(rewriter, op4); + + // Insert at beginning of block + mlirRewriterBaseSetInsertionPointToStart(rewriter, body); + MlirOperation op5 = createOperationWithName(ctx, "dialect.op5"); + mlirRewriterBaseInsert(rewriter, op5); + + // Insert at end of block + mlirRewriterBaseSetInsertionPointToEnd(rewriter, body); + MlirOperation op6 = createOperationWithName(ctx, "dialect.op6"); + mlirRewriterBaseInsert(rewriter, op6); + + // Get insertion blocks + MlirBlock block1 = mlirRewriterBaseGetBlock(rewriter); + MlirBlock block2 = mlirRewriterBaseGetInsertionBlock(rewriter); + assert(body.ptr == block1.ptr); + assert(body.ptr == block2.ptr); + + // clang-format off + // CHECK-NEXT: module { + // CHECK-NEXT: %{{.*}} = "dialect.op5"() : () -> index + // CHECK-NEXT: %{{.*}} = "dialect.op2"() : () -> index + // CHECK-NEXT: %{{.*}} = "dialect.op3"() : () -> index + // CHECK-NEXT: %{{.*}} = "dialect.op4"() : () -> index + // CHECK-NEXT: "dialect.op1"() : () -> () + // CHECK-NEXT: %{{.*}} = "dialect.op6"() : () -> index + // CHECK-NEXT: } + // clang-format on + mlirOperationDump(op); + + mlirIRRewriterDestroy(rewriter); + mlirModuleDestroy(module); +} + +void testCreateBlock(MlirContext ctx) { + // CHECK-LABEL: @testCreateBlock + fprintf(stderr, "@testCreateBlock\n"); + + const char *moduleString = "\"dialect.op1\"() ({^bb0:}) : () -> ()\n" + "\"dialect.op2\"() ({^bb0:}) : () -> ()\n"; + MlirModule module = + mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString)); + MlirOperation op = mlirModuleGetOperation(module); + MlirBlock body = mlirModuleGetBody(module); + + MlirOperation op1 = mlirBlockGetFirstOperation(body); + MlirRegion region1 = mlirOperationGetRegion(op1, 0); + MlirBlock block1 = mlirRegionGetFirstBlock(region1); + + MlirOperation op2 = mlirOperationGetNextInBlock(op1); + MlirRegion region2 = mlirOperationGetRegion(op2, 0); + MlirBlock block2 = mlirRegionGetFirstBlock(region2); + + MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx); + + // Create block before + MlirType indexType = mlirIndexTypeGet(ctx); + MlirLocation unknown = mlirLocationUnknownGet(ctx); + mlirRewriterBaseCreateBlockBefore(rewriter, block1, 1, &indexType, &unknown); + + mlirRewriterBaseSetInsertionPointToEnd(rewriter, body); + + // Clone operation + mlirRewriterBaseClone(rewriter, op1); + + // Clone without regions + mlirRewriterBaseCloneWithoutRegions(rewriter, op1); + + // Clone region before + mlirRewriterBaseCloneRegionBefore(rewriter, region1, block2); + + mlirOperationDump(op); + // clang-format off + // CHECK-NEXT: "builtin.module"() ({ + // CHECK-NEXT: "dialect.op1"() ({ + // CHECK-NEXT: ^{{.*}}(%{{.*}}: index): + // CHECK-NEXT: ^{{.*}}: + // CHECK-NEXT: }) : () -> () + // CHECK-NEXT: "dialect.op2"() ({ + // CHECK-NEXT: ^{{.*}}(%{{.*}}: index): + // CHECK-NEXT: ^{{.*}}: + // CHECK-NEXT: ^{{.*}}: + // CHECK-NEXT: }) : () -> () + // CHECK-NEXT: "dialect.op1"() ({ + // CHECK-NEXT: ^{{.*}}(%{{.*}}: index): + // CHECK-NEXT: ^{{.*}}: + // CHECK-NEXT: }) : () -> () + // CHECK-NEXT: "dialect.op1"() ({ + // CHECK-NEXT: }) : () -> () + // CHECK-NEXT: }) : () -> () + // clang-format on + + mlirIRRewriterDestroy(rewriter); + mlirModuleDestroy(module); +} + +void testInlineRegionBlock(MlirContext ctx) { + // CHECK-LABEL: @testInlineRegionBlock + fprintf(stderr, "@testInlineRegionBlock\n"); + + const char *moduleString = + "\"dialect.op1\"() ({\n" + " ^bb0(%arg0: index):\n" + " \"dialect.op1_in1\"(%arg0) [^bb1] : (index) -> ()\n" + " ^bb1():\n" + " \"dialect.op1_in2\"() : () -> ()\n" + "}) : () -> ()\n" + "\"dialect.op2\"() ({^bb0:}) : () -> ()\n" + "\"dialect.op3\"() ({\n" + " ^bb0(%arg0: index):\n" + " \"dialect.op3_in1\"(%arg0) : (index) -> ()\n" + " ^bb1():\n" + " %x = \"dialect.op3_in2\"() : () -> index\n" + " %y = \"dialect.op3_in3\"() : () -> index\n" + "}) : () -> ()\n" + "\"dialect.op4\"() ({\n" + " ^bb0():\n" + " \"dialect.op4_in1\"() : () -> index\n" + " ^bb1(%arg0: index):\n" + " \"dialect.op4_in2\"(%arg0) : (index) -> ()\n" + "}) : () -> ()\n"; + MlirModule module = + mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString)); + MlirOperation op = mlirModuleGetOperation(module); + MlirBlock body = mlirModuleGetBody(module); + + MlirOperation op1 = mlirBlockGetFirstOperation(body); + MlirRegion region1 = mlirOperationGetRegion(op1, 0); + + MlirOperation op2 = mlirOperationGetNextInBlock(op1); + MlirRegion region2 = mlirOperationGetRegion(op2, 0); + MlirBlock block2 = mlirRegionGetFirstBlock(region2); + + MlirOperation op3 = mlirOperationGetNextInBlock(op2); + MlirRegion region3 = mlirOperationGetRegion(op3, 0); + MlirBlock block3_1 = mlirRegionGetFirstBlock(region3); + MlirBlock block3_2 = mlirBlockGetNextInRegion(block3_1); + MlirOperation op3_in2 = mlirBlockGetFirstOperation(block3_2); + MlirValue op3_in2_res = mlirOperationGetResult(op3_in2, 0); + MlirOperation op3_in3 = mlirOperationGetNextInBlock(op3_in2); + + MlirOperation op4 = mlirOperationGetNextInBlock(op3); + MlirRegion region4 = mlirOperationGetRegion(op4, 0); + MlirBlock block4_1 = mlirRegionGetFirstBlock(region4); + MlirOperation op4_in1 = mlirBlockGetFirstOperation(block4_1); + MlirValue op4_in1_res = mlirOperationGetResult(op4_in1, 0); + MlirBlock block4_2 = mlirBlockGetNextInRegion(block4_1); + + MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx); + + // Test these three functions + mlirRewriterBaseInlineRegionBefore(rewriter, region1, block2); + mlirRewriterBaseInlineBlockBefore(rewriter, block3_1, op3_in3, 1, + &op3_in2_res); + mlirRewriterBaseMergeBlocks(rewriter, block4_2, block4_1, 1, &op4_in1_res); + + mlirOperationDump(op); + // clang-format off + // CHECK-NEXT: "builtin.module"() ({ + // CHECK-NEXT: "dialect.op1"() ({ + // CHECK-NEXT: }) : () -> () + // CHECK-NEXT: "dialect.op2"() ({ + // CHECK-NEXT: ^{{.*}}(%{{.*}}: index): + // CHECK-NEXT: "dialect.op1_in1"(%{{.*}})[^[[bb:.*]]] : (index) -> () + // CHECK-NEXT: ^[[bb]]: + // CHECK-NEXT: "dialect.op1_in2"() : () -> () + // CHECK-NEXT: ^{{.*}}: // no predecessors + // CHECK-NEXT: }) : () -> () + // CHECK-NEXT: "dialect.op3"() ({ + // CHECK-NEXT: %{{.*}} = "dialect.op3_in2"() : () -> index + // CHECK-NEXT: "dialect.op3_in1"(%{{.*}}) : (index) -> () + // CHECK-NEXT: %{{.*}} = "dialect.op3_in3"() : () -> index + // CHECK-NEXT: }) : () -> () + // CHECK-NEXT: "dialect.op4"() ({ + // CHECK-NEXT: %{{.*}} = "dialect.op4_in1"() : () -> index + // CHECK-NEXT: "dialect.op4_in2"(%{{.*}}) : (index) -> () + // CHECK-NEXT: }) : () -> () + // CHECK-NEXT: }) : () -> () + // clang-format on + + mlirIRRewriterDestroy(rewriter); + mlirModuleDestroy(module); +} + +void testReplaceOp(MlirContext ctx) { + // CHECK-LABEL: @testReplaceOp + fprintf(stderr, "@testReplaceOp\n"); + + const char *moduleString = + "%x, %y, %z = \"dialect.create_values\"() : () -> (index, index, index)\n" + "%x_1, %y_1 = \"dialect.op1\"() : () -> (index, index)\n" + "\"dialect.use_op1\"(%x_1, %y_1) : (index, index) -> ()\n" + "%x_2, %y_2 = \"dialect.op2\"() : () -> (index, index)\n" + "%x_3, %y_3 = \"dialect.op3\"() : () -> (index, index)\n" + "\"dialect.use_op2\"(%x_2, %y_2) : (index, index) -> ()\n"; + MlirModule module = + mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString)); + MlirOperation op = mlirModuleGetOperation(module); + MlirBlock body = mlirModuleGetBody(module); + + // get a handle to all operations/values + MlirOperation createValues = mlirBlockGetFirstOperation(body); + MlirValue x = mlirOperationGetResult(createValues, 0); + MlirValue z = mlirOperationGetResult(createValues, 2); + MlirOperation op1 = mlirOperationGetNextInBlock(createValues); + MlirOperation useOp1 = mlirOperationGetNextInBlock(op1); + MlirOperation op2 = mlirOperationGetNextInBlock(useOp1); + MlirOperation op3 = mlirOperationGetNextInBlock(op2); + + MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx); + + // Test replace op with values + MlirValue xz[2] = {x, z}; + mlirRewriterBaseReplaceOpWithValues(rewriter, op1, 2, xz); + + // Test replace op with op + mlirRewriterBaseReplaceOpWithOperation(rewriter, op2, op3); + + mlirOperationDump(op); + // clang-format off + // CHECK-NEXT: module { + // CHECK-NEXT: %[[res:.*]]:3 = "dialect.create_values"() : () -> (index, index, index) + // CHECK-NEXT: "dialect.use_op1"(%[[res]]#0, %[[res]]#2) : (index, index) -> () + // CHECK-NEXT: %[[res2:.*]]:2 = "dialect.op3"() : () -> (index, index) + // CHECK-NEXT: "dialect.use_op2"(%[[res2]]#0, %[[res2]]#1) : (index, index) -> () + // CHECK-NEXT: } + // clang-format on + + mlirIRRewriterDestroy(rewriter); + mlirModuleDestroy(module); +} + +void testErase(MlirContext ctx) { + // CHECK-LABEL: @testErase + fprintf(stderr, "@testErase\n"); + + const char *moduleString = "\"dialect.op_to_erase\"() : () -> ()\n" + "\"dialect.op2\"() ({\n" + "^bb0():\n" + " \"dialect.op2_nested\"() : () -> ()" + "^block_to_erase():\n" + " \"dialect.op2_nested\"() : () -> ()" + "^bb1():\n" + " \"dialect.op2_nested\"() : () -> ()" + "}) : () -> ()\n"; + MlirModule module = + mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString)); + MlirOperation op = mlirModuleGetOperation(module); + MlirBlock body = mlirModuleGetBody(module); + + // get a handle to all operations/values + MlirOperation opToErase = mlirBlockGetFirstOperation(body); + MlirOperation op2 = mlirOperationGetNextInBlock(opToErase); + MlirRegion op2Region = mlirOperationGetRegion(op2, 0); + MlirBlock bb0 = mlirRegionGetFirstBlock(op2Region); + MlirBlock blockToErase = mlirBlockGetNextInRegion(bb0); + + MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx); + mlirRewriterBaseEraseOp(rewriter, opToErase); + mlirRewriterBaseEraseBlock(rewriter, blockToErase); + + mlirOperationDump(op); + // CHECK-NEXT: module { + // CHECK-NEXT: "dialect.op2"() ({ + // CHECK-NEXT: "dialect.op2_nested"() : () -> () + // CHECK-NEXT: ^{{.*}}: + // CHECK-NEXT: "dialect.op2_nested"() : () -> () + // CHECK-NEXT: }) : () -> () + // CHECK-NEXT: } + + mlirIRRewriterDestroy(rewriter); + mlirModuleDestroy(module); +} + +void testMove(MlirContext ctx) { + // CHECK-LABEL: @testMove + fprintf(stderr, "@testMove\n"); + + const char *moduleString = "\"dialect.op1\"() : () -> ()\n" + "\"dialect.op2\"() ({\n" + "^bb0(%arg0: index):\n" + " \"dialect.op2_1\"(%arg0) : (index) -> ()" + "^bb1(%arg1: index):\n" + " \"dialect.op2_2\"(%arg1) : (index) -> ()" + "}) : () -> ()\n" + "\"dialect.op3\"() : () -> ()\n" + "\"dialect.op4\"() : () -> ()\n"; + + MlirModule module = + mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString)); + MlirOperation op = mlirModuleGetOperation(module); + MlirBlock body = mlirModuleGetBody(module); + + // get a handle to all operations/values + MlirOperation op1 = mlirBlockGetFirstOperation(body); + MlirOperation op2 = mlirOperationGetNextInBlock(op1); + MlirOperation op3 = mlirOperationGetNextInBlock(op2); + MlirOperation op4 = mlirOperationGetNextInBlock(op3); + + MlirRegion region2 = mlirOperationGetRegion(op2, 0); + MlirBlock block0 = mlirRegionGetFirstBlock(region2); + MlirBlock block1 = mlirBlockGetNextInRegion(block0); + + // Test move operations. + MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx); + mlirRewriterBaseMoveOpBefore(rewriter, op3, op1); + mlirRewriterBaseMoveOpAfter(rewriter, op4, op1); + mlirRewriterBaseMoveBlockBefore(rewriter, block1, block0); + + mlirOperationDump(op); + // CHECK-NEXT: module { + // CHECK-NEXT: "dialect.op3"() : () -> () + // CHECK-NEXT: "dialect.op1"() : () -> () + // CHECK-NEXT: "dialect.op4"() : () -> () + // CHECK-NEXT: "dialect.op2"() ({ + // CHECK-NEXT: ^{{.*}}(%[[arg0:.*]]: index): + // CHECK-NEXT: "dialect.op2_2"(%[[arg0]]) : (index) -> () + // CHECK-NEXT: ^{{.*}}(%[[arg1:.*]]: index): // no predecessors + // CHECK-NEXT: "dialect.op2_1"(%[[arg1]]) : (index) -> () + // CHECK-NEXT: }) : () -> () + // CHECK-NEXT: } + + mlirIRRewriterDestroy(rewriter); + mlirModuleDestroy(module); +} + +void testOpModification(MlirContext ctx) { + // CHECK-LABEL: @testOpModification + fprintf(stderr, "@testOpModification\n"); + + const char *moduleString = + "%x, %y = \"dialect.op1\"() : () -> (index, index)\n" + "\"dialect.op2\"(%x) : (index) -> ()\n"; + + MlirModule module = + mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString)); + MlirOperation op = mlirModuleGetOperation(module); + MlirBlock body = mlirModuleGetBody(module); + + // get a handle to all operations/values + MlirOperation op1 = mlirBlockGetFirstOperation(body); + MlirValue y = mlirOperationGetResult(op1, 1); + MlirOperation op2 = mlirOperationGetNextInBlock(op1); + + MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx); + mlirRewriterBaseStartOpModification(rewriter, op1); + mlirRewriterBaseCancelOpModification(rewriter, op1); + + mlirRewriterBaseStartOpModification(rewriter, op2); + mlirOperationSetOperand(op2, 0, y); + mlirRewriterBaseFinalizeOpModification(rewriter, op2); + + mlirOperationDump(op); + // CHECK-NEXT: module { + // CHECK-NEXT: %[[xy:.*]]:2 = "dialect.op1"() : () -> (index, index) + // CHECK-NEXT: "dialect.op2"(%[[xy]]#1) : (index) -> () + // CHECK-NEXT: } + + mlirIRRewriterDestroy(rewriter); + mlirModuleDestroy(module); +} + +void testReplaceUses(MlirContext ctx) { + // CHECK-LABEL: @testReplaceUses + fprintf(stderr, "@testReplaceUses\n"); + + const char *moduleString = + // Replace values with values + "%x1, %y1, %z1 = \"dialect.op1\"() : () -> (index, index, index)\n" + "%x2, %y2, %z2 = \"dialect.op2\"() : () -> (index, index, index)\n" + "\"dialect.op1_uses\"(%x1, %y1, %z1) : (index, index, index) -> ()\n" + // Replace op with values + "%x3 = \"dialect.op3\"() : () -> index\n" + "%x4 = \"dialect.op4\"() : () -> index\n" + "\"dialect.op3_uses\"(%x3) : (index) -> ()\n" + // Replace op with op + "%x5 = \"dialect.op5\"() : () -> index\n" + "%x6 = \"dialect.op6\"() : () -> index\n" + "\"dialect.op5_uses\"(%x5) : (index) -> ()\n" + // Replace op in block; + "%x7 = \"dialect.op7\"() : () -> index\n" + "%x8 = \"dialect.op8\"() : () -> index\n" + "\"dialect.op9\"() ({\n" + "^bb0:\n" + " \"dialect.op7_uses\"(%x7) : (index) -> ()\n" + "}): () -> ()\n" + "\"dialect.op7_uses\"(%x7) : (index) -> ()\n" + // Replace value with value except in op + "%x10 = \"dialect.op10\"() : () -> index\n" + "%x11 = \"dialect.op11\"() : () -> index\n" + "\"dialect.op10_uses\"(%x10) : (index) -> ()\n" + "\"dialect.op10_uses\"(%x10) : (index) -> ()\n"; + + MlirModule module = + mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString)); + MlirOperation op = mlirModuleGetOperation(module); + MlirBlock body = mlirModuleGetBody(module); + + // get a handle to all operations/values + MlirOperation op1 = mlirBlockGetFirstOperation(body); + MlirValue x1 = mlirOperationGetResult(op1, 0); + MlirValue y1 = mlirOperationGetResult(op1, 1); + MlirValue z1 = mlirOperationGetResult(op1, 2); + MlirOperation op2 = mlirOperationGetNextInBlock(op1); + MlirValue x2 = mlirOperationGetResult(op2, 0); + MlirValue y2 = mlirOperationGetResult(op2, 1); + MlirValue z2 = mlirOperationGetResult(op2, 2); + MlirOperation op1Uses = mlirOperationGetNextInBlock(op2); + + MlirOperation op3 = mlirOperationGetNextInBlock(op1Uses); + MlirOperation op4 = mlirOperationGetNextInBlock(op3); + MlirValue x4 = mlirOperationGetResult(op4, 0); + MlirOperation op3Uses = mlirOperationGetNextInBlock(op4); + + MlirOperation op5 = mlirOperationGetNextInBlock(op3Uses); + MlirOperation op6 = mlirOperationGetNextInBlock(op5); + MlirOperation op5Uses = mlirOperationGetNextInBlock(op6); + + MlirOperation op7 = mlirOperationGetNextInBlock(op5Uses); + MlirOperation op8 = mlirOperationGetNextInBlock(op7); + MlirValue x8 = mlirOperationGetResult(op8, 0); + MlirOperation op9 = mlirOperationGetNextInBlock(op8); + MlirRegion region9 = mlirOperationGetRegion(op9, 0); + MlirBlock block9 = mlirRegionGetFirstBlock(region9); + MlirOperation op7Uses = mlirOperationGetNextInBlock(op9); + + MlirOperation op10 = mlirOperationGetNextInBlock(op7Uses); + MlirValue x10 = mlirOperationGetResult(op10, 0); + MlirOperation op11 = mlirOperationGetNextInBlock(op10); + MlirValue x11 = mlirOperationGetResult(op11, 0); + MlirOperation op10Uses1 = mlirOperationGetNextInBlock(op11); + + MlirRewriterBase rewriter = mlirIRRewriterCreate(ctx); + + // Replace values + mlirRewriterBaseReplaceAllUsesWith(rewriter, x1, x2); + MlirValue y1z1[2] = {y1, z1}; + MlirValue y2z2[2] = {y2, z2}; + mlirRewriterBaseReplaceAllValueRangeUsesWith(rewriter, 2, y1z1, y2z2); + + // Replace op with values + mlirRewriterBaseReplaceOpWithValues(rewriter, op3, 1, &x4); + + // Replace op with op + mlirRewriterBaseReplaceOpWithOperation(rewriter, op5, op6); + + // Replace op with op in block + mlirRewriterBaseReplaceOpUsesWithinBlock(rewriter, op7, 1, &x8, block9); + + // Replace value with value except in op + mlirRewriterBaseReplaceAllUsesExcept(rewriter, x10, x11, op10Uses1); + + mlirOperationDump(op); + // clang-format off + // CHECK-NEXT: module { + // CHECK-NEXT: %{{.*}}:3 = "dialect.op1"() : () -> (index, index, index) + // CHECK-NEXT: %[[res2:.*]]:3 = "dialect.op2"() : () -> (index, index, index) + // CHECK-NEXT: "dialect.op1_uses"(%[[res2]]#0, %[[res2]]#1, %[[res2]]#2) : (index, index, index) -> () + // CHECK-NEXT: %[[res4:.*]] = "dialect.op4"() : () -> index + // CHECK-NEXT: "dialect.op3_uses"(%[[res4]]) : (index) -> () + // CHECK-NEXT: %[[res6:.*]] = "dialect.op6"() : () -> index + // CHECK-NEXT: "dialect.op5_uses"(%[[res6]]) : (index) -> () + // CHECK-NEXT: %[[res7:.*]] = "dialect.op7"() : () -> index + // CHECK-NEXT: %[[res8:.*]] = "dialect.op8"() : () -> index + // CHECK-NEXT: "dialect.op9"() ({ + // CHECK-NEXT: "dialect.op7_uses"(%[[res8]]) : (index) -> () + // CHECK-NEXT: }) : () -> () + // CHECK-NEXT: "dialect.op7_uses"(%[[res7]]) : (index) -> () + // CHECK-NEXT: %[[res10:.*]] = "dialect.op10"() : () -> index + // CHECK-NEXT: %[[res11:.*]] = "dialect.op11"() : () -> index + // CHECK-NEXT: "dialect.op10_uses"(%[[res10]]) : (index) -> () + // CHECK-NEXT: "dialect.op10_uses"(%[[res11]]) : (index) -> () + // CHECK-NEXT: } + // clang-format on + + mlirIRRewriterDestroy(rewriter); + mlirModuleDestroy(module); +} + +int main(void) { + MlirContext ctx = mlirContextCreate(); + mlirContextSetAllowUnregisteredDialects(ctx, true); + mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("builtin")); + + testInsertionPoint(ctx); + testCreateBlock(ctx); + testInlineRegionBlock(ctx); + testReplaceOp(ctx); + testErase(ctx); + testMove(ctx); + testOpModification(ctx); + testReplaceUses(ctx); + + mlirContextDestroy(ctx); + return 0; +} diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt index 45009a78aa49f..df95e5db11f1e 100644 --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -105,6 +105,7 @@ set(MLIR_TEST_DEPENDS mlir-capi-llvm-test mlir-capi-pass-test mlir-capi-quant-test + mlir-capi-rewrite-test mlir-capi-sparse-tensor-test mlir-capi-transform-test mlir-capi-transform-interpreter-test diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py index 1175f87877f9e..98d0ddd9a2be1 100644 --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -106,6 +106,7 @@ def add_runtime(name): "mlir-capi-pass-test", "mlir-capi-pdl-test", "mlir-capi-quant-test", + "mlir-capi-rewrite-test", "mlir-capi-sparse-tensor-test", "mlir-capi-transform-test", "mlir-capi-transform-interpreter-test", From 457a82f384c3da67721b4dbd74055583ecb0d8c0 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Tue, 16 Jul 2024 16:01:02 +0100 Subject: [PATCH 2/2] Move comments, and add missing function documenation --- mlir/include/mlir-c/Rewrite.h | 17 ++++++++++------- mlir/lib/CAPI/Transforms/Rewrite.cpp | 6 ------ 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h index 09f8a72a0c599..d8f2275b61532 100644 --- a/mlir/include/mlir-c/Rewrite.h +++ b/mlir/include/mlir-c/Rewrite.h @@ -49,8 +49,8 @@ mlirRewriterBaseGetContext(MlirRewriterBase rewriter); //===----------------------------------------------------------------------===// /// Insertion points methods -// They do not include functions using Block::iterator or Region::iterator, as -// they are not exposed by the C API yet. This includes methods using +// These do not include functions using Block::iterator or Region::iterator, as +// they are not exposed by the C API yet. Similarly for methods using // `InsertPoint` directly. /// Reset the insertion point to no location. Creating an operation without a @@ -102,6 +102,9 @@ mlirRewriterBaseGetBlock(MlirRewriterBase rewriter); //===----------------------------------------------------------------------===// /// Block and operation creation/insertion/cloning +// These functions do not include the IRMapper, as it is not yet exposed by the +// C API. + /// Add new block with 'argTypes' arguments and set the insertion point to the /// end of it. The block is placed before 'insertBefore'. `locs` contains the /// locations of the inserted arguments, and should match the size of @@ -114,15 +117,17 @@ MLIR_CAPI_EXPORTED MlirBlock mlirRewriterBaseCreateBlockBefore( MLIR_CAPI_EXPORTED MlirOperation mlirRewriterBaseInsert(MlirRewriterBase rewriter, MlirOperation op); -// The IRMapper is not yet exposed in the CAPI +/// Creates a deep copy of the specified operation. MLIR_CAPI_EXPORTED MlirOperation mlirRewriterBaseClone(MlirRewriterBase rewriter, MlirOperation op); -// The IRMapper is not yet exposed in the CAPI +/// Creates a deep copy of this operation but keep the operation regions +/// empty. MLIR_CAPI_EXPORTED MlirOperation mlirRewriterBaseCloneWithoutRegions( MlirRewriterBase rewriter, MlirOperation op); -// The IRMapper is not yet exposed in the CAPI, nor Region::iterator. +/// Clone the blocks that belong to "region" before the given position in +/// another region "parent". MLIR_CAPI_EXPORTED void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter, MlirRegion region, MlirBlock before); @@ -184,8 +189,6 @@ MLIR_CAPI_EXPORTED void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, intptr_t nArgValues, MlirValue const *argValues); -// splitBlock is not implemented as Block::iterator is not exposed by the CAPI - /// Unlink this operation from its current block and insert it right before /// `existingOp` which may be in the same or another block in the same /// function. diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 7f3c833df0910..379f09cf5cc26 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -90,19 +90,16 @@ MlirOperation mlirRewriterBaseInsert(MlirRewriterBase rewriter, // Other methods of OpBuilder -// The IRMapper is not yet exposed in the CAPI MlirOperation mlirRewriterBaseClone(MlirRewriterBase rewriter, MlirOperation op) { return wrap(unwrap(rewriter)->clone(*unwrap(op))); } -// The IRMapper is not yet exposed in the CAPI MlirOperation mlirRewriterBaseCloneWithoutRegions(MlirRewriterBase rewriter, MlirOperation op) { return wrap(unwrap(rewriter)->cloneWithoutRegions(*unwrap(op))); } -// The IRMapper is not yet exposed in the CAPI, nor Region::iterator. void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter, MlirRegion region, MlirBlock before) { @@ -113,7 +110,6 @@ void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter, /// RewriterBase API //===----------------------------------------------------------------------===// -// Region::iterator is not yet exposed in the CAPI. void mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter, MlirRegion region, MlirBlock before) { unwrap(rewriter)->inlineRegionBefore(*unwrap(region), unwrap(before)); @@ -160,8 +156,6 @@ void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, MlirBlock source, unwrap(rewriter)->mergeBlocks(unwrap(source), unwrap(dest), unwrappedArgs); } -// splitBlock is not implemented as Block::iterator is not exposed by the CAPI - void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, MlirOperation op, MlirOperation existingOp) { unwrap(rewriter)->moveOpBefore(unwrap(op), unwrap(existingOp));