diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index 5477147276a80..454fbdb2da264 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -454,9 +454,9 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) { return Res; }; - QualType FieldType = Field->getType(); - CXXRecordDecl *CRD = FieldType->getAsCXXRecordDecl(); - if (CRD && Util::isSyclAccessorType(FieldType)) { + auto getExprForAccessorInit = [&](const QualType ¶mTy, + FieldDecl *Field, + const CXXRecordDecl *CRD, Expr *Base) { // Since this is an accessor next 4 TargetFuncParams including current // should be set in __init method: _ValueType*, range, range, // id @@ -472,9 +472,9 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) { std::advance(TargetFuncParam, NumParams - 1); DeclAccessPair FieldDAP = DeclAccessPair::make(Field, AS_none); - // kernel_obj.accessor + // [kenrel_obj or wrapper object].accessor auto AccessorME = MemberExpr::Create( - S.Context, CloneRef, false, SourceLocation(), + S.Context, Base, false, SourceLocation(), NestedNameSpecifierLoc(), SourceLocation(), Field, FieldDAP, DeclarationNameInfo(Field->getDeclName(), SourceLocation()), nullptr, Field->getType(), VK_LValue, OK_Ordinary); @@ -488,7 +488,7 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) { } assert(InitMethod && "The accessor must have the __init method"); - // kernel_obj.accessor.__init + // [kenrel_obj or wrapper object].accessor.__init DeclAccessPair MethodDAP = DeclAccessPair::make(InitMethod, AS_none); auto ME = MemberExpr::Create( S.Context, AccessorME, false, SourceLocation(), @@ -515,11 +515,52 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) { S, ((*ParamItr++))->getOriginalType(), ParamDREs[2])); ParamStmts.push_back(getExprForRangeOrOffset( S, ((*ParamItr++))->getOriginalType(), ParamDREs[3])); - // kernel_obj.accessor.__init(_ValueType*, range, range, - // id) + // [kenrel_obj or wrapper object].accessor.__init(_ValueType*, + // range, range, id) CXXMemberCallExpr *Call = CXXMemberCallExpr::Create( S.Context, ME, ParamStmts, ResultTy, VK, SourceLocation()); BodyStmts.push_back(Call); + }; + + // Recursively search for accessor fields to initialize them with kernel + // parameters + std::function + getExprForWrappedAccessorInit = [&](const CXXRecordDecl *CRD, + Expr *Base) { + for (auto *WrapperFld : CRD->fields()) { + QualType FldType = WrapperFld->getType(); + CXXRecordDecl *WrapperFldCRD = FldType->getAsCXXRecordDecl(); + if (FldType->isStructureOrClassType()) { + if (Util::isSyclAccessorType(FldType)) { + // Accessor field found - create expr to initialize this + // accessor object. Need to start from the next target + // function parameter, since current one is the wrapper object + // or parameter of the previous processed accessor object. + TargetFuncParam++; + getExprForAccessorInit(FldType, WrapperFld, WrapperFldCRD, + Base); + } else { + // Field is a structure or class so change the wrapper object + // and recursively search for accessor field. + DeclAccessPair WrapperFieldDAP = + DeclAccessPair::make(WrapperFld, AS_none); + auto NewBase = MemberExpr::Create( + S.Context, Base, false, SourceLocation(), + NestedNameSpecifierLoc(), SourceLocation(), WrapperFld, + WrapperFieldDAP, + DeclarationNameInfo(WrapperFld->getDeclName(), + SourceLocation()), + nullptr, WrapperFld->getType(), VK_LValue, OK_Ordinary); + getExprForWrappedAccessorInit(WrapperFldCRD, NewBase); + } + } + } + }; + + QualType FieldType = Field->getType(); + CXXRecordDecl *CRD = FieldType->getAsCXXRecordDecl(); + if (Util::isSyclAccessorType(FieldType)) { + getExprForAccessorInit(FieldType, Field, CRD, CloneRef); } else if (CRD && Util::isSyclSamplerType(FieldType)) { // Sampler has only one TargetFuncParam, which should be set in @@ -596,6 +637,12 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) { BinaryOperator(Lhs, Rhs, BO_Assign, FieldType, VK_LValue, OK_Ordinary, SourceLocation(), FPOptions()); BodyStmts.push_back(Res); + + // If a structure/class type has accessor fields then we need to + // initialize these accessors in proper way by calling __init method of + // the accessor and passing corresponding kernel parameters. + if (CRD) + getExprForWrappedAccessorInit(CRD, Lhs); } else { llvm_unreachable("unsupported field type"); } @@ -675,56 +722,78 @@ static void buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj, // create a parameter descriptor and append it to the result ParamDescs.push_back(makeParamDesc(Fld, ArgType)); }; + + auto createAccessorParamDesc = [&](const FieldDecl *Fld, + const QualType &ArgTy) { + // the parameter is a SYCL accessor object + const auto *RecordDecl = ArgTy->getAsCXXRecordDecl(); + assert(RecordDecl && "accessor must be of a record type"); + const auto *TemplateDecl = + cast(RecordDecl); + // First accessor template parameter - data type + QualType PointeeType = TemplateDecl->getTemplateArgs()[0].getAsType(); + // Fourth parameter - access target + target AccessTarget = getAccessTarget(TemplateDecl); + Qualifiers Quals = PointeeType.getQualifiers(); + // TODO: Support all access targets + switch (AccessTarget) { + case target::global_buffer: + Quals.setAddressSpace(LangAS::opencl_global); + break; + case target::constant_buffer: + Quals.setAddressSpace(LangAS::opencl_constant); + break; + case target::local: + Quals.setAddressSpace(LangAS::opencl_local); + break; + default: + llvm_unreachable("Unsupported access target"); + } + PointeeType = + Context.getQualifiedType(PointeeType.getUnqualifiedType(), Quals); + QualType PointerType = Context.getPointerType(PointeeType); + + CreateAndAddPrmDsc(Fld, PointerType); + + FieldDecl *AccessRangeFld = + getFieldDeclByName(RecordDecl, {"impl", "AccessRange"}); + assert(AccessRangeFld && + "The accessor.impl must contain the AccessRange field"); + CreateAndAddPrmDsc(AccessRangeFld, AccessRangeFld->getType()); + + FieldDecl *MemRangeFld = + getFieldDeclByName(RecordDecl, {"impl", "MemRange"}); + assert(MemRangeFld && "The accessor.impl must contain the MemRange field"); + CreateAndAddPrmDsc(MemRangeFld, MemRangeFld->getType()); + + FieldDecl *OffsetFld = getFieldDeclByName(RecordDecl, {"impl", "Offset"}); + assert(OffsetFld && "The accessor.impl must contain the Offset field"); + CreateAndAddPrmDsc(OffsetFld, OffsetFld->getType()); + }; + + std::function + createParamDescForWrappedAccessors = + [&](const FieldDecl *Fld, const QualType &ArgTy) { + const auto *Wrapper = ArgTy->getAsCXXRecordDecl(); + for (const auto *WrapperFld : Wrapper->fields()) { + QualType FldType = WrapperFld->getType(); + if (FldType->isStructureOrClassType()) { + if (Util::isSyclAccessorType(FldType)) { + // accessor field is found - create descriptor + createAccessorParamDesc(WrapperFld, FldType); + } else { + // field is some class or struct - recursively check for + // accessor fields + createParamDescForWrappedAccessors(WrapperFld, FldType); + } + } + } + }; + for (const auto *Fld : KernelObj->fields()) { QualType ArgTy = Fld->getType(); if (Util::isSyclAccessorType(ArgTy)) { - // the parameter is a SYCL accessor object - const auto *RecordDecl = ArgTy->getAsCXXRecordDecl(); - assert(RecordDecl && "accessor must be of a record type"); - const auto *TemplateDecl = - cast(RecordDecl); - // First accessor template parameter - data type - QualType PointeeType = TemplateDecl->getTemplateArgs()[0].getAsType(); - // Fourth parameter - access target - target AccessTarget = getAccessTarget(TemplateDecl); - Qualifiers Quals = PointeeType.getQualifiers(); - // TODO: Support all access targets - switch (AccessTarget) { - case target::global_buffer: - Quals.setAddressSpace(LangAS::opencl_global); - break; - case target::constant_buffer: - Quals.setAddressSpace(LangAS::opencl_constant); - break; - case target::local: - Quals.setAddressSpace(LangAS::opencl_local); - break; - default: - llvm_unreachable("Unsupported access target"); - } - // TODO: get address space from accessor template parameter. - PointeeType = - Context.getQualifiedType(PointeeType.getUnqualifiedType(), Quals); - QualType PointerType = Context.getPointerType(PointeeType); - - CreateAndAddPrmDsc(Fld, PointerType); - - FieldDecl *AccessRangeFld = - getFieldDeclByName(RecordDecl, {"impl", "AccessRange"}); - assert(AccessRangeFld && - "The accessor.impl must contain the AccessRange field"); - CreateAndAddPrmDsc(AccessRangeFld, AccessRangeFld->getType()); - - FieldDecl *MemRangeFld = - getFieldDeclByName(RecordDecl, {"impl", "MemRange"}); - assert(MemRangeFld && - "The accessor.impl must contain the MemRange field"); - CreateAndAddPrmDsc(MemRangeFld, MemRangeFld->getType()); - - FieldDecl *OffsetFld = - getFieldDeclByName(RecordDecl, {"impl", "Offset"}); - assert(OffsetFld && "The accessor.impl must contain the Offset field"); - CreateAndAddPrmDsc(OffsetFld, OffsetFld->getType()); + createAccessorParamDesc(Fld, ArgTy); } else if (Util::isSyclSamplerType(ArgTy)) { // the parameter is a SYCL sampler object const auto *RecordDecl = ArgTy->getAsCXXRecordDecl(); @@ -747,6 +816,8 @@ static void buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj, } // structure or class typed parameter - the same handling as a scalar CreateAndAddPrmDsc(Fld, ArgTy); + // create descriptors for each accessor field in the class or struct + createParamDescForWrappedAccessors(Fld, ArgTy); } else if (ArgTy->isScalarType()) { // scalar typed parameter CreateAndAddPrmDsc(Fld, ArgTy); @@ -770,14 +841,7 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name, const ASTRecordLayout &Layout = Ctx.getASTRecordLayout(KernelObjTy); H.startKernel(Name, NameType); - for (const auto Fld : KernelObjTy->fields()) { - QualType ActualArgType; - QualType ArgTy = Fld->getType(); - - // Get offset in bytes - uint64_t Offset = Layout.getFieldOffset(Fld->getFieldIndex()) / 8; - - if (Util::isSyclAccessorType(ArgTy)) { + auto populateHeaderForAccessor = [&](const QualType &ArgTy, uint64_t Offset) { // The parameter is a SYCL accessor object. // The Info field of the parameter descriptor for accessor contains // two template parameters packed into thid integer field: @@ -790,6 +854,43 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name, AccTmplTy->getTemplateArgs()[1].getAsIntegral().getExtValue()); int Info = getAccessTarget(AccTmplTy) | (Dims << 11); H.addParamDesc(SYCLIntegrationHeader::kind_accessor, Info, Offset); + }; + + std::function + populateHeaderForWrappedAccessors = [&](const QualType &ArgTy, + uint64_t Offset) { + const auto *Wrapper = ArgTy->getAsCXXRecordDecl(); + for (const auto *WrapperFld : Wrapper->fields()) { + QualType FldType = WrapperFld->getType(); + if (FldType->isStructureOrClassType()) { + ASTContext &WrapperCtx = Wrapper->getASTContext(); + const ASTRecordLayout &WrapperLayout = + WrapperCtx.getASTRecordLayout(Wrapper); + // Get offset (in bytes) of the field in wrapper class or struct + uint64_t OffsetInWrapper = + WrapperLayout.getFieldOffset(WrapperFld->getFieldIndex()) / 8; + if (Util::isSyclAccessorType(FldType)) { + // This is an accesor - populate the header appropriately + populateHeaderForAccessor(FldType, Offset + OffsetInWrapper); + } else { + // This is an other class or struct - recursively search for an + // accessor field + populateHeaderForWrappedAccessors(FldType, + Offset + OffsetInWrapper); + } + } + } + }; + + for (const auto Fld : KernelObjTy->fields()) { + QualType ActualArgType; + QualType ArgTy = Fld->getType(); + + // Get offset in bytes + uint64_t Offset = Layout.getFieldOffset(Fld->getFieldIndex()) / 8; + + if (Util::isSyclAccessorType(ArgTy)) { + populateHeaderForAccessor(ArgTy, Offset); } else if (Util::isSyclSamplerType(ArgTy)) { // The parameter is a SYCL sampler object // It has only one descriptor, "m_Sampler" @@ -810,6 +911,12 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name, uint64_t Sz = Ctx.getTypeSizeInChars(Fld->getType()).getQuantity(); H.addParamDesc(SYCLIntegrationHeader::kind_std_layout, static_cast(Sz), static_cast(Offset)); + + // check for accessor fields in structure or class and populate the + // integration header appropriately + if (ArgTy->isStructureOrClassType()) { + populateHeaderForWrappedAccessors(ArgTy, Offset); + } } else { llvm_unreachable("unsupported kernel parameter type"); } diff --git a/clang/test/CodeGenSYCL/wrapped-accessor.cpp b/clang/test/CodeGenSYCL/wrapped-accessor.cpp new file mode 100644 index 0000000000000..56e51d3d6fe75 --- /dev/null +++ b/clang/test/CodeGenSYCL/wrapped-accessor.cpp @@ -0,0 +1,51 @@ +// RUN: %clang -I %S/Inputs --sycl -Xclang -fsycl-int-header=%t.h %s -c -o %T/kernel.spv +// RUN: FileCheck -input-file=%t.h %s +// +// CHECK: #include + +// CHECK: class wrapped_access; + +// CHECK: namespace cl { +// CHECK-NEXT: namespace sycl { +// CHECK-NEXT: namespace detail { + +// CHECK: static constexpr +// CHECK-NEXT: const char* const kernel_names[] = { +// CHECK-NEXT: "_ZTSZ4mainE14wrapped_access" +// CHECK-NEXT: }; + +// CHECK: static constexpr +// CHECK-NEXT: const kernel_param_desc_t kernel_signatures[] = { +// CHECK-NEXT: //--- _ZTSZ4mainE14wrapped_access +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 3, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 0 }, +// CHECK-EMPTY: +// CHECK-NEXT: }; + +// CHECK: static constexpr +// CHECK-NEXT: const unsigned kernel_signature_start[] = { +// CHECK-NEXT: 0 // _ZTSZ4mainE14wrapped_access +// CHECK-NEXT: }; + +// CHECK: template struct KernelInfo; + +// CHECK: template <> struct KernelInfo { + +#include + +template +struct AccWrapper { Acc accessor; }; + +template +__attribute__((sycl_kernel)) void kernel(Func kernelFunc) { + kernelFunc(); +} + +int main() { + cl::sycl::accessor acc; + auto acc_wrapped = AccWrapper{acc}; + kernel( + [=]() { + acc_wrapped.accessor.use(); + }); +} diff --git a/clang/test/SemaSYCL/wrapped-accessor.cpp b/clang/test/SemaSYCL/wrapped-accessor.cpp new file mode 100644 index 0000000000000..15104bef2b0b7 --- /dev/null +++ b/clang/test/SemaSYCL/wrapped-accessor.cpp @@ -0,0 +1,59 @@ +// RUN: %clang_cc1 -I %S/Inputs -fsycl-is-device -ast-dump %s | FileCheck %s + +#include + +template +struct AccWrapper { Acc accessor; }; + +template +__attribute__((sycl_kernel)) void kernel(Func kernelFunc) { + kernelFunc(); +} + +int main() { + cl::sycl::accessor acc; + auto acc_wrapped = AccWrapper{acc}; + kernel( + [=]() { + acc_wrapped.accessor.use(); + }); +} + +// Check declaration of the kernel +// CHECK: wrapped_access 'void (AccWrapper >, __global int *, range<1>, range<1>, id<1>)' + +// Check parameters of the kernel +// CHECK: ParmVarDecl {{.*}} used _arg_ 'AccWrapper >':'AccWrapper >' +// CHECK: ParmVarDecl {{.*}} used _arg_accessor '__global int *' +// CHECK: ParmVarDecl {{.*}} used _arg_AccessRange 'range<1>':'cl::sycl::range<1>' +// CHECK: ParmVarDecl {{.*}} used _arg_MemRange 'range<1>':'cl::sycl::range<1>' +// CHECK: ParmVarDecl {{.*}} used _arg_Offset 'id<1>':'cl::sycl::id<1>' + +// Check that wrapper object itself is initialized with corresponding kernel argument using operator= +// CHECK: BinaryOperator {{.*}} 'AccWrapper >':'AccWrapper >' lvalue '=' + +// Left operand is the field of the kernel object +// CHECK-NEXT: MemberExpr {{.*}} 'AccWrapper >':'AccWrapper >' lvalue . {{.*}} +// CHECK-NEXT: DeclRefExpr {{.*}} '(lambda at {{.*}}wrapped-accessor.cpp:17:7)' lvalue Var {{.*}} '(lambda at {{.*}}wrapped-accessor.cpp:17:7)' + +// Right operand is the kernel argument +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'AccWrapper >':'AccWrapper >' +// CHECK-NEXT: DeclRefExpr {{.*}} 'AccWrapper >':'AccWrapper >' lvalue ParmVar {{.*}} '_arg_' 'AccWrapper >':'AccWrapper >' + +// Check that accessor field of the wrapper object is initialized using __init method +// CHECK-NEXT: CXXMemberCallExpr {{.*}} 'void' +// CHECK-NEXT: MemberExpr {{.*}} 'void (__global int *, range<1>, range<1>, id<1>)' lvalue .__init +// CHECK-NEXT: MemberExpr {{.*}} 'cl::sycl::accessor':'cl::sycl::accessor' lvalue .accessor {{.*}} +// CHECK-NEXT: MemberExpr {{.*}} 'AccWrapper >':'AccWrapper >' lvalue . +// CHECK-NEXT: DeclRefExpr {{.*}} '(lambda at {{.*}}wrapped-accessor.cpp:17:7)' lvalue Var {{.*}} '(lambda at {{.*}}wrapped-accessor.cpp:17:7)' + +// Parameters of the _init method +// CHECK-NEXT: ImplicitCastExpr {{.*}} '__global int *' +// CHECK-NEXT: ImplicitCastExpr {{.*}} '__global int *' lvalue +// CHECK-NEXT: DeclRefExpr {{.*}} '__global int *' lvalue ParmVar {{.*}} '_arg_accessor' '__global int *' +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'range<1>':'cl::sycl::range<1>' +// CHECK-NEXT: DeclRefExpr {{.*}} 'range<1>':'cl::sycl::range<1>' lvalue ParmVar {{.*}} '_arg_AccessRange' 'range<1>':'cl::sycl::range<1>' +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'range<1>':'cl::sycl::range<1>' +// CHECK-NEXT: DeclRefExpr {{.*}} 'range<1>':'cl::sycl::range<1>' lvalue ParmVar {{.*}} '_arg_MemRange' 'range<1>':'cl::sycl::range<1>' +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'id<1>':'cl::sycl::id<1>' +// CHECK-NEXT: DeclRefExpr {{.*}} 'id<1>':'cl::sycl::id<1>' lvalue ParmVar {{.*}} '_arg_Offset' 'id<1>':'cl::sycl::id<1>' diff --git a/sycl/test/basic_tests/accessor/accessor.cpp b/sycl/test/basic_tests/accessor/accessor.cpp index c530730e3391c..13291a8510fad 100644 --- a/sycl/test/basic_tests/accessor/accessor.cpp +++ b/sycl/test/basic_tests/accessor/accessor.cpp @@ -40,6 +40,27 @@ struct IdxSzT { operator size_t() { return x; } }; +template struct AccWrapper { Acc accessor; }; + +template struct AccsWrapper { + int a; + Acc1 accessor1; + int b; + Acc2 accessor2; +}; + +struct Wrapper1 { + int a; + int b; +}; + +template struct Wrapper2 { + Wrapper1 w1; + AccWrapper wrapped; +}; + +template struct Wrapper3 { Wrapper2 w2; }; + int main() { // Host accessor. { @@ -220,4 +241,94 @@ int main() { return 1; } } + + // Check that accessor is initialized when accessor is wrapped to some class. + { + sycl::queue queue; + if (!queue.is_host()) { + int array[10] = {0}; + { + sycl::buffer buf((int *)array, sycl::range<1>(10), + {cl::sycl::property::buffer::use_host_ptr()}); + queue.submit([&](sycl::handler &cgh) { + auto acc = buf.get_access(cgh); + auto acc_wrapped = AccWrapper{acc}; + cgh.parallel_for( + sycl::range<1>(buf.get_count()), [=](sycl::item<1> it) { + auto idx = it.get_linear_id(); + acc_wrapped.accessor[idx] = 333; + }); + }); + queue.wait(); + } + for (int i = 0; i < 10; i++) { + std::cout << "array[" << i << "]=" << array[i] << std::endl; + assert(array[i] == 333); + } + } + } + + // Case when several accessors are wrapped to some class. Check that they are + // initialized in proper way and value is assigned. + { + sycl::queue queue; + if (!queue.is_host()) { + int array1[10] = {0}; + int array2[10] = {0}; + { + sycl::buffer buf1((int *)array1, sycl::range<1>(10), + {cl::sycl::property::buffer::use_host_ptr()}); + sycl::buffer buf2((int *)array2, sycl::range<1>(10), + {cl::sycl::property::buffer::use_host_ptr()}); + queue.submit([&](sycl::handler &cgh) { + auto acc1 = buf1.get_access(cgh); + auto acc2 = buf2.get_access(cgh); + auto acc_wrapped = + AccsWrapper{10, acc1, 5, acc2}; + cgh.parallel_for( + sycl::range<1>(10), [=](sycl::item<1> it) { + auto idx = it.get_linear_id(); + acc_wrapped.accessor1[idx] = 333; + acc_wrapped.accessor2[idx] = 666; + }); + }); + queue.wait(); + } + for (int i = 0; i < 10; i++) { + std::cout << "array1[" << i << "]=" << array1[i] << std::endl; + std::cout << "array2[" << i << "]=" << array2[i] << std::endl; + assert(array1[i] == 333); + assert(array2[i] == 666); + } + } + } + + // Several levels of wrappers for accessor. + { + sycl::queue queue; + if (!queue.is_host()) { + int array[10] = {0}; + { + sycl::buffer buf((int *)array, sycl::range<1>(10), + {cl::sycl::property::buffer::use_host_ptr()}); + queue.submit([&](sycl::handler &cgh) { + auto acc = buf.get_access(cgh); + auto acc_wrapped = AccWrapper{acc}; + Wrapper1 wr1; + auto wr2 = Wrapper2{wr1, acc_wrapped}; + auto wr3 = Wrapper3{wr2}; + cgh.parallel_for( + sycl::range<1>(buf.get_count()), [=](sycl::item<1> it) { + auto idx = it.get_linear_id(); + wr3.w2.wrapped.accessor[idx] = 333; + }); + }); + queue.wait(); + } + for (int i = 0; i < 10; i++) { + std::cout << "array[" << i << "]=" << array[i] << std::endl; + assert(array[i] == 333); + } + } + } }