Skip to content

Commit b89ed4c

Browse files
committed
[MLIR] Add support for overloading interface methods
1 parent 224da97 commit b89ed4c

File tree

8 files changed

+43
-11
lines changed

8 files changed

+43
-11
lines changed

mlir/include/mlir/Dialect/Async/IR/AsyncOps.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ def Async_ExecuteOp :
101101
let extraClassDeclaration = [{
102102
using BodyBuilderFn =
103103
function_ref<void(OpBuilder &, Location, ValueRange)>;
104-
105104
}];
106105
}
107106

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
2525

2626
def AlternativesOp : TransformDialectOp<"alternatives",
2727
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
28-
["getEntrySuccessorOperands", "getSuccessorRegions",
28+
["getEntrySuccessorOperands",
2929
"getRegionInvocationBounds"]>,
3030
DeclareOpInterfaceMethods<TransformOpInterface>,
3131
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
@@ -624,7 +624,7 @@ def ForeachOp : TransformDialectOp<"foreach",
624624
[DeclareOpInterfaceMethods<TransformOpInterface>,
625625
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
626626
DeclareOpInterfaceMethods<RegionBranchOpInterface, [
627-
"getSuccessorRegions", "getEntrySuccessorOperands"]>,
627+
"getEntrySuccessorOperands"]>,
628628
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">
629629
]> {
630630
let summary = "Executes the body for each element of the payload";
@@ -1237,7 +1237,7 @@ def SelectOp : TransformDialectOp<"select",
12371237

12381238
def SequenceOp : TransformDialectOp<"sequence",
12391239
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
1240-
["getEntrySuccessorOperands", "getSuccessorRegions",
1240+
["getEntrySuccessorOperands",
12411241
"getRegionInvocationBounds"]>,
12421242
MatchOpInterface,
12431243
DeclareOpInterfaceMethods<TransformOpInterface>,

mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def KnobOp : Op<Transform_Dialect, "tune.knob", [
6363

6464
def AlternativesOp : Op<Transform_Dialect, "tune.alternatives", [
6565
DeclareOpInterfaceMethods<RegionBranchOpInterface,
66-
["getEntrySuccessorOperands", "getSuccessorRegions",
66+
["getEntrySuccessorOperands",
6767
"getRegionInvocationBounds"]>,
6868
DeclareOpInterfaceMethods<TransformOpInterface>,
6969
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,

mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Dialect/Transform/IR/TransformOps.h"
1010
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
1111
#include "mlir/IR/OpImplementation.h"
12+
#include "mlir/Interfaces/ControlFlowInterfaces.h"
1213
#include "llvm/Support/Debug.h"
1314

1415
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"
@@ -112,7 +113,7 @@ static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer,
112113
}
113114

114115
OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands(
115-
RegionBranchPoint point) {
116+
RegionSuccessor successor) {
116117
// No operands will be forwarded to the region(s).
117118
return getOperands().slice(0, 0);
118119
}
@@ -128,7 +129,7 @@ void transform::tune::AlternativesOp::getSuccessorRegions(
128129
for (Region &alternative : getAlternatives())
129130
regions.emplace_back(&alternative, Block::BlockArgListType());
130131
else
131-
regions.emplace_back(getOperation()->getResults());
132+
regions.emplace_back(getOperation(), getOperation()->getResults());
132133
}
133134

134135
void transform::tune::AlternativesOp::getRegionInvocationBounds(

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2567,7 +2567,7 @@ def LoopBlockTerminatorOp : TEST_Op<"loop_block_term",
25672567

25682568
def TestNoTerminatorOp : TEST_Op<"switch_with_no_break", [
25692569
NoTerminator,
2570-
DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorRegions"]>
2570+
DeclareOpInterfaceMethods<RegionBranchOpInterface>
25712571
]> {
25722572
let arguments = (ins Index:$arg, DenseI64ArrayAttr:$cases);
25732573
let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions);

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,14 @@ class OpEmitter {
789789
Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
790790
bool declaration = true);
791791

792+
// Generate a `using` declaration for the op interface method to include
793+
// the default implementation from the interface trait.
794+
// This is needed when the interface defines multiple methods with the same
795+
// name, but some have a default implementation and some don't.
796+
UsingDeclaration *
797+
genOpInterfaceMethodUsingDecl(const tblgen::InterfaceTrait *opTrait,
798+
const tblgen::InterfaceMethod &method);
799+
792800
// Generate the side effect interface methods.
793801
void genSideEffectInterfaceMethods();
794802

@@ -815,6 +823,10 @@ class OpEmitter {
815823

816824
// Helper for emitting op code.
817825
OpOrAdaptorHelper emitHelper;
826+
827+
// Keep track of the interface using declarations that have been generated to
828+
// avoid duplicates.
829+
llvm::StringSet<> interfaceUsingNames;
818830
};
819831

820832
} // namespace
@@ -3672,8 +3684,10 @@ void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
36723684
// Don't declare if the method has a default implementation and the op
36733685
// didn't request that it always be declared.
36743686
if (method.getDefaultImplementation() &&
3675-
!alwaysDeclaredMethods.count(method.getName()))
3687+
!alwaysDeclaredMethods.count(method.getName())) {
3688+
genOpInterfaceMethodUsingDecl(opTrait, method);
36763689
continue;
3690+
}
36773691
// Interface methods are allowed to overlap with existing methods, so don't
36783692
// check if pruned.
36793693
(void)genOpInterfaceMethod(method);
@@ -3692,6 +3706,17 @@ Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
36923706
std::move(paramList));
36933707
}
36943708

3709+
UsingDeclaration *
3710+
OpEmitter::genOpInterfaceMethodUsingDecl(const tblgen::InterfaceTrait *opTrait,
3711+
const InterfaceMethod &method) {
3712+
std::string name = (llvm::Twine(opTrait->getFullyQualifiedTraitName()) + "<" +
3713+
op.getQualCppClassName() + ">::" + method.getName())
3714+
.str();
3715+
if (interfaceUsingNames.insert(name).second)
3716+
return opClass.declare<UsingDeclaration>(std::move(name));
3717+
return nullptr;
3718+
}
3719+
36953720
void OpEmitter::genOpInterfaceMethods() {
36963721
for (const auto &trait : op.getTraits()) {
36973722
if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))

mlir/tools/mlir-tblgen/OpInterfacesGen.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
374374
os << tblgen::tgfmt("return $_self.", &nonStaticMethodFmt);
375375

376376
// Add the arguments to the call.
377-
os << method.getDedupName() << '(';
377+
os << method.getName() << '(';
378378
llvm::interleaveComma(
379379
method.getArguments(), os,
380380
[&](const InterfaceMethod::Argument &arg) { os << arg.name; });
@@ -479,7 +479,7 @@ void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) {
479479
emitInterfaceMethodDoc(method, os, " ");
480480
os << " " << (method.isStatic() ? "static " : "");
481481
emitCPPType(method.getReturnType(), os);
482-
emitMethodNameAndArgs(method, method.getDedupName(), os, valueType,
482+
emitMethodNameAndArgs(method, method.getName(), os, valueType,
483483
/*addThisArg=*/false,
484484
/*addConst=*/!isOpInterface && !method.isStatic());
485485
os << " {\n " << tblgen::tgfmt(defaultImpl->trim(), &traitMethodFmt)

mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ struct MutuallyExclusiveRegionsOp
4646
// Regions have no successors.
4747
void getSuccessorRegions(RegionBranchPoint point,
4848
SmallVectorImpl<RegionSuccessor> &regions) {}
49+
using RegionBranchOpInterface::Trait<
50+
MutuallyExclusiveRegionsOp>::getSuccessorRegions;
4951
};
5052

5153
/// All regions of this op call each other in a large circle.
@@ -70,6 +72,7 @@ struct LoopRegionsOp
7072
regions.push_back(RegionSuccessor(region));
7173
}
7274
}
75+
using RegionBranchOpInterface::Trait<LoopRegionsOp>::getSuccessorRegions;
7376
};
7477

7578
/// Each region branches back it itself or the parent.
@@ -93,6 +96,8 @@ struct DoubleLoopRegionsOp
9396
regions.push_back(RegionSuccessor(region));
9497
}
9598
}
99+
using RegionBranchOpInterface::Trait<
100+
DoubleLoopRegionsOp>::getSuccessorRegions;
96101
};
97102

98103
/// Regions are executed sequentially.
@@ -113,6 +118,8 @@ struct SequentialRegionsOp
113118
regions.push_back(RegionSuccessor(&thisOp->getRegion(1)));
114119
}
115120
}
121+
using RegionBranchOpInterface::Trait<
122+
SequentialRegionsOp>::getSuccessorRegions;
116123
};
117124

118125
/// A dialect putting all the above together.

0 commit comments

Comments
 (0)