Skip to content

[CIR] Upstream CreateOp for ComplexType with folder #143192

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,9 @@ def ConstComplexAttr : CIR_Attr<"ConstComplex", "const_complex",
);

let builders = [
AttrBuilderWithInferredContext<(ins "cir::ComplexType":$type,
"mlir::TypedAttr":$real,
AttrBuilderWithInferredContext<(ins "mlir::TypedAttr":$real,
"mlir::TypedAttr":$imag), [{
auto type = cir::ComplexType::get(real.getType());
return $_get(type.getContext(), type, real, imag);
}]>,
];
Expand Down
32 changes: 32 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2353,4 +2353,36 @@ def BaseClassAddrOp : CIR_Op<"base_class_addr"> {
}];
}

//===----------------------------------------------------------------------===//
// ComplexCreateOp
//===----------------------------------------------------------------------===//

def ComplexCreateOp : CIR_Op<"complex.create", [Pure, SameTypeOperands]> {
let summary = "Create a complex value from its real and imaginary parts";
let description = [{
The `cir.complex.create` operation takes two operands that represent the
real and imaginary part of a complex number, and yields the complex number.

```mlir
%0 = cir.const #cir.fp<1.000000e+00> : !cir.double
%1 = cir.const #cir.fp<2.000000e+00> : !cir.double
%2 = cir.complex.create %0, %1 : !cir.double -> !cir.complex<!cir.double>
```
}];

let results = (outs CIR_ComplexType:$result);
let arguments = (ins
CIR_AnyIntOrFloatType:$real,
CIR_AnyIntOrFloatType:$imag
);

let assemblyFormat = [{
$real `,` $imag
`:` qualified(type($real)) `->` qualified(type($result)) attr-dict
}];

let hasVerifier = 1;
let hasFolder = 1;
}

#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD
3 changes: 2 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,8 @@ def CIRRecordType : Type<

def CIR_AnyType : AnyTypeOf<[
CIR_VoidType, CIR_BoolType, CIR_ArrayType, CIR_VectorType, CIR_IntType,
CIR_AnyFloatType, CIR_PointerType, CIR_FuncType, CIR_RecordType
CIR_AnyFloatType, CIR_PointerType, CIR_FuncType, CIR_RecordType,
CIR_ComplexType
]>;

#endif // MLIR_CIR_DIALECT_CIR_TYPES
1 change: 0 additions & 1 deletion clang/include/clang/CIR/MissingFeatures.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,6 @@ struct MissingFeatures {
// Future CIR operations
static bool awaitOp() { return false; }
static bool callOp() { return false; }
static bool complexCreateOp() { return false; }
static bool complexImagOp() { return false; }
static bool complexRealOp() { return false; }
static bool ifOp() { return false; }
Expand Down
6 changes: 6 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,12 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
return CIRBaseBuilderTy::createStore(loc, val, dst.getPointer(), align);
}

mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real,
mlir::Value imag) {
auto resultComplexTy = cir::ComplexType::get(real.getType());
return create<cir::ComplexCreateOp>(loc, resultComplexTy, real, imag);
}

/// Create a cir.ptr_stride operation to get access to an array element.
/// \p idx is the index of the element to access, \p shouldDecay is true if
/// the result should decay to a pointer to the element type.
Expand Down
12 changes: 9 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,13 @@ void CIRGenFunction::emitExprAsInit(const Expr *init, const ValueDecl *d,
emitScalarInit(init, getLoc(d->getSourceRange()), lvalue);
return;
case cir::TEK_Complex: {
cgm.errorNYI(init->getSourceRange(), "emitExprAsInit: complex type");
mlir::Value complex = emitComplexExpr(init);
if (capturedByInit)
cgm.errorNYI(init->getSourceRange(),
"emitExprAsInit: complex type captured by init");
mlir::Location loc = getLoc(init->getExprLoc());
emitStoreOfComplex(loc, complex, lvalue,
/*isInit*/ true);
return;
}
case cir::TEK_Aggregate:
Expand Down Expand Up @@ -593,8 +599,8 @@ void CIRGenFunction::emitDecl(const Decl &d) {
// None of these decls require codegen support.
return;

case Decl::Enum: // enum X;
case Decl::Record: // struct/union/class X;
case Decl::Enum: // enum X;
case Decl::Record: // struct/union/class X;
case Decl::CXXRecord: // struct/union/class X; [C++]
case Decl::NamespaceAlias:
case Decl::Using: // using X; [C++]
Expand Down
11 changes: 11 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1690,3 +1690,14 @@ mlir::Value CIRGenFunction::emitScalarConstant(
}
return builder.getConstant(getLoc(e->getSourceRange()), constant.getValue());
}

/// An LValue is a candidate for having its loads and stores be made atomic if
/// we are operating under /volatile:ms *and* the LValue itself is volatile and
/// performing such an operation can be performed without a libcall.
bool CIRGenFunction::isLValueSuitableForInlineAtomic(LValue lv) {
if (!cgm.getLangOpts().MSVolatile)
return false;

cgm.errorNYI("LValueSuitableForInlineAtomic LangOpts MSVolatile");
return false;
}
79 changes: 79 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#include "CIRGenBuilder.h"
#include "CIRGenFunction.h"

#include "clang/AST/StmtVisitor.h"

using namespace clang;
using namespace clang::CIRGen;

namespace {
class ComplexExprEmitter : public StmtVisitor<ComplexExprEmitter, mlir::Value> {
CIRGenFunction &cgf;
CIRGenBuilderTy &builder;

public:
explicit ComplexExprEmitter(CIRGenFunction &cgf)
: cgf(cgf), builder(cgf.getBuilder()) {}

/// Store the specified real/imag parts into the
/// specified value pointer.
void emitStoreOfComplex(mlir::Location loc, mlir::Value val, LValue lv,
bool isInit);

mlir::Value VisitInitListExpr(InitListExpr *e);
};

} // namespace

static const ComplexType *getComplexType(QualType type) {
type = type.getCanonicalType();
if (const ComplexType *comp = dyn_cast<ComplexType>(type))
return comp;
return cast<ComplexType>(cast<AtomicType>(type)->getValueType());
}

void ComplexExprEmitter::emitStoreOfComplex(mlir::Location loc, mlir::Value val,
LValue lv, bool isInit) {
if (lv.getType()->isAtomicType() ||
(!isInit && cgf.isLValueSuitableForInlineAtomic(lv))) {
cgf.cgm.errorNYI("StoreOfComplex with Atomic LV");
return;
}

const Address destAddr = lv.getAddress();
builder.createStore(loc, val, destAddr);
}

mlir::Value ComplexExprEmitter::VisitInitListExpr(InitListExpr *e) {
mlir::Location loc = cgf.getLoc(e->getExprLoc());
if (e->getNumInits() == 2) {
mlir::Value real = cgf.emitScalarExpr(e->getInit(0));
mlir::Value imag = cgf.emitScalarExpr(e->getInit(1));
return builder.createComplexCreate(loc, real, imag);
}

if (e->getNumInits() == 1) {
cgf.cgm.errorNYI("Create Complex with InitList with size 1");
return {};
}

assert(e->getNumInits() == 0 && "Unexpected number of inits");
QualType complexElemTy =
e->getType()->castAs<clang::ComplexType>()->getElementType();
mlir::Type complexElemLLVMTy = cgf.convertType(complexElemTy);
mlir::TypedAttr defaultValue = builder.getZeroInitAttr(complexElemLLVMTy);
auto complexAttr = cir::ConstComplexAttr::get(defaultValue, defaultValue);
return builder.create<cir::ConstantOp>(loc, complexAttr);
}

mlir::Value CIRGenFunction::emitComplexExpr(const Expr *e) {
assert(e && getComplexType(e->getType()) &&
"Invalid complex expression to emit");

return ComplexExprEmitter(*this).Visit(const_cast<Expr *>(e));
}

void CIRGenFunction::emitStoreOfComplex(mlir::Location loc, mlir::Value v,
LValue dest, bool isInit) {
ComplexExprEmitter(*this).emitStoreOfComplex(loc, v, dest, isInit);
}
9 changes: 9 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,8 @@ class CIRGenFunction : public CIRGenTypeCache {
PrototypeWrapper(const clang::ObjCMethodDecl *md) : p(md) {}
};

bool isLValueSuitableForInlineAtomic(LValue lv);

/// An abstract representation of regular/ObjC call/message targets.
class AbstractCallee {
/// The function declaration of the callee.
Expand Down Expand Up @@ -860,6 +862,10 @@ class CIRGenFunction : public CIRGenTypeCache {

mlir::LogicalResult emitForStmt(const clang::ForStmt &s);

/// Emit the computation of the specified expression of complex type,
/// returning the result.
mlir::Value emitComplexExpr(const Expr *e);

void emitCompoundStmt(const clang::CompoundStmt &s);

void emitCompoundStmtWithoutScope(const clang::CompoundStmt &s);
Expand Down Expand Up @@ -961,6 +967,9 @@ class CIRGenFunction : public CIRGenTypeCache {

void emitStaticVarDecl(const VarDecl &d, cir::GlobalLinkageKind linkage);

void emitStoreOfComplex(mlir::Location loc, mlir::Value v, LValue dest,
bool isInit);

void emitStoreOfScalar(mlir::Value value, Address addr, bool isVolatile,
clang::QualType ty, bool isInit = false,
bool isNontemporal = false);
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CIR/CodeGen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ add_clang_library(clangCIR
CIRGenDeclOpenACC.cpp
CIRGenExpr.cpp
CIRGenExprAggregate.cpp
CIRGenExprComplex.cpp
CIRGenExprConstant.cpp
CIRGenExprScalar.cpp
CIRGenFunction.cpp
Expand Down
27 changes: 27 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1739,6 +1739,33 @@ OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
vecTy, mlir::ArrayAttr::get(getContext(), elements));
}

//===----------------------------------------------------------------------===//
// ComplexCreateOp
//===----------------------------------------------------------------------===//

LogicalResult cir::ComplexCreateOp::verify() {
if (getType().getElementType() != getReal().getType()) {
emitOpError()
<< "operand type of cir.complex.create does not match its result type";
return failure();
}

return success();
}

OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
mlir::Attribute real = adaptor.getReal();
mlir::Attribute imag = adaptor.getImag();
if (!real || !imag)
return {};

// When both of real and imag are constants, we can fold the operation into an
// `#cir.const_complex` operation.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding a folder :)

auto realAttr = mlir::cast<mlir::TypedAttr>(real);
auto imagAttr = mlir::cast<mlir::TypedAttr>(imag);
return cir::ConstComplexAttr::get(realAttr, imagAttr);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 2 additions & 3 deletions clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,16 +134,15 @@ void CIRCanonicalizePass::runOnOperation() {
getOperation()->walk([&](Operation *op) {
assert(!cir::MissingFeatures::switchOp());
assert(!cir::MissingFeatures::tryOp());
assert(!cir::MissingFeatures::complexCreateOp());
assert(!cir::MissingFeatures::complexRealOp());
assert(!cir::MissingFeatures::complexImagOp());
assert(!cir::MissingFeatures::callOp());

// Many operations are here to perform a manual `fold` in
// applyOpPatternsGreedily.
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp,
VecCreateOp, VecExtractOp, VecShuffleOp, VecShuffleDynamicOp,
VecTernaryOp>(op))
ComplexCreateOp, VecCreateOp, VecExtractOp, VecShuffleOp,
VecShuffleDynamicOp, VecTernaryOp>(op))
ops.push_back(op);
});

Expand Down
48 changes: 46 additions & 2 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -905,7 +905,32 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
rewriter.replaceOp(op, lowerCirAttrAsValue(op, op.getValue(), rewriter,
getTypeConverter()));
return mlir::success();
} else {
} else if (auto complexTy = mlir::dyn_cast<cir::ComplexType>(op.getType())) {
auto complexAttr = mlir::cast<cir::ConstComplexAttr>(op.getValue());
mlir::Type complexElemTy = complexTy.getElementType();
mlir::Type complexElemLLVMTy = typeConverter->convertType(complexElemTy);

mlir::Attribute components[2];
if (mlir::isa<cir::IntType>(complexElemTy)) {
components[0] = rewriter.getIntegerAttr(
complexElemLLVMTy,
mlir::cast<cir::IntAttr>(complexAttr.getReal()).getValue());
components[1] = rewriter.getIntegerAttr(
complexElemLLVMTy,
mlir::cast<cir::IntAttr>(complexAttr.getImag()).getValue());
} else {
components[0] = rewriter.getFloatAttr(
complexElemLLVMTy,
mlir::cast<cir::FPAttr>(complexAttr.getReal()).getValue());
components[1] = rewriter.getFloatAttr(
complexElemLLVMTy,
mlir::cast<cir::FPAttr>(complexAttr.getImag()).getValue());
}

attr = rewriter.getArrayAttr(components);
}

else {
return op.emitError() << "unsupported constant type " << op.getType();
}

Expand Down Expand Up @@ -1810,7 +1835,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
CIRToLLVMVecSplatOpLowering,
CIRToLLVMVecShuffleOpLowering,
CIRToLLVMVecShuffleDynamicOpLowering,
CIRToLLVMVecTernaryOpLowering
CIRToLLVMVecTernaryOpLowering,
CIRToLLVMComplexCreateOpLowering
// clang-format on
>(converter, patterns.getContext());

Expand Down Expand Up @@ -2096,6 +2122,24 @@ mlir::LogicalResult CIRToLLVMVecTernaryOpLowering::matchAndRewrite(
return mlir::success();
}

mlir::LogicalResult CIRToLLVMComplexCreateOpLowering::matchAndRewrite(
cir::ComplexCreateOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Type complexLLVMTy =
getTypeConverter()->convertType(op.getResult().getType());
auto initialComplex =
rewriter.create<mlir::LLVM::UndefOp>(op->getLoc(), complexLLVMTy);

auto realComplex = rewriter.create<mlir::LLVM::InsertValueOp>(
op->getLoc(), initialComplex, adaptor.getReal(), 0);

auto complex = rewriter.create<mlir::LLVM::InsertValueOp>(
op->getLoc(), realComplex, adaptor.getImag(), 1);

rewriter.replaceOp(op, complex);
return mlir::success();
}

std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
return std::make_unique<ConvertCIRToLLVMPass>();
}
Expand Down
10 changes: 10 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,16 @@ class CIRToLLVMVecTernaryOpLowering
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMComplexCreateOpLowering
: public mlir::OpConversionPattern<cir::ComplexCreateOp> {
public:
using mlir::OpConversionPattern<cir::ComplexCreateOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(cir::ComplexCreateOp op, OpAdaptor,
mlir::ConversionPatternRewriter &) const override;
};

} // namespace direct
} // namespace cir

Expand Down
Loading
Loading