From fc19b8e1f412d35848f8e36664b58a60785a2d69 Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Fri, 10 Nov 2023 15:14:12 -0800 Subject: [PATCH 1/2] [Typed throws] Handle error conversions in witness thunks --- lib/SILGen/SILGenPoly.cpp | 62 ++++++++++++++++--------- lib/SILGen/SILGenType.cpp | 2 +- test/SILGen/typed_throws_generic.swift | 63 ++++++++++++++++++++++++++ 3 files changed, 105 insertions(+), 22 deletions(-) diff --git a/lib/SILGen/SILGenPoly.cpp b/lib/SILGen/SILGenPoly.cpp index 3c5dc2d384b5e..0a11d90654d41 100644 --- a/lib/SILGen/SILGenPoly.cpp +++ b/lib/SILGen/SILGenPoly.cpp @@ -891,6 +891,37 @@ void SILGenFunction::collectThunkParams( } } +/// If the inner function we are calling (with type \c fnType) from the thunk +/// created by \c SGF requires an indirect error argument, returns that +/// argument. +static llvm::Optional +emitThunkIndirectErrorArgument(SILGenFunction &SGF, SILLocation loc, + CanSILFunctionType fnType) { + // If the function we're calling has as indirect error result, create an + // argument for it. + auto innerError = fnType->getOptionalErrorResult(); + if (!innerError || innerError->getConvention() != ResultConvention::Indirect) + return llvm::None; + + // If the type of the indirect error is the same for both the inner + // function and the thunk, so we can re-use the indirect error slot. + auto loweredErrorResultType = SGF.getSILType(*innerError, fnType); + if (SGF.IndirectErrorResult && + SGF.IndirectErrorResult->getType().getObjectType() + == loweredErrorResultType) { + return SGF.IndirectErrorResult; + } + + // The type of the indirect error in the inner function differs from + // that of the thunk, or the thunk has a direct error, so allocate a + // stack location for the inner indirect error. + SILValue innerIndirectErrorAddr = + SGF.B.createAllocStack(loc, loweredErrorResultType); + SGF.enterDeallocStackCleanup(innerIndirectErrorAddr); + + return innerIndirectErrorAddr; +} + namespace { class TranslateIndirect : public Cleanup { @@ -4847,27 +4878,9 @@ static void buildThunkBody(SILGenFunction &SGF, SILLocation loc, // If the function we're calling has as indirect error result, create an // argument for it. - SILValue innerIndirectErrorAddr; - if (auto innerError = fnType->getOptionalErrorResult()) { - if (innerError->getConvention() == ResultConvention::Indirect) { - auto loweredErrorResultType = SGF.getSILType(*innerError, fnType); - if (SGF.IndirectErrorResult && - SGF.IndirectErrorResult->getType().getObjectType() - == loweredErrorResultType) { - // The type of the indirect error is the same for both the inner - // function and the thunk, so we can re-use the indirect error slot. - innerIndirectErrorAddr = SGF.IndirectErrorResult; - } else { - // The type of the indirect error in the inner function differs from - // that of the thunk, or the thunk has a direct error, so allocate a - // stack location for the inner indirect error. - innerIndirectErrorAddr = - SGF.B.createAllocStack(loc, loweredErrorResultType); - SGF.enterDeallocStackCleanup(innerIndirectErrorAddr); - } - - argValues.push_back(innerIndirectErrorAddr); - } + if (auto innerIndirectErrorAddr = + emitThunkIndirectErrorArgument(SGF, loc, fnType)) { + argValues.push_back(*innerIndirectErrorAddr); } // Add the rest of the arguments. @@ -6573,6 +6586,13 @@ void SILGenFunction::emitProtocolWitness( witnessFTy, thunkTy); } + // If the function we're calling has as indirect error result, create an + // argument for it. + if (auto innerIndirectErrorAddr = + emitThunkIndirectErrorArgument(*this, loc, witnessFTy)) { + args.push_back(*innerIndirectErrorAddr); + } + // - the rest of the arguments forwardFunctionArguments(*this, loc, witnessFTy, witnessParams, args); diff --git a/lib/SILGen/SILGenType.cpp b/lib/SILGen/SILGenType.cpp index 1735cd25f0bb9..8238aabcdf95f 100644 --- a/lib/SILGen/SILGenType.cpp +++ b/lib/SILGen/SILGenType.cpp @@ -758,7 +758,7 @@ SILFunction *SILGenModule::emitProtocolWitness( CanAnyFunctionType::get(genericSig, reqtSubstTy->getParams(), reqtSubstTy.getResult(), - reqtOrigTy->getExtInfo()); + reqtSubstTy->getExtInfo()); // Coroutine lowering requires us to provide these substitutions // in order to recreate the appropriate yield types for the accessor diff --git a/test/SILGen/typed_throws_generic.swift b/test/SILGen/typed_throws_generic.swift index e28ca6f35b16d..438d8dd084278 100644 --- a/test/SILGen/typed_throws_generic.swift +++ b/test/SILGen/typed_throws_generic.swift @@ -254,3 +254,66 @@ func forcedMap(_ source: [T]) -> [U] { // CHECK: bb0(%0 : $*U, %1 : $*Never, %2 : $*T) return source.typedMap { $0 as! U } } + +// Witness thunks +protocol P { + associatedtype E: Error + func f() throws(E) +} + +struct Res: P { + // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @$s20typed_throws_generic3ResVyxq_GAA1PA2aEP1fyy1EQzYKFTW : $@convention(witness_method: P) <τ_0_0, τ_0_1 where τ_0_1 : Error> (@in_guaranteed Res<τ_0_0, τ_0_1>) -> @error_indirect τ_0_1 + // CHECK: bb0(%0 : $*τ_0_1, %1 : $*Res<τ_0_0, τ_0_1>): + // CHECK: [[SELF:%.*]] = load [trivial] %1 : $*Res<τ_0_0, τ_0_1> + // CHECK: [[WITNESS:%.*]] = function_ref @$s20typed_throws_generic3ResV1fyyq_YKF : $@convention(method) <τ_0_0, τ_0_1 where τ_0_1 : Error> (Res<τ_0_0, τ_0_1>) -> @error_indirect τ_0_1 + // CHECK-NEXT: [[INNER_ERROR_BOX:%.*]] = alloc_stack $τ_0_1 + // CHECK-NEXT: try_apply [[WITNESS]]<τ_0_0, τ_0_1>([[INNER_ERROR_BOX]], [[SELF]]) : $@convention(method) <τ_0_0, τ_0_1 where τ_0_1 : Error> (Res<τ_0_0, τ_0_1>) -> @error_indirect τ_0_1, normal [[NORMAL_BB:bb[0-9]+]], error [[ERROR_BB:bb[0-9]+]] + + // CHECK: [[NORMAL_BB]] + // CHECK: dealloc_stack [[INNER_ERROR_BOX]] : $*τ_0_1 + + // CHECK: [[ERROR_BB]]: + // CHECK: throw_addr + func f() throws(Failure) { } +} + +struct TypedRes: P { + // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @$s20typed_throws_generic8TypedResVyxGAA1PA2aEP1fyy1EQzYKFTW : $@convention(witness_method: P) <τ_0_0> (@in_guaranteed TypedRes<τ_0_0>) -> @error_indirect MyError + // CHECK: bb0(%0 : $*MyError, %1 : $*TypedRes<τ_0_0>) + // CHECK: [[SELF:%.*]] = load [trivial] %1 : $*TypedRes<τ_0_0> + // CHECK: [[WITNESS:%.*]] = function_ref @$s20typed_throws_generic8TypedResV1fyyAA7MyErrorOYKF : $@convention(method) <τ_0_0> (TypedRes<τ_0_0>) -> @error MyError + // CHECK: try_apply [[WITNESS]]<τ_0_0>([[SELF]]) : $@convention(method) <τ_0_0> (TypedRes<τ_0_0>) -> @error MyError, normal [[NORMAL_BB:bb[0-9]+]], error [[ERROR_BB:bb[0-9]+]] + + // CHECK: [[NORMAL_BB]] + // CHECK: return + + // CHECK: [[ERROR_BB]]([[ERROR:%.*]] : $MyError): + // CHECK-NEXT: store [[ERROR]] to [trivial] %0 : $*MyError + // CHECK-NEXT: throw_addr + func f() throws(MyError) { } +} + +struct UntypedRes: P { + // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @$s20typed_throws_generic10UntypedResVyxGAA1PA2aEP1fyy1EQzYKFTW : $@convention(witness_method: P) <τ_0_0> (@in_guaranteed UntypedRes<τ_0_0>) -> @error_indirect any Error + // CHECK: bb0(%0 : $*any Error, %1 : $*UntypedRes<τ_0_0>): + // CHECK: [[SELF:%.*]] = load [trivial] %1 : $*UntypedRes<τ_0_0> + // CHECK: [[WITNESS:%.*]] = function_ref @$s20typed_throws_generic10UntypedResV1fyyKF : $@convention(method) <τ_0_0> (UntypedRes<τ_0_0>) -> @error any Error + // CHECK: try_apply [[WITNESS]]<τ_0_0>([[SELF]]) : $@convention(method) <τ_0_0> (UntypedRes<τ_0_0>) -> @error any Error, normal [[NORMAL_BB:bb[0-9]+]], error [[ERROR_BB:bb[0-9]+]] + + // CHECK: [[NORMAL_BB]] + // CHECK: return + + // CHECK: [[ERROR_BB]]([[ERROR:%.*]] : @owned $any Error): + // CHECK-NEXT: store [[ERROR]] to [init] %0 : $*any Error + // CHECK-NEXT: throw_addr + func f() throws { } +} + +struct InfallibleRes: P { + // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @$s20typed_throws_generic13InfallibleResVyxGAA1PA2aEP1fyy1EQzYKFTW : $@convention(witness_method: P) <τ_0_0> (@in_guaranteed InfallibleRes<τ_0_0>) -> @error_indirect any Error + // CHECK: bb0(%0 : $*any Error, %1 : $*InfallibleRes<τ_0_0>): + // CHECK: [[SELF:%.*]] = load [trivial] %1 : $*InfallibleRes<τ_0_0> + // CHECK: [[WITNESS:%.*]] = function_ref @$s20typed_throws_generic13InfallibleResV1fyyF : $@convention(method) <τ_0_0> (InfallibleRes<τ_0_0>) -> () + // CHECK: = apply [[WITNESS]]<τ_0_0>([[SELF]]) : $@convention(method) <τ_0_0> (InfallibleRes<τ_0_0>) + func f() { } +} From 76a89500d6a83159ae12bef2d7c64485aeff3173 Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Fri, 10 Nov 2023 15:28:36 -0800 Subject: [PATCH 2/2] [SILGen] Handle indirect errors in other thunk kinds --- lib/SILGen/SILGenBackDeploy.cpp | 6 ++++- lib/SILGen/SILGenPoly.cpp | 44 ++++++++++++++++++++++++--------- 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/lib/SILGen/SILGenBackDeploy.cpp b/lib/SILGen/SILGenBackDeploy.cpp index b8be13ecd16cd..c2ebc79ebb52d 100644 --- a/lib/SILGen/SILGenBackDeploy.cpp +++ b/lib/SILGen/SILGenBackDeploy.cpp @@ -221,7 +221,8 @@ void SILGenFunction::emitBackDeploymentThunk(SILDeclRef thunk) { // Generate the thunk prolog by collecting parameters. SmallVector params; SmallVector indirectParams; - collectThunkParams(loc, params, &indirectParams); + SmallVector indirectErrorResults; + collectThunkParams(loc, params, &indirectParams, &indirectErrorResults); // Build up the list of arguments that we're going to invoke the the real // function with. @@ -229,6 +230,9 @@ void SILGenFunction::emitBackDeploymentThunk(SILDeclRef thunk) { for (auto indirectParam : indirectParams) { paramsForForwarding.emplace_back(indirectParam.getLValueAddress()); } + for (auto indirectErrorResult : indirectErrorResults) { + paramsForForwarding.emplace_back(indirectErrorResult.getLValueAddress()); + } for (auto param : params) { // We're going to directly call either the original function or the fallback diff --git a/lib/SILGen/SILGenPoly.cpp b/lib/SILGen/SILGenPoly.cpp index 0a11d90654d41..b84e3383dc236 100644 --- a/lib/SILGen/SILGenPoly.cpp +++ b/lib/SILGen/SILGenPoly.cpp @@ -5188,7 +5188,8 @@ static void buildWithoutActuallyEscapingThunkBody(SILGenFunction &SGF, SmallVector params; SmallVector indirectResults; - SGF.collectThunkParams(loc, params, &indirectResults); + SmallVector indirectErrorResults; + SGF.collectThunkParams(loc, params, &indirectResults, &indirectErrorResults); // Ignore the self parameter at the SIL level. IRGen will use it to // recover type metadata. @@ -5198,13 +5199,16 @@ static void buildWithoutActuallyEscapingThunkBody(SILGenFunction &SGF, ManagedValue fnValue = params.pop_back_val(); auto fnType = fnValue.getType().castTo(); + // Forward indirect result arguments. SmallVector argValues; if (!indirectResults.empty()) { for (auto result : indirectResults) argValues.push_back(result.getLValueAddress()); } - // Forward indirect result arguments. + // Forward indirect error arguments. + for (auto indirectError : indirectErrorResults) + argValues.push_back(indirectError.getLValueAddress()); // Add the rest of the arguments. forwardFunctionArguments(SGF, loc, fnType, params, argValues); @@ -5389,7 +5393,9 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap( SILGenFunction thunkSGF(SGM, *thunk, FunctionDC); SmallVector params; SmallVector thunkIndirectResults; - thunkSGF.collectThunkParams(loc, params, &thunkIndirectResults); + SmallVector thunkIndirectErrorResults; + thunkSGF.collectThunkParams( + loc, params, &thunkIndirectResults, &thunkIndirectErrorResults); SILFunctionConventions fromConv(fromType, getModule()); SILFunctionConventions toConv(toType, getModule()); @@ -5397,6 +5403,8 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap( SmallVector thunkArguments; for (auto indRes : thunkIndirectResults) thunkArguments.push_back(indRes); + for (auto indErrRes : thunkIndirectErrorResults) + thunkArguments.push_back(indErrRes); thunkArguments.append(params.begin(), params.end()); SmallVector toParameters( toConv.getParameters().begin(), toConv.getParameters().end()); @@ -5405,7 +5413,8 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap( // Handle self reordering. // - For pullbacks: reorder result infos. // - For differentials: reorder parameter infos and arguments. - auto numIndirectResults = thunkIndirectResults.size(); + auto numIndirectResults = + thunkIndirectResults.size() + thunkIndirectErrorResults.size(); if (reorderSelf && linearMapKind == AutoDiffLinearMapKind::Pullback && toResults.size() > 1) { std::rotate(toResults.begin(), toResults.end() - 1, toResults.end()); @@ -5477,6 +5486,8 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap( SmallVector thunkArguments; thunkArguments.append(thunkIndirectResults.begin(), thunkIndirectResults.end()); + thunkArguments.append(thunkIndirectErrorResults.begin(), + thunkIndirectErrorResults.end()); thunkArguments.append(params.begin(), params.end()); SmallVector toParameters(toConv.getParameters().begin(), toConv.getParameters().end()); @@ -5723,7 +5734,9 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk( SILGenFunction thunkSGF(*this, *thunk, customDerivativeFn->getDeclContext()); SmallVector params; SmallVector indirectResults; - thunkSGF.collectThunkParams(loc, params, &indirectResults); + SmallVector indirectErrorResults; + thunkSGF.collectThunkParams( + loc, params, &indirectResults, &indirectErrorResults); auto *fnRef = thunkSGF.B.createFunctionRef(loc, customDerivativeFn); auto fnRefType = @@ -5767,6 +5780,8 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk( SmallVector arguments; for (auto indRes : indirectResults) arguments.push_back(indRes.getLValueAddress()); + for (auto indErrorRes : indirectErrorResults) + arguments.push_back(indErrorRes.getLValueAddress()); forwardFunctionArguments(thunkSGF, loc, fnRefType, params, arguments); // Apply function argument. @@ -6200,6 +6215,13 @@ SILGenFunction::emitVTableThunk(SILDeclRef base, inputOrigType.getFunctionResultType(), inputSubstType.getResult(), derivedFTy, thunkTy); + + // If the function we're calling has as indirect error result, create an + // argument for it. + if (auto innerIndirectErrorAddr = + emitThunkIndirectErrorArgument(*this, loc, derivedFTy)) { + args.push_back(*innerIndirectErrorAddr); + } } // Then, the arguments. @@ -6584,13 +6606,13 @@ void SILGenFunction::emitProtocolWitness( reqtOrigTy.getFunctionResultType(), reqtSubstTy.getResult(), witnessFTy, thunkTy); - } - // If the function we're calling has as indirect error result, create an - // argument for it. - if (auto innerIndirectErrorAddr = - emitThunkIndirectErrorArgument(*this, loc, witnessFTy)) { - args.push_back(*innerIndirectErrorAddr); + // If the function we're calling has as indirect error result, create an + // argument for it. + if (auto innerIndirectErrorAddr = + emitThunkIndirectErrorArgument(*this, loc, witnessFTy)) { + args.push_back(*innerIndirectErrorAddr); + } } // - the rest of the arguments