@@ -56,7 +56,7 @@ enum KernelInvocationKind {
5656
5757const static std::string InitMethodName = " __init" ;
5858const static std::string FinalizeMethodName = " __finalize" ;
59- constexpr unsigned GPUMaxKernelArgsNum = 2000 ;
59+ constexpr unsigned GPUMaxKernelArgsSize = 2048 ;
6060
6161namespace {
6262
@@ -1656,29 +1656,35 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
16561656 using SyclKernelFieldHandler::leaveStruct;
16571657};
16581658
1659- class SyclKernelNumArgsChecker : public SyclKernelFieldHandler {
1659+ class SyclKernelArgsSizeChecker : public SyclKernelFieldHandler {
16601660 SourceLocation KernelLoc;
1661- unsigned NumOfParams = 0 ;
1661+ unsigned SizeOfParams = 0 ;
1662+
1663+ void addParam (QualType ArgTy) {
1664+ SizeOfParams +=
1665+ SemaRef.getASTContext ().getTypeSizeInChars (ArgTy).getQuantity ();
1666+ }
16621667
16631668 bool handleSpecialType (QualType FieldTy) {
16641669 const CXXRecordDecl *RecordDecl = FieldTy->getAsCXXRecordDecl ();
16651670 assert (RecordDecl && " The accessor/sampler must be a RecordDecl" );
16661671 CXXMethodDecl *InitMethod = getMethodByName (RecordDecl, InitMethodName);
16671672 assert (InitMethod && " The accessor/sampler must have the __init method" );
1668- NumOfParams += InitMethod->getNumParams ();
1673+ for (const ParmVarDecl *Param : InitMethod->parameters ())
1674+ addParam (Param->getType ());
16691675 return true ;
16701676 }
16711677
16721678public:
1673- SyclKernelNumArgsChecker (Sema &S, SourceLocation Loc)
1679+ SyclKernelArgsSizeChecker (Sema &S, SourceLocation Loc)
16741680 : SyclKernelFieldHandler(S), KernelLoc(Loc) {}
16751681
1676- ~SyclKernelNumArgsChecker () {
1682+ ~SyclKernelArgsSizeChecker () {
16771683 if (SemaRef.Context .getTargetInfo ().getTriple ().getSubArch () ==
16781684 llvm::Triple::SPIRSubArch_gen) {
1679- if (NumOfParams > GPUMaxKernelArgsNum ) {
1680- SemaRef.Diag (KernelLoc, diag::warn_sycl_kernel_too_many_args )
1681- << NumOfParams << GPUMaxKernelArgsNum ;
1685+ if (SizeOfParams > GPUMaxKernelArgsSize ) {
1686+ SemaRef.Diag (KernelLoc, diag::warn_sycl_kernel_too_big_args )
1687+ << SizeOfParams << GPUMaxKernelArgsSize ;
16821688 SemaRef.Diag (KernelLoc, diag::note_sycl_kernel_args_count);
16831689 }
16841690 }
@@ -1703,12 +1709,12 @@ class SyclKernelNumArgsChecker : public SyclKernelFieldHandler {
17031709 }
17041710
17051711 bool handlePointerType (FieldDecl *FD, QualType FieldTy) final {
1706- NumOfParams++ ;
1712+ addParam (FieldTy) ;
17071713 return true ;
17081714 }
17091715
17101716 bool handleScalarType (FieldDecl *FD, QualType FieldTy) final {
1711- NumOfParams++ ;
1717+ addParam (FieldTy) ;
17121718 return true ;
17131719 }
17141720
@@ -1717,17 +1723,17 @@ class SyclKernelNumArgsChecker : public SyclKernelFieldHandler {
17171723 }
17181724
17191725 bool handleSyclHalfType (FieldDecl *FD, QualType FieldTy) final {
1720- NumOfParams++ ;
1726+ addParam (FieldTy) ;
17211727 return true ;
17221728 }
17231729
17241730 bool handleSyclStreamType (FieldDecl *FD, QualType FieldTy) final {
1725- NumOfParams++ ;
1731+ addParam (FieldTy) ;
17261732 return true ;
17271733 }
17281734 bool handleSyclStreamType (const CXXRecordDecl *, const CXXBaseSpecifier &,
17291735 QualType FieldTy) final {
1730- NumOfParams++ ;
1736+ addParam (FieldTy) ;
17311737 return true ;
17321738 }
17331739 using SyclKernelFieldHandler::handleSyclHalfType;
@@ -2468,7 +2474,7 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc,
24682474
24692475 SyclKernelFieldChecker FieldChecker (*this );
24702476 SyclKernelUnionChecker UnionChecker (*this );
2471- SyclKernelNumArgsChecker NumArgsChecker (*this , Args[0 ]->getExprLoc ());
2477+ SyclKernelArgsSizeChecker ArgsSizeChecker (*this , Args[0 ]->getExprLoc ());
24722478 // check that calling kernel conforms to spec
24732479 QualType KernelParamTy = KernelFunc->getParamDecl (0 )->getType ();
24742480 if (KernelParamTy->isReferenceType ()) {
@@ -2488,9 +2494,9 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc,
24882494 KernelObjVisitor Visitor{*this };
24892495 DiagnosingSYCLKernel = true ;
24902496 Visitor.VisitRecordBases (KernelObj, FieldChecker, UnionChecker,
2491- NumArgsChecker );
2497+ ArgsSizeChecker );
24922498 Visitor.VisitRecordFields (KernelObj, FieldChecker, UnionChecker,
2493- NumArgsChecker );
2499+ ArgsSizeChecker );
24942500 DiagnosingSYCLKernel = false ;
24952501 if (!FieldChecker.isValid () || !UnionChecker.isValid ())
24962502 KernelFunc->setInvalidDecl ();
0 commit comments