@@ -123,6 +123,10 @@ class Util {
123123 // / specialization id class.
124124 static bool isSyclSpecIdType (QualType Ty);
125125
126+ // / Checks whether given clang type is a full specialization of the SYCL
127+ // / device_global class.
128+ static bool isSyclDeviceGlobalType (QualType Ty);
129+
126130 // / Checks whether given clang type is a full specialization of the SYCL
127131 // / kernel_handler class.
128132 static bool isSyclKernelHandlerType (QualType Ty);
@@ -4692,7 +4696,23 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
46924696 O << " namespace sycl {\n " ;
46934697 O << " namespace detail {\n " ;
46944698
4695- O << " \n " ;
4699+ // Generate declaration of variable of type __sycl_device_global_registration
4700+ // whose sole purpose is to run its constructor before the application's
4701+ // main() function.
4702+
4703+ if (S.getSyclIntegrationFooter ().isDeviceGlobalsEmitted ()) {
4704+ O << " namespace {\n " ;
4705+
4706+ O << " class __sycl_device_global_registration {\n " ;
4707+ O << " public:\n " ;
4708+ O << " __sycl_device_global_registration() noexcept;\n " ;
4709+ O << " };\n " ;
4710+ O << " __sycl_device_global_registration __sycl_device_global_registrar;\n " ;
4711+
4712+ O << " } // namespace\n " ;
4713+
4714+ O << " \n " ;
4715+ }
46964716
46974717 O << " // names of all kernels defined in the corresponding source\n " ;
46984718 O << " static constexpr\n " ;
@@ -4874,9 +4894,9 @@ void SYCLIntegrationFooter::addVarDecl(const VarDecl *VD) {
48744894 // template instantiations as a VarDecl.
48754895 if (isa<VarTemplatePartialSpecializationDecl>(VD))
48764896 return ;
4877- // Step 1: ensure that this is of the correct type-spec-constant template
4878- // specialization).
4879- if ( !Util::isSyclSpecIdType (VD->getType ())) {
4897+ // Step 1: ensure that this is of the correct type template specialization.
4898+ if (! Util::isSyclSpecIdType (VD-> getType ()) &&
4899+ !Util::isSyclDeviceGlobalType (VD->getType ())) {
48804900 // Handle the case where this could be a deduced type, such as a deduction
48814901 // guide. We have to do this here since this function, unlike most of the
48824902 // rest of this file, is called during Sema instead of after it. We will
@@ -4892,8 +4912,8 @@ void SYCLIntegrationFooter::addVarDecl(const VarDecl *VD) {
48924912 // let an error happen during host compilation.
48934913 if (!VD->hasGlobalStorage () || VD->isLocalVarDeclOrParm ())
48944914 return ;
4895- // Step 3: Add to SpecConstants collection.
4896- SpecConstants .push_back (VD);
4915+ // Step 3: Add to collection.
4916+ GlobalVars .push_back (VD);
48974917}
48984918
48994919// Post-compile integration header support.
@@ -4967,29 +4987,28 @@ static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) {
49674987 [](raw_ostream &OS, const NamespaceDecl *NS) {}, OS, DC);
49684988}
49694989
4970- static std::string EmitSpecIdShim (raw_ostream &OS, unsigned &ShimCounter,
4971- const std::string &LastShim,
4972- const NamespaceDecl *AnonNS) {
4990+ static std::string EmitShim (raw_ostream &OS, unsigned &ShimCounter,
4991+ const std::string &LastShim,
4992+ const NamespaceDecl *AnonNS) {
49734993 std::string NewShimName =
4974- " __sycl_detail::__spec_id_shim_ " + std::to_string (ShimCounter) + " ()" ;
4994+ " __sycl_detail::__shim_ " + std::to_string (ShimCounter) + " ()" ;
49754995 // Print opening-namespace
49764996 PrintNamespaces (OS, Decl::castToDeclContext (AnonNS));
49774997 OS << " namespace __sycl_detail {\n " ;
4978- OS << " static constexpr decltype(" << LastShim << " ) &__spec_id_shim_ "
4979- << ShimCounter << " () {\n " ;
4998+ OS << " static constexpr decltype(" << LastShim << " ) &__shim_ " << ShimCounter
4999+ << " () {\n " ;
49805000 OS << " return " << LastShim << " ;\n " ;
49815001 OS << " }\n " ;
4982- OS << " } // namespace __sycl_detail \n " ;
5002+ OS << " } // namespace __sycl_detail\n " ;
49835003 PrintNSClosingBraces (OS, Decl::castToDeclContext (AnonNS));
49845004
49855005 ++ShimCounter;
49865006 return NewShimName;
49875007}
49885008
49895009// Emit the list of shims required for a DeclContext, calls itself recursively.
4990- static void EmitSpecIdShims (raw_ostream &OS, unsigned &ShimCounter,
4991- const DeclContext *DC,
4992- std::string &NameForLastShim) {
5010+ static void EmitShims (raw_ostream &OS, unsigned &ShimCounter,
5011+ const DeclContext *DC, std::string &NameForLastShim) {
49935012 if (DC->isTranslationUnit ()) {
49945013 NameForLastShim = " ::" + NameForLastShim;
49955014 return ;
@@ -5003,7 +5022,7 @@ static void EmitSpecIdShims(raw_ostream &OS, unsigned &ShimCounter,
50035022 } else if (const auto *ND = dyn_cast<NamespaceDecl>(CurDecl)) {
50045023 if (ND->isAnonymousNamespace ()) {
50055024 // Print current shim, reset 'name for last shim'.
5006- NameForLastShim = EmitSpecIdShim (OS, ShimCounter, NameForLastShim, ND);
5025+ NameForLastShim = EmitShim (OS, ShimCounter, NameForLastShim, ND);
50075026 } else {
50085027 NameForLastShim = ND->getNameAsString () + " ::" + NameForLastShim;
50095028 }
@@ -5017,22 +5036,22 @@ static void EmitSpecIdShims(raw_ostream &OS, unsigned &ShimCounter,
50175036 " Unhandled decl type" );
50185037 }
50195038
5020- EmitSpecIdShims (OS, ShimCounter, CurDecl->getDeclContext (), NameForLastShim);
5039+ EmitShims (OS, ShimCounter, CurDecl->getDeclContext (), NameForLastShim);
50215040}
50225041
50235042// Emit the list of shims required for a variable declaration.
50245043// Returns a string containing the FQN of the 'top most' shim, including its
50255044// function call parameters.
5026- static std::string EmitSpecIdShims (raw_ostream &OS, unsigned &ShimCounter,
5027- PrintingPolicy &Policy, const VarDecl *VD) {
5045+ static std::string EmitShims (raw_ostream &OS, unsigned &ShimCounter,
5046+ PrintingPolicy &Policy, const VarDecl *VD) {
50285047 if (!VD->isInAnonymousNamespace ())
50295048 return " " ;
50305049 std::string RelativeName;
50315050 llvm::raw_string_ostream stream (RelativeName);
50325051 VD->getNameForDiagnostic (stream, Policy, false );
50335052 stream.flush ();
50345053
5035- EmitSpecIdShims (OS, ShimCounter, VD->getDeclContext (), RelativeName);
5054+ EmitShims (OS, ShimCounter, VD->getDeclContext (), RelativeName);
50365055 return RelativeName;
50375056}
50385057
@@ -5042,58 +5061,90 @@ bool SYCLIntegrationFooter::emit(raw_ostream &OS) {
50425061 Policy.SuppressTypedefs = true ;
50435062 Policy.SuppressUnwrittenScope = true ;
50445063
5045- llvm::SmallSet<const VarDecl *, 8 > VisitedSpecConstants ;
5064+ llvm::SmallSet<const VarDecl *, 8 > Visited ;
50465065 bool EmittedFirstSpecConstant = false ;
50475066
50485067 // Used to uniquely name the 'shim's as we generate the names in each
50495068 // anonymous namespace.
50505069 unsigned ShimCounter = 0 ;
5051- for (const VarDecl *VD : SpecConstants) {
5070+
5071+ std::string DeviceGlobalsBuf;
5072+ llvm::raw_string_ostream DeviceGlobOS (DeviceGlobalsBuf);
5073+ for (const VarDecl *VD : GlobalVars) {
50525074 VD = VD->getCanonicalDecl ();
50535075
5054- // Skip if this isn't a SpecIdType. This can happen if it was a deduced
5055- // type.
5056- if (!Util::isSyclSpecIdType (VD->getType ()))
5076+ // Skip if this isn't a SpecIdType or DeviceGlobal. This can happen if it
5077+ // was a deduced type.
5078+ if (!Util::isSyclSpecIdType (VD->getType ()) &&
5079+ !Util::isSyclDeviceGlobalType (VD->getType ()))
50575080 continue ;
50585081
50595082 // Skip if we've already visited this.
5060- if (llvm::find (VisitedSpecConstants , VD) != VisitedSpecConstants .end ())
5083+ if (llvm::find (Visited , VD) != Visited .end ())
50615084 continue ;
50625085
5063- // We only want to emit the #includes if we have a spec-constant that needs
5086+ // We only want to emit the #includes if we have a variable that needs
50645087 // them, so emit this one on the first time through the loop.
5065- if (!EmittedFirstSpecConstant)
5088+ if (!EmittedFirstSpecConstant && !DeviceGlobalsEmitted )
50665089 OS << " #include <CL/sycl/detail/defines_elementary.hpp>\n " ;
5067- EmittedFirstSpecConstant = true ;
5068-
5069- VisitedSpecConstants.insert (VD);
5070- std::string TopShim = EmitSpecIdShims (OS, ShimCounter, Policy, VD);
5071- OS << " __SYCL_INLINE_NAMESPACE(cl) {\n " ;
5072- OS << " namespace sycl {\n " ;
5073- OS << " namespace detail {\n " ;
5074- OS << " template<>\n " ;
5075- OS << " inline const char *get_spec_constant_symbolic_ID_impl<" ;
5076-
5077- if (VD->isInAnonymousNamespace ()) {
5078- OS << TopShim;
5090+
5091+ Visited.insert (VD);
5092+ std::string TopShim = EmitShims (OS, ShimCounter, Policy, VD);
5093+ if (Util::isSyclDeviceGlobalType (VD->getType ())) {
5094+ DeviceGlobalsEmitted = true ;
5095+ DeviceGlobOS << " device_global_map::add(" ;
5096+ DeviceGlobOS << " (void *)&" ;
5097+ if (VD->isInAnonymousNamespace ()) {
5098+ DeviceGlobOS << TopShim;
5099+ } else {
5100+ DeviceGlobOS << " ::" ;
5101+ VD->getNameForDiagnostic (DeviceGlobOS, Policy, true );
5102+ }
5103+ DeviceGlobOS << " , \" " ;
5104+ DeviceGlobOS << SYCLUniqueStableIdExpr::ComputeName (S.getASTContext (),
5105+ VD);
5106+ DeviceGlobOS << " \" );\n " ;
50795107 } else {
5080- OS << " ::" ;
5081- VD->getNameForDiagnostic (OS, Policy, true );
5082- }
5108+ EmittedFirstSpecConstant = true ;
5109+ OS << " __SYCL_INLINE_NAMESPACE(cl) {\n " ;
5110+ OS << " namespace sycl {\n " ;
5111+ OS << " namespace detail {\n " ;
5112+ OS << " template<>\n " ;
5113+ OS << " inline const char *get_spec_constant_symbolic_ID_impl<" ;
5114+
5115+ if (VD->isInAnonymousNamespace ()) {
5116+ OS << TopShim;
5117+ } else {
5118+ OS << " ::" ;
5119+ VD->getNameForDiagnostic (OS, Policy, true );
5120+ }
50835121
5084- OS << " >() {\n " ;
5085- OS << " return \" " ;
5086- OS << SYCLUniqueStableIdExpr::ComputeName (S.getASTContext (), VD);
5087- OS << " \" ;\n " ;
5088- OS << " }\n " ;
5089- OS << " } // namespace detail\n " ;
5090- OS << " } // namespace sycl\n " ;
5091- OS << " } // __SYCL_INLINE_NAMESPACE(cl)\n " ;
5122+ OS << " >() {\n " ;
5123+ OS << " return \" " ;
5124+ OS << SYCLUniqueStableIdExpr::ComputeName (S.getASTContext (), VD);
5125+ OS << " \" ;\n " ;
5126+ OS << " }\n " ;
5127+ OS << " } // namespace detail\n " ;
5128+ OS << " } // namespace sycl\n " ;
5129+ OS << " } // __SYCL_INLINE_NAMESPACE(cl)\n " ;
5130+ }
50925131 }
50935132
50945133 if (EmittedFirstSpecConstant)
50955134 OS << " #include <CL/sycl/detail/spec_const_integration.hpp>\n " ;
50965135
5136+ if (DeviceGlobalsEmitted) {
5137+ OS << " #include <CL/sycl/detail/device_global_map.hpp>\n " ;
5138+ DeviceGlobOS.flush ();
5139+ OS << " namespace sycl::detail {\n " ;
5140+ OS << " namespace {\n " ;
5141+ OS << " __sycl_device_global_registration::__sycl_device_global_"
5142+ " registration() noexcept {\n " ;
5143+ OS << DeviceGlobalsBuf;
5144+ OS << " }\n " ;
5145+ OS << " } // namespace (unnamed)\n " ;
5146+ OS << " } // namespace sycl::detail\n " ;
5147+ }
50975148 return true ;
50985149}
50995150
@@ -5138,6 +5189,18 @@ bool Util::isSyclSpecIdType(QualType Ty) {
51385189 return matchQualifiedTypeName (Ty, Scopes);
51395190}
51405191
5192+ bool Util::isSyclDeviceGlobalType (QualType Ty) {
5193+ const CXXRecordDecl *RecTy = Ty->getAsCXXRecordDecl ();
5194+ if (!RecTy)
5195+ return false ;
5196+ if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(RecTy)) {
5197+ ClassTemplateDecl *Template = CTSD->getSpecializedTemplate ();
5198+ if (CXXRecordDecl *RD = Template->getTemplatedDecl ())
5199+ return RD->hasAttr <SYCLDeviceGlobalAttr>();
5200+ }
5201+ return RecTy->hasAttr <SYCLDeviceGlobalAttr>();
5202+ }
5203+
51415204bool Util::isSyclKernelHandlerType (QualType Ty) {
51425205 std::array<DeclContextDesc, 3 > Scopes = {
51435206 Util::MakeDeclContextDesc (Decl::Kind::Namespace, " cl" ),
0 commit comments