From 2b7d6b4d4eaf5fe7159f93196e17423771d21fb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Fri, 22 Sep 2023 15:10:30 +0000 Subject: [PATCH 1/2] [mlir][transform] Fix handling of transitive include in interpreter. Until now, the interpreter would only load those symbols from the provided library files that were declared in the main transform module. However, sequences in the library may include other sequences on their own. Until now, if such sequences were not *also* declared in the main transform module, the interpreter would fail to resolve them. Forward declaring all of them is undesirable as it defeats the purpose of encapsulation into library modules. This PR extends the loading missing as follows: in `defineDeclaredSymbols`, not only are the definitions inserted that are forward-declared in the main module, but any such inserted definition is scanned for further dependencies, and those are processed in the same way as the forward-declarations from the main module. --- .../TransformInterpreterPassBase.cpp | 72 ++++++++++++++++--- ...reter-external-symbol-decl-transitive.mlir | 27 +++++++ .../test-interpreter-external-symbol-def.mlir | 5 ++ 3 files changed, 95 insertions(+), 9 deletions(-) create mode 100644 mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp index d5c65b23e3a21..3c993e417b67e 100644 --- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp @@ -311,6 +311,9 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) { auto readOnlyName = StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName); + // Collect symbols missing in the block. + SmallVector missingSymbols; + LLVM_DEBUG(DBGS() << "searching block for missing symbols:\n"); for (Operation &op : llvm::make_early_inc_range(block)) { LLVM_DEBUG(DBGS() << op << "\n"); auto symbol = dyn_cast(op); @@ -318,25 +321,33 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) { continue; if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty()) continue; + LLVM_DEBUG(DBGS() << " -> symbol missing\n"); + missingSymbols.push_back(symbol); + } - LLVM_DEBUG(DBGS() << "looking for definition of symbol " - << symbol.getNameAttr() << ":"); - SymbolTable symbolTable(definitions); - Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr()); + // Resolve missing symbols until they are all resolved. + while (!missingSymbols.empty()) { + SymbolOpInterface symbol = missingSymbols.pop_back_val(); + LLVM_DEBUG(DBGS() << "looking for definition of symbol @" + << symbol.getNameAttr().getValue() << ": "); + SymbolTable definitionsSymbolTable(definitions); + Operation *externalSymbol = + definitionsSymbolTable.lookup(symbol.getNameAttr()); if (!externalSymbol || externalSymbol->getNumRegions() != 1 || externalSymbol->getRegion(0).empty()) { LLVM_DEBUG(llvm::dbgs() << "not found\n"); continue; } - auto symbolFunc = dyn_cast(op); + auto symbolFunc = dyn_cast(symbol.getOperation()); auto externalSymbolFunc = dyn_cast(externalSymbol); if (!symbolFunc || !externalSymbolFunc) { LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n"); continue; } - LLVM_DEBUG(llvm::dbgs() << "found @" << externalSymbol << "\n"); + LLVM_DEBUG(llvm::dbgs() << "found " << externalSymbol << " from " + << externalSymbol->getLoc() << "\n"); if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) { return symbolFunc.emitError() << "external definition has a mismatching signature (" @@ -367,10 +378,53 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) { } } - OpBuilder builder(&op); - builder.setInsertionPoint(&op); - builder.clone(*externalSymbol); + OpBuilder builder(symbol); + builder.setInsertionPoint(symbol); + Operation *newSymbol = builder.clone(*externalSymbol); + builder.setInsertionPoint(newSymbol); symbol->erase(); + + LLVM_DEBUG(DBGS() << "scanning definition of @" + << externalSymbolFunc.getNameAttr().getValue() + << " for symbol usages\n"); + externalSymbolFunc.walk([&](CallOpInterface callOp) { + LLVM_DEBUG(DBGS() << " found symbol usage in:\n" << callOp << "\n"); + CallInterfaceCallable callable = callOp.getCallableForCallee(); + if (!isa(callable)) { + LLVM_DEBUG(DBGS() << " not a 'SymbolRefAttr'\n"); + return WalkResult::advance(); + } + + StringRef callableSymbol = + cast(callable).getLeafReference(); + LLVM_DEBUG(DBGS() << " looking for @" << callableSymbol + << " in definitions: "); + + Operation *callableOp = definitionsSymbolTable.lookup(callableSymbol); + if (!isa(callable)) { + LLVM_DEBUG(llvm::dbgs() << "not found\n"); + return WalkResult::advance(); + } + LLVM_DEBUG(llvm::dbgs() << "found " << callableOp << " from " + << callableOp->getLoc() << "\n"); + + if (!block.getParent() || !block.getParent()->getParentOp()) { + LLVM_DEBUG(DBGS() << "could not get parent of provided block"); + return WalkResult::advance(); + } + + SymbolTable targetSymbolTable(block.getParent()->getParentOp()); + if (targetSymbolTable.lookup(callableSymbol)) { + LLVM_DEBUG(DBGS() << " symbol @" << callableSymbol + << " already present in target\n"); + return WalkResult::advance(); + } + + LLVM_DEBUG(DBGS() << " cloning op into target\n"); + builder.clone(*callableOp); + + return WalkResult::advance(); + }); } return success(); diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir new file mode 100644 index 0000000000000..0e9fa7c59bc41 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \ +// RUN: --verify-diagnostics --split-input-file | FileCheck %s + +// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \ +// RUN: --verify-diagnostics --split-input-file | FileCheck %s + +// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \ +// RUN: --verify-diagnostics --split-input-file | FileCheck %s + +// The definition of the @bar named sequence is provided in another file. It +// will be included because of the pass option. That sequence uses another named +// sequence @foo, which should be made available here. Repeated application of +// the same pass, with or without the library option, should not be a problem. +// Note that the same diagnostic produced twice at the same location only +// needs to be matched once. + +// expected-remark @below {{message}} +module attributes {transform.with_named_sequence} { + // CHECK-DAG: transform.named_sequence @foo + // CHECK-DAG: transform.named_sequence @bar + transform.named_sequence private @bar(!transform.any_op {transform.readonly}) + + transform.sequence failures(propagate) { + ^bb0(%arg0: !transform.any_op): + include @bar failures(propagate) (%arg0) : (!transform.any_op) -> () + } +} diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir index 1149bda98ab85..9aa2d46d5abb9 100644 --- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir @@ -1,6 +1,11 @@ // RUN: mlir-opt %s module attributes {transform.with_named_sequence} { + transform.named_sequence @bar(%arg0: !transform.any_op) { + transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> () + transform.yield + } + transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) { transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op transform.yield From 48a81af250ae8a8789e04b681635090927d817b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 25 Sep 2023 08:11:03 +0000 Subject: [PATCH 2/2] Minor improvements of variable names and debug output. --- .../Transforms/TransformInterpreterPassBase.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp index 3c993e417b67e..475368f0f406a 100644 --- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp @@ -388,19 +388,19 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) { << externalSymbolFunc.getNameAttr().getValue() << " for symbol usages\n"); externalSymbolFunc.walk([&](CallOpInterface callOp) { - LLVM_DEBUG(DBGS() << " found symbol usage in:\n" << callOp << "\n"); + LLVM_DEBUG(DBGS() << " call op in:\n" << callOp << "\n"); CallInterfaceCallable callable = callOp.getCallableForCallee(); if (!isa(callable)) { - LLVM_DEBUG(DBGS() << " not a 'SymbolRefAttr'\n"); + LLVM_DEBUG(DBGS() << " not a symbol usage\n"); return WalkResult::advance(); } - StringRef callableSymbol = + StringRef callableSymbolName = cast(callable).getLeafReference(); - LLVM_DEBUG(DBGS() << " looking for @" << callableSymbol + LLVM_DEBUG(DBGS() << " looking for @" << callableSymbolName << " in definitions: "); - Operation *callableOp = definitionsSymbolTable.lookup(callableSymbol); + Operation *callableOp = definitionsSymbolTable.lookup(callableSymbolName); if (!isa(callable)) { LLVM_DEBUG(llvm::dbgs() << "not found\n"); return WalkResult::advance(); @@ -409,13 +409,13 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) { << callableOp->getLoc() << "\n"); if (!block.getParent() || !block.getParent()->getParentOp()) { - LLVM_DEBUG(DBGS() << "could not get parent of provided block"); + LLVM_DEBUG(DBGS() << "could not get parent op of provided block"); return WalkResult::advance(); } SymbolTable targetSymbolTable(block.getParent()->getParentOp()); - if (targetSymbolTable.lookup(callableSymbol)) { - LLVM_DEBUG(DBGS() << " symbol @" << callableSymbol + if (targetSymbolTable.lookup(callableSymbolName)) { + LLVM_DEBUG(DBGS() << " symbol @" << callableSymbolName << " already present in target\n"); return WalkResult::advance(); }