diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 22e144738783..4c310afb0ee1 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -548,6 +548,11 @@ LogicalResult cir::CastOp::verify() { mlir::isa(resType)) return success(); + // Handle the pointer to member function types. + if (mlir::isa(srcType) && + mlir::isa(resType)) + return success(); + // This is the only cast kind where we don't want vector types to decay // into the element type. if ((!mlir::isa(getSrc().getType()) || @@ -724,8 +729,9 @@ LogicalResult cir::CastOp::verify() { return success(); } case cir::CastKind::member_ptr_to_bool: { - if (!mlir::isa(srcType)) - return emitOpError() << "requires !cir.data_member type for source"; + if (!mlir::isa(srcType)) + return emitOpError() + << "requires !cir.data_member or !cir.method type for source"; if (!mlir::isa(resType)) return emitOpError() << "requires !cir.bool type for result"; return success(); diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h index c36923a41e35..266a8e70288d 100644 --- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h @@ -134,6 +134,15 @@ class CIRCXXABI { virtual mlir::Value lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc, mlir::OpBuilder &builder) const = 0; + + virtual mlir::Value lowerMethodBitcast(cir::CastOp op, + mlir::Type loweredDstTy, + mlir::Value loweredSrc, + mlir::OpBuilder &builder) const = 0; + + virtual mlir::Value lowerMethodToBoolCast(cir::CastOp op, + mlir::Value loweredSrc, + mlir::OpBuilder &builder) const = 0; }; /// Creates an Itanium-family ABI. diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp b/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp index 2819adb25b8c..683d03b8f05d 100644 --- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp @@ -114,6 +114,13 @@ class ItaniumCXXABI : public CIRCXXABI { mlir::Value lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc, mlir::OpBuilder &builder) const override; + + mlir::Value lowerMethodBitcast(cir::CastOp op, mlir::Type loweredDstTy, + mlir::Value loweredSrc, + mlir::OpBuilder &builder) const override; + + mlir::Value lowerMethodToBoolCast(cir::CastOp op, mlir::Value loweredSrc, + mlir::OpBuilder &builder) const override; }; } // namespace @@ -556,6 +563,30 @@ ItaniumCXXABI::lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc, nullValue); } +mlir::Value ItaniumCXXABI::lowerMethodBitcast(cir::CastOp op, + mlir::Type loweredDstTy, + mlir::Value loweredSrc, + mlir::OpBuilder &builder) const { + return loweredSrc; +} + +mlir::Value +ItaniumCXXABI::lowerMethodToBoolCast(cir::CastOp op, mlir::Value loweredSrc, + mlir::OpBuilder &builder) const { + // Itanium C++ ABI 2.3.2: + // + // In the standard representation, a null member function pointer is + // represented with ptr set to a null pointer. The value of adj is + // unspecified for null member function pointers. + cir::IntType ptrdiffCIRTy = getPtrDiffCIRTy(LM); + mlir::Value ptrdiffZero = builder.create( + op.getLoc(), ptrdiffCIRTy, cir::IntAttr::get(ptrdiffCIRTy, 0)); + mlir::Value ptrField = builder.create( + op.getLoc(), ptrdiffCIRTy, loweredSrc, 0); + return builder.create(op.getLoc(), cir::CmpOpKind::ne, ptrField, + ptrdiffZero); +} + CIRCXXABI *CreateItaniumCXXABI(LowerModule &LM) { switch (LM.getCXXABIKind()) { // Note that AArch64 uses the generic ItaniumCXXABI class since it doesn't diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 12ce2fb6eda6..131b14cd9aef 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -1272,14 +1272,18 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite( auto dstTy = castOp.getType(); auto llvmDstTy = getTypeConverter()->convertType(dstTy); - if (mlir::isa(castOp.getSrc().getType())) { - mlir::Value loweredResult = lowerMod->getCXXABI().lowerDataMemberBitcast( - castOp, llvmDstTy, src, rewriter); + if (mlir::isa( + castOp.getSrc().getType())) { + mlir::Value loweredResult; + if (mlir::isa(castOp.getSrc().getType())) + loweredResult = lowerMod->getCXXABI().lowerDataMemberBitcast( + castOp, llvmDstTy, src, rewriter); + else + loweredResult = lowerMod->getCXXABI().lowerMethodBitcast( + castOp, llvmDstTy, src, rewriter); rewriter.replaceOp(castOp, loweredResult); return mlir::success(); } - if (mlir::isa(castOp.getSrc().getType())) - llvm_unreachable("NYI"); auto llvmSrcVal = adaptor.getOperands().front(); rewriter.replaceOpWithNewOp(castOp, llvmDstTy, @@ -1308,7 +1312,8 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite( case cir::CastKind::member_ptr_to_bool: { mlir::Value loweredResult; if (mlir::isa(castOp.getSrc().getType())) - llvm_unreachable("NYI"); + loweredResult = + lowerMod->getCXXABI().lowerMethodToBoolCast(castOp, src, rewriter); else loweredResult = lowerMod->getCXXABI().lowerDataMemberToBoolCast( castOp, src, rewriter); diff --git a/clang/test/CIR/CodeGen/pointer-to-member-func.cpp b/clang/test/CIR/CodeGen/pointer-to-member-func.cpp index 5baf9c9bd23a..612c589e6221 100644 --- a/clang/test/CIR/CodeGen/pointer-to-member-func.cpp +++ b/clang/test/CIR/CodeGen/pointer-to-member-func.cpp @@ -118,3 +118,39 @@ bool cmp_ne(void (Foo::*lhs)(int), void (Foo::*rhs)(int)) { // LLVM-NEXT: %[[#adj_cmp:]] = icmp ne i64 %[[#lhs_adj]], %[[#rhs_adj]] // LLVM-NEXT: %[[#tmp:]] = and i1 %[[#ptr_null]], %[[#adj_cmp]] // LLVM-NEXT: %{{.+}} = or i1 %[[#tmp]], %[[#ptr_cmp]] + +struct Bar { + void m4(); +}; + +bool memfunc_to_bool(void (Foo::*func)(int)) { + return func; +} + +// CIR-LABEL: @_Z15memfunc_to_boolM3FooFviE +// CIR: %{{.+}} = cir.cast(member_ptr_to_bool, %{{.+}} : !cir.method in !ty_Foo>), !cir.bool +// CIR: } + +// LLVM-LABEL: @_Z15memfunc_to_boolM3FooFviE +// LLVM: %[[#memfunc:]] = load { i64, i64 }, ptr %{{.+}} +// LLVM-NEXT: %[[#ptr:]] = extractvalue { i64, i64 } %[[#memfunc]], 0 +// LLVM-NEXT: %{{.+}} = icmp ne i64 %[[#ptr]], 0 +// LLVM: } + +auto memfunc_reinterpret(void (Foo::*func)(int)) -> void (Bar::*)() { + return reinterpret_cast(func); +} + +// CIR-LABEL: @_Z19memfunc_reinterpretM3FooFviE +// CIR: %{{.+}} = cir.cast(bitcast, %{{.+}} : !cir.method in !ty_Foo>), !cir.method in !ty_Bar> +// CIR: } + +// LLVM-LABEL: @_Z19memfunc_reinterpretM3FooFviE +// LLVM-NEXT: %[[#arg_slot:]] = alloca { i64, i64 }, i64 1 +// LLVM-NEXT: %[[#ret_slot:]] = alloca { i64, i64 }, i64 1 +// LLVM-NEXT: store { i64, i64 } %{{.+}}, ptr %[[#arg_slot]] +// LLVM-NEXT: %[[#tmp:]] = load { i64, i64 }, ptr %[[#arg_slot]] +// LLVM-NEXT: store { i64, i64 } %[[#tmp]], ptr %[[#ret_slot]] +// LLVM-NEXT: %[[#ret:]] = load { i64, i64 }, ptr %[[#ret_slot]] +// LLVM-NEXT: ret { i64, i64 } %[[#ret]] +// LLVM-NEXT: }