-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[flang][OpenMP] Emit requirements in module files #163449
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
REQUIRES clauses apply to the compilation unit, which the OpenMP spec defines as the program unit in Fortran. Don't set REQUIRES flags on all containing scopes, only on the containng program unit, where flags coming from different directives are gathered. If we wanted to set the flags on subprograms, we would need to first accummulate all of them, then propagate them down to all subprograms. That is not done as it is not necessary (the containing program unit is always available).
For each program unit, collect the set of requirements from REQUIRES directives in the source, and modules used by the program unit, and add them to the details of the program unit symbol. The requirements in the symbol details as now stored as clauses. Since requirements need to be emitted in the module files as OpenMP directives, this makes the clause emission straightforward via getOpenMPClauseName. Each program unit, including modules, the corresponding symbol will have the transitive closure of the requirements for everything contained or used in that program unit.
@llvm/pr-subscribers-flang-semantics @llvm/pr-subscribers-flang-fir-hlfir Author: Krzysztof Parzyszek (kparzysz) ChangesFor each program unit, collect the set of requirements from REQUIRES directives in the source, and modules used by the program unit, and add them to the details of the program unit symbol. The requirements in the symbol details as now stored as clauses. Since requirements need to be emitted in the module files as OpenMP directives, this makes the clause emission straightforward via getOpenMPClauseName. Each program unit, including modules, the corresponding symbol will have the transitive closure of the requirements for everything contained or used in that program unit. Full diff: https://github.com/llvm/llvm-project/pull/163449.diff 8 Files Affected:
diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h
index 77f567e69ce55..14da5b443633f 100644
--- a/flang/include/flang/Semantics/symbol.h
+++ b/flang/include/flang/Semantics/symbol.h
@@ -16,6 +16,7 @@
#include "flang/Semantics/module-dependences.h"
#include "flang/Support/Fortran.h"
#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/Frontend/OpenMP/OMP.h"
#include <array>
#include <functional>
@@ -50,32 +51,31 @@ using MutableSymbolVector = std::vector<MutableSymbolRef>;
// Mixin for details with OpenMP declarative constructs.
class WithOmpDeclarative {
- using OmpAtomicOrderType = common::OmpMemoryOrderType;
-
public:
- ENUM_CLASS(RequiresFlag, ReverseOffload, UnifiedAddress, UnifiedSharedMemory,
- DynamicAllocators);
- using RequiresFlags = common::EnumSet<RequiresFlag, RequiresFlag_enumSize>;
+ // The set of requirements for any program unit include requirements
+ // from any module used in the program unit.
+ using RequiresClauses =
+ common::EnumSet<llvm::omp::Clause, llvm::omp::Clause_enumSize>;
bool has_ompRequires() const { return ompRequires_.has_value(); }
- const RequiresFlags *ompRequires() const {
+ const RequiresClauses *ompRequires() const {
return ompRequires_ ? &*ompRequires_ : nullptr;
}
- void set_ompRequires(RequiresFlags flags) { ompRequires_ = flags; }
+ void set_ompRequires(RequiresClauses clauses) { ompRequires_ = clauses; }
bool has_ompAtomicDefaultMemOrder() const {
return ompAtomicDefaultMemOrder_.has_value();
}
- const OmpAtomicOrderType *ompAtomicDefaultMemOrder() const {
+ const common::OmpMemoryOrderType *ompAtomicDefaultMemOrder() const {
return ompAtomicDefaultMemOrder_ ? &*ompAtomicDefaultMemOrder_ : nullptr;
}
- void set_ompAtomicDefaultMemOrder(OmpAtomicOrderType flags) {
+ void set_ompAtomicDefaultMemOrder(common::OmpMemoryOrderType flags) {
ompAtomicDefaultMemOrder_ = flags;
}
private:
- std::optional<RequiresFlags> ompRequires_;
- std::optional<OmpAtomicOrderType> ompAtomicDefaultMemOrder_;
+ std::optional<RequiresClauses> ompRequires_;
+ std::optional<common::OmpMemoryOrderType> ompAtomicDefaultMemOrder_;
};
// A module or submodule.
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 444f27471020b..f86ee01355104 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -4208,18 +4208,17 @@ bool Fortran::lower::markOpenMPDeferredDeclareTargetFunctions(
void Fortran::lower::genOpenMPRequires(mlir::Operation *mod,
const semantics::Symbol *symbol) {
using MlirRequires = mlir::omp::ClauseRequires;
- using SemaRequires = semantics::WithOmpDeclarative::RequiresFlag;
if (auto offloadMod =
llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(mod)) {
- semantics::WithOmpDeclarative::RequiresFlags semaFlags;
+ semantics::WithOmpDeclarative::RequiresClauses reqs;
if (symbol) {
common::visit(
[&](const auto &details) {
if constexpr (std::is_base_of_v<semantics::WithOmpDeclarative,
std::decay_t<decltype(details)>>) {
if (details.has_ompRequires())
- semaFlags = *details.ompRequires();
+ reqs = *details.ompRequires();
}
},
symbol->details());
@@ -4228,14 +4227,14 @@ void Fortran::lower::genOpenMPRequires(mlir::Operation *mod,
// Use pre-populated omp.requires module attribute if it was set, so that
// the "-fopenmp-force-usm" compiler option is honored.
MlirRequires mlirFlags = offloadMod.getRequires();
- if (semaFlags.test(SemaRequires::ReverseOffload))
+ if (reqs.test(llvm::omp::Clause::OMPC_dynamic_allocators))
+ mlirFlags = mlirFlags | MlirRequires::dynamic_allocators;
+ if (reqs.test(llvm::omp::Clause::OMPC_reverse_offload))
mlirFlags = mlirFlags | MlirRequires::reverse_offload;
- if (semaFlags.test(SemaRequires::UnifiedAddress))
+ if (reqs.test(llvm::omp::Clause::OMPC_unified_address))
mlirFlags = mlirFlags | MlirRequires::unified_address;
- if (semaFlags.test(SemaRequires::UnifiedSharedMemory))
+ if (reqs.test(llvm::omp::Clause::OMPC_unified_shared_memory))
mlirFlags = mlirFlags | MlirRequires::unified_shared_memory;
- if (semaFlags.test(SemaRequires::DynamicAllocators))
- mlirFlags = mlirFlags | MlirRequires::dynamic_allocators;
offloadMod.setRequires(mlirFlags);
}
diff --git a/flang/lib/Semantics/mod-file.cpp b/flang/lib/Semantics/mod-file.cpp
index 8074c94b41e1a..86cc4632a5763 100644
--- a/flang/lib/Semantics/mod-file.cpp
+++ b/flang/lib/Semantics/mod-file.cpp
@@ -17,6 +17,7 @@
#include "flang/Semantics/semantics.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
+#include "llvm/Frontend/OpenMP/OMP.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/raw_ostream.h"
@@ -24,6 +25,7 @@
#include <fstream>
#include <set>
#include <string_view>
+#include <type_traits>
#include <variant>
#include <vector>
@@ -359,6 +361,40 @@ void ModFileWriter::PrepareRenamings(const Scope &scope) {
}
}
+static void PutOpenMPRequirements(llvm::raw_ostream &os, const Symbol &symbol) {
+ using RequiresClauses = WithOmpDeclarative::RequiresClauses;
+ using OmpMemoryOrderType = common::OmpMemoryOrderType;
+
+ const auto [reqs, order]{common::visit(
+ [&](auto &&details)
+ -> std::pair<const RequiresClauses *, const OmpMemoryOrderType *> {
+ if constexpr (std::is_convertible_v<decltype(details),
+ const WithOmpDeclarative &>) {
+ return {details.ompRequires(), details.ompAtomicDefaultMemOrder()};
+ } else {
+ return {nullptr, nullptr};
+ }
+ },
+ symbol.details())};
+
+ if (order) {
+ llvm::omp::Clause atmo{llvm::omp::Clause::OMPC_atomic_default_mem_order};
+ os << "!$omp requires "
+ << parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName(atmo))
+ << '(' << parser::ToLowerCaseLetters(EnumToString(*order)) << ")\n";
+ }
+ if (reqs) {
+ os << "!$omp requires";
+ reqs->IterateOverMembers([&](llvm::omp::Clause f) {
+ if (f != llvm::omp::Clause::OMPC_atomic_default_mem_order) {
+ os << ' '
+ << parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName(f));
+ }
+ });
+ os << "\n";
+ }
+}
+
// Put out the visible symbols from scope.
void ModFileWriter::PutSymbols(
const Scope &scope, UnorderedSymbolSet *hermeticModules) {
@@ -396,6 +432,7 @@ void ModFileWriter::PutSymbols(
for (const Symbol &symbol : uses) {
PutUse(symbol);
}
+ PutOpenMPRequirements(decls_, DEREF(scope.symbol()));
for (const auto &set : scope.equivalenceSets()) {
if (!set.empty() &&
!set.front().symbol.test(Symbol::Flag::CompilerCreated)) {
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index de680b41d1524..122849356ca39 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -435,6 +435,22 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
return true;
}
+ bool Pre(const parser::UseStmt &x) {
+ if (x.moduleName.symbol) {
+ Scope &thisScope{context_.FindScope(x.moduleName.source)};
+ common::visit(
+ [&](auto &&details) {
+ if constexpr (std::is_convertible_v<decltype(details),
+ const WithOmpDeclarative &>) {
+ AddOmpRequiresToScope(thisScope, details.ompRequires(),
+ details.ompAtomicDefaultMemOrder());
+ }
+ },
+ x.moduleName.symbol->details());
+ }
+ return true;
+ }
+
bool Pre(const parser::OmpMetadirectiveDirective &x) {
PushContext(x.v.source, llvm::omp::Directive::OMPD_metadirective);
return true;
@@ -538,38 +554,37 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
void Post(const parser::OpenMPFlushConstruct &) { PopContext(); }
bool Pre(const parser::OpenMPRequiresConstruct &x) {
- using Flags = WithOmpDeclarative::RequiresFlags;
- using Requires = WithOmpDeclarative::RequiresFlag;
+ using RequiresClauses = WithOmpDeclarative::RequiresClauses;
PushContext(x.source, llvm::omp::Directive::OMPD_requires);
// Gather information from the clauses.
- Flags flags;
- std::optional<common::OmpMemoryOrderType> memOrder;
+ RequiresClauses reqs;
+ const common::OmpMemoryOrderType *memOrder{nullptr};
for (const parser::OmpClause &clause : x.v.Clauses().v) {
- flags |= common::visit(
+ using OmpClause = parser::OmpClause;
+ reqs |= common::visit(
common::visitors{
- [&memOrder](
- const parser::OmpClause::AtomicDefaultMemOrder &atomic) {
- memOrder = atomic.v.v;
- return Flags{};
- },
- [](const parser::OmpClause::ReverseOffload &) {
- return Flags{Requires::ReverseOffload};
- },
- [](const parser::OmpClause::UnifiedAddress &) {
- return Flags{Requires::UnifiedAddress};
+ [&](const OmpClause::AtomicDefaultMemOrder &atomic) {
+ memOrder = &atomic.v.v;
+ return RequiresClauses{};
},
- [](const parser::OmpClause::UnifiedSharedMemory &) {
- return Flags{Requires::UnifiedSharedMemory};
- },
- [](const parser::OmpClause::DynamicAllocators &) {
- return Flags{Requires::DynamicAllocators};
+ [&](auto &&s) {
+ using TypeS = llvm::remove_cvref_t<decltype(s)>;
+ if constexpr ( //
+ std::is_same_v<TypeS, OmpClause::DynamicAllocators> ||
+ std::is_same_v<TypeS, OmpClause::ReverseOffload> ||
+ std::is_same_v<TypeS, OmpClause::UnifiedAddress> ||
+ std::is_same_v<TypeS, OmpClause::UnifiedSharedMemory>) {
+ return RequiresClauses{clause.Id()};
+ } else {
+ return RequiresClauses{};
+ }
},
- [](const auto &) { return Flags{}; }},
+ },
clause.u);
}
// Merge clauses into parents' symbols details.
- AddOmpRequiresToScope(currScope(), flags, memOrder);
+ AddOmpRequiresToScope(currScope(), &reqs, memOrder);
return true;
}
void Post(const parser::OpenMPRequiresConstruct &) { PopContext(); }
@@ -1001,8 +1016,9 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
std::int64_t ordCollapseLevel{0};
- void AddOmpRequiresToScope(Scope &, WithOmpDeclarative::RequiresFlags,
- std::optional<common::OmpMemoryOrderType>);
+ void AddOmpRequiresToScope(Scope &,
+ const WithOmpDeclarative::RequiresClauses *,
+ const common::OmpMemoryOrderType *);
void IssueNonConformanceWarning(llvm::omp::Directive D,
parser::CharBlock source, unsigned EmitFromVersion);
@@ -3309,86 +3325,6 @@ void ResolveOmpParts(
}
}
-void ResolveOmpTopLevelParts(
- SemanticsContext &context, const parser::Program &program) {
- if (!context.IsEnabled(common::LanguageFeature::OpenMP)) {
- return;
- }
-
- // Gather REQUIRES clauses from all non-module top-level program unit symbols,
- // combine them together ensuring compatibility and apply them to all these
- // program units. Modules are skipped because their REQUIRES clauses should be
- // propagated via USE statements instead.
- WithOmpDeclarative::RequiresFlags combinedFlags;
- std::optional<common::OmpMemoryOrderType> combinedMemOrder;
-
- // Function to go through non-module top level program units and extract
- // REQUIRES information to be processed by a function-like argument.
- auto processProgramUnits{[&](auto processFn) {
- for (const parser::ProgramUnit &unit : program.v) {
- if (!std::holds_alternative<common::Indirection<parser::Module>>(
- unit.u) &&
- !std::holds_alternative<common::Indirection<parser::Submodule>>(
- unit.u) &&
- !std::holds_alternative<
- common::Indirection<parser::CompilerDirective>>(unit.u)) {
- Symbol *symbol{common::visit(
- [&context](auto &x) {
- Scope *scope = GetScope(context, x.value());
- return scope ? scope->symbol() : nullptr;
- },
- unit.u)};
- // FIXME There is no symbol defined for MainProgram units in certain
- // circumstances, so REQUIRES information has no place to be stored in
- // these cases.
- if (!symbol) {
- continue;
- }
- common::visit(
- [&](auto &details) {
- if constexpr (std::is_convertible_v<decltype(&details),
- WithOmpDeclarative *>) {
- processFn(*symbol, details);
- }
- },
- symbol->details());
- }
- }
- }};
-
- // Combine global REQUIRES information from all program units except modules
- // and submodules.
- processProgramUnits([&](Symbol &symbol, WithOmpDeclarative &details) {
- if (const WithOmpDeclarative::RequiresFlags *
- flags{details.ompRequires()}) {
- combinedFlags |= *flags;
- }
- if (const common::OmpMemoryOrderType *
- memOrder{details.ompAtomicDefaultMemOrder()}) {
- if (combinedMemOrder && *combinedMemOrder != *memOrder) {
- context.Say(symbol.scope()->sourceRange(),
- "Conflicting '%s' REQUIRES clauses found in compilation "
- "unit"_err_en_US,
- parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(
- llvm::omp::Clause::OMPC_atomic_default_mem_order)
- .str()));
- }
- combinedMemOrder = *memOrder;
- }
- });
-
- // Update all program units except modules and submodules with the combined
- // global REQUIRES information.
- processProgramUnits([&](Symbol &, WithOmpDeclarative &details) {
- if (combinedFlags.any()) {
- details.set_ompRequires(combinedFlags);
- }
- if (combinedMemOrder) {
- details.set_ompAtomicDefaultMemOrder(*combinedMemOrder);
- }
- });
-}
-
static bool IsSymbolThreadprivate(const Symbol &symbol) {
if (const auto *details{symbol.detailsIf<HostAssocDetails>()}) {
return details->symbol().test(Symbol::Flag::OmpThreadprivate);
@@ -3547,23 +3483,22 @@ void OmpAttributeVisitor::CheckLabelContext(const parser::CharBlock source,
}
void OmpAttributeVisitor::AddOmpRequiresToScope(Scope &scope,
- WithOmpDeclarative::RequiresFlags flags,
- std::optional<common::OmpMemoryOrderType> memOrder) {
+ const WithOmpDeclarative::RequiresClauses *reqs,
+ const common::OmpMemoryOrderType *memOrder) {
const Scope &programUnit{omp::GetProgramUnit(scope)};
+ using RequiresClauses = WithOmpDeclarative::RequiresClauses;
+ RequiresClauses combinedReqs{reqs ? *reqs : RequiresClauses{}};
if (auto *symbol{const_cast<Symbol *>(programUnit.symbol())}) {
common::visit(
[&](auto &details) {
- // Store clauses information into the symbol for the parent and
- // enclosing modules, programs, functions and subroutines.
if constexpr (std::is_convertible_v<decltype(&details),
WithOmpDeclarative *>) {
- if (flags.any()) {
- if (const WithOmpDeclarative::RequiresFlags *otherFlags{
- details.ompRequires()}) {
- flags |= *otherFlags;
+ if (combinedReqs.any()) {
+ if (const RequiresClauses *otherReqs{details.ompRequires()}) {
+ combinedReqs |= *otherReqs;
}
- details.set_ompRequires(flags);
+ details.set_ompRequires(combinedReqs);
}
if (memOrder) {
if (details.has_ompAtomicDefaultMemOrder() &&
diff --git a/flang/lib/Semantics/resolve-directives.h b/flang/lib/Semantics/resolve-directives.h
index 5a890c26aa334..36d3ce988b1b1 100644
--- a/flang/lib/Semantics/resolve-directives.h
+++ b/flang/lib/Semantics/resolve-directives.h
@@ -23,7 +23,5 @@ class SemanticsContext;
void ResolveAccParts(
SemanticsContext &, const parser::ProgramUnit &, Scope *topScope);
void ResolveOmpParts(SemanticsContext &, const parser::ProgramUnit &);
-void ResolveOmpTopLevelParts(SemanticsContext &, const parser::Program &);
-
} // namespace Fortran::semantics
#endif
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index 861218809c0f9..ae0ff9ca8068d 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -10687,9 +10687,6 @@ void ResolveNamesVisitor::Post(const parser::Program &x) {
CHECK(!attrs_);
CHECK(!cudaDataAttr_);
CHECK(!GetDeclTypeSpec());
- // Top-level resolution to propagate information across program units after
- // each of them has been resolved separately.
- ResolveOmpTopLevelParts(context(), x);
}
// A singleton instance of the scope -> IMPLICIT rules mapping is
diff --git a/flang/test/Semantics/OpenMP/requires-modfile.f90 b/flang/test/Semantics/OpenMP/requires-modfile.f90
new file mode 100644
index 0000000000000..2f06104e208ef
--- /dev/null
+++ b/flang/test/Semantics/OpenMP/requires-modfile.f90
@@ -0,0 +1,39 @@
+!RUN: %python %S/../test_modfile.py %s %flang_fc1 -fopenmp -fopenmp-version=52
+
+module req
+contains
+
+! The requirements from the subprograms should be added to the module.
+subroutine f00
+ !$omp requires reverse_offload
+end
+
+subroutine f01
+ !$omp requires atomic_default_mem_order(seq_cst)
+end
+end module
+
+module user
+! The requirements from module req should be propagated to this module.
+use req
+end module
+
+
+!Expect: req.mod
+!module req
+!!$omp requires atomic_default_mem_order(seq_cst)
+!!$omp requires reverse_offload
+!contains
+!subroutine f00()
+!end
+!subroutine f01()
+!end
+!end
+
+!Expect: user.mod
+!module user
+!use req,only:f00
+!use req,only:f01
+!!$omp requires atomic_default_mem_order(seq_cst)
+!!$omp requires reverse_offload
+!end
diff --git a/flang/test/Semantics/OpenMP/requires09.f90 b/flang/test/Semantics/OpenMP/requires09.f90
index 2fa5d950b9c2d..ca6ad5e8b7b8a 100644
--- a/flang/test/Semantics/OpenMP/requires09.f90
+++ b/flang/test/Semantics/OpenMP/requires09.f90
@@ -3,12 +3,16 @@
! 2.4 Requires directive
! All atomic_default_mem_order clauses in 'requires' directives found within a
! compilation unit must specify the same ordering.
+!ERROR: Conflicting 'ATOMIC_DEFAULT_MEM_ORDER' REQUIRES clauses found in compilation unit
+module m
+contains
subroutine f
!$omp requires atomic_default_mem_order(seq_cst)
end subroutine f
-!ERROR: Conflicting 'ATOMIC_DEFAULT_MEM_ORDER' REQUIRES clauses found in compilation unit
subroutine g
!$omp requires atomic_default_mem_order(relaxed)
end subroutine g
+
+end module
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks
For each program unit, collect the set of requirements from REQUIRES directives in the source, and modules used by the program unit, and add them to the details of the program unit symbol.
The requirements in the symbol details as now stored as clauses. Since requirements need to be emitted in the module files as OpenMP directives, this makes the clause emission straightforward via getOpenMPClauseName.
Each program unit, including modules, the corresponding symbol will have the transitive closure of the requirements for everything contained or used in that program unit.