@@ -443,12 +443,12 @@ Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) {
443443 return extractPtr (builder, loc, kRealPosInComplexNumberStruct );
444444}
445445
446- void ComplexStructBuilder ::setImaginary (OpBuilder &builder, Location loc,
447- Value imaginary ) {
446+ void ComplexStructBuilder::setImaginary (OpBuilder &builder, Location loc,
447+ Value imaginary ) {
448448 setPtr (builder, loc, kImaginaryPosInComplexNumberStruct , imaginary );
449449}
450450
451- Value ComplexStructBuilder ::imaginary (OpBuilder &builder, Location loc) {
451+ Value ComplexStructBuilder::imaginary (OpBuilder &builder, Location loc) {
452452 return extractPtr (builder, loc, kImaginaryPosInComplexNumberStruct );
453453}
454454
@@ -1326,8 +1326,7 @@ using UnsignedShiftRightOpLowering =
13261326 OneToOneConvertToLLVMPattern<UnsignedShiftRightOp, LLVM::LShrOp>;
13271327using XOrOpLowering = VectorConvertToLLVMPattern<XOrOp, LLVM::XOrOp>;
13281328
1329- // Lowerings for operations on complex numbers, `CreateComplexOp`, `ReOp`, and
1330- // `ImOp`.
1329+ // Lowerings for operations on complex numbers.
13311330
13321331struct CreateComplexOpLowering
13331332 : public ConvertOpToLLVMPattern<CreateComplexOp> {
@@ -1385,6 +1384,82 @@ struct ImOpLowering : public ConvertOpToLLVMPattern<ImOp> {
13851384 }
13861385};
13871386
1387+ struct BinaryComplexOperands {
1388+ Value lhsReal, lhsImag, rhsReal, rhsImag;
1389+ };
1390+
1391+ template <typename OpTy>
1392+ BinaryComplexOperands
1393+ unpackBinaryComplexOperands (OpTy op, ArrayRef<Value> operands,
1394+ ConversionPatternRewriter &rewriter) {
1395+ auto bop = cast<OpTy>(op);
1396+ auto loc = bop.getLoc ();
1397+ OperandAdaptor<OpTy> transformed (operands);
1398+
1399+ // Extract real and imaginary values from operands.
1400+ BinaryComplexOperands unpacked;
1401+ ComplexStructBuilder lhs (transformed.lhs ());
1402+ unpacked.lhsReal = lhs.real (rewriter, loc);
1403+ unpacked.lhsImag = lhs.imaginary (rewriter, loc);
1404+ ComplexStructBuilder rhs (transformed.rhs ());
1405+ unpacked.rhsReal = rhs.real (rewriter, loc);
1406+ unpacked.rhsImag = rhs.imaginary (rewriter, loc);
1407+
1408+ return unpacked;
1409+ }
1410+
1411+ struct AddCFOpLowering : public ConvertOpToLLVMPattern <AddCFOp> {
1412+ using ConvertOpToLLVMPattern<AddCFOp>::ConvertOpToLLVMPattern;
1413+
1414+ LogicalResult
1415+ matchAndRewrite (Operation *operation, ArrayRef<Value> operands,
1416+ ConversionPatternRewriter &rewriter) const override {
1417+ auto op = cast<AddCFOp>(operation);
1418+ auto loc = op.getLoc ();
1419+ BinaryComplexOperands arg =
1420+ unpackBinaryComplexOperands<AddCFOp>(op, operands, rewriter);
1421+
1422+ // Initialize complex number struct for result.
1423+ auto structType = this ->typeConverter .convertType (op.getType ());
1424+ auto result = ComplexStructBuilder::undef (rewriter, loc, structType);
1425+
1426+ // Emit IR to add complex numbers.
1427+ Value real = rewriter.create <LLVM::FAddOp>(loc, arg.lhsReal , arg.rhsReal );
1428+ Value imag = rewriter.create <LLVM::FAddOp>(loc, arg.lhsImag , arg.rhsImag );
1429+ result.setReal (rewriter, loc, real);
1430+ result.setImaginary (rewriter, loc, imag);
1431+
1432+ rewriter.replaceOp (op, {result});
1433+ return success ();
1434+ }
1435+ };
1436+
1437+ struct SubCFOpLowering : public ConvertOpToLLVMPattern <SubCFOp> {
1438+ using ConvertOpToLLVMPattern<SubCFOp>::ConvertOpToLLVMPattern;
1439+
1440+ LogicalResult
1441+ matchAndRewrite (Operation *operation, ArrayRef<Value> operands,
1442+ ConversionPatternRewriter &rewriter) const override {
1443+ auto op = cast<SubCFOp>(operation);
1444+ auto loc = op.getLoc ();
1445+ BinaryComplexOperands arg =
1446+ unpackBinaryComplexOperands<SubCFOp>(op, operands, rewriter);
1447+
1448+ // Initialize complex number struct for result.
1449+ auto structType = this ->typeConverter .convertType (op.getType ());
1450+ auto result = ComplexStructBuilder::undef (rewriter, loc, structType);
1451+
1452+ // Emit IR to substract complex numbers.
1453+ Value real = rewriter.create <LLVM::FSubOp>(loc, arg.lhsReal , arg.rhsReal );
1454+ Value imag = rewriter.create <LLVM::FSubOp>(loc, arg.lhsImag , arg.rhsImag );
1455+ result.setReal (rewriter, loc, real);
1456+ result.setImaginary (rewriter, loc, imag);
1457+
1458+ rewriter.replaceOp (op, {result});
1459+ return success ();
1460+ }
1461+ };
1462+
13881463// Check if the MemRefType `type` is supported by the lowering. We currently
13891464// only support memrefs with identity maps.
13901465static bool isSupportedMemRefType (MemRefType type) {
@@ -2874,6 +2949,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
28742949 // clang-format off
28752950 patterns.insert <
28762951 AbsFOpLowering,
2952+ AddCFOpLowering,
28772953 AddFOpLowering,
28782954 AddIOpLowering,
28792955 AllocaOpLowering,
@@ -2921,6 +2997,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
29212997 SplatOpLowering,
29222998 SplatNdOpLowering,
29232999 SqrtOpLowering,
3000+ SubCFOpLowering,
29243001 SubFOpLowering,
29253002 SubIOpLowering,
29263003 TruncateIOpLowering,
0 commit comments