Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions flang/include/flang/Semantics/symbol.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "flang/Semantics/module-dependences.h"
#include "flang/Support/Fortran.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/Frontend/OpenMP/OMP.h"

#include <array>
#include <functional>
Expand Down Expand Up @@ -50,32 +51,31 @@ using MutableSymbolVector = std::vector<MutableSymbolRef>;

// Mixin for details with OpenMP declarative constructs.
class WithOmpDeclarative {
using OmpAtomicOrderType = common::OmpMemoryOrderType;

public:
ENUM_CLASS(RequiresFlag, ReverseOffload, UnifiedAddress, UnifiedSharedMemory,
DynamicAllocators);
using RequiresFlags = common::EnumSet<RequiresFlag, RequiresFlag_enumSize>;
// The set of requirements for any program unit include requirements
// from any module used in the program unit.
using RequiresClauses =
common::EnumSet<llvm::omp::Clause, llvm::omp::Clause_enumSize>;

bool has_ompRequires() const { return ompRequires_.has_value(); }
const RequiresFlags *ompRequires() const {
const RequiresClauses *ompRequires() const {
return ompRequires_ ? &*ompRequires_ : nullptr;
}
void set_ompRequires(RequiresFlags flags) { ompRequires_ = flags; }
void set_ompRequires(RequiresClauses clauses) { ompRequires_ = clauses; }

bool has_ompAtomicDefaultMemOrder() const {
return ompAtomicDefaultMemOrder_.has_value();
}
const OmpAtomicOrderType *ompAtomicDefaultMemOrder() const {
const common::OmpMemoryOrderType *ompAtomicDefaultMemOrder() const {
return ompAtomicDefaultMemOrder_ ? &*ompAtomicDefaultMemOrder_ : nullptr;
}
void set_ompAtomicDefaultMemOrder(OmpAtomicOrderType flags) {
void set_ompAtomicDefaultMemOrder(common::OmpMemoryOrderType flags) {
ompAtomicDefaultMemOrder_ = flags;
}

private:
std::optional<RequiresFlags> ompRequires_;
std::optional<OmpAtomicOrderType> ompAtomicDefaultMemOrder_;
std::optional<RequiresClauses> ompRequires_;
std::optional<common::OmpMemoryOrderType> ompAtomicDefaultMemOrder_;
};

// A module or submodule.
Expand Down
15 changes: 7 additions & 8 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4208,18 +4208,17 @@ bool Fortran::lower::markOpenMPDeferredDeclareTargetFunctions(
void Fortran::lower::genOpenMPRequires(mlir::Operation *mod,
const semantics::Symbol *symbol) {
using MlirRequires = mlir::omp::ClauseRequires;
using SemaRequires = semantics::WithOmpDeclarative::RequiresFlag;

if (auto offloadMod =
llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(mod)) {
semantics::WithOmpDeclarative::RequiresFlags semaFlags;
semantics::WithOmpDeclarative::RequiresClauses reqs;
if (symbol) {
common::visit(
[&](const auto &details) {
if constexpr (std::is_base_of_v<semantics::WithOmpDeclarative,
std::decay_t<decltype(details)>>) {
if (details.has_ompRequires())
semaFlags = *details.ompRequires();
reqs = *details.ompRequires();
}
},
symbol->details());
Expand All @@ -4228,14 +4227,14 @@ void Fortran::lower::genOpenMPRequires(mlir::Operation *mod,
// Use pre-populated omp.requires module attribute if it was set, so that
// the "-fopenmp-force-usm" compiler option is honored.
MlirRequires mlirFlags = offloadMod.getRequires();
if (semaFlags.test(SemaRequires::ReverseOffload))
if (reqs.test(llvm::omp::Clause::OMPC_dynamic_allocators))
mlirFlags = mlirFlags | MlirRequires::dynamic_allocators;
if (reqs.test(llvm::omp::Clause::OMPC_reverse_offload))
mlirFlags = mlirFlags | MlirRequires::reverse_offload;
if (semaFlags.test(SemaRequires::UnifiedAddress))
if (reqs.test(llvm::omp::Clause::OMPC_unified_address))
mlirFlags = mlirFlags | MlirRequires::unified_address;
if (semaFlags.test(SemaRequires::UnifiedSharedMemory))
if (reqs.test(llvm::omp::Clause::OMPC_unified_shared_memory))
mlirFlags = mlirFlags | MlirRequires::unified_shared_memory;
if (semaFlags.test(SemaRequires::DynamicAllocators))
mlirFlags = mlirFlags | MlirRequires::dynamic_allocators;

offloadMod.setRequires(mlirFlags);
}
Expand Down
37 changes: 37 additions & 0 deletions flang/lib/Semantics/mod-file.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
#include "flang/Semantics/semantics.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
#include "llvm/Frontend/OpenMP/OMP.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <fstream>
#include <set>
#include <string_view>
#include <type_traits>
#include <variant>
#include <vector>

Expand Down Expand Up @@ -359,6 +361,40 @@ void ModFileWriter::PrepareRenamings(const Scope &scope) {
}
}

static void PutOpenMPRequirements(llvm::raw_ostream &os, const Symbol &symbol) {
using RequiresClauses = WithOmpDeclarative::RequiresClauses;
using OmpMemoryOrderType = common::OmpMemoryOrderType;

const auto [reqs, order]{common::visit(
[&](auto &&details)
-> std::pair<const RequiresClauses *, const OmpMemoryOrderType *> {
if constexpr (std::is_convertible_v<decltype(details),
const WithOmpDeclarative &>) {
return {details.ompRequires(), details.ompAtomicDefaultMemOrder()};
} else {
return {nullptr, nullptr};
}
},
symbol.details())};

if (order) {
llvm::omp::Clause admo{llvm::omp::Clause::OMPC_atomic_default_mem_order};
os << "!$omp requires "
<< parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName(admo))
<< '(' << parser::ToLowerCaseLetters(EnumToString(*order)) << ")\n";
}
if (reqs) {
os << "!$omp requires";
reqs->IterateOverMembers([&](llvm::omp::Clause f) {
if (f != llvm::omp::Clause::OMPC_atomic_default_mem_order) {
os << ' '
<< parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName(f));
}
});
os << "\n";
}
}

// Put out the visible symbols from scope.
void ModFileWriter::PutSymbols(
const Scope &scope, UnorderedSymbolSet *hermeticModules) {
Expand Down Expand Up @@ -396,6 +432,7 @@ void ModFileWriter::PutSymbols(
for (const Symbol &symbol : uses) {
PutUse(symbol);
}
PutOpenMPRequirements(decls_, DEREF(scope.symbol()));
for (const auto &set : scope.equivalenceSets()) {
if (!set.empty() &&
!set.front().symbol.test(Symbol::Flag::CompilerCreated)) {
Expand Down
161 changes: 48 additions & 113 deletions flang/lib/Semantics/resolve-directives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,22 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
return true;
}

bool Pre(const parser::UseStmt &x) {
if (x.moduleName.symbol) {
Scope &thisScope{context_.FindScope(x.moduleName.source)};
common::visit(
[&](auto &&details) {
if constexpr (std::is_convertible_v<decltype(details),
const WithOmpDeclarative &>) {
AddOmpRequiresToScope(thisScope, details.ompRequires(),
details.ompAtomicDefaultMemOrder());
}
},
x.moduleName.symbol->details());
}
return true;
}

bool Pre(const parser::OmpMetadirectiveDirective &x) {
PushContext(x.v.source, llvm::omp::Directive::OMPD_metadirective);
return true;
Expand Down Expand Up @@ -538,38 +554,37 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
void Post(const parser::OpenMPFlushConstruct &) { PopContext(); }

bool Pre(const parser::OpenMPRequiresConstruct &x) {
using Flags = WithOmpDeclarative::RequiresFlags;
using Requires = WithOmpDeclarative::RequiresFlag;
using RequiresClauses = WithOmpDeclarative::RequiresClauses;
PushContext(x.source, llvm::omp::Directive::OMPD_requires);

// Gather information from the clauses.
Flags flags;
std::optional<common::OmpMemoryOrderType> memOrder;
RequiresClauses reqs;
const common::OmpMemoryOrderType *memOrder{nullptr};
for (const parser::OmpClause &clause : x.v.Clauses().v) {
flags |= common::visit(
using OmpClause = parser::OmpClause;
reqs |= common::visit(
common::visitors{
[&memOrder](
const parser::OmpClause::AtomicDefaultMemOrder &atomic) {
memOrder = atomic.v.v;
return Flags{};
},
[](const parser::OmpClause::ReverseOffload &) {
return Flags{Requires::ReverseOffload};
},
[](const parser::OmpClause::UnifiedAddress &) {
return Flags{Requires::UnifiedAddress};
[&](const OmpClause::AtomicDefaultMemOrder &atomic) {
memOrder = &atomic.v.v;
return RequiresClauses{};
},
[](const parser::OmpClause::UnifiedSharedMemory &) {
return Flags{Requires::UnifiedSharedMemory};
},
[](const parser::OmpClause::DynamicAllocators &) {
return Flags{Requires::DynamicAllocators};
[&](auto &&s) {
using TypeS = llvm::remove_cvref_t<decltype(s)>;
if constexpr ( //
std::is_same_v<TypeS, OmpClause::DynamicAllocators> ||
std::is_same_v<TypeS, OmpClause::ReverseOffload> ||
std::is_same_v<TypeS, OmpClause::UnifiedAddress> ||
std::is_same_v<TypeS, OmpClause::UnifiedSharedMemory>) {
return RequiresClauses{clause.Id()};
} else {
return RequiresClauses{};
}
},
[](const auto &) { return Flags{}; }},
},
clause.u);
}
// Merge clauses into parents' symbols details.
AddOmpRequiresToScope(currScope(), flags, memOrder);
AddOmpRequiresToScope(currScope(), &reqs, memOrder);
return true;
}
void Post(const parser::OpenMPRequiresConstruct &) { PopContext(); }
Expand Down Expand Up @@ -1001,8 +1016,9 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {

std::int64_t ordCollapseLevel{0};

void AddOmpRequiresToScope(Scope &, WithOmpDeclarative::RequiresFlags,
std::optional<common::OmpMemoryOrderType>);
void AddOmpRequiresToScope(Scope &,
const WithOmpDeclarative::RequiresClauses *,
const common::OmpMemoryOrderType *);
void IssueNonConformanceWarning(llvm::omp::Directive D,
parser::CharBlock source, unsigned EmitFromVersion);

Expand Down Expand Up @@ -3309,86 +3325,6 @@ void ResolveOmpParts(
}
}

void ResolveOmpTopLevelParts(
SemanticsContext &context, const parser::Program &program) {
if (!context.IsEnabled(common::LanguageFeature::OpenMP)) {
return;
}

// Gather REQUIRES clauses from all non-module top-level program unit symbols,
// combine them together ensuring compatibility and apply them to all these
// program units. Modules are skipped because their REQUIRES clauses should be
// propagated via USE statements instead.
WithOmpDeclarative::RequiresFlags combinedFlags;
std::optional<common::OmpMemoryOrderType> combinedMemOrder;

// Function to go through non-module top level program units and extract
// REQUIRES information to be processed by a function-like argument.
auto processProgramUnits{[&](auto processFn) {
for (const parser::ProgramUnit &unit : program.v) {
if (!std::holds_alternative<common::Indirection<parser::Module>>(
unit.u) &&
!std::holds_alternative<common::Indirection<parser::Submodule>>(
unit.u) &&
!std::holds_alternative<
common::Indirection<parser::CompilerDirective>>(unit.u)) {
Symbol *symbol{common::visit(
[&context](auto &x) {
Scope *scope = GetScope(context, x.value());
return scope ? scope->symbol() : nullptr;
},
unit.u)};
// FIXME There is no symbol defined for MainProgram units in certain
// circumstances, so REQUIRES information has no place to be stored in
// these cases.
if (!symbol) {
continue;
}
common::visit(
[&](auto &details) {
if constexpr (std::is_convertible_v<decltype(&details),
WithOmpDeclarative *>) {
processFn(*symbol, details);
}
},
symbol->details());
}
}
}};

// Combine global REQUIRES information from all program units except modules
// and submodules.
processProgramUnits([&](Symbol &symbol, WithOmpDeclarative &details) {
if (const WithOmpDeclarative::RequiresFlags *
flags{details.ompRequires()}) {
combinedFlags |= *flags;
}
if (const common::OmpMemoryOrderType *
memOrder{details.ompAtomicDefaultMemOrder()}) {
if (combinedMemOrder && *combinedMemOrder != *memOrder) {
context.Say(symbol.scope()->sourceRange(),
"Conflicting '%s' REQUIRES clauses found in compilation "
"unit"_err_en_US,
parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(
llvm::omp::Clause::OMPC_atomic_default_mem_order)
.str()));
}
combinedMemOrder = *memOrder;
}
});

// Update all program units except modules and submodules with the combined
// global REQUIRES information.
processProgramUnits([&](Symbol &, WithOmpDeclarative &details) {
if (combinedFlags.any()) {
details.set_ompRequires(combinedFlags);
}
if (combinedMemOrder) {
details.set_ompAtomicDefaultMemOrder(*combinedMemOrder);
}
});
}

static bool IsSymbolThreadprivate(const Symbol &symbol) {
if (const auto *details{symbol.detailsIf<HostAssocDetails>()}) {
return details->symbol().test(Symbol::Flag::OmpThreadprivate);
Expand Down Expand Up @@ -3547,23 +3483,22 @@ void OmpAttributeVisitor::CheckLabelContext(const parser::CharBlock source,
}

void OmpAttributeVisitor::AddOmpRequiresToScope(Scope &scope,
WithOmpDeclarative::RequiresFlags flags,
std::optional<common::OmpMemoryOrderType> memOrder) {
const WithOmpDeclarative::RequiresClauses *reqs,
const common::OmpMemoryOrderType *memOrder) {
const Scope &programUnit{omp::GetProgramUnit(scope)};
using RequiresClauses = WithOmpDeclarative::RequiresClauses;
RequiresClauses combinedReqs{reqs ? *reqs : RequiresClauses{}};

if (auto *symbol{const_cast<Symbol *>(programUnit.symbol())}) {
common::visit(
[&](auto &details) {
// Store clauses information into the symbol for the parent and
// enclosing modules, programs, functions and subroutines.
if constexpr (std::is_convertible_v<decltype(&details),
WithOmpDeclarative *>) {
if (flags.any()) {
if (const WithOmpDeclarative::RequiresFlags *otherFlags{
details.ompRequires()}) {
flags |= *otherFlags;
if (combinedReqs.any()) {
if (const RequiresClauses *otherReqs{details.ompRequires()}) {
combinedReqs |= *otherReqs;
}
details.set_ompRequires(flags);
details.set_ompRequires(combinedReqs);
}
if (memOrder) {
if (details.has_ompAtomicDefaultMemOrder() &&
Expand Down
2 changes: 0 additions & 2 deletions flang/lib/Semantics/resolve-directives.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,5 @@ class SemanticsContext;
void ResolveAccParts(
SemanticsContext &, const parser::ProgramUnit &, Scope *topScope);
void ResolveOmpParts(SemanticsContext &, const parser::ProgramUnit &);
void ResolveOmpTopLevelParts(SemanticsContext &, const parser::Program &);

} // namespace Fortran::semantics
#endif
3 changes: 0 additions & 3 deletions flang/lib/Semantics/resolve-names.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10687,9 +10687,6 @@ void ResolveNamesVisitor::Post(const parser::Program &x) {
CHECK(!attrs_);
CHECK(!cudaDataAttr_);
CHECK(!GetDeclTypeSpec());
// Top-level resolution to propagate information across program units after
// each of them has been resolved separately.
ResolveOmpTopLevelParts(context(), x);
}

// A singleton instance of the scope -> IMPLICIT rules mapping is
Expand Down
Loading
Loading