diff --git a/lib/AST/RequirementMachine/RequirementLowering.cpp b/lib/AST/RequirementMachine/RequirementLowering.cpp index 198f9714c013f..6f075b541ba75 100644 --- a/lib/AST/RequirementMachine/RequirementLowering.cpp +++ b/lib/AST/RequirementMachine/RequirementLowering.cpp @@ -562,15 +562,17 @@ struct InferRequirementsWalker : public TypeWalker { // - `@differentiable(_linear)`: add // `T: Differentiable`, `T == T.TangentVector` requirements. if (auto *fnTy = ty->getAs()) { + // Add a new conformance constraint for a fixed protocol. + auto addConformanceConstraint = [&](Type type, ProtocolDecl *protocol) { + Requirement req(RequirementKind::Conformance, type, + protocol->getDeclaredInterfaceType()); + desugarRequirement(req, SourceLoc(), reqs, errors); + }; + auto &ctx = module->getASTContext(); auto *differentiableProtocol = ctx.getProtocol(KnownProtocolKind::Differentiable); if (differentiableProtocol && fnTy->isDifferentiable()) { - auto addConformanceConstraint = [&](Type type, ProtocolDecl *protocol) { - Requirement req(RequirementKind::Conformance, type, - protocol->getDeclaredInterfaceType()); - desugarRequirement(req, SourceLoc(), reqs, errors); - }; auto addSameTypeConstraint = [&](Type firstType, AssociatedTypeDecl *assocType) { auto secondType = assocType->getDeclaredInterfaceType() @@ -596,6 +598,13 @@ struct InferRequirementsWalker : public TypeWalker { constrainParametersAndResult(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Linear); } + + // Infer that the thrown error type conforms to Error. + if (auto thrownError = fnTy->getThrownError()) { + if (auto errorProtocol = ctx.getErrorDecl()) { + addConformanceConstraint(thrownError, errorProtocol); + } + } } if (!ty->isSpecialized()) diff --git a/lib/Sema/TypeCheckGeneric.cpp b/lib/Sema/TypeCheckGeneric.cpp index 1a65cb3fe9ba4..3ab084b813c63 100644 --- a/lib/Sema/TypeCheckGeneric.cpp +++ b/lib/Sema/TypeCheckGeneric.cpp @@ -753,6 +753,28 @@ GenericSignatureRequest::evaluate(Evaluator &evaluator, inferenceSources.emplace_back(typeRepr, type); } + // Handle the thrown error type. + auto effectiveFunc = func ? func + : subscr ? subscr->getEffectfulGetAccessor() + : nullptr; + if (effectiveFunc) { + if (auto thrownTypeRepr = effectiveFunc->getThrownTypeRepr()) { + auto thrownOptions = baseOptions | TypeResolutionFlags::Direct; + const auto thrownType = resolution.withOptions(thrownOptions) + .resolveType(thrownTypeRepr); + + // Add this type as an inference source. + inferenceSources.emplace_back(thrownTypeRepr, thrownType); + + // Add conformance of this type to the Error protocol. + if (auto errorProtocol = ctx.getErrorDecl()) { + extraReqs.push_back( + Requirement(RequirementKind::Conformance, thrownType, + errorProtocol->getDeclaredInterfaceType())); + } + } + } + // Gather requirements from the result type. auto *resultTypeRepr = [&subscr, &func, ¯o]() -> TypeRepr * { if (subscr) { diff --git a/test/decl/func/typed_throws.swift b/test/decl/func/typed_throws.swift index 545a43ec3bcad..a3784c4368067 100644 --- a/test/decl/func/typed_throws.swift +++ b/test/decl/func/typed_throws.swift @@ -33,7 +33,6 @@ func testThrownMyErrorType() { func throwsGeneric(errorType: T.Type) throws(T) { } func throwsBadGeneric(errorType: T.Type) throws(T) { } -// expected-error@-1{{thrown type 'T' does not conform to the 'Error' protocol}} func throwsUnusedInSignature() throws(T) { } // expected-error@-1{{generic parameter 'T' is not used in function signature}} @@ -103,3 +102,24 @@ func testMapArray(numbers: [Int]) { let _: Int = error // expected-error{{cannot convert value of type 'MyError' to specified type 'Int'}} } } + +// Inference of Error conformance from the use of a generic parameter in typed +// throws. +func requiresError(_: E.Type) { } + +func infersThrowing(_ error: E.Type) throws(E) { + requiresError(error) +} + +func infersThrowingNested(_ body: () throws(E) -> Void) { + requiresError(E.self) +} + +struct HasASubscript { + subscript(_: E.Type) -> Int { + get throws(E) { + requiresError(E.self) + return 0 + } + } +}