Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions clang/include/clang/Basic/LangOptions.def
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,8 @@ LANGOPT(
LANGOPT(SYCLDisableRangeRounding, 1, 0, "Disable parallel for range rounding")
LANGOPT(SYCLEnableIntHeaderDiags, 1, 0, "Enable diagnostics that require the "
"SYCL integration header")
LANGOPT(SYCLAllowVirtualFunctions, 1, 0,
"Allow virtual functions calls in code for SYCL device")

LANGOPT(HIPUseNewLaunchAPI, 1, 0, "Use new kernel launching API for HIP")

Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Driver/Options.td
Original file line number Diff line number Diff line change
Expand Up @@ -6584,6 +6584,9 @@ def fsycl_use_main_file_name : Flag<["-"], "fsycl-use-main-file-name">,
HelpText<"Tells compiler that -main-file-name contains an absolute path and "
"file specified there should be used for checksum calculation.">,
MarshallingInfoFlag<CodeGenOpts<"SYCLUseMainFileName">>;
def fsycl_allow_virtual_functions : Flag<["-"], "fsycl-allow-virtual-functions">,
HelpText<"Allow virtual functions calls in code for SYCL device">,
MarshallingInfoFlag<LangOpts<"SYCLAllowVirtualFunctions">>;

} // let Flags = [CC1Option, NoDriverOption]

Expand Down
37 changes: 31 additions & 6 deletions clang/lib/CodeGen/ItaniumCXXABI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3565,8 +3565,33 @@ void ItaniumRTTIBuilder::BuildVTablePointer(const Type *Ty) {
// Check if the alias exists. If it doesn't, then get or create the global.
if (CGM.getItaniumVTableContext().isRelativeLayout())
VTable = CGM.getModule().getNamedAlias(VTableName);
if (!VTable)
VTable = CGM.getModule().getOrInsertGlobal(VTableName, CGM.Int8PtrTy);

// To generate valid device code global pointers should have global address
// space in SYCL.
bool GenTyInfoGVWithGlobalAS =
CGM.getLangOpts().SYCLIsDevice &&
CGM.getLangOpts().SYCLAllowVirtualFunctions &&
(VTableName == ClassTypeInfo || VTableName == SIClassTypeInfo);
auto VTableTy =
GenTyInfoGVWithGlobalAS
? CGM.Int8Ty->getPointerTo(
CGM.getContext().getTargetAddressSpace(LangAS::sycl_global))
: CGM.Int8PtrTy;
if (!VTable) {
if (GenTyInfoGVWithGlobalAS) {
VTable = CGM.getModule().getOrInsertGlobal(VTableName, VTableTy, [&] {
return new llvm::GlobalVariable(
CGM.getModule(), VTableTy, /*isConstant=*/false,
llvm::GlobalVariable::ExternalLinkage, /*Initializer=*/nullptr,
VTableName, /*InsertBefore=*/nullptr,
llvm::GlobalValue::ThreadLocalMode::NotThreadLocal,
llvm::Optional<unsigned>(
CGM.getContext().getTargetAddressSpace(LangAS::sycl_global)));
});
} else {
VTable = CGM.getModule().getOrInsertGlobal(VTableName, VTableTy);
}
}

CGM.setDSOLocal(cast<llvm::GlobalValue>(VTable->stripPointerCasts()));

Expand All @@ -3578,15 +3603,15 @@ void ItaniumRTTIBuilder::BuildVTablePointer(const Type *Ty) {
// The vtable address point is 8 bytes after its start:
// 4 for the offset to top + 4 for the relative offset to rtti.
llvm::Constant *Eight = llvm::ConstantInt::get(CGM.Int32Ty, 8);
VTable = llvm::ConstantExpr::getBitCast(VTable, CGM.Int8PtrTy);
VTable = llvm::ConstantExpr::getBitCast(VTable, VTableTy);
VTable =
llvm::ConstantExpr::getInBoundsGetElementPtr(CGM.Int8Ty, VTable, Eight);
} else {
llvm::Constant *Two = llvm::ConstantInt::get(PtrDiffTy, 2);
VTable = llvm::ConstantExpr::getInBoundsGetElementPtr(CGM.Int8PtrTy, VTable,
Two);
VTable =
llvm::ConstantExpr::getInBoundsGetElementPtr(VTableTy, VTable, Two);
}
VTable = llvm::ConstantExpr::getBitCast(VTable, CGM.Int8PtrTy);
VTable = llvm::ConstantExpr::getBitCast(VTable, VTableTy);

Fields.push_back(VTable);
}
Expand Down
3 changes: 2 additions & 1 deletion clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,8 @@ class DiagDeviceFunction : public RecursiveASTVisitor<DiagDeviceFunction> {
}

if (const CXXMethodDecl *Method = dyn_cast<CXXMethodDecl>(Callee))
if (Method->isVirtual())
if (Method->isVirtual() &&
!SemaRef.getLangOpts().SYCLAllowVirtualFunctions)
SemaRef.Diag(e->getExprLoc(), diag::err_sycl_restrict)
<< Sema::KernelCallVirtualFunction;

Expand Down
33 changes: 33 additions & 0 deletions clang/test/CodeGenSYCL/simple-sycl-virtual-function.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// This test checks that the FE generates global variables corresponding to the
// virtual table in the global address space (addrspace(1)) when
// -fsycl-allow-virtual-functions is passed.

// RUN: %clang_cc1 -triple spir64 -fsycl-allow-virtual-functions -fsycl-is-device -emit-llvm %s -o - | FileCheck %s --check-prefixes CHECK,CHECK-PTR
// RUN: %clang_cc1 -triple spir64 -fsycl-allow-virtual-functions -fsycl-is-device -fexperimental-relative-c++-abi-vtables -emit-llvm %s -o - | FileCheck %s --check-prefixes CHECK,CHECK-REL

// CHECK: @_ZTVN10__cxxabiv120__si_class_type_infoE = external addrspace(1) global ptr addrspace(1)
// CHECK: @_ZTVN10__cxxabiv117__class_type_infoE = external addrspace(1) global ptr addrspace(1)
// CHECK-PTR: @_ZTI4Base = linkonce_odr constant { ptr addrspace(1), ptr } { ptr addrspace(1) getelementptr inbounds (ptr addrspace(1), ptr addrspace(1) @_ZTVN10__cxxabiv117__class_type_infoE, i64 2)
// CHECK-PTR: @_ZTI8Derived1 = linkonce_odr constant { ptr addrspace(1), ptr, ptr } { ptr addrspace(1) getelementptr inbounds (ptr addrspace(1), ptr addrspace(1) @_ZTVN10__cxxabiv120__si_class_type_infoE, i64 2)
// CHECK-REL: @_ZTI4Base = linkonce_odr constant { ptr addrspace(1), ptr } { ptr addrspace(1) getelementptr inbounds (i8, ptr addrspace(1) @_ZTVN10__cxxabiv117__class_type_infoE, i32 8)
// CHECK-REL: @_ZTI8Derived1 = linkonce_odr constant { ptr addrspace(1), ptr, ptr } { ptr addrspace(1) getelementptr inbounds (i8, ptr addrspace(1) @_ZTVN10__cxxabiv120__si_class_type_infoE, i32 8)

SYCL_EXTERNAL bool rand();

class Base {
public:
virtual void display() {}
};

class Derived1 : public Base {
public:
void display() {}
};

SYCL_EXTERNAL void test() {
Derived1 d1;
Base *b = nullptr;
if (rand())
b = &d1;
b->display();
}