@@ -4195,6 +4195,13 @@ struct AsyncHandlerDesc {
41954195 return params ();
41964196 }
41974197
4198+ // / If the completion handler has an Error parameter, return it.
4199+ Optional<AnyFunctionType::Param> getErrorParam () const {
4200+ if (HasError && Type == HandlerType::PARAMS)
4201+ return params ().back ();
4202+ return None;
4203+ }
4204+
41984205 // / Get the type of the error that will be thrown by the \c async method or \c
41994206 // / None if the completion handler doesn't accept an error parameter.
42004207 // / This may be more specialized than the generic 'Error' type if the
@@ -5397,6 +5404,41 @@ class AsyncConverter : private SourceEntityWalker {
53975404 return true ;
53985405 }
53995406
5407+ // / Creates an async alternative function that forwards onto the completion
5408+ // / handler function through
5409+ // / withCheckedContinuation/withCheckedThrowingContinuation.
5410+ bool createAsyncWrapper () {
5411+ assert (Buffer.empty () && " AsyncConverter can only be used once" );
5412+ auto *FD = cast<FuncDecl>(StartNode.get <Decl *>());
5413+
5414+ // First add the new async function declaration.
5415+ addFuncDecl (FD);
5416+ OS << tok::l_brace << " \n " ;
5417+
5418+ // Then add the body.
5419+ OS << tok::kw_return << " " ;
5420+ if (TopHandler.HasError )
5421+ OS << tok::kw_try << " " ;
5422+
5423+ OS << " await " ;
5424+
5425+ // withChecked[Throwing]Continuation { cont in
5426+ if (TopHandler.HasError ) {
5427+ OS << " withCheckedThrowingContinuation" ;
5428+ } else {
5429+ OS << " withCheckedContinuation" ;
5430+ }
5431+ OS << " " << tok::l_brace << " cont " << tok::kw_in << " \n " ;
5432+
5433+ // fnWithHandler(args...) { ... }
5434+ auto ClosureStr = getAsyncWrapperCompletionClosure (" cont" , TopHandler);
5435+ addForwardingCallTo (FD, TopHandler, /* HandlerReplacement*/ ClosureStr);
5436+
5437+ OS << tok::r_brace << " \n " ; // end continuation closure
5438+ OS << tok::r_brace << " \n " ; // end function body
5439+ return true ;
5440+ }
5441+
54005442 void replace (ASTNode Node, SourceEditConsumer &EditConsumer,
54015443 SourceLoc StartOverride = SourceLoc()) {
54025444 SourceRange Range = Node.getSourceRange ();
@@ -5432,6 +5474,130 @@ class AsyncConverter : private SourceEntityWalker {
54325474 return TopHandler.isValid ();
54335475 }
54345476
5477+ // / Prints a tuple of elements, or a lone single element if only one is
5478+ // / present, using the provided printing function.
5479+ template <typename T, typename PrintFn>
5480+ void addTupleOf (ArrayRef<T> Elements, llvm::raw_ostream &OS,
5481+ PrintFn PrintElt) {
5482+ if (Elements.size () == 1 ) {
5483+ PrintElt (Elements[0 ]);
5484+ return ;
5485+ }
5486+ OS << tok::l_paren;
5487+ llvm::interleave (Elements, PrintElt, [&]() { OS << tok::comma << " " ; });
5488+ OS << tok::r_paren;
5489+ }
5490+
5491+ // / Retrieve the completion handler closure argument for an async wrapper
5492+ // / function.
5493+ std::string
5494+ getAsyncWrapperCompletionClosure (StringRef ContName,
5495+ const AsyncHandlerParamDesc &HandlerDesc) {
5496+ std::string OutputStr;
5497+ llvm::raw_string_ostream OS (OutputStr);
5498+
5499+ OS << " " << tok::l_brace; // start closure
5500+
5501+ // Prepare parameter names for the closure.
5502+ auto SuccessParams = HandlerDesc.getSuccessParams ();
5503+ SmallVector<SmallString<4 >, 2 > SuccessParamNames;
5504+ for (auto idx : indices (SuccessParams)) {
5505+ SuccessParamNames.emplace_back (" res" );
5506+
5507+ // If we have multiple success params, number them e.g res1, res2...
5508+ if (SuccessParams.size () > 1 )
5509+ SuccessParamNames.back ().append (std::to_string (idx + 1 ));
5510+ }
5511+ Optional<SmallString<4 >> ErrName;
5512+ if (HandlerDesc.getErrorParam ())
5513+ ErrName.emplace (" err" );
5514+
5515+ auto HasAnyParams = !SuccessParamNames.empty () || ErrName;
5516+ if (HasAnyParams)
5517+ OS << " " ;
5518+
5519+ // res1, res2
5520+ llvm::interleave (
5521+ SuccessParamNames, [&](auto Name) { OS << Name; },
5522+ [&]() { OS << tok::comma << " " ; });
5523+
5524+ // , err
5525+ if (ErrName) {
5526+ if (!SuccessParamNames.empty ())
5527+ OS << tok::comma << " " ;
5528+
5529+ OS << *ErrName;
5530+ }
5531+ if (HasAnyParams)
5532+ OS << " " << tok::kw_in;
5533+
5534+ OS << " \n " ;
5535+
5536+ // The closure body.
5537+ switch (HandlerDesc.Type ) {
5538+ case HandlerType::PARAMS: {
5539+ // For a (Success?, Error?) -> Void handler, we do an if let on the error.
5540+ if (ErrName) {
5541+ // if let err = err {
5542+ OS << tok::kw_if << " " << tok::kw_let << " " ;
5543+ OS << *ErrName << " " << tok::equal << " " << *ErrName << " " ;
5544+ OS << tok::l_brace << " \n " ;
5545+
5546+ // cont.resume(throwing: err)
5547+ OS << ContName << tok::period << " resume" << tok::l_paren;
5548+ OS << " throwing" << tok::colon << " " << *ErrName;
5549+ OS << tok::r_paren << " \n " ;
5550+
5551+ // return }
5552+ OS << tok::kw_return << " \n " ;
5553+ OS << tok::r_brace << " \n " ;
5554+ }
5555+
5556+ // If we have any success params that we need to unwrap, insert a guard.
5557+ for (auto Idx : indices (SuccessParamNames)) {
5558+ auto &Name = SuccessParamNames[Idx];
5559+ auto ParamTy = SuccessParams[Idx].getParameterType ();
5560+ if (!HandlerDesc.shouldUnwrap (ParamTy))
5561+ continue ;
5562+
5563+ // guard let res = res else {
5564+ OS << tok::kw_guard << " " << tok::kw_let << " " ;
5565+ OS << Name << " " << tok::equal << " " << Name << " " << tok::kw_else;
5566+ OS << " " << tok::l_brace << " \n " ;
5567+
5568+ // fatalError(...)
5569+ OS << " fatalError" << tok::l_paren;
5570+ OS << " \" Expected non-nil success param '" << Name;
5571+ OS << " ' for nil error\" " ;
5572+ OS << tok::r_paren << " \n " ;
5573+
5574+ // End guard.
5575+ OS << tok::r_brace << " \n " ;
5576+ }
5577+
5578+ // cont.resume(returning: (res1, res2, ...))
5579+ OS << ContName << tok::period << " resume" << tok::l_paren;
5580+ OS << " returning" << tok::colon << " " ;
5581+ addTupleOf (llvm::makeArrayRef (SuccessParamNames), OS,
5582+ [&](auto Ref) { OS << Ref; });
5583+ OS << tok::r_paren << " \n " ;
5584+ break ;
5585+ }
5586+ case HandlerType::RESULT: {
5587+ // cont.resume(with: res)
5588+ assert (SuccessParamNames.size () == 1 );
5589+ OS << ContName << tok::period << " resume" << tok::l_paren;
5590+ OS << " with" << tok::colon << " " << SuccessParamNames[0 ];
5591+ OS << tok::r_paren << " \n " ;
5592+ break ;
5593+ }
5594+ case HandlerType::INVALID:
5595+ llvm_unreachable (" Should not have an invalid handler here" );
5596+ }
5597+
5598+ OS << tok::r_brace << " \n " ; // end closure
5599+ return OutputStr;
5600+ }
54355601
54365602 // / Retrieves the location for the start of a comment attached to the token
54375603 // / at the provided location, or the location itself if there is no comment.
@@ -6075,16 +6241,9 @@ class AsyncConverter : private SourceEntityWalker {
60756241 }
60766242 OS << " " ;
60776243 }
6078- if (SuccessParams.size () > 1 )
6079- OS << tok::l_paren;
6080- OS << newNameFor (SuccessParams.front ());
6081- for (const auto Param : SuccessParams.drop_front ()) {
6082- OS << tok::comma << " " ;
6083- OS << newNameFor (Param);
6084- }
6085- if (SuccessParams.size () > 1 ) {
6086- OS << tok::r_paren;
6087- }
6244+ // 'res =' or '(res1, res2, ...) ='
6245+ addTupleOf (SuccessParams, OS,
6246+ [&](auto &Param) { OS << newNameFor (Param); });
60886247 OS << " " << tok::equal << " " ;
60896248 }
60906249
@@ -6271,22 +6430,46 @@ class AsyncConverter : private SourceEntityWalker {
62716430 // / 'await' keyword.
62726431 void addCallToAsyncMethod (const FuncDecl *FD,
62736432 const AsyncHandlerDesc &HandlerDesc) {
6433+ // The call to the async function is the same as the call to the old
6434+ // completion handler function, minus the completion handler arg.
6435+ addForwardingCallTo (FD, HandlerDesc, /* HandlerReplacement*/ " " );
6436+ }
6437+
6438+ // / Adds a forwarding call to the old completion handler function, with
6439+ // / \p HandlerReplacement that allows for a custom replacement or, if empty,
6440+ // / removal of the completion handler closure.
6441+ void addForwardingCallTo (
6442+ const FuncDecl *FD, const AsyncHandlerDesc &HandlerDesc,
6443+ StringRef HandlerReplacement, bool CanUseTrailingClosure = true ) {
62746444 OS << FD->getBaseName () << tok::l_paren;
6275- bool FirstParam = true ;
6276- for (auto Param : *FD->getParameters ()) {
6445+
6446+ auto *Params = FD->getParameters ();
6447+ for (auto Param : *Params) {
62776448 if (Param == HandlerDesc.getHandler ()) {
6278- // / We don't need to pass the completion handler to the async method.
6279- continue ;
6449+ // / If we're not replacing the handler with anything, drop it.
6450+ if (HandlerReplacement.empty ())
6451+ continue ;
6452+
6453+ // If this is the last param, and we can use a trailing closure, do so.
6454+ if (CanUseTrailingClosure && Param == Params->back ()) {
6455+ OS << tok::r_paren << " " ;
6456+ OS << HandlerReplacement;
6457+ return ;
6458+ }
6459+ // Otherwise fall through to do the replacement.
62806460 }
6281- if (!FirstParam) {
6461+
6462+ if (Param != Params->front ())
62826463 OS << tok::comma << " " ;
6283- } else {
6284- FirstParam = false ;
6285- }
6286- if (!Param->getArgumentName ().empty ()) {
6464+
6465+ if (!Param->getArgumentName ().empty ())
62876466 OS << Param->getArgumentName () << tok::colon << " " ;
6467+
6468+ if (Param == HandlerDesc.getHandler ()) {
6469+ OS << HandlerReplacement;
6470+ } else {
6471+ OS << Param->getParameterName ();
62886472 }
6289- OS << Param->getParameterName ();
62906473 }
62916474 OS << tok::r_paren;
62926475 }
@@ -6408,19 +6591,10 @@ class AsyncConverter : private SourceEntityWalker {
64086591 // / Adds the result type of a refactored async function that previously
64096592 // / returned results via a completion handler described by \p HandlerDesc.
64106593 void addAsyncFuncReturnType (const AsyncHandlerDesc &HandlerDesc) {
6594+ // Type or (Type1, Type2, ...)
64116595 SmallVector<Type, 2 > Scratch;
6412- auto ReturnTypes = HandlerDesc.getAsyncReturnTypes (Scratch);
6413- if (ReturnTypes.size () > 1 ) {
6414- OS << tok::l_paren;
6415- }
6416-
6417- llvm::interleave (
6418- ReturnTypes, [&](Type Ty) { Ty->print (OS); },
6419- [&]() { OS << tok::comma << " " ; });
6420-
6421- if (ReturnTypes.size () > 1 ) {
6422- OS << tok::r_paren;
6423- }
6596+ addTupleOf (HandlerDesc.getAsyncReturnTypes (Scratch), OS,
6597+ [&](auto Ty) { Ty->print (OS); });
64246598 }
64256599
64266600 // / If \p FD is generic, adds a type annotation with the return type of the
@@ -6450,6 +6624,24 @@ class AsyncConverter : private SourceEntityWalker {
64506624 }
64516625};
64526626
6627+ // / Adds an attribute to describe a completion handler function's async
6628+ // / alternative if necessary.
6629+ void addCompletionHandlerAsyncAttrIfNeccessary (
6630+ ASTContext &Ctx, const FuncDecl *FD,
6631+ const AsyncHandlerParamDesc &HandlerDesc,
6632+ SourceEditConsumer &EditConsumer) {
6633+ if (!Ctx.LangOpts .EnableExperimentalConcurrency )
6634+ return ;
6635+
6636+ llvm::SmallString<0 > HandlerAttribute;
6637+ llvm::raw_svector_ostream OS (HandlerAttribute);
6638+ OS << " @completionHandlerAsync(\" " ;
6639+ HandlerDesc.printAsyncFunctionName (OS);
6640+ OS << " \" , completionHandlerIndex: " << HandlerDesc.Index << " )\n " ;
6641+ EditConsumer.accept (Ctx.SourceMgr , FD->getAttributeInsertionLoc (false ),
6642+ HandlerAttribute);
6643+ }
6644+
64536645} // namespace asyncrefactorings
64546646
64556647bool RefactoringActionConvertCallToAsyncAlternative::isApplicable (
@@ -6571,16 +6763,7 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
65716763 " @available(*, deprecated, message: \" Prefer async "
65726764 " alternative instead\" )\n " );
65736765
6574- if (Ctx.LangOpts .EnableExperimentalConcurrency ) {
6575- // Add an attribute to describe its async alternative
6576- llvm::SmallString<0 > HandlerAttribute;
6577- llvm::raw_svector_ostream OS (HandlerAttribute);
6578- OS << " @completionHandlerAsync(\" " ;
6579- HandlerDesc.printAsyncFunctionName (OS);
6580- OS << " \" , completionHandlerIndex: " << HandlerDesc.Index << " )\n " ;
6581- EditConsumer.accept (SM, FD->getAttributeInsertionLoc (false ),
6582- HandlerAttribute);
6583- }
6766+ addCompletionHandlerAsyncAttrIfNeccessary (Ctx, FD, HandlerDesc, EditConsumer);
65846767
65856768 AsyncConverter LegacyBodyCreator (TheFile, SM, DiagEngine, FD, HandlerDesc);
65866769 if (LegacyBodyCreator.createLegacyBody ()) {
@@ -6592,6 +6775,43 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
65926775
65936776 return false ;
65946777}
6778+
6779+ bool RefactoringActionAddAsyncWrapper::isApplicable (
6780+ const ResolvedCursorInfo &CursorInfo, DiagnosticEngine &Diag) {
6781+ using namespace asyncrefactorings ;
6782+
6783+ auto *FD = findFunction (CursorInfo);
6784+ if (!FD)
6785+ return false ;
6786+
6787+ auto HandlerDesc =
6788+ AsyncHandlerParamDesc::find (FD, /* RequireAttributeOrName=*/ false );
6789+ return HandlerDesc.isValid ();
6790+ }
6791+
6792+ bool RefactoringActionAddAsyncWrapper::performChange () {
6793+ using namespace asyncrefactorings ;
6794+
6795+ auto *FD = findFunction (CursorInfo);
6796+ assert (FD &&
6797+ " Should not run performChange when refactoring is not applicable" );
6798+
6799+ auto HandlerDesc =
6800+ AsyncHandlerParamDesc::find (FD, /* RequireAttributeOrName=*/ false );
6801+ assert (HandlerDesc.isValid () &&
6802+ " Should not run performChange when refactoring is not applicable" );
6803+
6804+ AsyncConverter Converter (TheFile, SM, DiagEngine, FD, HandlerDesc);
6805+ if (!Converter.createAsyncWrapper ())
6806+ return true ;
6807+
6808+ addCompletionHandlerAsyncAttrIfNeccessary (Ctx, FD, HandlerDesc, EditConsumer);
6809+
6810+ // Add the async wrapper.
6811+ Converter.insertAfter (FD, EditConsumer);
6812+ return false ;
6813+ }
6814+
65956815} // end of anonymous namespace
65966816
65976817StringRef swift::ide::
0 commit comments