diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 0c4975a13d301..5abc112ab8c7a 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -692,6 +692,22 @@ static LogicalResult printOperation(CppEmitter &emitter, return failure(); os << callOpaqueOp.getCallee(); + // Template arguments can't refer to SSA values and as such the template + // arguments which are supplied in form of attributes can be emitted as is. We + // don't need to handle integer attributes specially like we do for arguments + // - see below. + auto emitTemplateArgs = [&](Attribute attr) -> LogicalResult { + return emitter.emitAttribute(op.getLoc(), attr); + }; + + if (callOpaqueOp.getTemplateArgs()) { + os << "<"; + if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os, + emitTemplateArgs))) + return failure(); + os << ">"; + } + auto emitArgs = [&](Attribute attr) -> LogicalResult { if (auto t = dyn_cast(attr)) { // Index attributes are treated specially as operand index. @@ -711,14 +727,6 @@ static LogicalResult printOperation(CppEmitter &emitter, return success(); }; - if (callOpaqueOp.getTemplateArgs()) { - os << "<"; - if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os, - emitArgs))) - return failure(); - os << ">"; - } - os << "("; LogicalResult emittedArgs = diff --git a/mlir/test/Target/Cpp/common-cpp.mlir b/mlir/test/Target/Cpp/common-cpp.mlir index 45fef618621cc..294e6af65bf14 100644 --- a/mlir/test/Target/Cpp/common-cpp.mlir +++ b/mlir/test/Target/Cpp/common-cpp.mlir @@ -109,3 +109,11 @@ func.func @apply() -> !emitc.ptr { func.func @array_type(%arg0: !emitc.array<3xi32>, %arg1: !emitc.array<10x20xf32>) { return } + +// CHECK: call_opaque_with_template_arg +func.func @call_opaque_with_template_arg() { + emitc.call_opaque "init_tile"() {template_args = [512 : index]} : () -> () + // CHECK-NEXT: init_tile<512>(); + // CHECK-NEXT: return + return +}