@@ -1638,6 +1638,18 @@ static std::string eraseAnonNamespace(std::string S) {
16381638 return S;
16391639}
16401640
1641+ static bool checkEnumTemplateParameter (const EnumDecl *ED,
1642+ DiagnosticsEngine &Diag,
1643+ SourceLocation KernelLocation) {
1644+ if (!ED->isScoped () && !ED->isFixed ()) {
1645+ Diag.Report (KernelLocation, diag::err_sycl_kernel_incorrectly_named) << 2 ;
1646+ Diag.Report (ED->getSourceRange ().getBegin (), diag::note_entity_declared_at)
1647+ << ED;
1648+ return true ;
1649+ }
1650+ return false ;
1651+ }
1652+
16411653// Emits a forward declaration
16421654void SYCLIntegrationHeader::emitFwdDecl (raw_ostream &O, const Decl *D,
16431655 SourceLocation KernelLocation) {
@@ -1691,10 +1703,22 @@ void SYCLIntegrationHeader::emitFwdDecl(raw_ostream &O, const Decl *D,
16911703 PrintingPolicy P (D->getASTContext ().getLangOpts ());
16921704 P.adjustForCPlusPlusFwdDecl ();
16931705 P.SuppressTypedefs = true ;
1706+ P.SuppressUnwrittenScope = true ;
16941707 std::string S;
16951708 llvm::raw_string_ostream SO (S);
16961709 D->print (SO, P);
1697- O << SO.str () << " ;\n " ;
1710+ O << SO.str ();
1711+
1712+ if (const auto *ED = dyn_cast<EnumDecl>(D)) {
1713+ QualType T = ED->getIntegerType ();
1714+ // Backup since getIntegerType() returns null for enum forward
1715+ // declaration with no fixed underlying type
1716+ if (T.isNull ())
1717+ T = ED->getPromotionType ();
1718+ O << " : " << T.getAsString ();
1719+ }
1720+
1721+ O << " ;\n " ;
16981722
16991723 // print closing braces for namespaces if needed
17001724 for (unsigned I = 0 ; I < NamespaceCnt; ++I)
@@ -1763,8 +1787,20 @@ void SYCLIntegrationHeader::emitForwardClassDecls(
17631787
17641788 switch (Arg.getKind ()) {
17651789 case TemplateArgument::ArgKind::Type:
1766- emitForwardClassDecls (O, Arg.getAsType (), KernelLocation, Printed);
1790+ case TemplateArgument::ArgKind::Integral: {
1791+ QualType T = (Arg.getKind () == TemplateArgument::ArgKind::Type)
1792+ ? Arg.getAsType ()
1793+ : Arg.getIntegralType ();
1794+
1795+ // Handle Kernel Name Type templated using enum type and value.
1796+ if (const auto *ET = T->getAs <EnumType>()) {
1797+ const EnumDecl *ED = ET->getDecl ();
1798+ if (!checkEnumTemplateParameter (ED, Diag, KernelLocation))
1799+ emitFwdDecl (O, ED, KernelLocation);
1800+ } else if (Arg.getKind () == TemplateArgument::ArgKind::Type)
1801+ emitForwardClassDecls (O, T, KernelLocation, Printed);
17671802 break ;
1803+ }
17681804 case TemplateArgument::ArgKind::Pack: {
17691805 ArrayRef<TemplateArgument> Pack = Arg.getPackAsArray ();
17701806
@@ -1823,6 +1859,97 @@ static std::string getCPPTypeString(QualType Ty) {
18231859 return eraseAnonNamespace (Ty.getAsString (P));
18241860}
18251861
1862+ static void printArguments (ASTContext &Ctx, raw_ostream &ArgOS,
1863+ ArrayRef<TemplateArgument> Args,
1864+ const PrintingPolicy &P);
1865+
1866+ static void printArgument (ASTContext &Ctx, raw_ostream &ArgOS,
1867+ TemplateArgument Arg, const PrintingPolicy &P) {
1868+ switch (Arg.getKind ()) {
1869+ case TemplateArgument::ArgKind::Pack: {
1870+ printArguments (Ctx, ArgOS, Arg.getPackAsArray (), P);
1871+ break ;
1872+ }
1873+ case TemplateArgument::ArgKind::Integral: {
1874+ QualType T = Arg.getIntegralType ();
1875+ const EnumType *ET = T->getAs <EnumType>();
1876+
1877+ if (ET) {
1878+ const llvm::APSInt &Val = Arg.getAsIntegral ();
1879+ ArgOS << " (" << ET->getDecl ()->getQualifiedNameAsString () << " )" << Val;
1880+ } else {
1881+ Arg.print (P, ArgOS);
1882+ }
1883+ break ;
1884+ }
1885+ case TemplateArgument::ArgKind::Type: {
1886+ LangOptions LO;
1887+ PrintingPolicy TypePolicy (LO);
1888+ TypePolicy.SuppressTypedefs = true ;
1889+ TypePolicy.SuppressTagKeyword = true ;
1890+ QualType T = Arg.getAsType ();
1891+ QualType FullyQualifiedType = TypeName::getFullyQualifiedType (T, Ctx, true );
1892+ ArgOS << FullyQualifiedType.getAsString (TypePolicy);
1893+ break ;
1894+ }
1895+ default :
1896+ Arg.print (P, ArgOS);
1897+ }
1898+ }
1899+
1900+ static void printArguments (ASTContext &Ctx, raw_ostream &ArgOS,
1901+ ArrayRef<TemplateArgument> Args,
1902+ const PrintingPolicy &P) {
1903+ for (unsigned I = 0 ; I < Args.size (); I++) {
1904+ const TemplateArgument &Arg = Args[I];
1905+
1906+ if (I != 0 )
1907+ ArgOS << " , " ;
1908+
1909+ printArgument (Ctx, ArgOS, Arg, P);
1910+ }
1911+ }
1912+
1913+ static void printTemplateArguments (ASTContext &Ctx, raw_ostream &ArgOS,
1914+ ArrayRef<TemplateArgument> Args,
1915+ const PrintingPolicy &P) {
1916+ ArgOS << " <" ;
1917+ printArguments (Ctx, ArgOS, Args, P);
1918+ ArgOS << " >" ;
1919+ }
1920+
1921+ static std::string getKernelNameTypeString (QualType T) {
1922+
1923+ const CXXRecordDecl *RD = T->getAsCXXRecordDecl ();
1924+
1925+ if (!RD)
1926+ return getCPPTypeString (T);
1927+
1928+ // If kernel name type is a template specialization with enum type
1929+ // template parameters, enumerators in name type string should be
1930+ // replaced with their underlying value since the enum definition
1931+ // is not visible in integration header.
1932+ if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
1933+ LangOptions LO;
1934+ PrintingPolicy P (LO);
1935+ P.SuppressTypedefs = true ;
1936+ SmallString<64 > Buf;
1937+ llvm::raw_svector_ostream ArgOS (Buf);
1938+
1939+ // Print template class name
1940+ TSD->printQualifiedName (ArgOS, P, /* WithGlobalNsPrefix*/ true );
1941+
1942+ // Print template arguments substituting enumerators
1943+ ASTContext &Ctx = RD->getASTContext ();
1944+ const TemplateArgumentList &Args = TSD->getTemplateArgs ();
1945+ printTemplateArguments (Ctx, ArgOS, Args.asArray (), P);
1946+
1947+ return eraseAnonNamespace (ArgOS.str ().str ());
1948+ }
1949+
1950+ return getCPPTypeString (T);
1951+ }
1952+
18261953void SYCLIntegrationHeader::emit (raw_ostream &O) {
18271954 O << " // This is auto-generated SYCL integration header.\n " ;
18281955 O << " \n " ;
@@ -1939,8 +2066,9 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
19392066 O << " ', '" << c;
19402067 O << " '> {\n " ;
19412068 } else {
1942- O << " template <> struct KernelInfo<" << getCPPTypeString (K.NameType )
1943- << " > {\n " ;
2069+
2070+ O << " template <> struct KernelInfo<"
2071+ << getKernelNameTypeString (K.NameType ) << " > {\n " ;
19442072 }
19452073 O << " DLL_LOCAL\n " ;
19462074 O << " static constexpr const char* getName() { return \" " << K.Name
0 commit comments