@@ -876,11 +876,12 @@ class KernelObjVisitor {
876876
877877 assert (ElemCount > 0 && " SYCL prohibits 0 sized arrays" );
878878 VisitFirstElement (nullptr , FD, ET, handlers...);
879- (void )std::initializer_list<int >{(handlers.nextElement (ET), 0 )...};
879+ (void )std::initializer_list<int >{(handlers.nextElement (ET, 1 ), 0 )...};
880880
881881 for (int64_t Count = 1 ; Count < ElemCount; Count++) {
882882 VisitNthElement (nullptr , FD, ET, handlers...);
883- (void )std::initializer_list<int >{(handlers.nextElement (ET), 0 )...};
883+ (void )std::initializer_list<int >{
884+ (handlers.nextElement (ET, Count + 1 ), 0 )...};
884885 }
885886
886887 (void )std::initializer_list<int >{
@@ -1085,7 +1086,7 @@ class SyclKernelFieldHandlerBase {
10851086 virtual bool enterField (const CXXRecordDecl *, FieldDecl *) { return true ; }
10861087 virtual bool leaveField (const CXXRecordDecl *, FieldDecl *) { return true ; }
10871088 virtual bool enterArray () { return true ; }
1088- virtual bool nextElement (QualType) { return true ; }
1089+ virtual bool nextElement (QualType, uint64_t ) { return true ; }
10891090 virtual bool leaveArray (FieldDecl *, QualType, int64_t ) { return true ; }
10901091
10911092 virtual ~SyclKernelFieldHandlerBase () = default ;
@@ -1665,7 +1666,6 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
16651666 InitializedEntity VarEntity;
16661667 const CXXRecordDecl *KernelObj;
16671668 llvm::SmallVector<Expr *, 16 > MemberExprBases;
1668- uint64_t ArrayIndex;
16691669 FunctionDecl *KernelCallerFunc;
16701670
16711671 // Using the statements/init expressions that we've created, this generates
@@ -1778,17 +1778,62 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
17781778 InitExprs.push_back (MemberInit.get ());
17791779 }
17801780
1781+ int getDims () {
1782+ int Dims = 0 ;
1783+ for (int i = MemberExprBases.size () - 1 ; i >= 0 ; --i) {
1784+ if (!isa<ArraySubscriptExpr>(MemberExprBases[i]))
1785+ break ;
1786+ ++Dims;
1787+ }
1788+ return Dims;
1789+ }
1790+
1791+ int64_t getArrayIndex (int Idx) {
1792+ ArraySubscriptExpr *LastArrayRef =
1793+ cast<ArraySubscriptExpr>(MemberExprBases[Idx]);
1794+ Expr *LastIdx = LastArrayRef->getIdx ();
1795+ llvm::APSInt Result;
1796+ SemaRef.VerifyIntegerConstantExpression (LastIdx, &Result);
1797+ return Result.getExtValue ();
1798+ }
1799+
17811800 void createExprForScalarElement (FieldDecl *FD) {
1782- InitializedEntity ArrayEntity =
1801+ llvm::SmallVector<InitializedEntity, 4 > InitEntities;
1802+
1803+ // For multi-dimensional arrays, an initialized entity needs to be
1804+ // generated for each 'dimension'. For example, the initialized entity
1805+ // for s.array[x][y][z] is constructed using initialized entities for
1806+ // s.array[x][y], s.array[x] and s.array. InitEntities is used to maintain
1807+ // this.
1808+ InitializedEntity Entity =
17831809 InitializedEntity::InitializeMember (FD, &VarEntity);
1810+ InitEntities.push_back (Entity);
1811+
1812+ // Calculate dimension using ArraySubscriptExpressions in MemberExprBases.
1813+ // Each dimension has an ArraySubscriptExpression (maintains index)
1814+ // in MemberExprBases. For example, if we are currently handling element
1815+ // a[0][0][1], the top of stack entries are ArraySubscriptExpressions for
1816+ // indices 0,0 and 1, with 1 on top.
1817+ int Dims = getDims ();
1818+
1819+ // MemberExprBasesIdx is used to get the index of each dimension, in correct
1820+ // order, from MemberExprBases. For example for a[0][0][1], getArrayIndex
1821+ // will return 0, 0 and then 1.
1822+ int MemberExprBasesIdx = MemberExprBases.size () - Dims;
1823+ for (int I = 0 ; I < Dims; ++I) {
1824+ InitializedEntity NewEntity = InitializedEntity::InitializeElement (
1825+ SemaRef.getASTContext (), getArrayIndex (MemberExprBasesIdx),
1826+ InitEntities.back ());
1827+ InitEntities.push_back (NewEntity);
1828+ ++MemberExprBasesIdx;
1829+ }
1830+
17841831 InitializationKind InitKind =
17851832 InitializationKind::CreateCopy (SourceLocation (), SourceLocation ());
17861833 Expr *DRE = createInitExpr (FD);
1787- InitializedEntity Entity = InitializedEntity::InitializeElement (
1788- SemaRef.getASTContext (), ArrayIndex, ArrayEntity);
1789- ArrayIndex++;
1790- InitializationSequence InitSeq (SemaRef, Entity, InitKind, DRE);
1791- ExprResult MemberInit = InitSeq.Perform (SemaRef, Entity, InitKind, DRE);
1834+ InitializationSequence InitSeq (SemaRef, InitEntities.back (), InitKind, DRE);
1835+ ExprResult MemberInit =
1836+ InitSeq.Perform (SemaRef, InitEntities.back (), InitKind, DRE);
17921837 InitExprs.push_back (MemberInit.get ());
17931838 }
17941839
@@ -1802,7 +1847,22 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
18021847 Expr *ILE = new (SemaRef.getASTContext ())
18031848 InitListExpr (SemaRef.getASTContext (), SourceLocation (), ArrayInitExprs,
18041849 SourceLocation ());
1805- ILE->setType (FD->getType ());
1850+
1851+ // We need to find the type of the element for which we are generating the
1852+ // InitListExpr. For example, for a multi-dimensional array say a[2][3][2],
1853+ // the types for InitListExpr of the array and its 'sub-arrays' are -
1854+ // int [2][3][2], int [3][2] and int [2]. This loop is used to obtain this
1855+ // information from MemberExprBases. MemberExprBases holds
1856+ // ArraySubscriptExprs and the top of stack shows how far we have descended
1857+ // down the array. getDims() calculates this depth.
1858+ QualType ILEType = FD->getType ();
1859+ for (int I = getDims (); I > 1 ; I--) {
1860+ const ConstantArrayType *CAT =
1861+ SemaRef.getASTContext ().getAsConstantArrayType (ILEType);
1862+ assert (CAT && " Should only be called on constant-size array." );
1863+ ILEType = CAT->getElementType ();
1864+ }
1865+ ILE->setType (ILEType);
18061866 InitExprs.push_back (ILE);
18071867 }
18081868
@@ -2063,20 +2123,18 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
20632123 ExprResult ElementBase = SemaRef.CreateBuiltinArraySubscriptExpr (
20642124 ArrayBase, SourceLocation (), IndexExpr.get (), SourceLocation ());
20652125 MemberExprBases.push_back (ElementBase.get ());
2066- ArrayIndex = 0 ;
20672126 return true ;
20682127 }
20692128
2070- bool nextElement (QualType ET) final {
2071- ArraySubscriptExpr *LastArrayRef =
2072- cast<ArraySubscriptExpr>(MemberExprBases.back ());
2129+ bool nextElement (QualType ET, uint64_t ) final {
2130+ // Top of MemberExprBases holds ArraySubscriptExpression of element
2131+ // we just handled, or the Array base for the dimension we are
2132+ // currently visiting.
2133+ int64_t nextIndex = getArrayIndex (MemberExprBases.size () - 1 ) + 1 ;
20732134 MemberExprBases.pop_back ();
2074- Expr *LastIdx = LastArrayRef->getIdx ();
2075- llvm::APSInt Result;
2076- SemaRef.VerifyIntegerConstantExpression (LastIdx, &Result);
20772135 Expr *ArrayBase = MemberExprBases.back ();
2078- ExprResult IndexExpr = SemaRef. ActOnIntegerConstant (
2079- SourceLocation (), Result. getExtValue () + 1 );
2136+ ExprResult IndexExpr =
2137+ SemaRef. ActOnIntegerConstant ( SourceLocation (), nextIndex );
20802138 ExprResult ElementBase = SemaRef.CreateBuiltinArraySubscriptExpr (
20812139 ArrayBase, SourceLocation (), IndexExpr.get (), SourceLocation ());
20822140 MemberExprBases.push_back (ElementBase.get ());
@@ -2101,6 +2159,7 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
21012159class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
21022160 SYCLIntegrationHeader &Header;
21032161 int64_t CurOffset = 0 ;
2162+ llvm::SmallVector<size_t , 16 > ArrayBaseOffsets;
21042163 int StructDepth = 0 ;
21052164
21062165 // A series of functions to calculate the change in offset based on the type.
@@ -2248,18 +2307,20 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
22482307 return true ;
22492308 }
22502309
2251- bool nextElement (QualType ET ) final {
2252- CurOffset += SemaRef. getASTContext (). getTypeSizeInChars (ET). getQuantity ( );
2310+ bool enterArray ( ) final {
2311+ ArrayBaseOffsets. push_back (CurOffset );
22532312 return true ;
22542313 }
22552314
2256- bool leaveArray (FieldDecl *, QualType ET, int64_t Count) final {
2257- int64_t ArraySize =
2258- SemaRef.getASTContext ().getTypeSizeInChars (ET).getQuantity ();
2259- if (!ET->isArrayType ()) {
2260- ArraySize *= Count;
2261- }
2262- CurOffset -= ArraySize;
2315+ bool nextElement (QualType ET, uint64_t Index) final {
2316+ int64_t Size = SemaRef.getASTContext ().getTypeSizeInChars (ET).getQuantity ();
2317+ CurOffset = ArrayBaseOffsets.back () + Size * (Index);
2318+ return true ;
2319+ }
2320+
2321+ bool leaveArray (FieldDecl *, QualType ET, int64_t ) final {
2322+ CurOffset = ArrayBaseOffsets.back ();
2323+ ArrayBaseOffsets.pop_back ();
22632324 return true ;
22642325 }
22652326 using SyclKernelFieldHandler::enterStruct;
0 commit comments