@@ -56,7 +56,7 @@ enum KernelInvocationKind {
5656
5757const static std::string InitMethodName = " __init" ;
5858const static std::string FinalizeMethodName = " __finalize" ;
59- constexpr unsigned GPUMaxKernelArgsNum = 2000 ;
59+ constexpr unsigned GPUMaxKernelArgsSize = 2048 ;
6060
6161namespace {
6262
@@ -1656,32 +1656,35 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
16561656 using SyclKernelFieldHandler::leaveStruct;
16571657};
16581658
1659- class SyclKernelNumArgsChecker : public SyclKernelFieldHandler {
1659+ class SyclKernelArgsSizeChecker : public SyclKernelFieldHandler {
16601660 SourceLocation KernelLoc;
1661- unsigned NumOfParams = 0 ;
1661+ unsigned SizeOfParams = 0 ;
1662+
1663+ void addParam (QualType ArgTy) {
1664+ SizeOfParams +=
1665+ SemaRef.getASTContext ().getTypeSizeInChars (ArgTy).getQuantity ();
1666+ }
16621667
16631668 bool handleSpecialType (QualType FieldTy) {
16641669 const CXXRecordDecl *RecordDecl = FieldTy->getAsCXXRecordDecl ();
16651670 assert (RecordDecl && " The accessor/sampler must be a RecordDecl" );
16661671 CXXMethodDecl *InitMethod = getMethodByName (RecordDecl, InitMethodName);
16671672 assert (InitMethod && " The accessor/sampler must have the __init method" );
1668- NumOfParams += InitMethod->getNumParams ();
1673+ for (const ParmVarDecl *Param : InitMethod->parameters ())
1674+ addParam (Param->getType ());
16691675 return true ;
16701676 }
16711677
16721678public:
1673- SyclKernelNumArgsChecker (Sema &S, SourceLocation Loc)
1679+ SyclKernelArgsSizeChecker (Sema &S, SourceLocation Loc)
16741680 : SyclKernelFieldHandler(S), KernelLoc(Loc) {}
16751681
1676- ~SyclKernelNumArgsChecker () {
1682+ ~SyclKernelArgsSizeChecker () {
16771683 if (SemaRef.Context .getTargetInfo ().getTriple ().getSubArch () ==
1678- llvm::Triple::SPIRSubArch_gen) {
1679- if (NumOfParams > GPUMaxKernelArgsNum) {
1680- SemaRef.Diag (KernelLoc, diag::warn_sycl_kernel_too_many_args)
1681- << NumOfParams << GPUMaxKernelArgsNum;
1682- SemaRef.Diag (KernelLoc, diag::note_sycl_kernel_args_count);
1683- }
1684- }
1684+ llvm::Triple::SPIRSubArch_gen)
1685+ if (SizeOfParams > GPUMaxKernelArgsSize)
1686+ SemaRef.Diag (KernelLoc, diag::warn_sycl_kernel_too_big_args)
1687+ << SizeOfParams << GPUMaxKernelArgsSize;
16851688 }
16861689
16871690 bool handleSyclAccessorType (FieldDecl *FD, QualType FieldTy) final {
@@ -1703,12 +1706,12 @@ class SyclKernelNumArgsChecker : public SyclKernelFieldHandler {
17031706 }
17041707
17051708 bool handlePointerType (FieldDecl *FD, QualType FieldTy) final {
1706- NumOfParams++ ;
1709+ addParam (FieldTy) ;
17071710 return true ;
17081711 }
17091712
17101713 bool handleScalarType (FieldDecl *FD, QualType FieldTy) final {
1711- NumOfParams++ ;
1714+ addParam (FieldTy) ;
17121715 return true ;
17131716 }
17141717
@@ -1717,17 +1720,17 @@ class SyclKernelNumArgsChecker : public SyclKernelFieldHandler {
17171720 }
17181721
17191722 bool handleSyclHalfType (FieldDecl *FD, QualType FieldTy) final {
1720- NumOfParams++ ;
1723+ addParam (FieldTy) ;
17211724 return true ;
17221725 }
17231726
17241727 bool handleSyclStreamType (FieldDecl *FD, QualType FieldTy) final {
1725- NumOfParams++ ;
1728+ addParam (FieldTy) ;
17261729 return true ;
17271730 }
17281731 bool handleSyclStreamType (const CXXRecordDecl *, const CXXBaseSpecifier &,
17291732 QualType FieldTy) final {
1730- NumOfParams++ ;
1733+ addParam (FieldTy) ;
17311734 return true ;
17321735 }
17331736 using SyclKernelFieldHandler::handleSyclHalfType;
@@ -2468,7 +2471,7 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc,
24682471
24692472 SyclKernelFieldChecker FieldChecker (*this );
24702473 SyclKernelUnionChecker UnionChecker (*this );
2471- SyclKernelNumArgsChecker NumArgsChecker (*this , Args[0 ]->getExprLoc ());
2474+ SyclKernelArgsSizeChecker ArgsSizeChecker (*this , Args[0 ]->getExprLoc ());
24722475 // check that calling kernel conforms to spec
24732476 QualType KernelParamTy = KernelFunc->getParamDecl (0 )->getType ();
24742477 if (KernelParamTy->isReferenceType ()) {
@@ -2488,9 +2491,9 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc,
24882491 KernelObjVisitor Visitor{*this };
24892492 DiagnosingSYCLKernel = true ;
24902493 Visitor.VisitRecordBases (KernelObj, FieldChecker, UnionChecker,
2491- NumArgsChecker );
2494+ ArgsSizeChecker );
24922495 Visitor.VisitRecordFields (KernelObj, FieldChecker, UnionChecker,
2493- NumArgsChecker );
2496+ ArgsSizeChecker );
24942497 DiagnosingSYCLKernel = false ;
24952498 if (!FieldChecker.isValid () || !UnionChecker.isValid ())
24962499 KernelFunc->setInvalidDecl ();
0 commit comments