Skip to content

Commit 5d5f61f

Browse files
committed
[MLIR] Add complex addition and substraction to the standard dialect
Complex addition and substraction are the first two binary operations on complex numbers. Remaining operations will follow the same pattern. Differential Revision: https://reviews.llvm.org/D79479
1 parent 9c198b5 commit 5d5f61f

File tree

4 files changed

+178
-6
lines changed

4 files changed

+178
-6
lines changed

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
109109
// integer tensor. The custom assembly form of the operation is as follows
110110
//
111111
// <op>i %0, %1 : i32
112+
//
112113
class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
113114
ArithmeticOp<mnemonic, traits>,
114115
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>;
@@ -121,10 +122,23 @@ class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
121122
// is as follows
122123
//
123124
// <op>f %0, %1 : f32
125+
//
124126
class FloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
125127
ArithmeticOp<mnemonic, traits>,
126128
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>;
127129

130+
// Base class for standard arithmetic operations on complex numbers with a
131+
// floating-point element type.
132+
// These operations take two operands and return one result, all of which must
133+
// be complex numbers of the same type.
134+
// The assembly format is as follows
135+
//
136+
// <op>cf %0, %1 : complex<f32>
137+
//
138+
class ComplexFloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
139+
ArithmeticOp<mnemonic, traits>,
140+
Arguments<(ins Complex<AnyFloat>:$lhs, Complex<AnyFloat>:$rhs)>;
141+
128142
// Base class for memref allocating ops: alloca and alloc.
129143
//
130144
// %0 = alloclike(%m)[%s] : memref<8x?xf32, (d0, d1)[s0] -> ((d0 + s0), d1)>
@@ -201,6 +215,26 @@ def AbsFOp : FloatUnaryOp<"absf"> {
201215
}];
202216
}
203217

218+
//===----------------------------------------------------------------------===//
219+
// AddCFOp
220+
//===----------------------------------------------------------------------===//
221+
222+
def AddCFOp : ComplexFloatArithmeticOp<"addcf"> {
223+
let summary = "complex number addition";
224+
let description = [{
225+
The `addcf` operation takes two complex number operands and returns their
226+
sum, a single complex number.
227+
All operands and result must be of the same type, a complex number with a
228+
floating-point element type.
229+
230+
Example:
231+
232+
```mlir
233+
%a = addcf %b, %c : complex<f32>
234+
```
235+
}];
236+
}
237+
204238
//===----------------------------------------------------------------------===//
205239
// AddFOp
206240
//===----------------------------------------------------------------------===//
@@ -2407,6 +2441,26 @@ def StoreOp : Std_Op<"store",
24072441
}];
24082442
}
24092443

2444+
//===----------------------------------------------------------------------===//
2445+
// SubCFOp
2446+
//===----------------------------------------------------------------------===//
2447+
2448+
def SubCFOp : ComplexFloatArithmeticOp<"subcf"> {
2449+
let summary = "complex number subtraction";
2450+
let description = [{
2451+
The `subcf` operation takes two complex number operands and returns their
2452+
difference, a single complex number.
2453+
All operands and result must be of the same type, a complex number with a
2454+
floating-point element type.
2455+
2456+
Example:
2457+
2458+
```mlir
2459+
%a = subcf %b, %c : complex<f32>
2460+
```
2461+
}];
2462+
}
2463+
24102464
//===----------------------------------------------------------------------===//
24112465
// SubFOp
24122466
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/OpBase.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -719,7 +719,6 @@ def SignlessIntegerOrFloatLike : TypeConstraint<Or<[
719719
SignlessIntegerLike.predicate, FloatLike.predicate]>,
720720
"signless-integer-like or floating-point-like">;
721721

722-
723722
//===----------------------------------------------------------------------===//
724723
// Attribute definitions
725724
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp

Lines changed: 82 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -443,12 +443,12 @@ Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) {
443443
return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
444444
}
445445

446-
void ComplexStructBuilder ::setImaginary(OpBuilder &builder, Location loc,
447-
Value imaginary) {
446+
void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc,
447+
Value imaginary) {
448448
setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
449449
}
450450

451-
Value ComplexStructBuilder ::imaginary(OpBuilder &builder, Location loc) {
451+
Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
452452
return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
453453
}
454454

@@ -1326,8 +1326,7 @@ using UnsignedShiftRightOpLowering =
13261326
OneToOneConvertToLLVMPattern<UnsignedShiftRightOp, LLVM::LShrOp>;
13271327
using XOrOpLowering = VectorConvertToLLVMPattern<XOrOp, LLVM::XOrOp>;
13281328

1329-
// Lowerings for operations on complex numbers, `CreateComplexOp`, `ReOp`, and
1330-
// `ImOp`.
1329+
// Lowerings for operations on complex numbers.
13311330

13321331
struct CreateComplexOpLowering
13331332
: public ConvertOpToLLVMPattern<CreateComplexOp> {
@@ -1385,6 +1384,82 @@ struct ImOpLowering : public ConvertOpToLLVMPattern<ImOp> {
13851384
}
13861385
};
13871386

1387+
struct BinaryComplexOperands {
1388+
Value lhsReal, lhsImag, rhsReal, rhsImag;
1389+
};
1390+
1391+
template <typename OpTy>
1392+
BinaryComplexOperands
1393+
unpackBinaryComplexOperands(OpTy op, ArrayRef<Value> operands,
1394+
ConversionPatternRewriter &rewriter) {
1395+
auto bop = cast<OpTy>(op);
1396+
auto loc = bop.getLoc();
1397+
OperandAdaptor<OpTy> transformed(operands);
1398+
1399+
// Extract real and imaginary values from operands.
1400+
BinaryComplexOperands unpacked;
1401+
ComplexStructBuilder lhs(transformed.lhs());
1402+
unpacked.lhsReal = lhs.real(rewriter, loc);
1403+
unpacked.lhsImag = lhs.imaginary(rewriter, loc);
1404+
ComplexStructBuilder rhs(transformed.rhs());
1405+
unpacked.rhsReal = rhs.real(rewriter, loc);
1406+
unpacked.rhsImag = rhs.imaginary(rewriter, loc);
1407+
1408+
return unpacked;
1409+
}
1410+
1411+
struct AddCFOpLowering : public ConvertOpToLLVMPattern<AddCFOp> {
1412+
using ConvertOpToLLVMPattern<AddCFOp>::ConvertOpToLLVMPattern;
1413+
1414+
LogicalResult
1415+
matchAndRewrite(Operation *operation, ArrayRef<Value> operands,
1416+
ConversionPatternRewriter &rewriter) const override {
1417+
auto op = cast<AddCFOp>(operation);
1418+
auto loc = op.getLoc();
1419+
BinaryComplexOperands arg =
1420+
unpackBinaryComplexOperands<AddCFOp>(op, operands, rewriter);
1421+
1422+
// Initialize complex number struct for result.
1423+
auto structType = this->typeConverter.convertType(op.getType());
1424+
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
1425+
1426+
// Emit IR to add complex numbers.
1427+
Value real = rewriter.create<LLVM::FAddOp>(loc, arg.lhsReal, arg.rhsReal);
1428+
Value imag = rewriter.create<LLVM::FAddOp>(loc, arg.lhsImag, arg.rhsImag);
1429+
result.setReal(rewriter, loc, real);
1430+
result.setImaginary(rewriter, loc, imag);
1431+
1432+
rewriter.replaceOp(op, {result});
1433+
return success();
1434+
}
1435+
};
1436+
1437+
struct SubCFOpLowering : public ConvertOpToLLVMPattern<SubCFOp> {
1438+
using ConvertOpToLLVMPattern<SubCFOp>::ConvertOpToLLVMPattern;
1439+
1440+
LogicalResult
1441+
matchAndRewrite(Operation *operation, ArrayRef<Value> operands,
1442+
ConversionPatternRewriter &rewriter) const override {
1443+
auto op = cast<SubCFOp>(operation);
1444+
auto loc = op.getLoc();
1445+
BinaryComplexOperands arg =
1446+
unpackBinaryComplexOperands<SubCFOp>(op, operands, rewriter);
1447+
1448+
// Initialize complex number struct for result.
1449+
auto structType = this->typeConverter.convertType(op.getType());
1450+
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
1451+
1452+
// Emit IR to substract complex numbers.
1453+
Value real = rewriter.create<LLVM::FSubOp>(loc, arg.lhsReal, arg.rhsReal);
1454+
Value imag = rewriter.create<LLVM::FSubOp>(loc, arg.lhsImag, arg.rhsImag);
1455+
result.setReal(rewriter, loc, real);
1456+
result.setImaginary(rewriter, loc, imag);
1457+
1458+
rewriter.replaceOp(op, {result});
1459+
return success();
1460+
}
1461+
};
1462+
13881463
// Check if the MemRefType `type` is supported by the lowering. We currently
13891464
// only support memrefs with identity maps.
13901465
static bool isSupportedMemRefType(MemRefType type) {
@@ -2874,6 +2949,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
28742949
// clang-format off
28752950
patterns.insert<
28762951
AbsFOpLowering,
2952+
AddCFOpLowering,
28772953
AddFOpLowering,
28782954
AddIOpLowering,
28792955
AllocaOpLowering,
@@ -2921,6 +2997,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
29212997
SplatOpLowering,
29222998
SplatNdOpLowering,
29232999
SqrtOpLowering,
3000+
SubCFOpLowering,
29243001
SubFOpLowering,
29253002
SubIOpLowering,
29263003
TruncateIOpLowering,

mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,48 @@ func @complex_numbers() {
8383
return
8484
}
8585

86+
// CHECK-LABEL: llvm.func @complex_addition()
87+
// CHECK-DAG: %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm<"{ double, double }">
88+
// CHECK-DAG: %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm<"{ double, double }">
89+
// CHECK-DAG: %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm<"{ double, double }">
90+
// CHECK-DAG: %[[B_IMAG:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"{ double, double }">
91+
// CHECK: %[[C0:.*]] = llvm.mlir.undef : !llvm<"{ double, double }">
92+
// CHECK-DAG: %[[C_REAL:.*]] = llvm.fadd %[[A_REAL]], %[[B_REAL]] : !llvm.double
93+
// CHECK-DAG: %[[C_IMAG:.*]] = llvm.fadd %[[A_IMAG]], %[[B_IMAG]] : !llvm.double
94+
// CHECK: %[[C1:.*]] = llvm.insertvalue %[[C_REAL]], %[[C0]][0] : !llvm<"{ double, double }">
95+
// CHECK: %[[C2:.*]] = llvm.insertvalue %[[C_IMAG]], %[[C1]][1] : !llvm<"{ double, double }">
96+
func @complex_addition() {
97+
%a_re = constant 1.2 : f64
98+
%a_im = constant 3.4 : f64
99+
%a = create_complex %a_re, %a_im : complex<f64>
100+
%b_re = constant 5.6 : f64
101+
%b_im = constant 7.8 : f64
102+
%b = create_complex %b_re, %b_im : complex<f64>
103+
%c = addcf %a, %b : complex<f64>
104+
return
105+
}
106+
107+
// CHECK-LABEL: llvm.func @complex_substraction()
108+
// CHECK-DAG: %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm<"{ double, double }">
109+
// CHECK-DAG: %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm<"{ double, double }">
110+
// CHECK-DAG: %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm<"{ double, double }">
111+
// CHECK-DAG: %[[B_IMAG:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"{ double, double }">
112+
// CHECK: %[[C0:.*]] = llvm.mlir.undef : !llvm<"{ double, double }">
113+
// CHECK-DAG: %[[C_REAL:.*]] = llvm.fsub %[[A_REAL]], %[[B_REAL]] : !llvm.double
114+
// CHECK-DAG: %[[C_IMAG:.*]] = llvm.fsub %[[A_IMAG]], %[[B_IMAG]] : !llvm.double
115+
// CHECK: %[[C1:.*]] = llvm.insertvalue %[[C_REAL]], %[[C0]][0] : !llvm<"{ double, double }">
116+
// CHECK: %[[C2:.*]] = llvm.insertvalue %[[C_IMAG]], %[[C1]][1] : !llvm<"{ double, double }">
117+
func @complex_substraction() {
118+
%a_re = constant 1.2 : f64
119+
%a_im = constant 3.4 : f64
120+
%a = create_complex %a_re, %a_im : complex<f64>
121+
%b_re = constant 5.6 : f64
122+
%b_im = constant 7.8 : f64
123+
%b = create_complex %b_re, %b_im : complex<f64>
124+
%c = subcf %a, %b : complex<f64>
125+
return
126+
}
127+
86128
// CHECK-LABEL: func @simple_caller() {
87129
// CHECK-NEXT: llvm.call @simple_loop() : () -> ()
88130
// CHECK-NEXT: llvm.return

0 commit comments

Comments
 (0)