Skip to content

[CIR] Add special type and new operations for vptrs #1745

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 103 additions & 7 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2580,7 +2580,7 @@ def CIR_GetGlobalOp : CIR_Op<"get_global", [
// VTableAddrPointOp
//===----------------------------------------------------------------------===//

def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point",[
def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point", [
Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>
]> {
let summary = "Get the vtable (global variable) address point";
Expand All @@ -2589,17 +2589,18 @@ def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point",[
(address point) of a C++ virtual table. An object internal `__vptr`
gets initializated on top of the value returned by this operation.

`address_point.index` (vtable index) provides the appropriate vtable within the vtable group
(as specified by Itanium ABI), and `address_point.offset` (address point index) the actual address
point within that vtable.
`address_point.index` (vtable index) provides the appropriate vtable within
the vtable group (as specified by Itanium ABI), and `address_point.offset`
(address point index) the actual address point within that vtable.

The return type is always a `!cir.ptr<!cir.ptr<() -> i32>>`.
The return type is always `!cir.ptr<!cir.vptr>`.

Example:
```mlir
cir.global linkonce_odr @_ZTV1B = ...
...
%3 = cir.vtable.address_point(@_ZTV1B, address_point = <index = 0, offset = 2>) : !cir.ptr<!cir.ptr<() -> i32>>
%3 = cir.vtable.address_point(@_ZTV1B,
address_point = <index = 0, offset = 2>) : !cir.ptr<!cir.vptr>
```
}];

Expand All @@ -2609,7 +2610,7 @@ def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point",[
CIR_AddressPointAttr:$address_point
);

let results = (outs Res<CIR_PointerType, "", []>:$addr);
let results = (outs Res<CIR_PtrToVPtr, "", []>:$addr);

let assemblyFormat = [{
`(`
Expand All @@ -2624,6 +2625,101 @@ def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point",[
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// VTableGetVptr
//===----------------------------------------------------------------------===//

def CIR_VTableGetVptrOp : CIR_Op<"vtable.get_vptr", [Pure]> {
let summary = "Get a the address of the vtable pointer for an object";
let description = [{
The `vtable.get_vptr` operation retrieves the address of the vptr for a
C++ object. This operation requires that the object pointer points to
the start of a complete object. (TODO: Describe how we get that).
The vptr will always be at offset zero in the object, but this operation
is more explicit about what is being retrieved than a direct bitcast.

The return type is always `!cir.ptr<!cir.vptr>`.

Example:
```mlir
%2 = cir.load %0 : !cir.ptr<!cir.ptr<!rec_C>>, !cir.ptr<!rec_C>
%3 = cir.vtable.get_vptr %2 : !cir.ptr<!rec_C> -> !cir.ptr<!cir.vptr>
```
}];

let arguments = (ins
Arg<CIR_PointerType, "the vptr address", [MemRead]>:$src);

let results = (outs CIR_PtrToVPtr:$vptr_ty);


let assemblyFormat = [{
$src `:` qualified(type($src)) `->` qualified(type($vptr_ty)) attr-dict
}];

}

//===----------------------------------------------------------------------===//
// VTableGetVirtualFnAddrOp
//===----------------------------------------------------------------------===//

def CIR_VTableGetVirtualFnAddrOp : CIR_Op<"vtable.get_virtual_fn_addr", [
Pure
]> {
let summary = "Get a the address of a virtual function pointer";
let description = [{
The `vtable.get_virtual_fn_addr` operation retrieves the address of a
virtual function pointer from an object's vtable (__vptr).
This is an abstraction to perform the basic pointer arithmetic to get
the address of the virtual function pointer, which can then be loaded and
called.

The return type is a pointer-to-pointer to the function type.

Example:
```mlir
%2 = cir.load %0 : !cir.ptr<!cir.ptr<!rec_C>>, !cir.ptr<!rec_C>
%3 = cir.vtable.get_vptr %2 : !cir.ptr<!rec_C> -> !cir.ptr<!cir.vptr>
%4 = cir.load %3 : !cir.ptr<!cir.vptr>, !cir.vptr
%5 = cir.vtable.get_virtual_fn_addr(%4, index = 2) : !cir.vptr
-> !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> !s32i>>>
%6 = cir.load align(8) %5 : !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_C>)
-> !s32i>>>,
!cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> !s32i>>
%7 = cir.call %6(%2) : (!cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> !s32i>>,
!cir.ptr<!rec_C>) -> !s32i
```
}];

let arguments = (ins
Arg<CIR_VPtrType, "vptr", [MemRead]>:$vptr,
IndexAttr:$index_attr);

let results = (outs CIR_PointerType:$vfptr_ty);

let assemblyFormat = [{
`(`
$vptr `,` `index` `=` $index_attr
`)`
`:` qualified(type($vptr)) `,` qualified(type($vfptr_ty)) attr-dict
}];

let builders = [
OpBuilder<(ins "mlir::Type":$type,
"mlir::Value":$value,
"unsigned":$index),
[{
mlir::APInt fnIdx(64, index);
build($_builder, $_state, type, value, fnIdx);
}]>
];

let extraClassDeclaration = [{
/// Return the index of the record member being accessed.
uint64_t getIndex() { return getIndexAttr().getZExtValue(); }
}];
}

//===----------------------------------------------------------------------===//
// VTTAddrPointOp
//===----------------------------------------------------------------------===//
Expand Down
11 changes: 10 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -263,13 +263,22 @@ def CIR_PtrToExceptionInfoType
def CIR_AnyDataMemberType : CIR_TypeBase<"::cir::DataMemberType",
"data member type">;

//===----------------------------------------------------------------------===//
// VPtr type predicates
//===----------------------------------------------------------------------===//

def CIR_AnyVPtrType : CIR_TypeBase<"::cir::VPtrType",
"vptr type">;

def CIR_PtrToVPtr : CIR_PtrToType<CIR_AnyVPtrType>;

//===----------------------------------------------------------------------===//
// Scalar Type predicates
//===----------------------------------------------------------------------===//

defvar CIR_ScalarTypes = [
CIR_AnyBoolType, CIR_AnyIntType, CIR_AnyFloatType, CIR_AnyPtrType,
CIR_AnyDataMemberType
CIR_AnyDataMemberType, CIR_AnyVPtrType
];

def CIR_AnyScalarType : AnyTypeOf<CIR_ScalarTypes, "cir scalar type"> {
Expand Down
33 changes: 32 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,36 @@ def CIR_DataMemberType : CIR_Type<"DataMember", "data_member",
}];
}

//===----------------------------------------------------------------------===//
// CIR_VPtrType
//===----------------------------------------------------------------------===//

def CIR_VPtrType : CIR_Type<"VPtr", "vptr",
[DeclareTypeInterfaceMethods<DataLayoutTypeInterface>]> {

let summary = "CIR type that is used for the vptr member of C++ objects";
let description = [{
`cir.vptr` is a special type used as the type for the vptr member of a C++
object. This avoids using arbitrary pointer types to declare vptr values
and allows stronger type-based checking for operations that use or provide
access to the vptr.

This type will be the element type of the 'vptr' member of structures that
require a vtable pointer. A pointer to this type is returned by the
`cir.vtable.address_point` and `cir.vtable.get_vptr` operations, and this
pointer may be passed to the `cir.vtable.get_virtual_fn_addr` operation to
get the address of a virtual function pointer.

The pointer may also be cast to other pointer types in order to perform
pointer arithmetic based on information encoded in the AST layout to get
the offset from a pointer to a dynamic object to the base object pointer,
the base object offset value from the vtable, or the type information
entry for an object.
TODO: We should have special operations to do that too.
}];
}


//===----------------------------------------------------------------------===//
// BoolType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -751,7 +781,8 @@ def CIRRecordType : Type<
def CIR_AnyType : AnyTypeOf<[
CIR_IntType, CIR_PointerType, CIR_DataMemberType, CIR_MethodType,
CIR_BoolType, CIR_ArrayType, CIR_VectorType, CIR_FuncType, CIR_VoidType,
CIR_RecordType, CIR_ExceptionType, CIR_AnyFloatType, CIR_ComplexType
CIR_RecordType, CIR_ExceptionType, CIR_AnyFloatType, CIR_ComplexType,
CIR_VPtrType
]>;

#endif // MLIR_CIR_DIALECT_CIR_TYPES
8 changes: 2 additions & 6 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -424,12 +424,8 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
llvm_unreachable("unsupported long double format");
}

mlir::Type getVirtualFnPtrType(bool isVarArg = false) {
// FIXME: replay LLVM codegen for now, perhaps add a vtable ptr special
// type so it's a bit more clear and C++ idiomatic.
auto fnTy = cir::FuncType::get({}, getUInt32Ty(), isVarArg);
assert(!cir::MissingFeatures::isVarArg());
return getPointerTo(getPointerTo(fnTy));
mlir::Type getPtrToVPtrType() {
return getPointerTo(cir::VPtrType::get(getContext()));
}

cir::FuncType getFuncType(llvm::ArrayRef<mlir::Type> params, mlir::Type retTy,
Expand Down
8 changes: 5 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenClass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1704,10 +1704,12 @@ void CIRGenFunction::emitTypeMetadataCodeForVCall(const CXXRecordDecl *RD,
}

mlir::Value CIRGenFunction::getVTablePtr(mlir::Location Loc, Address This,
mlir::Type VTableTy,
const CXXRecordDecl *RD) {
Address VTablePtrSrc = builder.createElementBitCast(Loc, This, VTableTy);
auto VTable = builder.createLoad(Loc, VTablePtrSrc);
auto VTablePtr = builder.create<cir::VTableGetVptrOp>(
Loc, builder.getPtrToVPtrType(), This.getPointer());
Address VTablePtrAddr = Address(VTablePtr, This.getAlignment());

auto VTable = builder.createLoad(Loc, VTablePtrAddr);
assert(!cir::MissingFeatures::tbaa());

if (CGM.getCodeGenOpts().OptimizationLevel > 0 &&
Expand Down
1 change: 0 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,6 @@ class CIRGenFunction : public CIRGenTypeCache {
VisitedVirtualBasesSetTy &VBases, VPtrsVector &vptrs);
/// Return the Value of the vtable pointer member pointed to by This.
mlir::Value getVTablePtr(mlir::Location Loc, Address This,
mlir::Type VTableTy,
const CXXRecordDecl *VTableClass);

/// Returns whether we should perform a type checked load when loading a
Expand Down
21 changes: 9 additions & 12 deletions clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -931,8 +931,7 @@ CIRGenCallee CIRGenItaniumCXXABI::getVirtualFunctionPointer(
auto loc = CGF.getLoc(Loc);
auto TyPtr = CGF.getBuilder().getPointerTo(Ty);
auto *MethodDecl = cast<CXXMethodDecl>(GD.getDecl());
auto VTable = CGF.getVTablePtr(
loc, This, CGF.getBuilder().getPointerTo(TyPtr), MethodDecl->getParent());
auto VTable = CGF.getVTablePtr(loc, This, MethodDecl->getParent());

uint64_t VTableIndex = CGM.getItaniumVTableContext().getMethodVTableIndex(GD);
mlir::Value VFunc{};
Expand All @@ -945,13 +944,9 @@ CIRGenCallee CIRGenItaniumCXXABI::getVirtualFunctionPointer(
if (CGM.getItaniumVTableContext().isRelativeLayout()) {
llvm_unreachable("NYI");
} else {
VTable = CGF.getBuilder().createBitcast(
loc, VTable, CGF.getBuilder().getPointerTo(TyPtr));
auto VTableSlotPtr = CGF.getBuilder().create<cir::VTableAddrPointOp>(
loc, CGF.getBuilder().getPointerTo(TyPtr),
::mlir::FlatSymbolRefAttr{}, VTable,
cir::AddressPointAttr::get(CGF.getBuilder().getContext(), 0,
VTableIndex));
auto VTableSlotPtr =
CGF.getBuilder().create<cir::VTableGetVirtualFnAddrOp>(
loc, CGF.getBuilder().getPointerTo(TyPtr), VTable, VTableIndex);
VFuncLoad = CGF.getBuilder().createAlignedLoad(loc, TyPtr, VTableSlotPtr,
CGF.getPointerAlign());
}
Expand Down Expand Up @@ -1007,7 +1002,7 @@ CIRGenItaniumCXXABI::getVTableAddressPoint(BaseSubobject Base,
.getAddressPoint(Base);

auto &builder = CGM.getBuilder();
auto vtablePtrTy = builder.getVirtualFnPtrType(/*isVarArg=*/false);
auto vtablePtrTy = builder.getPtrToVPtrType();

return builder.create<cir::VTableAddrPointOp>(
CGM.getLoc(VTableClass->getSourceRange()), vtablePtrTy,
Expand Down Expand Up @@ -2377,14 +2372,16 @@ void CIRGenItaniumCXXABI::emitThrow(CIRGenFunction &CGF,
mlir::Value CIRGenItaniumCXXABI::getVirtualBaseClassOffset(
mlir::Location loc, CIRGenFunction &CGF, Address This,
const CXXRecordDecl *ClassDecl, const CXXRecordDecl *BaseClassDecl) {
auto VTablePtr = CGF.getVTablePtr(loc, This, CGM.UInt8PtrTy, ClassDecl);
auto VTablePtr = CGF.getVTablePtr(loc, This, ClassDecl);
auto VTableBytePtr =
CGF.getBuilder().createBitcast(VTablePtr, CGM.UInt8PtrTy);
CharUnits VBaseOffsetOffset =
CGM.getItaniumVTableContext().getVirtualBaseOffsetOffset(ClassDecl,
BaseClassDecl);
mlir::Value OffsetVal =
CGF.getBuilder().getSInt64(VBaseOffsetOffset.getQuantity(), loc);
auto VBaseOffsetPtr = CGF.getBuilder().create<cir::PtrStrideOp>(
loc, VTablePtr.getType(), VTablePtr,
loc, CGM.UInt8PtrTy, VTableBytePtr,
OffsetVal); // vbase.offset.ptr

mlir::Value VBaseOffset;
Expand Down
4 changes: 1 addition & 3 deletions clang/lib/CIR/CodeGen/CIRRecordLayoutBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,9 +488,7 @@ void CIRRecordLowering::accumulateVPtrs() {
}

mlir::Type CIRRecordLowering::getVFPtrType() {
// FIXME: replay LLVM codegen for now, perhaps add a vtable ptr special
// type so it's a bit more clear and C++ idiomatic.
return builder.getVirtualFnPtrType();
return cir::VPtrType::get(builder.getContext());
}

void CIRRecordLowering::fillOutputFields() {
Expand Down
11 changes: 7 additions & 4 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,12 @@ LogicalResult cir::CastOp::verify() {
return success();
}

// Allow casting cir.vptr to pointer types.
// TODO: Add operations to get object offset and type info and remove this.
if (mlir::isa<cir::VPtrType>(srcType) &&
mlir::dyn_cast<cir::PointerType>(resType))
return success();

// Handle the data member pointer types.
if (mlir::isa<cir::DataMemberType>(srcType) &&
mlir::isa<cir::DataMemberType>(resType))
Expand Down Expand Up @@ -2423,10 +2429,7 @@ LogicalResult cir::VTableAddrPointOp::verify() {
return success();

auto resultType = getAddr().getType();
auto intTy = cir::IntType::get(getContext(), 32, /*isSigned=*/false);
auto fnTy = cir::FuncType::get({}, intTy);

auto resTy = cir::PointerType::get(cir::PointerType::get(fnTy));
auto resTy = cir::PointerType::get(cir::VPtrType::get(getContext()));

if (resultType != resTy)
return emitOpError("result type must be '")
Expand Down
14 changes: 14 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,20 @@ DataMemberType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
return 8;
}

llvm::TypeSize
VPtrType::getTypeSizeInBits(const ::mlir::DataLayout &dataLayout,
::mlir::DataLayoutEntryListRef params) const {
// FIXME: consider size differences under different ABIs
return llvm::TypeSize::getFixed(64);
}

uint64_t
VPtrType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
::mlir::DataLayoutEntryListRef params) const {
// FIXME: consider alignment differences under different ABIs
return 8;
}

llvm::TypeSize
ArrayType::getTypeSizeInBits(const ::mlir::DataLayout &dataLayout,
::mlir::DataLayoutEntryListRef params) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,17 @@ buildDynamicCastToVoidAfterNullCheck(CIRBaseBuilderTy &builder,

// Access vtable to get the offset from the given object to its containing
// complete object.
auto vtablePtrTy = builder.getPointerTo(vtableElemTy);
auto vtablePtrPtr =
builder.createBitcast(op.getSrc(), builder.getPointerTo(vtablePtrTy));
auto vtablePtr = builder.createLoad(loc, vtablePtrPtr);
auto offsetToTopSlotPtr = builder.create<cir::VTableAddrPointOp>(
loc, vtablePtrTy, mlir::FlatSymbolRefAttr{}, vtablePtr,
cir::AddressPointAttr::get(builder.getContext(), 0, -2));
// TODO: Add a specialized operation to get the object offset?
auto vptrTy = cir::VPtrType::get(builder.getContext());
auto vptrPtrTy = builder.getPointerTo(vptrTy);
auto vptrPtr =
builder.create<cir::VTableGetVptrOp>(loc, vptrPtrTy, op.getSrc());
auto vptr = builder.createLoad(loc, vptrPtr);
auto elementPtr =
builder.createBitcast(vptr, builder.getPointerTo(vtableElemTy));
auto minusTwo = builder.getSignedInt(loc, -2, 64);
auto offsetToTopSlotPtr = builder.create<cir::PtrStrideOp>(
loc, builder.getPointerTo(vtableElemTy), elementPtr, minusTwo);
auto offsetToTop =
builder.createAlignedLoad(loc, offsetToTopSlotPtr, vtableElemAlign);

Expand Down
Loading
Loading