@@ -830,6 +830,13 @@ class KernelObjVisitor {
830830 else if (ElementTy->isStructureOrClassType ())
831831 VisitRecord (Owner, ArrayField, ElementTy->getAsCXXRecordDecl (),
832832 handlers...);
833+ else if (ElementTy->isUnionType ())
834+ // TODO: This check is still necessary I think?! Array seems to handle
835+ // this differently (see above) for structs I think.
836+ // if (KF_FOR_EACH(handleUnionType, Field, FieldTy)) {
837+ VisitUnion (Owner, ArrayField, ElementTy->getAsCXXRecordDecl (),
838+ handlers...);
839+ // }
833840 else if (ElementTy->isArrayType ())
834841 VisitArrayElements (ArrayField, ElementTy, handlers...);
835842 else if (ElementTy->isScalarType ())
@@ -857,6 +864,41 @@ class KernelObjVisitor {
857864 void VisitRecord (const CXXRecordDecl *Owner, ParentTy &Parent,
858865 const CXXRecordDecl *Wrapper, Handlers &... handlers);
859866
867+ // Base case, only calls these when filtered.
868+ template <typename ... FilteredHandlers, typename ParentTy>
869+ void VisitUnion (const CXXRecordDecl *Owner, ParentTy &Parent,
870+ const CXXRecordDecl *Wrapper,
871+ FilteredHandlers &... handlers) {
872+ (void )std::initializer_list<int >{
873+ (handlers.enterUnion (Owner, Parent), 0 )...};
874+ VisitRecordHelper (Wrapper, Wrapper->fields (), handlers...);
875+ (void )std::initializer_list<int >{
876+ (handlers.leaveUnion (Owner, Parent), 0 )...};
877+ }
878+
879+
880+ template <typename ... FilteredHandlers, typename ParentTy,
881+ typename CurHandler, typename ... Handlers>
882+ std::enable_if_t <!CurHandler::VisitUnionBody>
883+ VisitUnion (const CXXRecordDecl *Owner, ParentTy &Parent,
884+ const CXXRecordDecl *Wrapper,
885+ FilteredHandlers &... filtered_handlers,
886+ CurHandler &cur_handler, Handlers &... handlers) {
887+ VisitUnion<FilteredHandlers...>(
888+ Owner, Parent, Wrapper, filtered_handlers..., handlers...);
889+ }
890+
891+ template <typename ... FilteredHandlers, typename ParentTy,
892+ typename CurHandler, typename ... Handlers>
893+ std::enable_if_t <CurHandler::VisitUnionBody>
894+ VisitUnion (const CXXRecordDecl *Owner, ParentTy &Parent,
895+ const CXXRecordDecl *Wrapper,
896+ FilteredHandlers &... filtered_handlers,
897+ CurHandler &cur_handler, Handlers &... handlers) {
898+ VisitUnion<FilteredHandlers..., CurHandler>(
899+ Owner, Parent, Wrapper, filtered_handlers..., cur_handler, handlers...);
900+ }
901+
860902 template <typename ... Handlers>
861903 void VisitRecordHelper (const CXXRecordDecl *Owner,
862904 clang::CXXRecordDecl::base_class_const_range Range,
@@ -942,6 +984,11 @@ class KernelObjVisitor {
942984 CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl ();
943985 VisitRecord (Owner, Field, RD, handlers...);
944986 }
987+ } else if (FieldTy->isUnionType ()) {
988+ if (KF_FOR_EACH (handleUnionType, Field, FieldTy)) {
989+ CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl ();
990+ VisitUnion (Owner, Field, RD, handlers...);
991+ }
945992 } else if (FieldTy->isReferenceType ())
946993 KF_FOR_EACH (handleReferenceType, Field, FieldTy);
947994 else if (FieldTy->isPointerType ())
@@ -1005,6 +1052,7 @@ class SyclKernelFieldHandler {
10051052 }
10061053 virtual bool handleSyclHalfType (FieldDecl *, QualType) { return true ; }
10071054 virtual bool handleStructType (FieldDecl *, QualType) { return true ; }
1055+ virtual bool handleUnionType (FieldDecl *, QualType) { return true ; }
10081056 virtual bool handleReferenceType (FieldDecl *, QualType) { return true ; }
10091057 virtual bool handlePointerType (FieldDecl *, QualType) { return true ; }
10101058 virtual bool handleArrayType (FieldDecl *, QualType) { return true ; }
@@ -1024,6 +1072,8 @@ class SyclKernelFieldHandler {
10241072 virtual bool leaveStruct (const CXXRecordDecl *, const CXXBaseSpecifier &) {
10251073 return true ;
10261074 }
1075+ virtual bool enterUnion (const CXXRecordDecl *, FieldDecl *) { return true ; }
1076+ virtual bool leaveUnion (const CXXRecordDecl *, FieldDecl *) { return true ; }
10271077
10281078 // The following are used for stepping through array elements.
10291079
@@ -1201,6 +1251,65 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
12011251 }
12021252};
12031253
1254+ // A type to check the validity of passing union with accessor/sampler/stream
1255+ // member as a kernel argument types.
1256+ class SyclKernelUnionBodyChecker : public SyclKernelFieldHandler {
1257+ static constexpr const bool VisitUnionBody = true ;
1258+ int UnionCount = 0 ;
1259+ bool IsInvalid = false ;
1260+ DiagnosticsEngine &Diag;
1261+
1262+ public:
1263+ SyclKernelUnionBodyChecker (Sema &S)
1264+ : SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()) {}
1265+ bool isValid () { return !IsInvalid; }
1266+
1267+ bool enterUnion (const CXXRecordDecl *RD, FieldDecl *FD) {
1268+ ++UnionCount;
1269+ return true ;
1270+ }
1271+
1272+ bool leaveUnion (const CXXRecordDecl *RD, FieldDecl *FD) {
1273+ --UnionCount;
1274+ return true ;
1275+ }
1276+
1277+ bool handlePointerType (FieldDecl *FD, QualType FieldTy) final {
1278+ if (UnionCount) {
1279+ IsInvalid = true ;
1280+ Diag.Report (FD->getLocation (), diag::err_bad_kernel_param_type)
1281+ << FieldTy;
1282+ }
1283+ return isValid ();
1284+ }
1285+
1286+ bool handleSyclAccessorType (FieldDecl *FD, QualType FieldTy) final {
1287+ if (UnionCount) {
1288+ IsInvalid = true ;
1289+ Diag.Report (FD->getLocation (), diag::err_bad_kernel_param_type)
1290+ << FieldTy;
1291+ }
1292+ return isValid ();
1293+ }
1294+
1295+ bool handleSyclSamplerType (FieldDecl *FD, QualType FieldTy) final {
1296+ if (UnionCount) {
1297+ IsInvalid = true ;
1298+ Diag.Report (FD->getLocation (), diag::err_bad_kernel_param_type)
1299+ << FieldTy;
1300+ }
1301+ return isValid ();
1302+ }
1303+ bool handleSyclStreamType (FieldDecl *FD, QualType FieldTy) final {
1304+ if (UnionCount) {
1305+ IsInvalid = true ;
1306+ Diag.Report (FD->getLocation (), diag::err_bad_kernel_param_type)
1307+ << FieldTy;
1308+ }
1309+ return isValid ();
1310+ }
1311+ };
1312+
12041313// A type to Create and own the FunctionDecl for the kernel.
12051314class SyclKernelDeclCreator : public SyclKernelFieldHandler {
12061315 FunctionDecl *KernelDecl;
@@ -1416,6 +1525,10 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
14161525 return true ;
14171526 }
14181527
1528+ bool handleUnionType (FieldDecl *FD, QualType FieldTy) final {
1529+ return handleScalarType (FD, FieldTy);
1530+ }
1531+
14191532 bool handleSyclHalfType (FieldDecl *FD, QualType FieldTy) final {
14201533 addParam (FD, FieldTy);
14211534 return true ;
@@ -1751,6 +1864,10 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
17511864 return true ;
17521865 }
17531866
1867+ bool handleUnionType (FieldDecl *FD, QualType FieldTy) final {
1868+ return handleScalarType (FD, FieldTy);
1869+ }
1870+
17541871 bool enterStruct (const CXXRecordDecl *RD, const CXXBaseSpecifier &BS) final {
17551872 CXXCastPath BasePath;
17561873 QualType DerivedTy (RD->getTypeForDecl (), 0 );
@@ -1955,6 +2072,10 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
19552072 return true ;
19562073 }
19572074
2075+ bool handleUnionType (FieldDecl *FD, QualType FieldTy) final {
2076+ return handleScalarType (FD, FieldTy);
2077+ }
2078+
19582079 bool handleSyclStreamType (FieldDecl *FD, QualType FieldTy) final {
19592080 addParam (FD, FieldTy, SYCLIntegrationHeader::kind_std_layout);
19602081 return true ;
0 commit comments