@@ -3130,21 +3130,14 @@ class SYCLKernelNameTypeVisitor
31303130 void Visit (QualType T) {
31313131 if (T.isNull ())
31323132 return ;
3133+
31333134 const CXXRecordDecl *RD = T->getAsCXXRecordDecl ();
3134- if (!RD) {
3135- if (T->isNullPtrType ()) {
3136- S.Diag (KernelInvocationFuncLoc, diag::err_sycl_kernel_incorrectly_named)
3137- << KernelNameType;
3138- S.Diag (KernelInvocationFuncLoc, diag::note_invalid_type_in_sycl_kernel)
3139- << /* kernel name cannot be a type in the std namespace */ 2 << T;
3140- IsInvalid = true ;
3141- }
3142- return ;
3143- }
31443135 // If KernelNameType has template args visit each template arg via
31453136 // ConstTemplateArgumentVisitor
3146- if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
3137+ if (const auto *TSD =
3138+ dyn_cast_or_null<ClassTemplateSpecializationDecl>(RD)) {
31473139 ArrayRef<TemplateArgument> Args = TSD->getTemplateArgs ().asArray ();
3140+
31483141 VisitTemplateArgs (Args);
31493142 } else {
31503143 InnerTypeVisitor::Visit (T.getTypePtr ());
@@ -3157,62 +3150,104 @@ class SYCLKernelNameTypeVisitor
31573150 InnerTemplArgVisitor::Visit (TA);
31583151 }
31593152
3160- void VisitEnumType (const EnumType *T) {
3161- const EnumDecl *ED = T->getDecl ();
3162- if (!ED->isScoped () && !ED->isFixed ()) {
3163- S.Diag (KernelInvocationFuncLoc, diag::err_sycl_kernel_incorrectly_named)
3153+ void VisitBuiltinType (const BuiltinType *TT) {
3154+ if (TT->isNullPtrType ()) {
3155+ S.Diag (KernelInvocationFuncLoc, diag::err_nullptr_t_type_in_sycl_kernel)
31643156 << KernelNameType;
3165- S.Diag (KernelInvocationFuncLoc, diag::note_invalid_type_in_sycl_kernel)
3166- << /* Unscoped enum requires fixed underlying type */ 1
3167- << QualType (ED->getTypeForDecl (), 0 );
3157+
31683158 IsInvalid = true ;
31693159 }
3160+ return ;
31703161 }
31713162
3172- void VisitRecordType (const RecordType *T) {
3173- return VisitTagDecl (T->getDecl ());
3174- }
3163+ void VisitTagType (const TagType *TT) {
3164+ return DiagnoseKernelNameType (TT->getDecl ());
3165+ }
3166+
3167+ void DiagnoseKernelNameType (const NamedDecl *DeclNamed) {
3168+ /*
3169+ This is a helper function which throws an error if the kernel name
3170+ declaration is:
3171+ * declared within namespace 'std' (at any level)
3172+ e.g., namespace std { namespace literals { class Whatever; } }
3173+ h.single_task<std::literals::Whatever>([]() {});
3174+ * declared within an anonymous namespace (at any level)
3175+ e.g., namespace foo { namespace { class Whatever; } }
3176+ h.single_task<foo::Whatever>([]() {});
3177+ * declared within a function
3178+ e.g., void foo() { struct S { int i; };
3179+ h.single_task<S>([]() {}); }
3180+ * declared within another tag
3181+ e.g., struct S { struct T { int i } t; };
3182+ h.single_task<S::T>([]() {});
3183+ */
3184+
3185+ if (const auto *ED = dyn_cast<EnumDecl>(DeclNamed)) {
3186+ if (!ED->isScoped () && !ED->isFixed ()) {
3187+ S.Diag (KernelInvocationFuncLoc, diag::err_sycl_kernel_incorrectly_named)
3188+ << /* unscoped enum requires fixed underlying type */ 1
3189+ << DeclNamed;
3190+ IsInvalid = true ;
3191+ }
3192+ }
31753193
3176- void VisitTagDecl (const TagDecl *Tag) {
31773194 bool UnnamedLambdaEnabled =
31783195 S.getASTContext ().getLangOpts ().SYCLUnnamedLambda ;
3179- const DeclContext *DeclCtx = Tag ->getDeclContext ();
3196+ const DeclContext *DeclCtx = DeclNamed ->getDeclContext ();
31803197 if (DeclCtx && !UnnamedLambdaEnabled) {
3181- auto *NameSpace = dyn_cast_or_null<NamespaceDecl>(DeclCtx);
3182- if (NameSpace && NameSpace->isStdNamespace ()) {
3183- S.Diag (KernelInvocationFuncLoc, diag::err_sycl_kernel_incorrectly_named)
3184- << KernelNameType;
3185- S.Diag (KernelInvocationFuncLoc, diag::note_invalid_type_in_sycl_kernel)
3186- << /* kernel name cannot be a type in the std namespace */ 2
3187- << QualType (Tag->getTypeForDecl (), 0 );
3188- IsInvalid = true ;
3189- return ;
3190- }
3191- if (!DeclCtx->isTranslationUnit () && !isa<NamespaceDecl>(DeclCtx)) {
3192- const bool KernelNameIsMissing = Tag->getName ().empty ();
3193- if (KernelNameIsMissing) {
3194- S.Diag (KernelInvocationFuncLoc,
3195- diag::err_sycl_kernel_incorrectly_named)
3196- << KernelNameType;
3198+
3199+ // Check if the kernel name declaration is declared within namespace
3200+ // "std" or "anonymous" namespace (at any level).
3201+ while (!DeclCtx->isTranslationUnit () && isa<NamespaceDecl>(DeclCtx)) {
3202+ const auto *NSDecl = cast<NamespaceDecl>(DeclCtx);
3203+ if (NSDecl->isStdNamespace ()) {
31973204 S.Diag (KernelInvocationFuncLoc,
3198- diag::note_invalid_type_in_sycl_kernel )
3199- << /* unnamed type used in a SYCL kernel name */ 3 ;
3205+ diag::err_invalid_std_type_in_sycl_kernel )
3206+ << KernelNameType << DeclNamed ;
32003207 IsInvalid = true ;
32013208 return ;
32023209 }
3203- if (Tag-> isCompleteDefinition ()) {
3210+ if (NSDecl-> isAnonymousNamespace ()) {
32043211 S.Diag (KernelInvocationFuncLoc,
32053212 diag::err_sycl_kernel_incorrectly_named)
3213+ << /* kernel name should be globally visible */ 0
32063214 << KernelNameType;
3207- S.Diag (KernelInvocationFuncLoc,
3208- diag::note_invalid_type_in_sycl_kernel)
3209- << /* kernel name is not globally-visible */ 0
3210- << QualType (Tag->getTypeForDecl (), 0 );
32113215 IsInvalid = true ;
3212- } else {
3213- S.Diag (KernelInvocationFuncLoc, diag::warn_sycl_implicit_decl);
3214- S.Diag (Tag->getSourceRange ().getBegin (), diag::note_previous_decl)
3215- << Tag->getName ();
3216+ return ;
3217+ }
3218+ DeclCtx = DeclCtx->getParent ();
3219+ }
3220+
3221+ // Check if the kernel name is a Tag declaration
3222+ // local to a non-namespace scope (i.e. Inside a function or within
3223+ // another Tag etc).
3224+ if (!DeclCtx->isTranslationUnit () && !isa<NamespaceDecl>(DeclCtx)) {
3225+ if (const auto *Tag = dyn_cast<TagDecl>(DeclNamed)) {
3226+ bool UnnamedLambdaUsed = Tag->getIdentifier () == nullptr ;
3227+
3228+ if (UnnamedLambdaUsed) {
3229+ S.Diag (KernelInvocationFuncLoc,
3230+ diag::err_sycl_kernel_incorrectly_named)
3231+ << /* unnamed lambda used */ 2 << KernelNameType;
3232+
3233+ IsInvalid = true ;
3234+ return ;
3235+ }
3236+ // Check if the declaration is completely defined within a
3237+ // function or class/struct.
3238+
3239+ if (Tag->isCompleteDefinition ()) {
3240+ S.Diag (KernelInvocationFuncLoc,
3241+ diag::err_sycl_kernel_incorrectly_named)
3242+ << /* kernel name should be globally visible */ 0
3243+ << KernelNameType;
3244+
3245+ IsInvalid = true ;
3246+ } else {
3247+ S.Diag (KernelInvocationFuncLoc, diag::warn_sycl_implicit_decl);
3248+ S.Diag (DeclNamed->getLocation (), diag::note_previous_decl)
3249+ << DeclNamed->getName ();
3250+ }
32163251 }
32173252 }
32183253 }
@@ -3221,15 +3256,15 @@ class SYCLKernelNameTypeVisitor
32213256 void VisitTypeTemplateArgument (const TemplateArgument &TA) {
32223257 QualType T = TA.getAsType ();
32233258 if (const auto *ET = T->getAs <EnumType>())
3224- VisitEnumType (ET);
3259+ VisitTagType (ET);
32253260 else
32263261 Visit (T);
32273262 }
32283263
32293264 void VisitIntegralTemplateArgument (const TemplateArgument &TA) {
32303265 QualType T = TA.getIntegralType ();
32313266 if (const EnumType *ET = T->getAs <EnumType>())
3232- VisitEnumType (ET);
3267+ VisitTagType (ET);
32333268 }
32343269
32353270 void VisitTemplateTemplateArgument (const TemplateArgument &TA) {
@@ -3240,7 +3275,7 @@ class SYCLKernelNameTypeVisitor
32403275 if (NonTypeTemplateParmDecl *TemplateParam =
32413276 dyn_cast<NonTypeTemplateParmDecl>(P))
32423277 if (const EnumType *ET = TemplateParam->getType ()->getAs <EnumType>())
3243- VisitEnumType (ET);
3278+ VisitTagType (ET);
32443279 }
32453280 }
32463281
@@ -3301,7 +3336,7 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc,
33013336
33023337 // Emit diagnostics for SYCL device kernels only
33033338 if (LangOpts.SYCLIsDevice )
3304- KernelNameTypeVisitor.Visit (KernelNameType);
3339+ KernelNameTypeVisitor.Visit (KernelNameType. getCanonicalType () );
33053340 Visitor.VisitRecordBases (KernelObj, FieldChecker, UnionChecker, DecompMarker);
33063341 Visitor.VisitRecordFields (KernelObj, FieldChecker, UnionChecker,
33073342 DecompMarker);
0 commit comments