Skip to content

Commit fe7a761

Browse files
committed
[OpenMP][SPIR-V] Fix addrspace of pointer kernel arguments
Signed-off-by: Sarnie, Nick <[email protected]>
1 parent daf8f9f commit fe7a761

File tree

8 files changed

+64
-31
lines changed

8 files changed

+64
-31
lines changed

clang/lib/CodeGen/CGCall.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -752,9 +752,8 @@ const CGFunctionInfo &CodeGenTypes::arrangeBuiltinFunctionDeclaration(
752752
RequiredArgs::All);
753753
}
754754

755-
const CGFunctionInfo &
756-
CodeGenTypes::arrangeSYCLKernelCallerDeclaration(QualType resultType,
757-
const FunctionArgList &args) {
755+
const CGFunctionInfo &CodeGenTypes::arrangeDeviceKernelCallerDeclaration(
756+
QualType resultType, const FunctionArgList &args) {
758757
CanQualTypeList argTypes = getArgTypesForDeclaration(Context, args);
759758

760759
return arrangeLLVMFunctionInfo(GetReturnType(resultType), FnInfoOpts::None,

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,7 +1238,7 @@ static llvm::Function *emitParallelOrTeamsOutlinedFunction(
12381238
CGOpenMPOutlinedRegionInfo CGInfo(*CS, ThreadIDVar, CodeGen, InnermostKind,
12391239
HasCancel, OutlinedHelperName);
12401240
CodeGenFunction::CGCapturedStmtRAII CapInfoRAII(CGF, &CGInfo);
1241-
return CGF.GenerateOpenMPCapturedStmtFunction(*CS, D.getBeginLoc());
1241+
return CGF.GenerateOpenMPCapturedStmtFunction(*CS, D);
12421242
}
12431243

12441244
std::string CGOpenMPRuntime::getOutlinedHelperName(StringRef Name) const {
@@ -6227,7 +6227,7 @@ void CGOpenMPRuntime::emitTargetOutlinedFunctionHelper(
62276227

62286228
CGOpenMPTargetRegionInfo CGInfo(CS, CodeGen, EntryFnName);
62296229
CodeGenFunction::CGCapturedStmtRAII CapInfoRAII(CGF, &CGInfo);
6230-
return CGF.GenerateOpenMPCapturedStmtFunction(CS, D.getBeginLoc());
6230+
return CGF.GenerateOpenMPCapturedStmtFunction(CS, D);
62316231
};
62326232

62336233
cantFail(OMPBuilder.emitTargetRegionFunction(

clang/lib/CodeGen/CGStmtOpenMP.cpp

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -471,12 +471,13 @@ struct FunctionOptions {
471471
const StringRef FunctionName;
472472
/// Location of the non-debug version of the outlined function.
473473
SourceLocation Loc;
474+
const bool IsDeviceKernel = false;
474475
explicit FunctionOptions(const CapturedStmt *S, bool UIntPtrCastRequired,
475476
bool RegisterCastedArgsOnly, StringRef FunctionName,
476-
SourceLocation Loc)
477+
SourceLocation Loc, bool IsDeviceKernel)
477478
: S(S), UIntPtrCastRequired(UIntPtrCastRequired),
478479
RegisterCastedArgsOnly(UIntPtrCastRequired && RegisterCastedArgsOnly),
479-
FunctionName(FunctionName), Loc(Loc) {}
480+
FunctionName(FunctionName), Loc(Loc), IsDeviceKernel(IsDeviceKernel) {}
480481
};
481482
} // namespace
482483

@@ -570,7 +571,11 @@ static llvm::Function *emitOutlinedFunctionPrologue(
570571

571572
// Create the function declaration.
572573
const CGFunctionInfo &FuncInfo =
573-
CGM.getTypes().arrangeBuiltinFunctionDeclaration(Ctx.VoidTy, TargetArgs);
574+
FO.IsDeviceKernel
575+
? CGM.getTypes().arrangeDeviceKernelCallerDeclaration(Ctx.VoidTy,
576+
TargetArgs)
577+
: CGM.getTypes().arrangeBuiltinFunctionDeclaration(Ctx.VoidTy,
578+
TargetArgs);
574579
llvm::FunctionType *FuncLLVMTy = CGM.getTypes().GetFunctionType(FuncInfo);
575580

576581
auto *F =
@@ -664,9 +669,9 @@ static llvm::Function *emitOutlinedFunctionPrologue(
664669
return F;
665670
}
666671

667-
llvm::Function *
668-
CodeGenFunction::GenerateOpenMPCapturedStmtFunction(const CapturedStmt &S,
669-
SourceLocation Loc) {
672+
llvm::Function *CodeGenFunction::GenerateOpenMPCapturedStmtFunction(
673+
const CapturedStmt &S, const OMPExecutableDirective &D) {
674+
SourceLocation Loc = D.getBeginLoc();
670675
assert(
671676
CapturedStmtInfo &&
672677
"CapturedStmtInfo should be set when generating the captured function");
@@ -682,23 +687,27 @@ CodeGenFunction::GenerateOpenMPCapturedStmtFunction(const CapturedStmt &S,
682687
SmallString<256> Buffer;
683688
llvm::raw_svector_ostream Out(Buffer);
684689
Out << CapturedStmtInfo->getHelperName();
685-
690+
OpenMPDirectiveKind EKind = getEffectiveDirectiveKind(D);
691+
bool IsDeviceKernel = CGM.getOpenMPRuntime().isGPU() &&
692+
isOpenMPTargetExecutionDirective(EKind) &&
693+
D.getCapturedStmt(OMPD_target) == &S;
686694
CodeGenFunction WrapperCGF(CGM, /*suppressNewContext=*/true);
687695
llvm::Function *WrapperF = nullptr;
688696
if (NeedWrapperFunction) {
689697
// Emit the final kernel early to allow attributes to be added by the
690698
// OpenMPI-IR-Builder.
691699
FunctionOptions WrapperFO(&S, /*UIntPtrCastRequired=*/true,
692700
/*RegisterCastedArgsOnly=*/true,
693-
CapturedStmtInfo->getHelperName(), Loc);
701+
CapturedStmtInfo->getHelperName(), Loc,
702+
IsDeviceKernel);
694703
WrapperCGF.CapturedStmtInfo = CapturedStmtInfo;
695704
WrapperF =
696705
emitOutlinedFunctionPrologue(WrapperCGF, Args, LocalAddrs, VLASizes,
697706
WrapperCGF.CXXThisValue, WrapperFO);
698707
Out << "_debug__";
699708
}
700709
FunctionOptions FO(&S, !NeedWrapperFunction, /*RegisterCastedArgsOnly=*/false,
701-
Out.str(), Loc);
710+
Out.str(), Loc, !NeedWrapperFunction && IsDeviceKernel);
702711
llvm::Function *F = emitOutlinedFunctionPrologue(
703712
*this, WrapperArgs, WrapperLocalAddrs, WrapperVLASizes, CXXThisValue, FO);
704713
CodeGenFunction::OMPPrivateScope LocalScope(*this);
@@ -6119,13 +6128,13 @@ void CodeGenFunction::EmitOMPDistributeDirective(
61196128
emitOMPDistributeDirective(S, *this, CGM);
61206129
}
61216130

6122-
static llvm::Function *emitOutlinedOrderedFunction(CodeGenModule &CGM,
6123-
const CapturedStmt *S,
6124-
SourceLocation Loc) {
6131+
static llvm::Function *
6132+
emitOutlinedOrderedFunction(CodeGenModule &CGM, const CapturedStmt *S,
6133+
const OMPExecutableDirective &D) {
61256134
CodeGenFunction CGF(CGM, /*suppressNewContext=*/true);
61266135
CodeGenFunction::CGCapturedStmtInfo CapStmtInfo;
61276136
CGF.CapturedStmtInfo = &CapStmtInfo;
6128-
llvm::Function *Fn = CGF.GenerateOpenMPCapturedStmtFunction(*S, Loc);
6137+
llvm::Function *Fn = CGF.GenerateOpenMPCapturedStmtFunction(*S, D);
61296138
Fn->setDoesNotRecurse();
61306139
return Fn;
61316140
}
@@ -6190,8 +6199,7 @@ void CodeGenFunction::EmitOMPOrderedDirective(const OMPOrderedDirective &S) {
61906199
Builder, /*CreateBranch=*/false, ".ordered.after");
61916200
llvm::SmallVector<llvm::Value *, 16> CapturedVars;
61926201
GenerateOpenMPCapturedVars(*CS, CapturedVars);
6193-
llvm::Function *OutlinedFn =
6194-
emitOutlinedOrderedFunction(CGM, CS, S.getBeginLoc());
6202+
llvm::Function *OutlinedFn = emitOutlinedOrderedFunction(CGM, CS, S);
61956203
assert(S.getBeginLoc().isValid() &&
61966204
"Outlined function call location must be valid.");
61976205
ApplyDebugLocation::CreateDefaultArtificial(*this, S.getBeginLoc());
@@ -6233,8 +6241,7 @@ void CodeGenFunction::EmitOMPOrderedDirective(const OMPOrderedDirective &S) {
62336241
if (C) {
62346242
llvm::SmallVector<llvm::Value *, 16> CapturedVars;
62356243
CGF.GenerateOpenMPCapturedVars(*CS, CapturedVars);
6236-
llvm::Function *OutlinedFn =
6237-
emitOutlinedOrderedFunction(CGM, CS, S.getBeginLoc());
6244+
llvm::Function *OutlinedFn = emitOutlinedOrderedFunction(CGM, CS, S);
62386245
CGM.getOpenMPRuntime().emitOutlinedFunctionCall(CGF, S.getBeginLoc(),
62396246
OutlinedFn, CapturedVars);
62406247
} else {

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3694,8 +3694,9 @@ class CodeGenFunction : public CodeGenTypeCache {
36943694
llvm::Function *EmitCapturedStmt(const CapturedStmt &S, CapturedRegionKind K);
36953695
llvm::Function *GenerateCapturedStmtFunction(const CapturedStmt &S);
36963696
Address GenerateCapturedStmtArgument(const CapturedStmt &S);
3697-
llvm::Function *GenerateOpenMPCapturedStmtFunction(const CapturedStmt &S,
3698-
SourceLocation Loc);
3697+
llvm::Function *
3698+
GenerateOpenMPCapturedStmtFunction(const CapturedStmt &S,
3699+
const OMPExecutableDirective &D);
36993700
void GenerateOpenMPCapturedVars(const CapturedStmt &S,
37003701
SmallVectorImpl<llvm::Value *> &CapturedVars);
37013702
void emitOMPSimpleStore(LValue LVal, RValue RVal, QualType RValTy,

clang/lib/CodeGen/CodeGenSYCL.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ void CodeGenModule::EmitSYCLKernelCaller(const FunctionDecl *KernelEntryPointFn,
4949

5050
// Compute the function info and LLVM function type.
5151
const CGFunctionInfo &FnInfo =
52-
getTypes().arrangeSYCLKernelCallerDeclaration(Ctx.VoidTy, Args);
52+
getTypes().arrangeDeviceKernelCallerDeclaration(Ctx.VoidTy, Args);
5353
llvm::FunctionType *FnTy = getTypes().GetFunctionType(FnInfo);
5454

5555
// Retrieve the generated name for the SYCL kernel caller function.

clang/lib/CodeGen/CodeGenTypes.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -229,12 +229,12 @@ class CodeGenTypes {
229229
const CGFunctionInfo &arrangeBuiltinFunctionCall(QualType resultType,
230230
const CallArgList &args);
231231

232-
/// A SYCL kernel caller function is an offload device entry point function
232+
/// A device kernel caller function is an offload device entry point function
233233
/// with a target device dependent calling convention such as amdgpu_kernel,
234234
/// ptx_kernel, or spir_kernel.
235235
const CGFunctionInfo &
236-
arrangeSYCLKernelCallerDeclaration(QualType resultType,
237-
const FunctionArgList &args);
236+
arrangeDeviceKernelCallerDeclaration(QualType resultType,
237+
const FunctionArgList &args);
238238

239239
/// Objective-C methods are C functions with some implicit parameters.
240240
const CGFunctionInfo &arrangeObjCMethodDeclaration(const ObjCMethodDecl *MD);

clang/lib/CodeGen/Targets/SPIR.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,12 @@ ABIArgInfo SPIRVABIInfo::classifyReturnType(QualType RetTy) const {
132132
}
133133

134134
ABIArgInfo SPIRVABIInfo::classifyKernelArgumentType(QualType Ty) const {
135-
if (getContext().getLangOpts().CUDAIsDevice) {
135+
if (getContext().getLangOpts().CUDAIsDevice ||
136+
getContext().getLangOpts().OpenMPIsTargetDevice) {
136137
// Coerce pointer arguments with default address space to CrossWorkGroup
137-
// pointers for HIPSPV/CUDASPV. When the language mode is HIP/CUDA, the
138-
// SPIRTargetInfo maps cuda_device to SPIR-V's CrossWorkGroup address space.
138+
// pointers for HIPSPV/CUDASPV/OMPSPV. When the language mode is
139+
// HIP/CUDA/OMP, the SPIRTargetInfo maps cuda_device to SPIR-V's
140+
// CrossWorkGroup address space.
139141
llvm::Type *LTy = CGT.ConvertType(Ty);
140142
auto DefaultAS = getContext().getTargetAddressSpace(LangAS::Default);
141143
auto GlobalAS = getContext().getTargetAddressSpace(LangAS::cuda_device);
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: %clang_cc1 -verify -fopenmp -x c++ -triple x86_64-unknown-linux -fopenmp-targets=spirv64-intel -emit-llvm-bc %s -o %t-host.bc
2+
// RUN: %clang_cc1 -verify -fopenmp -x c++ -triple spirv64-intel -fopenmp-targets=spirv64-intel -emit-llvm %s -fopenmp-is-target-device -fopenmp-host-ir-file-path %t-host.bc -o - | FileCheck %s
3+
// RUN: %clang_cc1 -verify -fopenmp -x c++ -triple x86_64-unknown-linux -fopenmp-targets=spirv64-intel -emit-llvm-bc %s -o %t-host.bc -DTEAMS
4+
// RUN: %clang_cc1 -verify -fopenmp -x c++ -triple spirv64-intel -fopenmp-targets=spirv64-intel -emit-llvm %s -fopenmp-is-target-device -fopenmp-host-ir-file-path %t-host.bc -DTEAMS -o - | FileCheck %s
5+
// expected-no-diagnostics
6+
7+
// CHECK: define weak_odr protected spir_kernel void @__omp_offloading_{{.*}}(ptr addrspace(1) noalias noundef %{{.*}}, ptr addrspace(1) noundef align 4 dereferenceable(128) %{{.*}})
8+
9+
int main() {
10+
int x[32] = {0};
11+
12+
#ifdef TEAMS
13+
#pragma omp target teams
14+
#else
15+
#pragma omp target
16+
#endif
17+
for(int i = 0; i < 32; i++) {
18+
if(i > 0)
19+
x[i] = x[i-1] + i;
20+
}
21+
22+
return x[31];
23+
}
24+

0 commit comments

Comments
 (0)