Skip to content

Commit adea7e7

Browse files
Simon Camphausenmarbre
authored andcommitted
[mlir][emitc] Add comparison operation
This adds a comparison operation to EmitC which supports ==, !=, <=, <, >=, >, <=>. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D158180
1 parent 54784b1 commit adea7e7

File tree

8 files changed

+146
-18
lines changed

8 files changed

+146
-18
lines changed

mlir/include/mlir/Dialect/EmitC/IR/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ add_mlir_dialect(EmitC emitc)
22
add_mlir_doc(EmitC EmitC Dialects/ -gen-dialect-doc)
33

44
set(LLVM_TARGET_DEFINITIONS EmitCAttributes.td)
5+
mlir_tablegen(EmitCEnums.h.inc -gen-enum-decls)
6+
mlir_tablegen(EmitCEnums.cpp.inc -gen-enum-defs)
57
mlir_tablegen(EmitCAttributes.h.inc -gen-attrdef-decls)
68
mlir_tablegen(EmitCAttributes.cpp.inc -gen-attrdef-defs)
79
add_public_tablegen_target(MLIREmitCAttributesIncGen)

mlir/include/mlir/Dialect/EmitC/IR/EmitC.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/Interfaces/SideEffectInterfaces.h"
2222

2323
#include "mlir/Dialect/EmitC/IR/EmitCDialect.h.inc"
24+
#include "mlir/Dialect/EmitC/IR/EmitCEnums.h.inc"
2425

2526
#define GET_ATTRDEF_CLASSES
2627
#include "mlir/Dialect/EmitC/IR/EmitCAttributes.h.inc"

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
2727
class EmitC_Op<string mnemonic, list<Trait> traits = []>
2828
: Op<EmitC_Dialect, mnemonic, traits>;
2929

30-
// Base class for binary arithmetic operations.
31-
class EmitC_BinaryArithOp<string mnemonic, list<Trait> traits = []> :
30+
// Base class for binary operations.
31+
class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
3232
EmitC_Op<mnemonic, traits> {
3333
let arguments = (ins AnyType:$lhs, AnyType:$rhs);
3434
let results = (outs AnyType);
@@ -39,7 +39,7 @@ class EmitC_BinaryArithOp<string mnemonic, list<Trait> traits = []> :
3939
def IntegerIndexOrOpaqueType : AnyTypeOf<[AnyInteger, Index, EmitC_OpaqueType]>;
4040
def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[AnyFloat, IntegerIndexOrOpaqueType]>;
4141

42-
def EmitC_AddOp : EmitC_BinaryArithOp<"add", []> {
42+
def EmitC_AddOp : EmitC_BinaryOp<"add", []> {
4343
let summary = "Addition operation";
4444
let description = [{
4545
With the `add` operation the arithmetic operator + (addition) can
@@ -150,6 +150,37 @@ def EmitC_CastOp : EmitC_Op<"cast", [
150150
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
151151
}
152152

153+
def EmitC_CmpOp : EmitC_BinaryOp<"cmp", []> {
154+
let summary = "Comparison operation";
155+
let description = [{
156+
With the `cmp` operation the comparison operators ==, !=, <, <=, >, >=, <=>
157+
can be applied.
158+
159+
Example:
160+
```mlir
161+
// Custom form of the cmp operation.
162+
%0 = emitc.cmp eq, %arg0, %arg1 : (i32, i32) -> i1
163+
%1 = emitc.cmp lt, %arg2, %arg3 :
164+
(
165+
!emitc.opaque<"std::valarray<float>">,
166+
!emitc.opaque<"std::valarray<float>">
167+
) -> !emitc.opaque<"std::valarray<bool>">
168+
```
169+
```c++
170+
// Code emitted for the operations above.
171+
bool v5 = v1 == v2;
172+
std::valarray<bool> v6 = v3 < v4;
173+
```
174+
}];
175+
176+
let arguments = (ins EmitC_CmpPredicateAttr:$predicate,
177+
AnyType:$lhs,
178+
AnyType:$rhs);
179+
let results = (outs AnyType);
180+
181+
let assemblyFormat = "$predicate `,` operands attr-dict `:` functional-type(operands, results)";
182+
}
183+
153184
def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
154185
let summary = "Constant operation";
155186
let description = [{
@@ -180,7 +211,7 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
180211
let hasVerifier = 1;
181212
}
182213

183-
def EmitC_DivOp : EmitC_BinaryArithOp<"div", []> {
214+
def EmitC_DivOp : EmitC_BinaryOp<"div", []> {
184215
let summary = "Division operation";
185216
let description = [{
186217
With the `div` operation the arithmetic operator / (division) can
@@ -248,7 +279,7 @@ def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> {
248279
let assemblyFormat = "$value attr-dict `:` type($result)";
249280
}
250281

251-
def EmitC_MulOp : EmitC_BinaryArithOp<"mul", []> {
282+
def EmitC_MulOp : EmitC_BinaryOp<"mul", []> {
252283
let summary = "Multiplication operation";
253284
let description = [{
254285
With the `mul` operation the arithmetic operator * (multiplication) can
@@ -272,7 +303,7 @@ def EmitC_MulOp : EmitC_BinaryArithOp<"mul", []> {
272303
let results = (outs FloatIntegerIndexOrOpaqueType);
273304
}
274305

275-
def EmitC_RemOp : EmitC_BinaryArithOp<"rem", []> {
306+
def EmitC_RemOp : EmitC_BinaryOp<"rem", []> {
276307
let summary = "Remainder operation";
277308
let description = [{
278309
With the `rem` operation the arithmetic operator % (remainder) can
@@ -294,7 +325,7 @@ def EmitC_RemOp : EmitC_BinaryArithOp<"rem", []> {
294325
let results = (outs IntegerIndexOrOpaqueType);
295326
}
296327

297-
def EmitC_SubOp : EmitC_BinaryArithOp<"sub", []> {
328+
def EmitC_SubOp : EmitC_BinaryOp<"sub", []> {
298329
let summary = "Subtraction operation";
299330
let description = [{
300331
With the `sub` operation the arithmetic operator - (subtraction) can

mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
include "mlir/IR/AttrTypeBase.td"
1717
include "mlir/IR/BuiltinAttributeInterfaces.td"
18+
include "mlir/IR/EnumAttr.td"
1819
include "mlir/Dialect/EmitC/IR/EmitCBase.td"
1920

2021
//===----------------------------------------------------------------------===//
@@ -26,6 +27,20 @@ class EmitC_Attr<string name, string attrMnemonic, list<Trait> traits = []>
2627
let mnemonic = attrMnemonic;
2728
}
2829

30+
def EmitC_CmpPredicateAttr : I64EnumAttr<
31+
"CmpPredicate", "",
32+
[
33+
I64EnumAttrCase<"eq", 0>,
34+
I64EnumAttrCase<"ne", 1>,
35+
I64EnumAttrCase<"lt", 2>,
36+
I64EnumAttrCase<"le", 3>,
37+
I64EnumAttrCase<"gt", 4>,
38+
I64EnumAttrCase<"ge", 5>,
39+
I64EnumAttrCase<"three_way", 6>,
40+
]> {
41+
let cppNamespace = "::mlir::emitc";
42+
}
43+
2944
def EmitC_OpaqueAttr : EmitC_Attr<"Opaque", "opaque"> {
3045
let summary = "An opaque attribute";
3146

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,12 @@ LogicalResult emitc::VariableOp::verify() {
257257
#define GET_OP_CLASSES
258258
#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
259259

260+
//===----------------------------------------------------------------------===//
261+
// EmitC Enums
262+
//===----------------------------------------------------------------------===//
263+
264+
#include "mlir/Dialect/EmitC/IR/EmitCEnums.cpp.inc"
265+
260266
//===----------------------------------------------------------------------===//
261267
// EmitC Attributes
262268
//===----------------------------------------------------------------------===//

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -246,15 +246,15 @@ static LogicalResult printOperation(CppEmitter &emitter,
246246
return printConstantOp(emitter, operation, value);
247247
}
248248

249-
static LogicalResult printBinaryArithOperation(CppEmitter &emitter,
250-
Operation *operation,
251-
StringRef binaryArithOperator) {
249+
static LogicalResult printBinaryOperation(CppEmitter &emitter,
250+
Operation *operation,
251+
StringRef binaryOperator) {
252252
raw_ostream &os = emitter.ostream();
253253

254254
if (failed(emitter.emitAssignPrefix(*operation)))
255255
return failure();
256256
os << emitter.getOrCreateName(operation->getOperand(0));
257-
os << " " << binaryArithOperator;
257+
os << " " << binaryOperator;
258258
os << " " << emitter.getOrCreateName(operation->getOperand(1));
259259

260260
return success();
@@ -263,31 +263,65 @@ static LogicalResult printBinaryArithOperation(CppEmitter &emitter,
263263
static LogicalResult printOperation(CppEmitter &emitter, emitc::AddOp addOp) {
264264
Operation *operation = addOp.getOperation();
265265

266-
return printBinaryArithOperation(emitter, operation, "+");
266+
return printBinaryOperation(emitter, operation, "+");
267267
}
268268

269269
static LogicalResult printOperation(CppEmitter &emitter, emitc::DivOp divOp) {
270270
Operation *operation = divOp.getOperation();
271271

272-
return printBinaryArithOperation(emitter, operation, "/");
272+
return printBinaryOperation(emitter, operation, "/");
273273
}
274274

275275
static LogicalResult printOperation(CppEmitter &emitter, emitc::MulOp mulOp) {
276276
Operation *operation = mulOp.getOperation();
277277

278-
return printBinaryArithOperation(emitter, operation, "*");
278+
return printBinaryOperation(emitter, operation, "*");
279279
}
280280

281281
static LogicalResult printOperation(CppEmitter &emitter, emitc::RemOp remOp) {
282282
Operation *operation = remOp.getOperation();
283283

284-
return printBinaryArithOperation(emitter, operation, "%");
284+
return printBinaryOperation(emitter, operation, "%");
285285
}
286286

287287
static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) {
288288
Operation *operation = subOp.getOperation();
289289

290-
return printBinaryArithOperation(emitter, operation, "-");
290+
return printBinaryOperation(emitter, operation, "-");
291+
}
292+
293+
static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
294+
Operation *operation = cmpOp.getOperation();
295+
296+
StringRef binaryOperator;
297+
298+
switch (cmpOp.getPredicate()) {
299+
case emitc::CmpPredicate::eq:
300+
binaryOperator = "==";
301+
break;
302+
case emitc::CmpPredicate::ne:
303+
binaryOperator = "!=";
304+
break;
305+
case emitc::CmpPredicate::lt:
306+
binaryOperator = "<";
307+
break;
308+
case emitc::CmpPredicate::le:
309+
binaryOperator = "<=";
310+
break;
311+
case emitc::CmpPredicate::gt:
312+
binaryOperator = ">";
313+
break;
314+
case emitc::CmpPredicate::ge:
315+
binaryOperator = ">=";
316+
break;
317+
case emitc::CmpPredicate::three_way:
318+
binaryOperator = "<=>";
319+
break;
320+
default:
321+
return cmpOp.emitError("unhandled comparison predicate");
322+
}
323+
324+
return printBinaryOperation(emitter, operation, binaryOperator);
291325
}
292326

293327
static LogicalResult printOperation(CppEmitter &emitter,
@@ -977,8 +1011,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
9771011
[&](auto op) { return printOperation(*this, op); })
9781012
// EmitC ops.
9791013
.Case<emitc::AddOp, emitc::ApplyOp, emitc::CallOp, emitc::CastOp,
980-
emitc::ConstantOp, emitc::DivOp, emitc::IncludeOp, emitc::MulOp,
981-
emitc::RemOp, emitc::SubOp, emitc::VariableOp>(
1014+
emitc::CmpOp, emitc::ConstantOp, emitc::DivOp, emitc::IncludeOp,
1015+
emitc::MulOp, emitc::RemOp, emitc::SubOp, emitc::VariableOp>(
9821016
[&](auto op) { return printOperation(*this, op); })
9831017
// Func ops.
9841018
.Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(

mlir/test/Dialect/EmitC/ops.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,21 @@ func.func @sub_pointer(%arg0: !emitc.ptr<f32>, %arg1: i32, %arg2: !emitc.opaque<
7979
%4 = "emitc.sub" (%arg0, %arg3) : (!emitc.ptr<f32>, !emitc.ptr<f32>) -> i32
8080
return
8181
}
82+
83+
func.func @cmp(%arg0 : i32, %arg1 : f32, %arg2 : i64, %arg3 : f64, %arg4 : !emitc.opaque<"unsigned">, %arg5 : !emitc.opaque<"std::valarray<int>">, %arg6 : !emitc.opaque<"custom">) {
84+
%1 = "emitc.cmp" (%arg0, %arg0) {predicate = 0} : (i32, i32) -> i1
85+
%2 = emitc.cmp eq, %arg0, %arg0 : (i32, i32) -> i1
86+
%3 = "emitc.cmp" (%arg1, %arg1) {predicate = 1} : (f32, f32) -> i1
87+
%4 = emitc.cmp ne, %arg1, %arg1 : (f32, f32) -> i1
88+
%5 = "emitc.cmp" (%arg2, %arg2) {predicate = 2} : (i64, i64) -> i1
89+
%6 = emitc.cmp lt, %arg2, %arg2 : (i64, i64) -> i1
90+
%7 = "emitc.cmp" (%arg3, %arg3) {predicate = 3} : (f64, f64) -> i1
91+
%8 = emitc.cmp le, %arg3, %arg3 : (f64, f64) -> i1
92+
%9 = "emitc.cmp" (%arg4, %arg4) {predicate = 4} : (!emitc.opaque<"unsigned">, !emitc.opaque<"unsigned">) -> i1
93+
%10 = emitc.cmp gt, %arg4, %arg4 : (!emitc.opaque<"unsigned">, !emitc.opaque<"unsigned">) -> i1
94+
%11 = "emitc.cmp" (%arg5, %arg5) {predicate = 5} : (!emitc.opaque<"std::valarray<int>">, !emitc.opaque<"std::valarray<int>">) -> !emitc.opaque<"std::valarray<bool>">
95+
%12 = emitc.cmp ge, %arg5, %arg5 : (!emitc.opaque<"std::valarray<int>">, !emitc.opaque<"std::valarray<int>">) -> !emitc.opaque<"std::valarray<bool>">
96+
%13 = "emitc.cmp" (%arg6, %arg6) {predicate = 6} : (!emitc.opaque<"custom">, !emitc.opaque<"custom">) -> !emitc.opaque<"custom">
97+
%14 = emitc.cmp three_way, %arg6, %arg6 : (!emitc.opaque<"custom">, !emitc.opaque<"custom">) -> !emitc.opaque<"custom">
98+
return
99+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
2+
3+
func.func @cmp(%arg0 : i32, %arg1 : f32, %arg2 : i64, %arg3 : f64, %arg4 : !emitc.opaque<"unsigned">, %arg5 : !emitc.opaque<"std::valarray<int>">, %arg6 : !emitc.opaque<"custom">) {
4+
%1 = emitc.cmp eq, %arg0, %arg2 : (i32, i64) -> i1
5+
%2 = emitc.cmp ne, %arg1, %arg3 : (f32, f64) -> i1
6+
%3 = emitc.cmp lt, %arg2, %arg4 : (i64, !emitc.opaque<"unsigned">) -> !emitc.opaque<"int">
7+
%4 = emitc.cmp le, %arg3, %arg3 : (f64, f64) -> i1
8+
%5 = emitc.cmp gt, %arg6, %arg4 : (!emitc.opaque<"custom">, !emitc.opaque<"unsigned">) -> !emitc.opaque<"custom">
9+
%6 = emitc.cmp ge, %arg5, %arg5 : (!emitc.opaque<"std::valarray<int>">, !emitc.opaque<"std::valarray<int>">) -> !emitc.opaque<"std::valarray<bool>">
10+
%7 = emitc.cmp three_way, %arg6, %arg6 : (!emitc.opaque<"custom">, !emitc.opaque<"custom">) -> !emitc.opaque<"custom">
11+
12+
return
13+
}
14+
// CHECK-LABEL: void cmp
15+
// CHECK-NEXT: bool [[V7:[^ ]*]] = [[V0:[^ ]*]] == [[V2:[^ ]*]];
16+
// CHECK-NEXT: bool [[V8:[^ ]*]] = [[V1:[^ ]*]] != [[V3:[^ ]*]];
17+
// CHECK-NEXT: int [[V9:[^ ]*]] = [[V2]] < [[V4:[^ ]*]];
18+
// CHECK-NEXT: bool [[V10:[^ ]*]] = [[V3]] <= [[V3]];
19+
// CHECK-NEXT: custom [[V11:[^ ]*]] = [[V6:[^ ]*]] > [[V4]];
20+
// CHECK-NEXT: std::valarray<bool> [[V12:[^ ]*]] = [[V5:[^ ]*]] >= [[V5]];
21+
// CHECK-NEXT: custom [[V13:[^ ]*]] = [[V6]] <=> [[V6]];

0 commit comments

Comments
 (0)