@@ -4887,18 +4887,19 @@ void TypeChecker::checkConformancesInContext(DeclContext *dc,
4887
4887
4888
4888
llvm::TinyPtrVector<ValueDecl *>
4889
4889
TypeChecker::findWitnessedObjCRequirements (const ValueDecl *witness,
4890
- bool onlyFirstRequirement ) {
4890
+ bool anySingleRequirement ) {
4891
4891
llvm::TinyPtrVector<ValueDecl *> result;
4892
4892
4893
4893
// Types don't infer @objc this way.
4894
4894
if (isa<TypeDecl>(witness)) return result;
4895
4895
4896
4896
auto dc = witness->getDeclContext ();
4897
4897
auto name = witness->getFullName ();
4898
- for (auto conformance : dc->getLocalConformances (ConformanceLookupKind::All,
4899
- nullptr , /* sorted=*/ true )) {
4898
+ auto nominal = dc->getAsNominalTypeOrNominalTypeExtensionContext ();
4899
+ if (!nominal) return result;
4900
+
4901
+ for (auto proto : nominal->getAllProtocols ()) {
4900
4902
// We only care about Objective-C protocols.
4901
- auto proto = conformance->getProtocol ();
4902
4903
if (!proto->isObjC ()) continue ;
4903
4904
4904
4905
for (auto req : proto->lookupDirect (name, true )) {
@@ -4907,16 +4908,41 @@ TypeChecker::findWitnessedObjCRequirements(const ValueDecl *witness,
4907
4908
4908
4909
// Skip types.
4909
4910
if (isa<TypeDecl>(req)) continue ;
4911
+
4912
+ // Dig out the conformance.
4913
+ Optional<ProtocolConformance *> conformance;
4914
+ if (!conformance.hasValue ()) {
4915
+ SmallVector<ProtocolConformance *, 2 > conformances;
4916
+ nominal->lookupConformance (dc->getParentModule (), proto,
4917
+ conformances);
4918
+ if (conformances.size () == 1 )
4919
+ conformance = conformances.front ();
4920
+ else
4921
+ conformance = nullptr ;
4922
+ }
4923
+ if (!*conformance) continue ;
4910
4924
4911
4925
// Determine whether the witness for this conformance is in fact
4912
4926
// our witness.
4913
- if (conformance->getWitness (req, this ).getDecl () == witness) {
4927
+ if ((* conformance) ->getWitness (req, this ).getDecl () == witness) {
4914
4928
result.push_back (req);
4915
- if (onlyFirstRequirement ) return result;
4929
+ if (anySingleRequirement ) return result;
4916
4930
}
4917
4931
}
4918
4932
}
4919
4933
4934
+ // Sort the results.
4935
+ if (result.size () > 2 ) {
4936
+ std::stable_sort (result.begin (), result.end (),
4937
+ [&](ValueDecl *lhs, ValueDecl *rhs) {
4938
+ ProtocolDecl *lhsProto
4939
+ = cast<ProtocolDecl>(lhs->getDeclContext ());
4940
+ ProtocolDecl *rhsProto
4941
+ = cast<ProtocolDecl>(rhs->getDeclContext ());
4942
+ return ProtocolType::compareProtocols (&lhsProto,
4943
+ &rhsProto) < 0 ;
4944
+ });
4945
+ }
4920
4946
return result;
4921
4947
}
4922
4948
0 commit comments