diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index 0a19b9930242..f8f1411646d2 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -830,6 +830,13 @@ class KernelObjVisitor { else if (ElementTy->isStructureOrClassType()) VisitRecord(Owner, ArrayField, ElementTy->getAsCXXRecordDecl(), handlers...); + else if (ElementTy->isUnionType()) + // TODO: This check is still necessary I think?! Array seems to handle + // this differently (see above) for structs I think. + //if (KF_FOR_EACH(handleUnionType, Field, FieldTy)) { + VisitUnion(Owner, ArrayField, ElementTy->getAsCXXRecordDecl(), + handlers...); + //} else if (ElementTy->isArrayType()) VisitArrayElements(ArrayField, ElementTy, handlers...); else if (ElementTy->isScalarType()) @@ -857,6 +864,41 @@ class KernelObjVisitor { void VisitRecord(const CXXRecordDecl *Owner, ParentTy &Parent, const CXXRecordDecl *Wrapper, Handlers &... handlers); + // Base case, only calls these when filtered. + template + void VisitUnion(const CXXRecordDecl *Owner, ParentTy &Parent, + const CXXRecordDecl *Wrapper, + FilteredHandlers &... handlers) { + (void)std::initializer_list{ + (handlers.enterUnion(Owner, Parent), 0)...}; + VisitRecordHelper(Wrapper, Wrapper->fields(), handlers...); + (void)std::initializer_list{ + (handlers.leaveUnion(Owner, Parent), 0)...}; + } + + + template + std::enable_if_t + VisitUnion(const CXXRecordDecl *Owner, ParentTy &Parent, + const CXXRecordDecl *Wrapper, + FilteredHandlers &... filtered_handlers, + CurHandler &cur_handler, Handlers &... handlers) { + VisitUnion( + Owner, Parent, Wrapper, filtered_handlers..., handlers...); + } + + template + std::enable_if_t + VisitUnion(const CXXRecordDecl *Owner, ParentTy &Parent, + const CXXRecordDecl *Wrapper, + FilteredHandlers &... filtered_handlers, + CurHandler &cur_handler, Handlers &... handlers) { + VisitUnion( + Owner, Parent, Wrapper, filtered_handlers..., cur_handler, handlers...); + } + template void VisitRecordHelper(const CXXRecordDecl *Owner, clang::CXXRecordDecl::base_class_const_range Range, @@ -942,6 +984,11 @@ class KernelObjVisitor { CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl(); VisitRecord(Owner, Field, RD, handlers...); } + } else if (FieldTy->isUnionType()) { + if (KF_FOR_EACH(handleUnionType, Field, FieldTy)) { + CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl(); + VisitUnion(Owner, Field, RD, handlers...); + } } else if (FieldTy->isReferenceType()) KF_FOR_EACH(handleReferenceType, Field, FieldTy); else if (FieldTy->isPointerType()) @@ -1005,6 +1052,7 @@ class SyclKernelFieldHandler { } virtual bool handleSyclHalfType(FieldDecl *, QualType) { return true; } virtual bool handleStructType(FieldDecl *, QualType) { return true; } + virtual bool handleUnionType(FieldDecl *, QualType) { return true; } virtual bool handleReferenceType(FieldDecl *, QualType) { return true; } virtual bool handlePointerType(FieldDecl *, QualType) { return true; } virtual bool handleArrayType(FieldDecl *, QualType) { return true; } @@ -1024,6 +1072,8 @@ class SyclKernelFieldHandler { virtual bool leaveStruct(const CXXRecordDecl *, const CXXBaseSpecifier &) { return true; } + virtual bool enterUnion(const CXXRecordDecl *, FieldDecl *) { return true; } + virtual bool leaveUnion(const CXXRecordDecl *, FieldDecl *) { return true; } // The following are used for stepping through array elements. @@ -1201,6 +1251,65 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler { } }; +// A type to check the validity of passing union with accessor/sampler/stream +// member as a kernel argument types. +class SyclKernelUnionBodyChecker : public SyclKernelFieldHandler { + static constexpr const bool VisitUnionBody = true; + int UnionCount = 0; + bool IsInvalid = false; + DiagnosticsEngine &Diag; + + public: + SyclKernelUnionBodyChecker(Sema &S) + : SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()) {} + bool isValid() { return !IsInvalid; } + + bool enterUnion(const CXXRecordDecl *RD, FieldDecl *FD) { + ++UnionCount; + return true; + } + + bool leaveUnion(const CXXRecordDecl *RD, FieldDecl *FD) { + --UnionCount; + return true; + } + + bool handlePointerType(FieldDecl *FD, QualType FieldTy) final { + if (UnionCount) { + IsInvalid = true; + Diag.Report(FD->getLocation(), diag::err_bad_kernel_param_type) + << FieldTy; + } + return isValid(); + } + + bool handleSyclAccessorType(FieldDecl *FD, QualType FieldTy) final { + if (UnionCount) { + IsInvalid = true; + Diag.Report(FD->getLocation(), diag::err_bad_kernel_param_type) + << FieldTy; + } + return isValid(); + } + + bool handleSyclSamplerType(FieldDecl *FD, QualType FieldTy) final { + if (UnionCount) { + IsInvalid = true; + Diag.Report(FD->getLocation(), diag::err_bad_kernel_param_type) + << FieldTy; + } + return isValid(); + } + bool handleSyclStreamType(FieldDecl *FD, QualType FieldTy) final { + if (UnionCount) { + IsInvalid = true; + Diag.Report(FD->getLocation(), diag::err_bad_kernel_param_type) + << FieldTy; + } + return isValid(); + } +}; + // A type to Create and own the FunctionDecl for the kernel. class SyclKernelDeclCreator : public SyclKernelFieldHandler { FunctionDecl *KernelDecl; @@ -1416,6 +1525,10 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { return true; } + bool handleUnionType(FieldDecl *FD, QualType FieldTy) final { + return handleScalarType(FD, FieldTy); + } + bool handleSyclHalfType(FieldDecl *FD, QualType FieldTy) final { addParam(FD, FieldTy); return true; @@ -1751,6 +1864,10 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler { return true; } + bool handleUnionType(FieldDecl *FD, QualType FieldTy) final { + return handleScalarType(FD, FieldTy); + } + bool enterStruct(const CXXRecordDecl *RD, const CXXBaseSpecifier &BS) final { CXXCastPath BasePath; QualType DerivedTy(RD->getTypeForDecl(), 0); @@ -1955,6 +2072,10 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { return true; } + bool handleUnionType(FieldDecl *FD, QualType FieldTy) final { + return handleScalarType(FD, FieldTy); + } + bool handleSyclStreamType(FieldDecl *FD, QualType FieldTy) final { addParam(FD, FieldTy, SYCLIntegrationHeader::kind_std_layout); return true;