@@ -1039,9 +1039,11 @@ class AsyncPartialApplicationForwarderEmission
10391039 : public PartialApplicationForwarderEmission {
10401040 using super = PartialApplicationForwarderEmission;
10411041 AsyncContextLayout layout;
1042- llvm::Value *contextBuffer;
1042+ llvm::Value *calleeFunction;
1043+ llvm::Value *currentResumeFn;
10431044 Size contextSize;
10441045 Address context;
1046+ Address calleeContextBuffer;
10451047 unsigned currentArgumentIndex;
10461048 struct DynamicFunction {
10471049 using Kind = DynamicFunctionKind;
@@ -1060,23 +1062,11 @@ class AsyncPartialApplicationForwarderEmission
10601062 };
10611063 Optional<Self> self = llvm::None;
10621064
1063- llvm::Value *loadValue (ElementLayout layout) {
1064- Address addr = layout.project (subIGF, context, /* offsets*/ llvm::None);
1065- auto &ti = cast<LoadableTypeInfo>(layout.getType ());
1066- Explosion explosion;
1067- ti.loadAsTake (subIGF, addr, explosion);
1068- return explosion.claimNext ();
1069- }
10701065 void saveValue (ElementLayout layout, Explosion &explosion) {
10711066 Address addr = layout.project (subIGF, context, /* offsets*/ llvm::None);
10721067 auto &ti = cast<LoadableTypeInfo>(layout.getType ());
10731068 ti.initialize (subIGF, explosion, addr, /* isOutlined*/ false );
10741069 }
1075- void loadValue (ElementLayout layout, Explosion &explosion) {
1076- Address addr = layout.project (subIGF, context, /* offsets*/ llvm::None);
1077- auto &ti = cast<LoadableTypeInfo>(layout.getType ());
1078- ti.loadAsTake (subIGF, addr, explosion);
1079- }
10801070
10811071public:
10821072 AsyncPartialApplicationForwarderEmission (
@@ -1099,9 +1089,59 @@ class AsyncPartialApplicationForwarderEmission
10991089 void begin () override { super::begin (); }
11001090
11011091 void mapAsyncParameters () override {
1102- contextBuffer = origParams.claimNext ();
1103- context = layout.emitCastTo (subIGF, contextBuffer);
1104- args.add (contextBuffer);
1092+ // Ignore the original context.
1093+ (void )origParams.claimNext ();
1094+
1095+ llvm::Value *dynamicContextSize32;
1096+ auto initialContextSize = Size (0 );
1097+ std::tie (calleeFunction, dynamicContextSize32) = getAsyncFunctionAndSize (
1098+ subIGF, origType->getRepresentation (), *staticFnPtr,
1099+ nullptr , std::make_pair (true , true ), initialContextSize);
1100+ auto *dynamicContextSize =
1101+ subIGF.Builder .CreateZExt (dynamicContextSize32, subIGF.IGM .SizeTy );
1102+ calleeContextBuffer =
1103+ emitAllocAsyncContext (subIGF, dynamicContextSize);
1104+ context = layout.emitCastTo (subIGF, calleeContextBuffer.getAddress ());
1105+ auto calleeContext =
1106+ layout.emitCastTo (subIGF, calleeContextBuffer.getAddress ());
1107+ args.add (subIGF.Builder .CreateBitOrPointerCast (
1108+ calleeContextBuffer.getAddress (), IGM.SwiftContextPtrTy ));
1109+
1110+ // Set caller info into the context.
1111+ { // caller context
1112+ Explosion explosion;
1113+ auto fieldLayout = layout.getParentLayout ();
1114+ auto *context = subIGF.getAsyncContext ();
1115+ if (auto schema =
1116+ subIGF.IGM .getOptions ().PointerAuth .AsyncContextParent ) {
1117+ Address fieldAddr =
1118+ fieldLayout.project (subIGF, calleeContext, /* offsets*/ llvm::None);
1119+ auto authInfo = PointerAuthInfo::emit (
1120+ subIGF, schema, fieldAddr.getAddress (), PointerAuthEntity ());
1121+ context = emitPointerAuthSign (subIGF, context, authInfo);
1122+ }
1123+ explosion.add (context);
1124+ saveValue (fieldLayout, explosion);
1125+ }
1126+ { // Return to caller function.
1127+ auto fieldLayout = layout.getResumeParentLayout ();
1128+ currentResumeFn = subIGF.Builder .CreateIntrinsicCall (
1129+ llvm::Intrinsic::coro_async_resume, {});
1130+ auto fnVal = currentResumeFn;
1131+ // Sign the pointer.
1132+ if (auto schema = subIGF.IGM .getOptions ().PointerAuth .AsyncContextResume ) {
1133+ Address fieldAddr =
1134+ fieldLayout.project (subIGF, calleeContext, /* offsets*/ llvm::None);
1135+ auto authInfo = PointerAuthInfo::emit (
1136+ subIGF, schema, fieldAddr.getAddress (), PointerAuthEntity ());
1137+ fnVal = emitPointerAuthSign (subIGF, fnVal, authInfo);
1138+ }
1139+ fnVal = subIGF.Builder .CreateBitCast (
1140+ fnVal, subIGF.IGM .TaskContinuationFunctionPtrTy );
1141+ Explosion explosion;
1142+ explosion.add (fnVal);
1143+ saveValue (fieldLayout, explosion);
1144+ }
11051145 }
11061146 void gatherArgumentsFromApply () override {
11071147 super::gatherArgumentsFromApply (true );
@@ -1127,13 +1167,87 @@ class AsyncPartialApplicationForwarderEmission
11271167 // Nothing to do here. The error result pointer is already in the
11281168 // appropriate position.
11291169 }
1170+ FunctionPointer getFunctionPointerForDispatchCall (const FunctionPointer &fn) {
1171+ auto &IGM = subIGF.IGM ;
1172+ // Strip off the return type. The original function pointer signature
1173+ // captured both the entry point type and the resume function type.
1174+ auto *fnTy = llvm::FunctionType::get (
1175+ IGM.VoidTy , fn.getSignature ().getType ()->params (), false /* vaargs*/ );
1176+ auto signature =
1177+ Signature (fnTy, fn.getSignature ().getAttributes (), IGM.SwiftAsyncCC );
1178+ auto fnPtr =
1179+ FunctionPointer (FunctionPointer::Kind::Function, fn.getRawPointer (),
1180+ fn.getAuthInfo (), signature);
1181+ return fnPtr;
1182+ }
11301183 llvm::CallInst *createCall (FunctionPointer &fnPtr) override {
1131- return subIGF.Builder .CreateCall (fnPtr.getAsFunction (subIGF),
1132- args.claimAll ());
1184+ auto newFnPtr = FunctionPointer (
1185+ FunctionPointer::Kind::Function, fnPtr.getPointer (subIGF),
1186+ fnPtr.getAuthInfo (), Signature::forAsyncAwait (subIGF.IGM , origType));
1187+ auto &Builder = subIGF.Builder ;
1188+
1189+ auto argValues = args.claimAll ();
1190+
1191+ // Setup the suspend point.
1192+ SmallVector<llvm::Value *, 8 > arguments;
1193+ auto signature = newFnPtr.getSignature ();
1194+ auto asyncContextIndex = signature.getAsyncContextIndex ();
1195+ auto paramAttributeFlags =
1196+ asyncContextIndex |
1197+ (signature.getAsyncResumeFunctionSwiftSelfIndex () << 8 );
1198+ // Index of swiftasync context | ((index of swiftself) << 8).
1199+ arguments.push_back (
1200+ IGM.getInt32 (paramAttributeFlags));
1201+ arguments.push_back (currentResumeFn);
1202+ auto resumeProjFn = subIGF.getOrCreateResumePrjFn ();
1203+ arguments.push_back (
1204+ Builder.CreateBitOrPointerCast (resumeProjFn, IGM.Int8PtrTy ));
1205+ auto dispatchFn = subIGF.createAsyncDispatchFn (
1206+ getFunctionPointerForDispatchCall (newFnPtr), argValues);
1207+ arguments.push_back (
1208+ Builder.CreateBitOrPointerCast (dispatchFn, IGM.Int8PtrTy ));
1209+ arguments.push_back (
1210+ Builder.CreateBitOrPointerCast (newFnPtr.getRawPointer (), IGM.Int8PtrTy ));
1211+ if (auto authInfo = newFnPtr.getAuthInfo ()) {
1212+ arguments.push_back (newFnPtr.getAuthInfo ().getDiscriminator ());
1213+ }
1214+ for (auto arg : argValues)
1215+ arguments.push_back (arg);
1216+ auto resultTy =
1217+ cast<llvm::StructType>(signature.getType ()->getReturnType ());
1218+ return subIGF.emitSuspendAsyncCall (asyncContextIndex, resultTy, arguments);
11331219 }
11341220 void createReturn (llvm::CallInst *call) override {
1135- call->setTailCallKind (IGM.AsyncTailCallKind );
1136- subIGF.Builder .CreateRetVoid ();
1221+ emitDeallocAsyncContext (subIGF, calleeContextBuffer);
1222+ auto numAsyncContextParams =
1223+ Signature::forAsyncReturn (IGM, outType).getAsyncContextIndex () + 1 ;
1224+ llvm::Value *result = call;
1225+ auto *suspendResultTy = cast<llvm::StructType>(result->getType ());
1226+ Explosion resultExplosion;
1227+ Explosion errorExplosion;
1228+ SILFunctionConventions conv (outType, subIGF.getSILModule ());
1229+ auto hasError = outType->hasErrorResult ();
1230+
1231+ Optional<ArrayRef<llvm::Value *>> nativeResults = llvm::None;
1232+ SmallVector<llvm::Value *, 16 > nativeResultsStorage;
1233+
1234+ if (suspendResultTy->getNumElements () == numAsyncContextParams) {
1235+ // no result to forward.
1236+ assert (!hasError);
1237+ } else {
1238+ auto &Builder = subIGF.Builder ;
1239+ auto resultTys =
1240+ makeArrayRef (suspendResultTy->element_begin () + numAsyncContextParams,
1241+ suspendResultTy->element_end ());
1242+
1243+ for (unsigned i = 0 , e = resultTys.size (); i != e; ++i) {
1244+ llvm::Value *elt =
1245+ Builder.CreateExtractValue (result, numAsyncContextParams + i);
1246+ nativeResultsStorage.push_back (elt);
1247+ }
1248+ nativeResults = nativeResultsStorage;
1249+ }
1250+ emitAsyncReturn (subIGF, layout, origType, nativeResults);
11371251 }
11381252 void end () override {
11391253 assert (context.isValid ());
@@ -1180,6 +1294,7 @@ static llvm::Value *emitPartialApplicationForwarder(IRGenModule &IGM,
11801294 llvm::AttributeList outAttrs = outSig.getAttributes ();
11811295 llvm::FunctionType *fwdTy = outSig.getType ();
11821296 SILFunctionConventions outConv (outType, IGM.getSILModule ());
1297+ Optional<AsyncContextLayout> asyncLayout;
11831298
11841299 StringRef FnName;
11851300 if (staticFnPtr)
@@ -1203,19 +1318,29 @@ static llvm::Value *emitPartialApplicationForwarder(IRGenModule &IGM,
12031318
12041319 IRGenFunction subIGF (IGM, fwd);
12051320 if (origType->isAsync ()) {
1206- subIGF.setupAsync (
1207- Signature::forAsyncEntry (IGM, outType).getAsyncContextIndex ());
1321+ auto asyncContextIdx =
1322+ Signature::forAsyncEntry (IGM, outType).getAsyncContextIndex ();
1323+ asyncLayout.emplace (irgen::getAsyncContextLayout (
1324+ IGM, origType, substType, subs, /* suppress generics*/ false ,
1325+ FunctionPointer::Kind (
1326+ FunctionPointer::BasicKind::AsyncFunctionPointer)));
1327+
1328+ subIGF.setupAsync (asyncContextIdx);
12081329
1209- auto *calleeAFP = staticFnPtr->getDirectPointer ();
1330+ // auto *calleeAFP = staticFnPtr->getDirectPointer();
12101331 LinkEntity entity = LinkEntity::forPartialApplyForwarder (fwd);
1211- auto size = Size (0 );
12121332 assert (!asyncFunctionPtr &&
12131333 " already had an async function pointer to the forwarder?!" );
1214- asyncFunctionPtr = emitAsyncFunctionPointer (IGM, fwd, entity, size);
1334+ emitAsyncFunctionEntry (subIGF, *asyncLayout, entity, asyncContextIdx);
1335+ asyncFunctionPtr =
1336+ emitAsyncFunctionPointer (IGM, fwd, entity, asyncLayout->getSize ());
1337+ // TODO: if calleeAFP is definition:
1338+ #if 0
12151339 subIGF.Builder.CreateIntrinsicCall(
12161340 llvm::Intrinsic::coro_async_size_replace,
12171341 {subIGF.Builder.CreateBitCast(asyncFunctionPtr, IGM.Int8PtrTy),
12181342 subIGF.Builder.CreateBitCast(calleeAFP, IGM.Int8PtrTy)});
1343+ #endif
12191344 }
12201345 if (IGM.DebugInfo )
12211346 IGM.DebugInfo ->emitArtificialFunction (subIGF, fwd);
@@ -1679,7 +1804,7 @@ static llvm::Value *emitPartialApplicationForwarder(IRGenModule &IGM,
16791804
16801805 llvm::CallInst *call = emission->createCall (fnPtr);
16811806
1682- if (addressesToDeallocate.empty () && !needsAllocas &&
1807+ if (!origType-> isAsync () && addressesToDeallocate.empty () && !needsAllocas &&
16831808 (!consumesContext || !dependsOnContextLifetime))
16841809 call->setTailCall ();
16851810
0 commit comments