@@ -320,6 +320,32 @@ llvm::Value *CodeGenFunction::getTypeSize(QualType Ty) {
320320 return CGM.getSize (SizeInChars);
321321}
322322
323+ void CodeGenFunction::GenerateOpenMPCapturedVarsAggregate (
324+ const CapturedStmt &S, SmallVectorImpl<llvm::Value *> &CapturedVars) {
325+ const RecordDecl *RD = S.getCapturedRecordDecl ();
326+ QualType RecordTy = getContext ().getRecordType (RD);
327+ // Create the aggregate argument struct for the outlined function.
328+ LValue AggLV = MakeAddrLValue (
329+ CreateMemTemp (RecordTy, " omp.outlined.arg.agg." ), RecordTy);
330+
331+ // Initialize the aggregate with captured values.
332+ auto CurField = RD->field_begin ();
333+ for (CapturedStmt::const_capture_init_iterator I = S.capture_init_begin (),
334+ E = S.capture_init_end ();
335+ I != E; ++I, ++CurField) {
336+ LValue LV = EmitLValueForFieldInitialization (AggLV, *CurField);
337+ // Initialize for VLA.
338+ if (CurField->hasCapturedVLAType ()) {
339+ EmitLambdaVLACapture (CurField->getCapturedVLAType (), LV);
340+ } else
341+ // Initialize for capturesThis, capturesVariableByCopy,
342+ // capturesVariable
343+ EmitInitializerForField (*CurField, LV, *I);
344+ }
345+
346+ CapturedVars.push_back (AggLV.getPointer (*this ));
347+ }
348+
323349void CodeGenFunction::GenerateOpenMPCapturedVars (
324350 const CapturedStmt &S, SmallVectorImpl<llvm::Value *> &CapturedVars) {
325351 const RecordDecl *RD = S.getCapturedRecordDecl ();
@@ -420,6 +446,101 @@ struct FunctionOptions {
420446};
421447} // namespace
422448
449+ static llvm::Function *emitOutlinedFunctionPrologueAggregate (
450+ CodeGenFunction &CGF, FunctionArgList &Args,
451+ llvm::MapVector<const Decl *, std::pair<const VarDecl *, Address>>
452+ &LocalAddrs,
453+ llvm::DenseMap<const Decl *, std::pair<const Expr *, llvm::Value *>>
454+ &VLASizes,
455+ llvm::Value *&CXXThisValue, const CapturedStmt &CS, SourceLocation Loc,
456+ StringRef FunctionName) {
457+ const CapturedDecl *CD = CS.getCapturedDecl ();
458+ const RecordDecl *RD = CS.getCapturedRecordDecl ();
459+ assert (CD->hasBody () && " missing CapturedDecl body" );
460+
461+ CXXThisValue = nullptr ;
462+ // Build the argument list.
463+ CodeGenModule &CGM = CGF.CGM ;
464+ ASTContext &Ctx = CGM.getContext ();
465+ Args.append (CD->param_begin (), CD->param_end ());
466+
467+ // Create the function declaration.
468+ const CGFunctionInfo &FuncInfo =
469+ CGM.getTypes ().arrangeBuiltinFunctionDeclaration (Ctx.VoidTy , Args);
470+ llvm::FunctionType *FuncLLVMTy = CGM.getTypes ().GetFunctionType (FuncInfo);
471+
472+ auto *F =
473+ llvm::Function::Create (FuncLLVMTy, llvm::GlobalValue::InternalLinkage,
474+ FunctionName, &CGM.getModule ());
475+ CGM.SetInternalFunctionAttributes (CD, F, FuncInfo);
476+ if (CD->isNothrow ())
477+ F->setDoesNotThrow ();
478+ F->setDoesNotRecurse ();
479+
480+ // Generate the function.
481+ CGF.StartFunction (CD, Ctx.VoidTy , F, FuncInfo, Args, Loc, Loc);
482+ Address ContextAddr = CGF.GetAddrOfLocalVar (CD->getContextParam ());
483+ llvm::Value *ContextV = CGF.Builder .CreateLoad (ContextAddr);
484+ LValue ContextLV = CGF.MakeNaturalAlignAddrLValue (
485+ ContextV, CGM.getContext ().getTagDeclType (RD));
486+ auto I = CS.captures ().begin ();
487+ for (const FieldDecl *FD : RD->fields ()) {
488+ LValue FieldLV = CGF.EmitLValueForFieldInitialization (ContextLV, FD);
489+ // Do not map arguments if we emit function with non-original types.
490+ Address LocalAddr = FieldLV.getAddress (CGF);
491+ // If we are capturing a pointer by copy we don't need to do anything, just
492+ // use the value that we get from the arguments.
493+ if (I->capturesVariableByCopy () && FD->getType ()->isAnyPointerType ()) {
494+ const VarDecl *CurVD = I->getCapturedVar ();
495+ LocalAddrs.insert ({FD, {CurVD, LocalAddr}});
496+ ++I;
497+ continue ;
498+ }
499+
500+ LValue ArgLVal =
501+ CGF.MakeAddrLValue (LocalAddr, FD->getType (), AlignmentSource::Decl);
502+ if (FD->hasCapturedVLAType ()) {
503+ llvm::Value *ExprArg = CGF.EmitLoadOfScalar (ArgLVal, I->getLocation ());
504+ const VariableArrayType *VAT = FD->getCapturedVLAType ();
505+ VLASizes.try_emplace (FD, VAT->getSizeExpr (), ExprArg);
506+ } else if (I->capturesVariable ()) {
507+ const VarDecl *Var = I->getCapturedVar ();
508+ QualType VarTy = Var->getType ();
509+ Address ArgAddr = ArgLVal.getAddress (CGF);
510+ if (ArgLVal.getType ()->isLValueReferenceType ()) {
511+ ArgAddr = CGF.EmitLoadOfReference (ArgLVal);
512+ } else if (!VarTy->isVariablyModifiedType () || !VarTy->isPointerType ()) {
513+ assert (ArgLVal.getType ()->isPointerType ());
514+ ArgAddr = CGF.EmitLoadOfPointer (
515+ ArgAddr, ArgLVal.getType ()->castAs <PointerType>());
516+ }
517+ LocalAddrs.insert (
518+ {FD, {Var, Address (ArgAddr.getPointer (), Ctx.getDeclAlign (Var))}});
519+ } else if (I->capturesVariableByCopy ()) {
520+ assert (!FD->getType ()->isAnyPointerType () &&
521+ " Not expecting a captured pointer." );
522+ const VarDecl *Var = I->getCapturedVar ();
523+ Address CopyAddr = CGF.CreateMemTemp (FD->getType (), Ctx.getDeclAlign (FD),
524+ Var->getName ());
525+ LValue CopyLVal =
526+ CGF.MakeAddrLValue (CopyAddr, FD->getType (), AlignmentSource::Decl);
527+
528+ RValue ArgRVal = CGF.EmitLoadOfLValue (ArgLVal, I->getLocation ());
529+ CGF.EmitStoreThroughLValue (ArgRVal, CopyLVal);
530+
531+ LocalAddrs.insert ({FD, {Var, CopyAddr}});
532+ } else {
533+ // If 'this' is captured, load it into CXXThisValue.
534+ assert (I->capturesThis ());
535+ CXXThisValue = CGF.EmitLoadOfScalar (ArgLVal, I->getLocation ());
536+ LocalAddrs.insert ({FD, {nullptr , ArgLVal.getAddress (CGF)}});
537+ }
538+ ++I;
539+ }
540+
541+ return F;
542+ }
543+
423544static llvm::Function *emitOutlinedFunctionPrologue (
424545 CodeGenFunction &CGF, FunctionArgList &Args,
425546 llvm::MapVector<const Decl *, std::pair<const VarDecl *, Address>>
@@ -595,6 +716,37 @@ static llvm::Function *emitOutlinedFunctionPrologue(
595716 return F;
596717}
597718
719+ llvm::Function *CodeGenFunction::GenerateOpenMPCapturedStmtFunctionAggregate (
720+ const CapturedStmt &S, SourceLocation Loc) {
721+ assert (
722+ CapturedStmtInfo &&
723+ " CapturedStmtInfo should be set when generating the captured function" );
724+ const CapturedDecl *CD = S.getCapturedDecl ();
725+ // Build the argument list.
726+ FunctionArgList Args;
727+ llvm::MapVector<const Decl *, std::pair<const VarDecl *, Address>> LocalAddrs;
728+ llvm::DenseMap<const Decl *, std::pair<const Expr *, llvm::Value *>> VLASizes;
729+ StringRef FunctionName = CapturedStmtInfo->getHelperName ();
730+ llvm::Function *F = emitOutlinedFunctionPrologueAggregate (
731+ *this , Args, LocalAddrs, VLASizes, CXXThisValue, S, Loc, FunctionName);
732+ CodeGenFunction::OMPPrivateScope LocalScope (*this );
733+ for (const auto &LocalAddrPair : LocalAddrs) {
734+ if (LocalAddrPair.second .first ) {
735+ LocalScope.addPrivate (LocalAddrPair.second .first , [&LocalAddrPair]() {
736+ return LocalAddrPair.second .second ;
737+ });
738+ }
739+ }
740+ (void )LocalScope.Privatize ();
741+ for (const auto &VLASizePair : VLASizes)
742+ VLASizeMap[VLASizePair.second .first ] = VLASizePair.second .second ;
743+ PGO.assignRegionCounters (GlobalDecl (CD), F);
744+ CapturedStmtInfo->EmitBody (*this , CD->getBody ());
745+ (void )LocalScope.ForceCleanup ();
746+ FinishFunction (CD->getBodyRBrace ());
747+ return F;
748+ }
749+
598750llvm::Function *
599751CodeGenFunction::GenerateOpenMPCapturedStmtFunction (const CapturedStmt &S,
600752 SourceLocation Loc) {
@@ -1582,7 +1734,7 @@ static void emitCommonOMPParallelDirective(
15821734 // The following lambda takes care of appending the lower and upper bound
15831735 // parameters when necessary
15841736 CodeGenBoundParameters (CGF, S, CapturedVars);
1585- CGF.GenerateOpenMPCapturedVars (*CS, CapturedVars);
1737+ CGF.GenerateOpenMPCapturedVarsAggregate (*CS, CapturedVars);
15861738 CGF.CGM .getOpenMPRuntime ().emitParallelCall (CGF, S.getBeginLoc (), OutlinedFn,
15871739 CapturedVars, IfCond);
15881740}
@@ -6050,7 +6202,7 @@ static void emitCommonOMPTeamsDirective(CodeGenFunction &CGF,
60506202
60516203 OMPTeamsScope Scope (CGF, S);
60526204 llvm::SmallVector<llvm::Value *, 16 > CapturedVars;
6053- CGF.GenerateOpenMPCapturedVars (*CS, CapturedVars);
6205+ CGF.GenerateOpenMPCapturedVarsAggregate (*CS, CapturedVars);
60546206 CGF.CGM .getOpenMPRuntime ().emitTeamsCall (CGF, S, S.getBeginLoc (), OutlinedFn,
60556207 CapturedVars);
60566208}
0 commit comments