1414#include " clang/AST/QualTypeNames.h"
1515#include " clang/AST/RecordLayout.h"
1616#include " clang/AST/RecursiveASTVisitor.h"
17+ #include " clang/AST/TemplateArgumentVisitor.h"
18+ #include " clang/AST/TypeVisitor.h"
1719#include " clang/Analysis/CallGraph.h"
1820#include " clang/Basic/Attributes.h"
1921#include " clang/Basic/Builtins.h"
@@ -2473,9 +2475,111 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
24732475
24742476} // namespace
24752477
2478+ class SYCLKernelNameTypeVisitor
2479+ : public TypeVisitor<SYCLKernelNameTypeVisitor>,
2480+ public ConstTemplateArgumentVisitor<SYCLKernelNameTypeVisitor> {
2481+ Sema &S;
2482+ SourceLocation KernelInvocationFuncLoc;
2483+ using InnerTypeVisitor = TypeVisitor<SYCLKernelNameTypeVisitor>;
2484+ using InnerTAVisitor =
2485+ ConstTemplateArgumentVisitor<SYCLKernelNameTypeVisitor>;
2486+
2487+ public:
2488+ SYCLKernelNameTypeVisitor (Sema &S, SourceLocation KernelInvocationFuncLoc)
2489+ : S(S), KernelInvocationFuncLoc(KernelInvocationFuncLoc) {}
2490+
2491+ void Visit (QualType T) {
2492+ if (T.isNull ())
2493+ return ;
2494+ const CXXRecordDecl *RD = T->getAsCXXRecordDecl ();
2495+ if (!RD)
2496+ return ;
2497+ // If KernelNameType has template args visit each template arg via
2498+ // ConstTemplateArgumentVisitor
2499+ if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
2500+ const TemplateArgumentList &Args = TSD->getTemplateArgs ();
2501+ for (unsigned I = 0 ; I < Args.size (); I++) {
2502+ Visit (Args[I]);
2503+ }
2504+ } else {
2505+ InnerTypeVisitor::Visit (T.getTypePtr ());
2506+ }
2507+ }
2508+
2509+ void Visit (const TemplateArgument &TA) {
2510+ if (TA.isNull ())
2511+ return ;
2512+ InnerTAVisitor::Visit (TA);
2513+ }
2514+
2515+ void VisitEnumType (const EnumType *T) {
2516+ const EnumDecl *ED = T->getDecl ();
2517+ if (!ED->isScoped () && !ED->isFixed ()) {
2518+ S.Diag (KernelInvocationFuncLoc, diag::err_sycl_kernel_incorrectly_named)
2519+ << /* Unscoped enum requires fixed underlying type */ 2 ;
2520+ S.Diag (ED->getSourceRange ().getBegin (), diag::note_entity_declared_at)
2521+ << ED;
2522+ }
2523+ }
2524+
2525+ void VisitRecordType (const RecordType *T) {
2526+ return VisitTagDecl (T->getDecl ());
2527+ }
2528+
2529+ void VisitTagDecl (const TagDecl *Tag) {
2530+ bool UnnamedLambdaEnabled =
2531+ S.getASTContext ().getLangOpts ().SYCLUnnamedLambda ;
2532+ if (!Tag->getDeclContext ()->isTranslationUnit () &&
2533+ !isa<NamespaceDecl>(Tag->getDeclContext ()) && !UnnamedLambdaEnabled) {
2534+ const bool KernelNameIsMissing = Tag->getName ().empty ();
2535+ if (KernelNameIsMissing) {
2536+ S.Diag (KernelInvocationFuncLoc, diag::err_sycl_kernel_incorrectly_named)
2537+ << /* kernel name is missing */ 0 ;
2538+ } else {
2539+ if (Tag->isCompleteDefinition ())
2540+ S.Diag (KernelInvocationFuncLoc,
2541+ diag::err_sycl_kernel_incorrectly_named)
2542+ << /* kernel name is not globally-visible */ 1 ;
2543+ else
2544+ S.Diag (KernelInvocationFuncLoc, diag::warn_sycl_implicit_decl);
2545+
2546+ S.Diag (Tag->getSourceRange ().getBegin (), diag::note_previous_decl)
2547+ << Tag->getName ();
2548+ }
2549+ }
2550+ }
2551+
2552+ void VisitTypeTemplateArgument (const TemplateArgument &TA) {
2553+ QualType T = TA.getAsType ();
2554+ if (const auto *ET = T->getAs <EnumType>())
2555+ VisitEnumType (ET);
2556+ else
2557+ Visit (T);
2558+ }
2559+
2560+ void VisitIntegralTemplateArgument (const TemplateArgument &TA) {
2561+ QualType T = TA.getIntegralType ();
2562+ if (const EnumType *ET = T->getAs <EnumType>())
2563+ VisitEnumType (ET);
2564+ }
2565+
2566+ void VisitTemplateTemplateArgument (const TemplateArgument &TA) {
2567+ TemplateDecl *TD = TA.getAsTemplate ().getAsTemplateDecl ();
2568+ TemplateParameterList *TemplateParams = TD->getTemplateParameters ();
2569+ for (NamedDecl *P : *TemplateParams) {
2570+ if (NonTypeTemplateParmDecl *TemplateParam =
2571+ dyn_cast<NonTypeTemplateParmDecl>(P))
2572+ if (const EnumType *ET = TemplateParam->getType ()->getAs <EnumType>())
2573+ VisitEnumType (ET);
2574+ }
2575+ }
2576+ };
2577+
24762578void Sema::CheckSYCLKernelCall (FunctionDecl *KernelFunc, SourceRange CallLoc,
24772579 ArrayRef<const Expr *> Args) {
24782580 const CXXRecordDecl *KernelObj = getKernelObjectType (KernelFunc);
2581+ QualType KernelNameType =
2582+ calculateKernelNameType (getASTContext (), KernelFunc);
24792583 if (!KernelObj) {
24802584 Diag (Args[0 ]->getExprLoc (), diag::err_sycl_kernel_not_function_object);
24812585 KernelFunc->setInvalidDecl ();
@@ -2511,6 +2615,10 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc,
25112615 return ;
25122616
25132617 KernelObjVisitor Visitor{*this };
2618+ SYCLKernelNameTypeVisitor KernelTypeVisitor (*this , Args[0 ]->getExprLoc ());
2619+ // Emit diagnostics for SYCL device kernels only
2620+ if (LangOpts.SYCLIsDevice )
2621+ KernelTypeVisitor.Visit (KernelNameType);
25142622 DiagnosingSYCLKernel = true ;
25152623 Visitor.VisitRecordBases (KernelObj, FieldChecker, UnionChecker,
25162624 ArgsSizeChecker);
@@ -2856,18 +2964,6 @@ static void emitWithoutAnonNamespaces(llvm::raw_ostream &OS, StringRef Source) {
28562964 OS << Source;
28572965}
28582966
2859- static bool checkEnumTemplateParameter (const EnumDecl *ED,
2860- DiagnosticsEngine &Diag,
2861- SourceLocation KernelLocation) {
2862- if (!ED->isScoped () && !ED->isFixed ()) {
2863- Diag.Report (KernelLocation, diag::err_sycl_kernel_incorrectly_named) << 2 ;
2864- Diag.Report (ED->getSourceRange ().getBegin (), diag::note_entity_declared_at)
2865- << ED;
2866- return true ;
2867- }
2868- return false ;
2869- }
2870-
28712967// Emits a forward declaration
28722968void SYCLIntegrationHeader::emitFwdDecl (raw_ostream &O, const Decl *D,
28732969 SourceLocation KernelLocation) {
@@ -2880,32 +2976,6 @@ void SYCLIntegrationHeader::emitFwdDecl(raw_ostream &O, const Decl *D,
28802976 auto *NS = dyn_cast_or_null<NamespaceDecl>(DC);
28812977
28822978 if (!NS) {
2883- if (!DC->isTranslationUnit ()) {
2884- const TagDecl *TD = isa<ClassTemplateDecl>(D)
2885- ? cast<ClassTemplateDecl>(D)->getTemplatedDecl ()
2886- : dyn_cast<TagDecl>(D);
2887-
2888- if (TD && !UnnamedLambdaSupport) {
2889- // defined class constituting the kernel name is not globally
2890- // accessible - contradicts the spec
2891- const bool KernelNameIsMissing = TD->getName ().empty ();
2892- if (KernelNameIsMissing) {
2893- Diag.Report (KernelLocation, diag::err_sycl_kernel_incorrectly_named)
2894- << /* kernel name is missing */ 0 ;
2895- // Don't emit note if kernel name was completely omitted
2896- } else {
2897- if (TD->isCompleteDefinition ())
2898- Diag.Report (KernelLocation,
2899- diag::err_sycl_kernel_incorrectly_named)
2900- << /* kernel name is not globally-visible */ 1 ;
2901- else
2902- Diag.Report (KernelLocation, diag::warn_sycl_implicit_decl);
2903- Diag.Report (D->getSourceRange ().getBegin (),
2904- diag::note_previous_decl)
2905- << TD->getName ();
2906- }
2907- }
2908- }
29092979 break ;
29102980 }
29112981
@@ -3025,7 +3095,6 @@ void SYCLIntegrationHeader::emitForwardClassDecls(
30253095 // Handle Kernel Name Type templated using enum type and value.
30263096 if (const auto *ET = T->getAs <EnumType>()) {
30273097 const EnumDecl *ED = ET->getDecl ();
3028- if (!checkEnumTemplateParameter (ED, Diag, KernelLocation))
30293098 emitFwdDecl (O, ED, KernelLocation);
30303099 } else if (Arg.getKind () == TemplateArgument::ArgKind::Type)
30313100 emitForwardClassDecls (O, T, KernelLocation, Printed);
@@ -3085,7 +3154,6 @@ void SYCLIntegrationHeader::emitForwardClassDecls(
30853154 QualType T = TemplateParam->getType ();
30863155 if (const auto *ET = T->getAs <EnumType>()) {
30873156 const EnumDecl *ED = ET->getDecl ();
3088- if (!checkEnumTemplateParameter (ED, Diag, KernelLocation))
30893157 emitFwdDecl (O, ED, KernelLocation);
30903158 }
30913159 }
0 commit comments