@@ -3172,6 +3172,250 @@ bool RefactoringActionMemberwiseInitLocalRefactoring::performChange() {
31723172 return false ;
31733173}
31743174
3175+ class AddEquatableContext {
3176+
3177+ // / Declaration context
3178+ DeclContext *DC;
3179+
3180+ // / Adopter type
3181+ Type Adopter;
3182+
3183+ // / Start location of declaration context brace
3184+ SourceLoc StartLoc;
3185+
3186+ // / Array of all inherited protocols' locations
3187+ ArrayRef<TypeLoc> ProtocolsLocations;
3188+
3189+ // / Array of all conformed protocols
3190+ SmallVector<swift::ProtocolDecl *, 2 > Protocols;
3191+
3192+ // / Start location of declaration,
3193+ // / a place to write protocol name
3194+ SourceLoc ProtInsertStartLoc;
3195+
3196+ // / Stored properties of extending adopter
3197+ ArrayRef<VarDecl *> StoredProperties;
3198+
3199+ // / Range of internal members in declaration
3200+ DeclRange Range;
3201+
3202+ bool conformsToEquatableProtocol () {
3203+ for (ProtocolDecl *Protocol : Protocols) {
3204+ if (Protocol->getKnownProtocolKind () == KnownProtocolKind::Equatable) {
3205+ return true ;
3206+ }
3207+ }
3208+ return false ;
3209+ }
3210+
3211+ bool isRequirementValid () {
3212+ auto Reqs = getProtocolRequirements ();
3213+ if (Reqs.empty ()) {
3214+ return false ;
3215+ }
3216+ auto Req = dyn_cast<FuncDecl>(Reqs[0 ]);
3217+ return Req && Req->getParameters ()->size () == 2 ;
3218+ }
3219+
3220+ bool isPropertiesListValid () {
3221+ return !getUserAccessibleProperties ().empty ();
3222+ }
3223+
3224+ void printFunctionBody (ASTPrinter &Printer, StringRef ExtraIndent,
3225+ ParameterList *Params);
3226+
3227+ std::vector<ValueDecl *> getProtocolRequirements ();
3228+
3229+ std::vector<VarDecl *> getUserAccessibleProperties ();
3230+
3231+ public:
3232+
3233+ AddEquatableContext (NominalTypeDecl *Decl) : DC(Decl),
3234+ Adopter (Decl->getDeclaredType ()), StartLoc(Decl->getBraces ().Start),
3235+ ProtocolsLocations(Decl->getInherited ()),
3236+ Protocols(Decl->getAllProtocols ()), ProtInsertStartLoc(Decl->getNameLoc ()),
3237+ StoredProperties(Decl->getStoredProperties ()), Range(Decl->getMembers ()) {};
3238+
3239+ AddEquatableContext (ExtensionDecl *Decl) : DC(Decl),
3240+ Adopter(Decl->getExtendedType ()), StartLoc(Decl->getBraces ().Start),
3241+ ProtocolsLocations(Decl->getInherited ()),
3242+ Protocols(Decl->getExtendedNominal ()->getAllProtocols()),
3243+ ProtInsertStartLoc(Decl->getExtendedTypeRepr ()->getEndLoc()),
3244+ StoredProperties(Decl->getExtendedNominal ()->getStoredProperties()), Range(Decl->getMembers ()) {};
3245+
3246+ AddEquatableContext () : DC(nullptr ), Adopter(), ProtocolsLocations(),
3247+ Protocols(), StoredProperties(), Range(nullptr , nullptr ) {};
3248+
3249+ static AddEquatableContext getDeclarationContextFromInfo (ResolvedCursorInfo Info);
3250+
3251+ std::string getInsertionTextForProtocol ();
3252+
3253+ std::string getInsertionTextForFunction (SourceManager &SM);
3254+
3255+ bool isValid () {
3256+ // FIXME: Allow to generate explicit == method for declarations which already have
3257+ // compiler-generated == method
3258+ return StartLoc.isValid () && ProtInsertStartLoc.isValid () &&
3259+ !conformsToEquatableProtocol () && isPropertiesListValid () &&
3260+ isRequirementValid ();
3261+ }
3262+
3263+ SourceLoc getStartLocForProtocolDecl () {
3264+ if (ProtocolsLocations.empty ()) {
3265+ return ProtInsertStartLoc;
3266+ }
3267+ return ProtocolsLocations.back ().getSourceRange ().Start ;
3268+ }
3269+
3270+ bool isMembersRangeEmpty () {
3271+ return Range.empty ();
3272+ }
3273+
3274+ SourceLoc getInsertStartLoc ();
3275+ };
3276+
3277+ SourceLoc AddEquatableContext::
3278+ getInsertStartLoc () {
3279+ SourceLoc MaxLoc = StartLoc;
3280+ for (auto Mem : Range) {
3281+ if (Mem->getEndLoc ().getOpaquePointerValue () >
3282+ MaxLoc.getOpaquePointerValue ()) {
3283+ MaxLoc = Mem->getEndLoc ();
3284+ }
3285+ }
3286+ return MaxLoc;
3287+ }
3288+
3289+ std::string AddEquatableContext::
3290+ getInsertionTextForProtocol () {
3291+ StringRef ProtocolName = getProtocolName (KnownProtocolKind::Equatable);
3292+ std::string Buffer;
3293+ llvm::raw_string_ostream OS (Buffer);
3294+ if (ProtocolsLocations.empty ()) {
3295+ OS << " : " << ProtocolName;
3296+ return Buffer;
3297+ }
3298+ OS << " , " << ProtocolName;
3299+ return Buffer;
3300+ }
3301+
3302+ std::string AddEquatableContext::
3303+ getInsertionTextForFunction (SourceManager &SM) {
3304+ auto Reqs = getProtocolRequirements ();
3305+ auto Req = dyn_cast<FuncDecl>(Reqs[0 ]);
3306+ auto Params = Req->getParameters ();
3307+ StringRef ExtraIndent;
3308+ StringRef CurrentIndent =
3309+ Lexer::getIndentationForLine (SM, getInsertStartLoc (), &ExtraIndent);
3310+ std::string Indent;
3311+ if (isMembersRangeEmpty ()) {
3312+ Indent = (CurrentIndent + ExtraIndent).str ();
3313+ } else {
3314+ Indent = CurrentIndent.str ();
3315+ }
3316+ PrintOptions Options = PrintOptions::printVerbose ();
3317+ Options.PrintDocumentationComments = false ;
3318+ Options.setBaseType (Adopter);
3319+ Options.FunctionBody = [&](const ValueDecl *VD, ASTPrinter &Printer) {
3320+ Printer << " {" ;
3321+ Printer.printNewline ();
3322+ printFunctionBody (Printer, ExtraIndent, Params);
3323+ Printer.printNewline ();
3324+ Printer << " }" ;
3325+ };
3326+ std::string Buffer;
3327+ llvm::raw_string_ostream OS (Buffer);
3328+ ExtraIndentStreamPrinter Printer (OS, Indent);
3329+ Printer.printNewline ();
3330+ if (!isMembersRangeEmpty ()) {
3331+ Printer.printNewline ();
3332+ }
3333+ Reqs[0 ]->print (Printer, Options);
3334+ return Buffer;
3335+ }
3336+
3337+ std::vector<VarDecl *> AddEquatableContext::
3338+ getUserAccessibleProperties () {
3339+ std::vector<VarDecl *> PublicProperties;
3340+ for (VarDecl *Decl : StoredProperties) {
3341+ if (Decl->Decl ::isUserAccessible ()) {
3342+ PublicProperties.push_back (Decl);
3343+ }
3344+ }
3345+ return PublicProperties;
3346+ }
3347+
3348+ std::vector<ValueDecl *> AddEquatableContext::
3349+ getProtocolRequirements () {
3350+ std::vector<ValueDecl *> Collection;
3351+ auto Proto = DC->getASTContext ().getProtocol (KnownProtocolKind::Equatable);
3352+ for (auto Member : Proto->getMembers ()) {
3353+ auto Req = dyn_cast<ValueDecl>(Member);
3354+ if (!Req || Req->isInvalid () || !Req->isProtocolRequirement ()) {
3355+ continue ;
3356+ }
3357+ Collection.push_back (Req);
3358+ }
3359+ return Collection;
3360+ }
3361+
3362+ AddEquatableContext AddEquatableContext::
3363+ getDeclarationContextFromInfo (ResolvedCursorInfo Info) {
3364+ if (Info.isInvalid ()) {
3365+ return AddEquatableContext ();
3366+ }
3367+ if (!Info.IsRef ) {
3368+ if (auto *NomDecl = dyn_cast<NominalTypeDecl>(Info.ValueD )) {
3369+ return AddEquatableContext (NomDecl);
3370+ }
3371+ } else if (auto *ExtDecl = Info.ExtTyRef ) {
3372+ return AddEquatableContext (ExtDecl);
3373+ }
3374+ return AddEquatableContext ();
3375+ }
3376+
3377+ void AddEquatableContext::
3378+ printFunctionBody (ASTPrinter &Printer, StringRef ExtraIndent, ParameterList *Params) {
3379+ llvm::SmallString<128 > Return;
3380+ llvm::raw_svector_ostream SS (Return);
3381+ SS << tok::kw_return;
3382+ StringRef Space = " " ;
3383+ StringRef AdditionalSpace = " " ;
3384+ StringRef Point = " ." ;
3385+ StringRef Join = " == " ;
3386+ StringRef And = " &&" ;
3387+ auto Props = getUserAccessibleProperties ();
3388+ auto FParam = Params->get (0 )->getName ();
3389+ auto SParam = Params->get (1 )->getName ();
3390+ auto Prop = Props[0 ]->getName ();
3391+ Printer << ExtraIndent << Return << Space
3392+ << FParam << Point << Prop << Join << SParam << Point << Prop;
3393+ if (Props.size () > 1 ) {
3394+ std::for_each (Props.begin () + 1 , Props.end (), [&](VarDecl *VD){
3395+ auto Name = VD->getName ();
3396+ Printer << And;
3397+ Printer.printNewline ();
3398+ Printer << ExtraIndent << AdditionalSpace << FParam << Point
3399+ << Name << Join << SParam << Point << Name;
3400+ });
3401+ }
3402+ }
3403+
3404+ bool RefactoringActionAddEquatableConformance::
3405+ isApplicable (ResolvedCursorInfo Tok, DiagnosticEngine &Diag) {
3406+ return AddEquatableContext::getDeclarationContextFromInfo (Tok).isValid ();
3407+ }
3408+
3409+ bool RefactoringActionAddEquatableConformance::
3410+ performChange () {
3411+ auto Context = AddEquatableContext::getDeclarationContextFromInfo (CursorInfo);
3412+ EditConsumer.insertAfter (SM, Context.getStartLocForProtocolDecl (),
3413+ Context.getInsertionTextForProtocol ());
3414+ EditConsumer.insertAfter (SM, Context.getInsertStartLoc (),
3415+ Context.getInsertionTextForFunction (SM));
3416+ return false ;
3417+ }
3418+
31753419static CharSourceRange
31763420 findSourceRangeToWrapInCatch (ResolvedCursorInfo CursorInfo,
31773421 SourceFile *TheFile,
0 commit comments