@@ -129,7 +129,11 @@ struct UnboundImport {
129129
130130private:
131131 void validatePrivate (ModuleDecl *topLevelModule);
132- void validateImplementationOnly (ASTContext &ctx);
132+
133+ // / Check that no import has more than one of the following modifiers:
134+ // / @_exported, @_implementationOnly, and @_spiOnly.
135+ void validateRestrictedImport (ASTContext &ctx);
136+
133137 void validateTestable (ModuleDecl *topLevelModule);
134138 void validateResilience (NullablePtr<ModuleDecl> topLevelModule,
135139 SourceFile &SF);
@@ -601,7 +605,7 @@ bool UnboundImport::checkModuleLoaded(ModuleDecl *M, SourceFile &SF) {
601605
602606void UnboundImport::validateOptions (NullablePtr<ModuleDecl> topLevelModule,
603607 SourceFile &SF) {
604- validateImplementationOnly (SF.getASTContext ());
608+ validateRestrictedImport (SF.getASTContext ());
605609
606610 if (auto *top = topLevelModule.getPtrOrNull ()) {
607611 // FIXME: Having these two calls in this if condition seems dubious.
@@ -636,16 +640,62 @@ void UnboundImport::validatePrivate(ModuleDecl *topLevelModule) {
636640 import .sourceFileArg = StringRef ();
637641}
638642
639- void UnboundImport::validateImplementationOnly (ASTContext &ctx) {
640- if (!import .options .contains (ImportFlags::ImplementationOnly) ||
641- !import .options .contains (ImportFlags::Exported))
643+ void UnboundImport::validateRestrictedImport (ASTContext &ctx) {
644+ static llvm::SmallVector<ImportFlags, 2 > flags = {ImportFlags::Exported,
645+ ImportFlags::ImplementationOnly,
646+ ImportFlags::SPIOnly};
647+ llvm::SmallVector<ImportFlags, 2 > conflicts;
648+
649+ for (auto flag : flags) {
650+ if (import .options .contains (flag))
651+ conflicts.push_back (flag);
652+ }
653+
654+ // Quit if there's no conflicting attributes.
655+ if (conflicts.size () < 2 )
642656 return ;
643657
644- // Remove one flag to maintain the invariant.
645- import .options -= ImportFlags::ImplementationOnly;
658+ // Remove all but one flag to maintain the invariant.
659+ for (auto iter = conflicts.begin (); iter != std::prev (conflicts.end ()); iter ++)
660+ import .options -= *iter;
646661
647- diagnoseInvalidAttr (DAK_ImplementationOnly, ctx.Diags ,
648- diag::import_implementation_cannot_be_exported);
662+ DeclAttrKind attrToRemove = conflicts[0 ] == ImportFlags::ImplementationOnly?
663+ DAK_Exported : DAK_ImplementationOnly;
664+
665+ // More dense enum with some cases of ImportFlags,
666+ // used by import_restriction_conflict.
667+ enum class ImportFlagForDiag : uint8_t {
668+ ImplementationOnly,
669+ SPIOnly,
670+ Exported
671+ };
672+ auto flagToDiag = [](ImportFlags flag) {
673+ switch (flag) {
674+ case ImportFlags::ImplementationOnly:
675+ return ImportFlagForDiag::ImplementationOnly;
676+ case ImportFlags::SPIOnly:
677+ return ImportFlagForDiag::SPIOnly;
678+ case ImportFlags::Exported:
679+ return ImportFlagForDiag::Exported;
680+ default :
681+ llvm_unreachable (" Unexpected ImportFlag" );
682+ }
683+ };
684+
685+ // Report the conflict, only the first two conflicts should be enough.
686+ auto diag = ctx.Diags .diagnose (import .module .getModulePath ().front ().Loc ,
687+ diag::import_restriction_conflict,
688+ import .module .getModulePath ().front ().Item ,
689+ (uint8_t )flagToDiag (conflicts[0 ]),
690+ (uint8_t )flagToDiag (conflicts[1 ]));
691+
692+ auto *ID = getImportDecl ().getPtrOrNull ();
693+ if (!ID) return ;
694+ auto *attr = ID->getAttrs ().getAttribute (attrToRemove);
695+ if (!attr) return ;
696+
697+ diag.fixItRemove (attr->getRangeWithAt ());
698+ attr->setInvalid ();
649699}
650700
651701void UnboundImport::validateTestable (ModuleDecl *topLevelModule) {
@@ -709,54 +759,66 @@ static bool moduleHasAnyImportsMatchingFlag(ModuleDecl *mod, ImportFlags flag) {
709759 return false ;
710760}
711761
712- // / Finds all import declarations for a single module that inconsistently match
762+ // / Finds all import declarations for a single file that inconsistently match
713763// / \c predicate and passes each pair of inconsistent imports to \c diagnose.
714764template <typename Pred, typename Diag>
715- static void findInconsistentImports (ModuleDecl *mod, Pred predicate,
716- Diag diagnose) {
717- llvm::DenseMap<ModuleDecl *, const ImportDecl *> matchingImports;
718- llvm::DenseMap<ModuleDecl *, std::vector<const ImportDecl *>> otherImports;
765+ static void findInconsistentImportsAcrossFile (
766+ const SourceFile *SF, Pred predicate, Diag diagnose,
767+ llvm::DenseMap<ModuleDecl *, const ImportDecl *> &matchingImports,
768+ llvm::DenseMap<ModuleDecl *, std::vector<const ImportDecl *>> &otherImports) {
769+
770+ for (auto *topLevelDecl : SF->getTopLevelDecls ()) {
771+ auto *nextImport = dyn_cast<ImportDecl>(topLevelDecl);
772+ if (!nextImport)
773+ continue ;
719774
720- for (const FileUnit *file : mod->getFiles ()) {
721- auto *SF = dyn_cast<SourceFile>(file);
722- if (!SF)
775+ ModuleDecl *module = nextImport->getModule ();
776+ if (!module )
723777 continue ;
724778
725- for (auto *topLevelDecl : SF->getTopLevelDecls ()) {
726- auto *nextImport = dyn_cast<ImportDecl>(topLevelDecl);
727- if (!nextImport)
779+ if (predicate (nextImport)) {
780+ // We found a matching import.
781+ bool isNew = matchingImports.insert ({module , nextImport}).second ;
782+ if (!isNew)
728783 continue ;
729784
730- ModuleDecl *module = nextImport->getModule ();
731- if (!module )
732- continue ;
785+ auto seenOtherImportPosition = otherImports.find (module );
786+ if (seenOtherImportPosition != otherImports.end ()) {
787+ for (auto *seenOtherImport : seenOtherImportPosition->getSecond ())
788+ diagnose (seenOtherImport, nextImport);
733789
734- if ( predicate (nextImport)) {
735- // We found a matching import.
736- bool isNew = matchingImports. insert ({ module , nextImport}). second ;
737- if (!isNew)
738- continue ;
790+ // We're done with these; keep the map small if possible.
791+ otherImports. erase (seenOtherImportPosition);
792+ }
793+ continue ;
794+ }
739795
740- auto seenOtherImportPosition = otherImports.find (module );
741- if (seenOtherImportPosition != otherImports.end ()) {
742- for (auto *seenOtherImport : seenOtherImportPosition->getSecond ())
743- diagnose (seenOtherImport, nextImport);
796+ // We saw a non-matching import. Is that in conflict with what we've seen?
797+ if (auto *seenMatchingImport = matchingImports.lookup (module )) {
798+ diagnose (nextImport, seenMatchingImport);
799+ continue ;
800+ }
744801
745- // We're done with these; keep the map small if possible.
746- otherImports.erase (seenOtherImportPosition);
747- }
748- continue ;
749- }
802+ // Otherwise, record it for later.
803+ otherImports[module ].push_back (nextImport);
804+ }
805+ }
750806
751- // We saw a non-matching import. Is that in conflict with what we've seen?
752- if (auto *seenMatchingImport = matchingImports.lookup (module )) {
753- diagnose (nextImport, seenMatchingImport);
754- continue ;
755- }
807+ // / Finds all import declarations for a single module that inconsistently match
808+ // / \c predicate and passes each pair of inconsistent imports to \c diagnose.
809+ template <typename Pred, typename Diag>
810+ static void findInconsistentImportsAcrossModule (ModuleDecl *mod, Pred predicate,
811+ Diag diagnose) {
812+ llvm::DenseMap<ModuleDecl *, const ImportDecl *> matchingImports;
813+ llvm::DenseMap<ModuleDecl *, std::vector<const ImportDecl *>> otherImports;
756814
757- // Otherwise, record it for later.
758- otherImports[module ].push_back (nextImport);
759- }
815+ for (const FileUnit *file : mod->getFiles ()) {
816+ auto *SF = dyn_cast<SourceFile>(file);
817+ if (!SF)
818+ continue ;
819+
820+ findInconsistentImportsAcrossFile (SF, predicate, diagnose,
821+ matchingImports, otherImports);
760822 }
761823}
762824
@@ -790,7 +852,34 @@ CheckInconsistentImplementationOnlyImportsRequest::evaluate(
790852 return decl->getAttrs ().hasAttribute <ImplementationOnlyAttr>();
791853 };
792854
793- findInconsistentImports (mod, predicate, diagnose);
855+ findInconsistentImportsAcrossModule (mod, predicate, diagnose);
856+ return {};
857+ }
858+
859+ evaluator::SideEffect
860+ CheckInconsistentSPIOnlyImportsRequest::evaluate (
861+ Evaluator &evaluator, SourceFile *SF) const {
862+
863+ auto mod = SF->getParentModule ();
864+ auto diagnose = [mod](const ImportDecl *normalImport,
865+ const ImportDecl *spiOnlyImport) {
866+ auto &diags = mod->getDiags ();
867+ {
868+ diags.diagnose (normalImport, diag::spi_only_import_conflict,
869+ normalImport->getModule ()->getName ());
870+ }
871+ diags.diagnose (spiOnlyImport,
872+ diag::spi_only_import_conflict_here);
873+ };
874+
875+ auto predicate = [](ImportDecl *decl) {
876+ return decl->getAttrs ().hasAttribute <SPIOnlyAttr>();
877+ };
878+
879+ llvm::DenseMap<ModuleDecl *, const ImportDecl *> matchingImports;
880+ llvm::DenseMap<ModuleDecl *, std::vector<const ImportDecl *>> otherImports;
881+ findInconsistentImportsAcrossFile (SF, predicate, diagnose,
882+ matchingImports, otherImports);
794883 return {};
795884}
796885
@@ -815,7 +904,7 @@ CheckInconsistentWeakLinkedImportsRequest::evaluate(Evaluator &evaluator,
815904 return decl->getAttrs ().hasAttribute <WeakLinkedAttr>();
816905 };
817906
818- findInconsistentImports (mod, predicate, diagnose);
907+ findInconsistentImportsAcrossModule (mod, predicate, diagnose);
819908 return {};
820909}
821910
0 commit comments