From c9ee93f1c28e9d4518ec842f66f497bf0911c9d5 Mon Sep 17 00:00:00 2001 From: Sergio Afonso Date: Tue, 26 Mar 2024 16:31:34 +0000 Subject: [PATCH 01/13] [MLIR][OpenMP] Group clause operands into structures This patch introduces a set of composable structures grouping the MLIR operands associated with each OpenMP clause. This makes it easier to keep the MLIR representation for the same clause consistent throughout all operations that accept it. The relevant clause operand structures are grouped into per-operation structures using a mixin pattern and used to define new operation constructors. These constructors can be used to avoid having to get the order of a possibly large list of operands right. Missing clauses are documented as TODOs, as well as operands which are part of the relevant operation's operand structure but cannot be attached to the associated operation yet, due to missing op arguments in its MLIR definition. A follow-up patch will update Flang lowering to make use of these structures, simplifying the passing of information from clause processing to operation-generating functions and also simplifying the creation of operations through the use of the new operation constructors. 
--- .../Dialect/OpenMP/OpenMPClauseOperands.h | 300 ++++++++++++++++++ .../mlir/Dialect/OpenMP/OpenMPDialect.h | 7 +- mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 72 ++++- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 226 ++++++++++++- 4 files changed, 595 insertions(+), 10 deletions(-) create mode 100644 mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h new file mode 100644 index 0000000000000..6454076f7593b --- /dev/null +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h @@ -0,0 +1,300 @@ +//===-- OpenMPClauseOperands.h ----------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the structures defining MLIR operands associated with each +// OpenMP clause, and structures grouping the appropriate operands for each +// construct. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_OPENMP_OPENMPCLAUSEOPERANDS_H_ +#define MLIR_DIALECT_OPENMP_OPENMPCLAUSEOPERANDS_H_ + +#include "mlir/IR/BuiltinAttributes.h" +#include "llvm/ADT/SmallVector.h" + +#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.h.inc" + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.h.inc" + +namespace mlir { +namespace omp { + +//===----------------------------------------------------------------------===// +// Mixin structures defining MLIR operands associated with each OpenMP clause. 
+//===----------------------------------------------------------------------===// + +struct AlignedClauseOps { + llvm::SmallVector alignedVars; + llvm::SmallVector alignmentAttrs; +}; + +struct AllocateClauseOps { + llvm::SmallVector allocatorVars, allocateVars; +}; + +struct CollapseClauseOps { + llvm::SmallVector loopLBVar, loopUBVar, loopStepVar; +}; + +struct CopyprivateClauseOps { + llvm::SmallVector copyprivateVars; + llvm::SmallVector copyprivateFuncs; +}; + +struct DependClauseOps { + llvm::SmallVector dependTypeAttrs; + llvm::SmallVector dependVars; +}; + +struct DeviceClauseOps { + Value deviceVar; +}; + +struct DeviceTypeClauseOps { + // The default capture type. + DeclareTargetDeviceType deviceType = DeclareTargetDeviceType::any; +}; + +struct DistScheduleClauseOps { + UnitAttr distScheduleStaticAttr; + Value distScheduleChunkSizeVar; +}; + +struct DoacrossClauseOps { + llvm::SmallVector doacrossVectorVars; + ClauseDependAttr doacrossDependTypeAttr; + IntegerAttr doacrossNumLoopsAttr; +}; + +struct FinalClauseOps { + Value finalVar; +}; + +struct GrainsizeClauseOps { + Value grainsizeVar; +}; + +struct HintClauseOps { + IntegerAttr hintAttr; +}; + +struct IfClauseOps { + Value ifVar; +}; + +struct InReductionClauseOps { + llvm::SmallVector inReductionVars; + llvm::SmallVector inReductionDeclSymbols; +}; + +struct LinearClauseOps { + llvm::SmallVector linearVars, linearStepVars; +}; + +struct LoopRelatedOps { + UnitAttr loopInclusiveAttr; +}; + +struct MapClauseOps { + llvm::SmallVector mapVars; +}; + +struct MergeableClauseOps { + UnitAttr mergeableAttr; +}; + +struct NameClauseOps { + StringAttr nameAttr; +}; + +struct NogroupClauseOps { + UnitAttr nogroupAttr; +}; + +struct NontemporalClauseOps { + llvm::SmallVector nontemporalVars; +}; + +struct NowaitClauseOps { + UnitAttr nowaitAttr; +}; + +struct NumTasksClauseOps { + Value numTasksVar; +}; + +struct NumTeamsClauseOps { + Value numTeamsLowerVar, numTeamsUpperVar; +}; + +struct NumThreadsClauseOps 
{ + Value numThreadsVar; +}; + +struct OrderClauseOps { + ClauseOrderKindAttr orderAttr; +}; + +struct OrderedClauseOps { + IntegerAttr orderedAttr; +}; + +struct ParallelizationLevelClauseOps { + UnitAttr parLevelSimdAttr; +}; + +struct PriorityClauseOps { + Value priorityVar; +}; + +struct PrivateClauseOps { + // SSA values that correspond to "original" values being privatized. + // They refer to the SSA value outside the OpenMP region from which a clone is + // created inside the region. + llvm::SmallVector privateVars; + // The list of symbols referring to delayed privatizer ops (i.e. `omp.private` + // ops). + llvm::SmallVector privatizers; +}; + +struct ProcBindClauseOps { + ClauseProcBindKindAttr procBindKindAttr; +}; + +struct ReductionClauseOps { + llvm::SmallVector reductionVars; + llvm::SmallVector reductionDeclSymbols; + UnitAttr reductionByRefAttr; +}; + +struct SafelenClauseOps { + IntegerAttr safelenAttr; +}; + +struct ScheduleClauseOps { + ClauseScheduleKindAttr scheduleValAttr; + ScheduleModifierAttr scheduleModAttr; + Value scheduleChunkVar; + UnitAttr scheduleSimdAttr; +}; + +struct SimdlenClauseOps { + IntegerAttr simdlenAttr; +}; + +struct TaskReductionClauseOps { + llvm::SmallVector taskReductionVars; + llvm::SmallVector taskReductionDeclSymbols; +}; + +struct ThreadLimitClauseOps { + Value threadLimitVar; +}; + +struct UntiedClauseOps { + UnitAttr untiedAttr; +}; + +struct UseDeviceClauseOps { + llvm::SmallVector useDevicePtrVars, useDeviceAddrVars; +}; + +//===----------------------------------------------------------------------===// +// Structures defining clause operands associated with each OpenMP leaf +// construct. +// +// These mirror the arguments expected by the corresponding OpenMP MLIR ops. +//===----------------------------------------------------------------------===// + +namespace detail { +template +struct Clauses : public Mixins... 
{}; +} // namespace detail + +using CriticalClauseOps = detail::Clauses; + +// TODO `indirect` clause. +using DeclareTargetClauseOps = detail::Clauses; + +using DistributeClauseOps = + detail::Clauses; + +// TODO `filter` clause. +using MaskedClauseOps = detail::Clauses<>; + +using OrderedOpClauseOps = detail::Clauses; + +using OrderedRegionClauseOps = detail::Clauses; + +using ParallelClauseOps = + detail::Clauses; + +using SectionsClauseOps = detail::Clauses; + +// TODO `linear` clause. +using SimdLoopClauseOps = + detail::Clauses; + +using SingleClauseOps = detail::Clauses; + +// TODO `defaultmap`, `has_device_addr`, `is_device_ptr`, `uses_allocators` +// clauses. +using TargetClauseOps = + detail::Clauses; + +using TargetDataClauseOps = detail::Clauses; + +using TargetEnterExitUpdateDataClauseOps = + detail::Clauses; + +// TODO `affinity`, `detach` clauses. +using TaskClauseOps = + detail::Clauses; + +using TaskgroupClauseOps = + detail::Clauses; + +using TaskloopClauseOps = + detail::Clauses; + +using TaskwaitClauseOps = detail::Clauses; + +using TeamsClauseOps = + detail::Clauses; + +using WsloopClauseOps = + detail::Clauses; + +} // namespace omp +} // namespace mlir + +#endif // MLIR_DIALECT_OPENMP_OPENMPCLAUSEOPERANDS_H_ diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h index 23509c5b60701..c656bdc870976 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h @@ -26,11 +26,10 @@ #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.h.inc" #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.h.inc" -#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.h.inc" -#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.h.inc" -#define GET_ATTRDEF_CLASSES -#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.h.inc" +#include "mlir/Dialect/OpenMP/OpenMPClauseOperands.h" + +#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.h.inc" #include 
"mlir/Dialect/OpenMP/OpenMPInterfaces.h" diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index f33942b3c7c02..2643348d66869 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -287,7 +287,8 @@ def ParallelOp : OpenMP_Op<"parallel", [ let regions = (region AnyRegion:$region); let builders = [ - OpBuilder<(ins CArg<"ArrayRef", "{}">:$attributes)> + OpBuilder<(ins CArg<"ArrayRef", "{}">:$attributes)>, + OpBuilder<(ins CArg<"const ParallelClauseOps &">:$clauses)> ]; let extraClassDeclaration = [{ /// Returns the number of reduction variables. @@ -362,6 +363,10 @@ def TeamsOp : OpenMP_Op<"teams", [ let regions = (region AnyRegion:$region); + let builders = [ + OpBuilder<(ins CArg<"const TeamsClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ oilist( `num_teams` `(` ( $num_teams_lower^ `:` type($num_teams_lower) )? `to` @@ -451,6 +456,10 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments, let regions = (region SizedRegion<1>:$region); + let builders = [ + OpBuilder<(ins CArg<"const SectionsClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ oilist( `reduction` `(` custom( @@ -495,6 +504,10 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> { let regions = (region AnyRegion:$region); + let builders = [ + OpBuilder<(ins CArg<"const SingleClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ oilist(`allocate` `(` custom( @@ -601,6 +614,7 @@ def WsloopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments, OpBuilder<(ins "ValueRange":$lowerBound, "ValueRange":$upperBound, "ValueRange":$step, CArg<"ArrayRef", "{}">:$attributes)>, + OpBuilder<(ins CArg<"const WsloopClauseOps &">:$clauses)> ]; let regions = (region AnyRegion:$region); @@ -698,6 +712,11 @@ def SimdLoopOp : OpenMP_Op<"simdloop", [AttrSizedOperandSegments, ); let regions = (region AnyRegion:$region); + + let builders = [ + OpBuilder<(ins CArg<"const 
SimdLoopClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ oilist(`aligned` `(` custom($aligned_vars, type($aligned_vars), @@ -781,6 +800,10 @@ def DistributeOp : OpenMP_Op<"distribute", [AttrSizedOperandSegments, let regions = (region AnyRegion:$region); + let builders = [ + OpBuilder<(ins CArg<"const DistributeClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ oilist(`dist_schedule_static` $dist_schedule_static |`chunk_size` `(` $chunk_size `:` type($chunk_size) `)` @@ -883,6 +906,9 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments, Variadic:$allocate_vars, Variadic:$allocators_vars); let regions = (region AnyRegion:$region); + let builders = [ + OpBuilder<(ins CArg<"const TaskClauseOps &">:$clauses)> + ]; let assemblyFormat = [{ oilist(`if` `(` $if_expr `)` |`final` `(` $final_expr `)` @@ -1037,6 +1063,10 @@ def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments, let regions = (region AnyRegion:$region); + let builders = [ + OpBuilder<(ins CArg<"const TaskloopClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ oilist(`if` `(` $if_expr `)` |`final` `(` $final_expr `)` @@ -1106,6 +1136,10 @@ def TaskgroupOp : OpenMP_Op<"taskgroup", [AttrSizedOperandSegments, let regions = (region AnyRegion:$region); + let builders = [ + OpBuilder<(ins CArg<"const TaskgroupClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ oilist(`task_reduction` `(` custom( @@ -1432,6 +1466,10 @@ def TargetDataOp: OpenMP_Op<"target_data", [AttrSizedOperandSegments, let regions = (region AnyRegion:$region); + let builders = [ + OpBuilder<(ins CArg<"const TargetDataClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ oilist(`if` `(` $if_expr `:` type($if_expr) `)` | `device` `(` $device `:` type($device) `)` @@ -1486,6 +1524,10 @@ def TargetEnterDataOp: OpenMP_Op<"target_enter_data", UnitAttr:$nowait, Variadic:$map_operands); + let builders = [ + OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ 
oilist(`if` `(` $if_expr `:` type($if_expr) `)` | `device` `(` $device `:` type($device) `)` @@ -1540,6 +1582,10 @@ def TargetExitDataOp: OpenMP_Op<"target_exit_data", UnitAttr:$nowait, Variadic:$map_operands); + let builders = [ + OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ oilist(`if` `(` $if_expr `:` type($if_expr) `)` | `device` `(` $device `:` type($device) `)` @@ -1596,6 +1642,10 @@ def TargetUpdateOp: OpenMP_Op<"target_update", [AttrSizedOperandSegments, UnitAttr:$nowait, Variadic:$map_operands); + let builders = [ + OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ oilist(`if` `(` $if_expr `:` type($if_expr) `)` | `device` `(` $device `:` type($device) `)` @@ -1649,6 +1699,10 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac let regions = (region AnyRegion:$region); + let builders = [ + OpBuilder<(ins CArg<"const TargetClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ oilist( `if` `(` $if_expr `)` | `device` `(` $device `:` type($device) `)` @@ -1693,6 +1747,10 @@ def CriticalDeclareOp : OpenMP_Op<"critical.declare", [Symbol]> { let arguments = (ins SymbolNameAttr:$sym_name, DefaultValuedAttr:$hint_val); + let builders = [ + OpBuilder<(ins CArg<"const CriticalClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ $sym_name oilist(`hint` `(` custom($hint_val) `)`) attr-dict @@ -1773,6 +1831,10 @@ def OrderedOp : OpenMP_Op<"ordered"> { ConfinedAttr, [IntMinValue<0>]>:$num_loops_val, Variadic:$depend_vec_vars); + let builders = [ + OpBuilder<(ins CArg<"const OrderedOpClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ ( `depend_type` `` $depend_type_val^ )? ( `depend_vec` `(` $depend_vec_vars^ `:` type($depend_vec_vars) `)` )? 
@@ -1797,6 +1859,10 @@ def OrderedRegionOp : OpenMP_Op<"ordered.region"> { let regions = (region AnyRegion:$region); + let builders = [ + OpBuilder<(ins CArg<"const OrderedRegionClauseOps &">:$clauses)> + ]; + let assemblyFormat = [{ ( `simd` $simd^ )? $region attr-dict}]; let hasVerifier = 1; } @@ -1812,6 +1878,10 @@ def TaskwaitOp : OpenMP_Op<"taskwait"> { of the current task. }]; + let builders = [ + OpBuilder<(ins CArg<"const TaskwaitClauseOps &">:$clauses)> + ]; + let assemblyFormat = "attr-dict"; } diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index bf5875071e0dc..28869c1ddfb3f 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -41,6 +41,11 @@ using namespace mlir; using namespace mlir::omp; +static ArrayAttr makeArrayAttr(MLIRContext *context, + llvm::ArrayRef attrs) { + return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs); +} + namespace { struct MemRefPointerLikeModel : public PointerLikeType::ExternalModel static LogicalResult verifyPrivateVarList(OpType &op) { auto privateVars = op.getPrivateVars(); @@ -1280,6 +1364,17 @@ static bool opInGlobalImplicitParallelRegion(Operation *op) { return true; } +void TeamsOp::build(OpBuilder &builder, OperationState &state, + const TeamsClauseOps &clauses) { + MLIRContext *ctx = builder.getContext(); + // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers. 
+ TeamsOp::build(builder, state, clauses.numTeamsLowerVar, + clauses.numTeamsUpperVar, clauses.ifVar, + clauses.threadLimitVar, clauses.allocateVars, + clauses.allocatorVars, clauses.reductionVars, + makeArrayAttr(ctx, clauses.reductionDeclSymbols)); +} + LogicalResult TeamsOp::verify() { // Check parent region // TODO If nested inside of a target region, also check that it does not @@ -1312,9 +1407,19 @@ LogicalResult TeamsOp::verify() { } //===----------------------------------------------------------------------===// -// Verifier for SectionsOp +// SectionsOp //===----------------------------------------------------------------------===// +void SectionsOp::build(OpBuilder &builder, OperationState &state, + const SectionsClauseOps &clauses) { + MLIRContext *ctx = builder.getContext(); + // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers. + SectionsOp::build(builder, state, clauses.reductionVars, + makeArrayAttr(ctx, clauses.reductionDeclSymbols), + clauses.allocateVars, clauses.allocatorVars, + clauses.nowaitAttr); +} + LogicalResult SectionsOp::verify() { if (getAllocateVars().size() != getAllocatorsVars().size()) return emitError( @@ -1334,6 +1439,20 @@ LogicalResult SectionsOp::verifyRegions() { return success(); } +//===----------------------------------------------------------------------===// +// SingleOp +//===----------------------------------------------------------------------===// + +void SingleOp::build(OpBuilder &builder, OperationState &state, + const SingleClauseOps &clauses) { + MLIRContext *ctx = builder.getContext(); + // TODO Store clauses in op: privateVars, privatizers. 
+ SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, + clauses.copyprivateVars, + makeArrayAttr(ctx, clauses.copyprivateFuncs), + clauses.nowaitAttr); +} + LogicalResult SingleOp::verify() { // Check for allocate clause restrictions if (getAllocateVars().size() != getAllocatorsVars().size()) @@ -1481,9 +1600,21 @@ void printLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion, } //===----------------------------------------------------------------------===// -// Verifier for Simd construct [2.9.3.1] +// Simd construct [2.9.3.1] //===----------------------------------------------------------------------===// +void SimdLoopOp::build(OpBuilder &builder, OperationState &state, + const SimdLoopClauseOps &clauses) { + MLIRContext *ctx = builder.getContext(); + // TODO Store clauses in op: privateVars, reductionByRefAttr, reductionVars, + // privatizers, reductionDeclSymbols. + SimdLoopOp::build( + builder, state, clauses.loopLBVar, clauses.loopUBVar, clauses.loopStepVar, + clauses.alignedVars, makeArrayAttr(ctx, clauses.alignmentAttrs), + clauses.ifVar, clauses.nontemporalVars, clauses.orderAttr, + clauses.simdlenAttr, clauses.safelenAttr, clauses.loopInclusiveAttr); +} + LogicalResult SimdLoopOp::verify() { if (this->getLowerBound().empty()) { return emitOpError() << "empty lowerbound for simd loop operation"; @@ -1504,9 +1635,17 @@ LogicalResult SimdLoopOp::verify() { } //===----------------------------------------------------------------------===// -// Verifier for Distribute construct [2.9.4.1] +// Distribute construct [2.9.4.1] //===----------------------------------------------------------------------===// +void DistributeOp::build(OpBuilder &builder, OperationState &state, + const DistributeClauseOps &clauses) { + // TODO Store clauses in op: privateVars, privatizers. 
+ DistributeOp::build(builder, state, clauses.distScheduleStaticAttr, + clauses.distScheduleChunkSizeVar, clauses.allocateVars, + clauses.allocatorVars, clauses.orderAttr); +} + LogicalResult DistributeOp::verify() { if (this->getChunkSize() && !this->getDistScheduleStatic()) return emitOpError() << "chunk size set without " @@ -1607,6 +1746,19 @@ LogicalResult ReductionOp::verify() { //===----------------------------------------------------------------------===// // TaskOp //===----------------------------------------------------------------------===// + +void TaskOp::build(OpBuilder &builder, OperationState &state, + const TaskClauseOps &clauses) { + MLIRContext *ctx = builder.getContext(); + // TODO Store clauses in op: privateVars, privatizers. + TaskOp::build( + builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr, + clauses.mergeableAttr, clauses.inReductionVars, + makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.priorityVar, + makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars, + clauses.allocateVars, clauses.allocatorVars); +} + LogicalResult TaskOp::verify() { LogicalResult verifyDependVars = verifyDependVarList(*this, getDepends(), getDependVars()); @@ -1619,6 +1771,15 @@ LogicalResult TaskOp::verify() { //===----------------------------------------------------------------------===// // TaskgroupOp //===----------------------------------------------------------------------===// + +void TaskgroupOp::build(OpBuilder &builder, OperationState &state, + const TaskgroupClauseOps &clauses) { + MLIRContext *ctx = builder.getContext(); + TaskgroupOp::build(builder, state, clauses.taskReductionVars, + makeArrayAttr(ctx, clauses.taskReductionDeclSymbols), + clauses.allocateVars, clauses.allocatorVars); +} + LogicalResult TaskgroupOp::verify() { return verifyReductionVarList(*this, getTaskReductions(), getTaskReductionVars()); @@ -1627,6 +1788,21 @@ LogicalResult TaskgroupOp::verify() { 
//===----------------------------------------------------------------------===// // TaskloopOp //===----------------------------------------------------------------------===// + +void TaskloopOp::build(OpBuilder &builder, OperationState &state, + const TaskloopClauseOps &clauses) { + MLIRContext *ctx = builder.getContext(); + // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers. + TaskloopOp::build( + builder, state, clauses.loopLBVar, clauses.loopUBVar, clauses.loopStepVar, + clauses.loopInclusiveAttr, clauses.ifVar, clauses.finalVar, + clauses.untiedAttr, clauses.mergeableAttr, clauses.inReductionVars, + makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.reductionVars, + makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.priorityVar, + clauses.allocateVars, clauses.allocatorVars, clauses.grainsizeVar, + clauses.numTasksVar, clauses.nogroupAttr); +} + SmallVector TaskloopOp::getAllReductionVars() { SmallVector allReductionNvars(getInReductionVars().begin(), getInReductionVars().end()); @@ -1680,14 +1856,33 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state, state.addAttributes(attributes); } +void WsloopOp::build(OpBuilder &builder, OperationState &state, + const WsloopClauseOps &clauses) { + MLIRContext *ctx = builder.getContext(); + // TODO Store clauses in op: allocateVars, allocatorVars, privateVars, + // privatizers. 
+ WsloopOp::build( + builder, state, clauses.loopLBVar, clauses.loopUBVar, clauses.loopStepVar, + clauses.linearVars, clauses.linearStepVars, clauses.reductionVars, + makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.scheduleValAttr, + clauses.scheduleChunkVar, clauses.scheduleModAttr, + clauses.scheduleSimdAttr, clauses.nowaitAttr, clauses.reductionByRefAttr, + clauses.orderedAttr, clauses.orderAttr, clauses.loopInclusiveAttr); +} + LogicalResult WsloopOp::verify() { return verifyReductionVarList(*this, getReductions(), getReductionVars()); } //===----------------------------------------------------------------------===// -// Verifier for critical construct (2.17.1) +// Critical construct (2.17.1) //===----------------------------------------------------------------------===// +void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state, + const CriticalClauseOps &clauses) { + CriticalDeclareOp::build(builder, state, clauses.nameAttr, clauses.hintAttr); +} + LogicalResult CriticalDeclareOp::verify() { return verifySynchronizationHint(*this, getHintVal()); } @@ -1707,9 +1902,15 @@ LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { } //===----------------------------------------------------------------------===// -// Verifier for ordered construct +// Ordered construct //===----------------------------------------------------------------------===// +void OrderedOp::build(OpBuilder &builder, OperationState &state, + const OrderedOpClauseOps &clauses) { + OrderedOp::build(builder, state, clauses.doacrossDependTypeAttr, + clauses.doacrossNumLoopsAttr, clauses.doacrossVectorVars); +} + LogicalResult OrderedOp::verify() { auto container = (*this)->getParentOfType(); if (!container || !container.getOrderedValAttr() || @@ -1726,6 +1927,11 @@ LogicalResult OrderedOp::verify() { return success(); } +void OrderedRegionOp::build(OpBuilder &builder, OperationState &state, + const OrderedRegionClauseOps &clauses) { + 
OrderedRegionOp::build(builder, state, clauses.parLevelSimdAttr); +} + LogicalResult OrderedRegionOp::verify() { // TODO: The code generation for ordered simd directive is not supported yet. if (getSimd()) @@ -1742,6 +1948,16 @@ LogicalResult OrderedRegionOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// TaskwaitOp +//===----------------------------------------------------------------------===// + +void TaskwaitOp::build(OpBuilder &builder, OperationState &state, + const TaskwaitClauseOps &clauses) { + // TODO Store clauses in op: dependTypeAttrs, dependVars, nowaitAttr. + TaskwaitOp::build(builder, state); +} + //===----------------------------------------------------------------------===// // Verifier for AtomicReadOp //===----------------------------------------------------------------------===// From 7af7e9d13fc2134e76bb532bfa4313aa3df17924 Mon Sep 17 00:00:00 2001 From: Sergio Afonso Date: Tue, 26 Mar 2024 16:46:56 +0000 Subject: [PATCH 02/13] [Flang][OpenMP][Lower] Use clause operand structures This patch updates Flang lowering to use the new set of OpenMP clause operand structures and their groupings into directive-specific sets of clause operands. It simplifies the passing of information from the clause processor and the creation of operations. The `DataSharingProcessor` is slightly modified to not hold delayed privatization state. Instead, optional arguments are added to `processStep1` which are only passed when delayed privatization is used. This enables using the clause operand structure for `private` and removes the need for the ad-hoc `DelayedPrivatizationInfo` structure. The processing of the `schedule` clause is updated to process the `chunk` modifier rather than requiring two separate calls to the `ClauseProcessor`. 
Lowering of a block-associated `ordered` construct is updated to emit a TODO error if the `simd` clause is specified, since it is not currently supported by the `ClauseProcessor` or later compilation stages. Removed processing of `schedule` from `omp.simdloop`, as it doesn't apply to `simd` constructs. --- flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 261 +++++---- flang/lib/Lower/OpenMP/ClauseProcessor.h | 105 ++-- .../lib/Lower/OpenMP/DataSharingProcessor.cpp | 38 +- flang/lib/Lower/OpenMP/DataSharingProcessor.h | 45 +- flang/lib/Lower/OpenMP/OpenMP.cpp | 517 +++++++----------- 5 files changed, 428 insertions(+), 538 deletions(-) diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index 0a57a1496289f..ee1f6c2fbc7e8 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -162,14 +162,13 @@ getIfClauseOperand(Fortran::lower::AbstractConverter &converter, ifVal); } -static void -addUseDeviceClause(Fortran::lower::AbstractConverter &converter, - const omp::ObjectList &objects, - llvm::SmallVectorImpl &operands, - llvm::SmallVectorImpl &useDeviceTypes, - llvm::SmallVectorImpl &useDeviceLocs, - llvm::SmallVectorImpl - &useDeviceSymbols) { +static void addUseDeviceClause( + Fortran::lower::AbstractConverter &converter, + const omp::ObjectList &objects, + llvm::SmallVectorImpl &operands, + llvm::SmallVectorImpl &useDeviceTypes, + llvm::SmallVectorImpl &useDeviceLocs, + llvm::SmallVectorImpl &useDeviceSyms) { genObjectList(objects, converter, operands); for (mlir::Value &operand : operands) { checkMapType(operand.getLoc(), operand.getType()); @@ -177,25 +176,24 @@ addUseDeviceClause(Fortran::lower::AbstractConverter &converter, useDeviceLocs.push_back(operand.getLoc()); } for (const omp::Object &object : objects) - useDeviceSymbols.push_back(object.id()); + useDeviceSyms.push_back(object.id()); } static void convertLoopBounds(Fortran::lower::AbstractConverter &converter, 
mlir::Location loc, - llvm::SmallVectorImpl &lowerBound, - llvm::SmallVectorImpl &upperBound, - llvm::SmallVectorImpl &step, + mlir::omp::CollapseClauseOps &result, std::size_t loopVarTypeSize) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); // The types of lower bound, upper bound, and step are converted into the // type of the loop variable if necessary. mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); - for (unsigned it = 0; it < (unsigned)lowerBound.size(); it++) { - lowerBound[it] = - firOpBuilder.createConvert(loc, loopVarType, lowerBound[it]); - upperBound[it] = - firOpBuilder.createConvert(loc, loopVarType, upperBound[it]); - step[it] = firOpBuilder.createConvert(loc, loopVarType, step[it]); + for (unsigned it = 0; it < (unsigned)result.loopLBVar.size(); it++) { + result.loopLBVar[it] = + firOpBuilder.createConvert(loc, loopVarType, result.loopLBVar[it]); + result.loopUBVar[it] = + firOpBuilder.createConvert(loc, loopVarType, result.loopUBVar[it]); + result.loopStepVar[it] = + firOpBuilder.createConvert(loc, loopVarType, result.loopStepVar[it]); } } @@ -205,9 +203,7 @@ static void convertLoopBounds(Fortran::lower::AbstractConverter &converter, bool ClauseProcessor::processCollapse( mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval, - llvm::SmallVectorImpl &lowerBound, - llvm::SmallVectorImpl &upperBound, - llvm::SmallVectorImpl &step, + mlir::omp::CollapseClauseOps &result, llvm::SmallVectorImpl &iv) const { bool found = false; fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); @@ -238,15 +234,15 @@ bool ClauseProcessor::processCollapse( std::get_if(&loopControl->u); assert(bounds && "Expected bounds for worksharing do loop"); Fortran::lower::StatementContext stmtCtx; - lowerBound.push_back(fir::getBase(converter.genExprValue( + result.loopLBVar.push_back(fir::getBase(converter.genExprValue( *Fortran::semantics::GetExpr(bounds->lower), stmtCtx))); - 
upperBound.push_back(fir::getBase(converter.genExprValue( + result.loopUBVar.push_back(fir::getBase(converter.genExprValue( *Fortran::semantics::GetExpr(bounds->upper), stmtCtx))); if (bounds->step) { - step.push_back(fir::getBase(converter.genExprValue( + result.loopStepVar.push_back(fir::getBase(converter.genExprValue( *Fortran::semantics::GetExpr(bounds->step), stmtCtx))); } else { // If `step` is not present, assume it as `1`. - step.push_back(firOpBuilder.createIntegerConstant( + result.loopStepVar.push_back(firOpBuilder.createIntegerConstant( currentLocation, firOpBuilder.getIntegerType(32), 1)); } iv.push_back(bounds->name.thing.symbol); @@ -257,8 +253,7 @@ bool ClauseProcessor::processCollapse( &*std::next(doConstructEval->getNestedEvaluations().begin()); } while (collapseValue > 0); - convertLoopBounds(converter, currentLocation, lowerBound, upperBound, step, - loopVarTypeSize); + convertLoopBounds(converter, currentLocation, result, loopVarTypeSize); return found; } @@ -286,7 +281,7 @@ bool ClauseProcessor::processDefault() const { } bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const { + mlir::omp::DeviceClauseOps &result) const { const Fortran::parser::CharBlock *source = nullptr; if (auto *clause = findUniqueClause(&source)) { mlir::Location clauseLocation = converter.genLocation(*source); @@ -298,25 +293,26 @@ bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx, } } const auto &deviceExpr = std::get(clause->t); - result = fir::getBase(converter.genExprValue(deviceExpr, stmtCtx)); + result.deviceVar = + fir::getBase(converter.genExprValue(deviceExpr, stmtCtx)); return true; } return false; } bool ClauseProcessor::processDeviceType( - mlir::omp::DeclareTargetDeviceType &result) const { + mlir::omp::DeviceTypeClauseOps &result) const { if (auto *clause = findUniqueClause()) { // Case: declare target ... 
device_type(any | host | nohost) switch (clause->v) { case omp::clause::DeviceType::DeviceTypeDescription::Nohost: - result = mlir::omp::DeclareTargetDeviceType::nohost; + result.deviceType = mlir::omp::DeclareTargetDeviceType::nohost; break; case omp::clause::DeviceType::DeviceTypeDescription::Host: - result = mlir::omp::DeclareTargetDeviceType::host; + result.deviceType = mlir::omp::DeclareTargetDeviceType::host; break; case omp::clause::DeviceType::DeviceTypeDescription::Any: - result = mlir::omp::DeclareTargetDeviceType::any; + result.deviceType = mlir::omp::DeclareTargetDeviceType::any; break; } return true; @@ -325,7 +321,7 @@ bool ClauseProcessor::processDeviceType( } bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const { + mlir::omp::FinalClauseOps &result) const { const Fortran::parser::CharBlock *source = nullptr; if (auto *clause = findUniqueClause(&source)) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); @@ -333,100 +329,108 @@ bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx, mlir::Value finalVal = fir::getBase(converter.genExprValue(clause->v, stmtCtx)); - result = firOpBuilder.createConvert(clauseLocation, - firOpBuilder.getI1Type(), finalVal); + result.finalVar = firOpBuilder.createConvert( + clauseLocation, firOpBuilder.getI1Type(), finalVal); return true; } return false; } -bool ClauseProcessor::processHint(mlir::IntegerAttr &result) const { +bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const { if (auto *clause = findUniqueClause()) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); int64_t hintValue = *Fortran::evaluate::ToInt64(clause->v); - result = firOpBuilder.getI64IntegerAttr(hintValue); + result.hintAttr = firOpBuilder.getI64IntegerAttr(hintValue); return true; } return false; } -bool ClauseProcessor::processMergeable(mlir::UnitAttr &result) const { - return markClauseOccurrence(result); +bool 
ClauseProcessor::processMergeable( + mlir::omp::MergeableClauseOps &result) const { + return markClauseOccurrence(result.mergeableAttr); } -bool ClauseProcessor::processNowait(mlir::UnitAttr &result) const { - return markClauseOccurrence(result); +bool ClauseProcessor::processNowait(mlir::omp::NowaitClauseOps &result) const { + return markClauseOccurrence(result.nowaitAttr); } -bool ClauseProcessor::processNumTeams(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const { +bool ClauseProcessor::processNumTeams( + Fortran::lower::StatementContext &stmtCtx, + mlir::omp::NumTeamsClauseOps &result) const { // TODO Get lower and upper bounds for num_teams when parser is updated to // accept both. if (auto *clause = findUniqueClause()) { // auto lowerBound = std::get>(clause->t); auto &upperBound = std::get(clause->t); - result = fir::getBase(converter.genExprValue(upperBound, stmtCtx)); + result.numTeamsUpperVar = + fir::getBase(converter.genExprValue(upperBound, stmtCtx)); return true; } return false; } bool ClauseProcessor::processNumThreads( - Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const { + Fortran::lower::StatementContext &stmtCtx, + mlir::omp::NumThreadsClauseOps &result) const { if (auto *clause = findUniqueClause()) { // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`. 
- result = fir::getBase(converter.genExprValue(clause->v, stmtCtx)); + result.numThreadsVar = + fir::getBase(converter.genExprValue(clause->v, stmtCtx)); return true; } return false; } -bool ClauseProcessor::processOrdered(mlir::IntegerAttr &result) const { +bool ClauseProcessor::processOrdered( + mlir::omp::OrderedClauseOps &result) const { if (auto *clause = findUniqueClause()) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); int64_t orderedClauseValue = 0l; if (clause->v.has_value()) orderedClauseValue = *Fortran::evaluate::ToInt64(*clause->v); - result = firOpBuilder.getI64IntegerAttr(orderedClauseValue); + result.orderedAttr = firOpBuilder.getI64IntegerAttr(orderedClauseValue); return true; } return false; } -bool ClauseProcessor::processPriority(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const { +bool ClauseProcessor::processPriority( + Fortran::lower::StatementContext &stmtCtx, + mlir::omp::PriorityClauseOps &result) const { if (auto *clause = findUniqueClause()) { - result = fir::getBase(converter.genExprValue(clause->v, stmtCtx)); + result.priorityVar = + fir::getBase(converter.genExprValue(clause->v, stmtCtx)); return true; } return false; } bool ClauseProcessor::processProcBind( - mlir::omp::ClauseProcBindKindAttr &result) const { + mlir::omp::ProcBindClauseOps &result) const { if (auto *clause = findUniqueClause()) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - result = genProcBindKindAttr(firOpBuilder, *clause); + result.procBindKindAttr = genProcBindKindAttr(firOpBuilder, *clause); return true; } return false; } -bool ClauseProcessor::processSafelen(mlir::IntegerAttr &result) const { +bool ClauseProcessor::processSafelen( + mlir::omp::SafelenClauseOps &result) const { if (auto *clause = findUniqueClause()) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); const std::optional safelenVal = Fortran::evaluate::ToInt64(clause->v); - result = 
firOpBuilder.getI64IntegerAttr(*safelenVal); + result.safelenAttr = firOpBuilder.getI64IntegerAttr(*safelenVal); return true; } return false; } bool ClauseProcessor::processSchedule( - mlir::omp::ClauseScheduleKindAttr &valAttr, - mlir::omp::ScheduleModifierAttr &modifierAttr, - mlir::UnitAttr &simdModifierAttr) const { + Fortran::lower::StatementContext &stmtCtx, + mlir::omp::ScheduleClauseOps &result) const { if (auto *clause = findUniqueClause()) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::MLIRContext *context = firOpBuilder.getContext(); @@ -451,53 +455,51 @@ bool ClauseProcessor::processSchedule( break; } - mlir::omp::ScheduleModifier scheduleModifier = getScheduleModifier(*clause); + result.scheduleValAttr = + mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind); + mlir::omp::ScheduleModifier scheduleModifier = getScheduleModifier(*clause); if (scheduleModifier != mlir::omp::ScheduleModifier::none) - modifierAttr = + result.scheduleModAttr = mlir::omp::ScheduleModifierAttr::get(context, scheduleModifier); if (getSimdModifier(*clause) != mlir::omp::ScheduleModifier::none) - simdModifierAttr = firOpBuilder.getUnitAttr(); + result.scheduleSimdAttr = firOpBuilder.getUnitAttr(); - valAttr = mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind); - return true; - } - return false; -} - -bool ClauseProcessor::processScheduleChunk( - Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const { - if (auto *clause = findUniqueClause()) { if (const auto &chunkExpr = std::get(clause->t)) - result = fir::getBase(converter.genExprValue(*chunkExpr, stmtCtx)); + result.scheduleChunkVar = + fir::getBase(converter.genExprValue(*chunkExpr, stmtCtx)); + return true; } return false; } -bool ClauseProcessor::processSimdlen(mlir::IntegerAttr &result) const { +bool ClauseProcessor::processSimdlen( + mlir::omp::SimdlenClauseOps &result) const { if (auto *clause = findUniqueClause()) { fir::FirOpBuilder &firOpBuilder = 
converter.getFirOpBuilder(); const std::optional simdlenVal = Fortran::evaluate::ToInt64(clause->v); - result = firOpBuilder.getI64IntegerAttr(*simdlenVal); + result.simdlenAttr = firOpBuilder.getI64IntegerAttr(*simdlenVal); return true; } return false; } bool ClauseProcessor::processThreadLimit( - Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const { + Fortran::lower::StatementContext &stmtCtx, + mlir::omp::ThreadLimitClauseOps &result) const { if (auto *clause = findUniqueClause()) { - result = fir::getBase(converter.genExprValue(clause->v, stmtCtx)); + result.threadLimitVar = + fir::getBase(converter.genExprValue(clause->v, stmtCtx)); return true; } return false; } -bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const { - return markClauseOccurrence(result); +bool ClauseProcessor::processUntied(mlir::omp::UntiedClauseOps &result) const { + return markClauseOccurrence(result.untiedAttr); } //===----------------------------------------------------------------------===// @@ -505,13 +507,12 @@ bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const { //===----------------------------------------------------------------------===// bool ClauseProcessor::processAllocate( - llvm::SmallVectorImpl &allocatorOperands, - llvm::SmallVectorImpl &allocateOperands) const { + mlir::omp::AllocateClauseOps &result) const { return findRepeatableClause( [&](const omp::clause::Allocate &clause, const Fortran::parser::CharBlock &) { - genAllocateClause(converter, clause, allocatorOperands, - allocateOperands); + genAllocateClause(converter, clause, result.allocatorVars, + result.allocateVars); }); } @@ -660,10 +661,9 @@ createCopyFunc(mlir::Location loc, Fortran::lower::AbstractConverter &converter, return funcOp; } -bool ClauseProcessor::processCopyPrivate( +bool ClauseProcessor::processCopyprivate( mlir::Location currentLocation, - llvm::SmallVectorImpl &copyPrivateVars, - llvm::SmallVectorImpl &copyPrivateFuncs) const { + 
mlir::omp::CopyprivateClauseOps &result) const { auto addCopyPrivateVar = [&](Fortran::semantics::Symbol *sym) { mlir::Value symVal = converter.getSymbolAddress(*sym); auto declOp = symVal.getDefiningOp(); @@ -690,10 +690,10 @@ bool ClauseProcessor::processCopyPrivate( cpVar = alloca; } - copyPrivateVars.push_back(cpVar); + result.copyprivateVars.push_back(cpVar); mlir::func::FuncOp funcOp = createCopyFunc(currentLocation, converter, cpVar.getType(), attrs); - copyPrivateFuncs.push_back(mlir::SymbolRefAttr::get(funcOp)); + result.copyprivateFuncs.push_back(mlir::SymbolRefAttr::get(funcOp)); }; bool hasCopyPrivate = findRepeatableClause( @@ -714,9 +714,7 @@ bool ClauseProcessor::processCopyPrivate( return hasCopyPrivate; } -bool ClauseProcessor::processDepend( - llvm::SmallVectorImpl &dependTypeOperands, - llvm::SmallVectorImpl &dependOperands) const { +bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); return findRepeatableClause( @@ -731,7 +729,7 @@ bool ClauseProcessor::processDepend( mlir::omp::ClauseTaskDependAttr dependTypeOperand = genDependKindAttr(firOpBuilder, kind); - dependTypeOperands.append(objects.size(), dependTypeOperand); + result.dependTypeAttrs.append(objects.size(), dependTypeOperand); for (const omp::Object &object : objects) { assert(object.ref() && "Expecting designator"); @@ -746,14 +744,14 @@ bool ClauseProcessor::processDepend( Fortran::semantics::Symbol *sym = object.id(); const mlir::Value variable = converter.getSymbolAddress(*sym); - dependOperands.push_back(variable); + result.dependVars.push_back(variable); } }); } bool ClauseProcessor::processIf( omp::clause::If::DirectiveNameModifier directiveName, - mlir::Value &result) const { + mlir::omp::IfClauseOps &result) const { bool found = false; findRepeatableClause( [&](const omp::clause::If &clause, @@ -764,7 +762,7 @@ bool ClauseProcessor::processIf( // Assume that, at most, a single 'if' 
clause will be applicable to the // given directive. if (operand) { - result = operand; + result.ifVar = operand; found = true; } }); @@ -807,12 +805,10 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc, bool ClauseProcessor::processMap( mlir::Location currentLocation, const llvm::omp::Directive &directive, - Fortran::lower::StatementContext &stmtCtx, - llvm::SmallVectorImpl &mapOperands, - llvm::SmallVectorImpl *mapSymTypes, + Fortran::lower::StatementContext &stmtCtx, mlir::omp::MapClauseOps &result, + llvm::SmallVectorImpl *mapSyms, llvm::SmallVectorImpl *mapSymLocs, - llvm::SmallVectorImpl *mapSymbols) - const { + llvm::SmallVectorImpl *mapSymTypes) const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); return findRepeatableClause( [&](const omp::clause::Map &clause, @@ -887,25 +883,23 @@ bool ClauseProcessor::processMap( mapTypeBits), mlir::omp::VariableCaptureKind::ByRef, symAddr.getType()); - mapOperands.push_back(mapOp); - if (mapSymTypes) - mapSymTypes->push_back(symAddr.getType()); + result.mapVars.push_back(mapOp); + + if (mapSyms) + mapSyms->push_back(object.id()); if (mapSymLocs) mapSymLocs->push_back(symAddr.getLoc()); - - if (mapSymbols) - mapSymbols->push_back(object.id()); + if (mapSymTypes) + mapSymTypes->push_back(symAddr.getType()); } }); } bool ClauseProcessor::processReduction( - mlir::Location currentLocation, - llvm::SmallVectorImpl &outReductionVars, - llvm::SmallVectorImpl &outReductionTypes, - llvm::SmallVectorImpl &outReductionDeclSymbols, - llvm::SmallVectorImpl - *outReductionSymbols) const { + mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result, + llvm::SmallVectorImpl *outReductionTypes, + llvm::SmallVectorImpl *outReductionSyms) + const { return findRepeatableClause( [&](const omp::clause::Reduction &clause, const Fortran::parser::CharBlock &) { @@ -915,30 +909,31 @@ bool ClauseProcessor::processReduction( // whether to do the reduction byref. 
llvm::SmallVector reductionVars; llvm::SmallVector reductionDeclSymbols; - llvm::SmallVector reductionSymbols; + llvm::SmallVector reductionSyms; ReductionProcessor rp; rp.addDeclareReduction(currentLocation, converter, clause, reductionVars, reductionDeclSymbols, - outReductionSymbols ? &reductionSymbols - : nullptr); + outReductionSyms ? &reductionSyms : nullptr); // Copy local lists into the output. - llvm::copy(reductionVars, std::back_inserter(outReductionVars)); + llvm::copy(reductionVars, std::back_inserter(result.reductionVars)); llvm::copy(reductionDeclSymbols, - std::back_inserter(outReductionDeclSymbols)); - if (outReductionSymbols) - llvm::copy(reductionSymbols, - std::back_inserter(*outReductionSymbols)); - - outReductionTypes.reserve(outReductionTypes.size() + - reductionVars.size()); - llvm::transform(reductionVars, std::back_inserter(outReductionTypes), - [](mlir::Value v) { return v.getType(); }); + std::back_inserter(result.reductionDeclSymbols)); + + if (outReductionTypes) { + outReductionTypes->reserve(outReductionTypes->size() + + reductionVars.size()); + llvm::transform(reductionVars, std::back_inserter(*outReductionTypes), + [](mlir::Value v) { return v.getType(); }); + } + + if (outReductionSyms) + llvm::copy(reductionSyms, std::back_inserter(*outReductionSyms)); }); } bool ClauseProcessor::processSectionsReduction( - mlir::Location currentLocation) const { + mlir::Location currentLocation, mlir::omp::ReductionClauseOps &) const { return findRepeatableClause( [&](const omp::clause::Reduction &, const Fortran::parser::CharBlock &) { TODO(currentLocation, "OMPC_Reduction"); @@ -967,30 +962,30 @@ bool ClauseProcessor::processEnter( } bool ClauseProcessor::processUseDeviceAddr( - llvm::SmallVectorImpl &operands, + mlir::omp::UseDeviceClauseOps &result, llvm::SmallVectorImpl &useDeviceTypes, llvm::SmallVectorImpl &useDeviceLocs, - llvm::SmallVectorImpl &useDeviceSymbols) + llvm::SmallVectorImpl &useDeviceSyms) const { return findRepeatableClause( 
[&](const omp::clause::UseDeviceAddr &clause, const Fortran::parser::CharBlock &) { - addUseDeviceClause(converter, clause.v, operands, useDeviceTypes, - useDeviceLocs, useDeviceSymbols); + addUseDeviceClause(converter, clause.v, result.useDeviceAddrVars, + useDeviceTypes, useDeviceLocs, useDeviceSyms); }); } bool ClauseProcessor::processUseDevicePtr( - llvm::SmallVectorImpl &operands, + mlir::omp::UseDeviceClauseOps &result, llvm::SmallVectorImpl &useDeviceTypes, llvm::SmallVectorImpl &useDeviceLocs, - llvm::SmallVectorImpl &useDeviceSymbols) + llvm::SmallVectorImpl &useDeviceSyms) const { return findRepeatableClause( [&](const omp::clause::UseDevicePtr &clause, const Fortran::parser::CharBlock &) { - addUseDeviceClause(converter, clause.v, operands, useDeviceTypes, - useDeviceLocs, useDeviceSymbols); + addUseDeviceClause(converter, clause.v, result.useDevicePtrVars, + useDeviceTypes, useDeviceLocs, useDeviceSyms); }); } } // namespace omp diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index c0c603feb296a..d933e0a913d2b 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -37,7 +37,7 @@ namespace omp { /// corresponding clause if it is present in the clause list. Otherwise, they /// will return `false` to signal that the clause was not found. /// -/// The intended use is of this class is to move clause processing outside of +/// The intended use of this class is to move clause processing outside of /// construct processing, since the same clauses can appear attached to /// different constructs and constructs can be combined, so that code /// duplication is minimized. @@ -56,94 +56,83 @@ class ClauseProcessor { // 'Unique' clauses: They can appear at most once in the clause list. 
bool processCollapse( mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval, - llvm::SmallVectorImpl &lowerBound, - llvm::SmallVectorImpl &upperBound, - llvm::SmallVectorImpl &step, + mlir::omp::CollapseClauseOps &result, llvm::SmallVectorImpl &iv) const; bool processDefault() const; bool processDevice(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const; - bool processDeviceType(mlir::omp::DeclareTargetDeviceType &result) const; + mlir::omp::DeviceClauseOps &result) const; + bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const; bool processFinal(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const; - bool processHint(mlir::IntegerAttr &result) const; - bool processMergeable(mlir::UnitAttr &result) const; - bool processNowait(mlir::UnitAttr &result) const; + mlir::omp::FinalClauseOps &result) const; + bool processHint(mlir::omp::HintClauseOps &result) const; + bool processMergeable(mlir::omp::MergeableClauseOps &result) const; + bool processNowait(mlir::omp::NowaitClauseOps &result) const; bool processNumTeams(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const; + mlir::omp::NumTeamsClauseOps &result) const; bool processNumThreads(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const; - bool processOrdered(mlir::IntegerAttr &result) const; + mlir::omp::NumThreadsClauseOps &result) const; + bool processOrdered(mlir::omp::OrderedClauseOps &result) const; bool processPriority(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const; - bool processProcBind(mlir::omp::ClauseProcBindKindAttr &result) const; - bool processSafelen(mlir::IntegerAttr &result) const; - bool processSchedule(mlir::omp::ClauseScheduleKindAttr &valAttr, - mlir::omp::ScheduleModifierAttr &modifierAttr, - mlir::UnitAttr &simdModifierAttr) const; - bool processScheduleChunk(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const; - bool 
processSimdlen(mlir::IntegerAttr &result) const; + mlir::omp::PriorityClauseOps &result) const; + bool processProcBind(mlir::omp::ProcBindClauseOps &result) const; + bool processSafelen(mlir::omp::SafelenClauseOps &result) const; + bool processSchedule(Fortran::lower::StatementContext &stmtCtx, + mlir::omp::ScheduleClauseOps &result) const; + bool processSimdlen(mlir::omp::SimdlenClauseOps &result) const; bool processThreadLimit(Fortran::lower::StatementContext &stmtCtx, - mlir::Value &result) const; - bool processUntied(mlir::UnitAttr &result) const; + mlir::omp::ThreadLimitClauseOps &result) const; + bool processUntied(mlir::omp::UntiedClauseOps &result) const; // 'Repeatable' clauses: They can appear multiple times in the clause list. - bool - processAllocate(llvm::SmallVectorImpl &allocatorOperands, - llvm::SmallVectorImpl &allocateOperands) const; + bool processAllocate(mlir::omp::AllocateClauseOps &result) const; bool processCopyin() const; - bool processCopyPrivate( - mlir::Location currentLocation, - llvm::SmallVectorImpl &copyPrivateVars, - llvm::SmallVectorImpl &copyPrivateFuncs) const; - bool processDepend(llvm::SmallVectorImpl &dependTypeOperands, - llvm::SmallVectorImpl &dependOperands) const; + bool processCopyprivate(mlir::Location currentLocation, + mlir::omp::CopyprivateClauseOps &result) const; + bool processDepend(mlir::omp::DependClauseOps &result) const; bool processEnter(llvm::SmallVectorImpl &result) const; bool processIf(omp::clause::If::DirectiveNameModifier directiveName, - mlir::Value &result) const; + mlir::omp::IfClauseOps &result) const; bool processLink(llvm::SmallVectorImpl &result) const; // This method is used to process a map clause. - // The optional parameters - mapSymTypes, mapSymLocs & mapSyms are used to + // The optional parameters - mapSymTypes, mapSymLocs & mapSyms are used to // store the original type, location and Fortran symbol for the map operands. 
// They may be used later on to create the block_arguments for some of the // target directives that require it. - bool processMap(mlir::Location currentLocation, - const llvm::omp::Directive &directive, - Fortran::lower::StatementContext &stmtCtx, - llvm::SmallVectorImpl &mapOperands, - llvm::SmallVectorImpl *mapSymTypes = nullptr, - llvm::SmallVectorImpl *mapSymLocs = nullptr, - llvm::SmallVectorImpl - *mapSymbols = nullptr) const; - bool - processReduction(mlir::Location currentLocation, - llvm::SmallVectorImpl &reductionVars, - llvm::SmallVectorImpl &reductionTypes, - llvm::SmallVectorImpl &reductionDeclSymbols, - llvm::SmallVectorImpl - *reductionSymbols = nullptr) const; - bool processSectionsReduction(mlir::Location currentLocation) const; + bool processMap( + mlir::Location currentLocation, const llvm::omp::Directive &directive, + Fortran::lower::StatementContext &stmtCtx, + mlir::omp::MapClauseOps &result, + llvm::SmallVectorImpl *mapSyms = + nullptr, + llvm::SmallVectorImpl *mapSymLocs = nullptr, + llvm::SmallVectorImpl *mapSymTypes = nullptr) const; + bool processReduction( + mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result, + llvm::SmallVectorImpl *reductionTypes = nullptr, + llvm::SmallVectorImpl *reductionSyms = + nullptr) const; + bool processSectionsReduction(mlir::Location currentLocation, + mlir::omp::ReductionClauseOps &result) const; bool processTo(llvm::SmallVectorImpl &result) const; bool - processUseDeviceAddr(llvm::SmallVectorImpl &operands, + processUseDeviceAddr(mlir::omp::UseDeviceClauseOps &result, llvm::SmallVectorImpl &useDeviceTypes, llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl - &useDeviceSymbols) const; + &useDeviceSyms) const; bool - processUseDevicePtr(llvm::SmallVectorImpl &operands, + processUseDevicePtr(mlir::omp::UseDeviceClauseOps &result, llvm::SmallVectorImpl &useDeviceTypes, llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl - &useDeviceSymbols) const; + &useDeviceSyms) const; 
template bool processMotionClauses(Fortran::lower::StatementContext &stmtCtx, - llvm::SmallVectorImpl &mapOperands); + mlir::omp::MapClauseOps &result); // Call this method for these clauses that should be supported but are not // implemented yet. It triggers a compilation error if any of the given @@ -185,7 +174,7 @@ class ClauseProcessor { template bool ClauseProcessor::processMotionClauses( Fortran::lower::StatementContext &stmtCtx, - llvm::SmallVectorImpl &mapOperands) { + mlir::omp::MapClauseOps &result) { return findRepeatableClause( [&](const T &clause, const Fortran::parser::CharBlock &source) { mlir::Location clauseLocation = converter.genLocation(source); @@ -227,7 +216,7 @@ bool ClauseProcessor::processMotionClauses( mapTypeBits), mlir::omp::VariableCaptureKind::ByRef, symAddr.getType()); - mapOperands.push_back(mapOp); + result.mapVars.push_back(mapOp); } }); } diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp index e114ab9f4548a..5a42e6a6aa417 100644 --- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp +++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp @@ -23,11 +23,13 @@ namespace Fortran { namespace lower { namespace omp { -void DataSharingProcessor::processStep1() { +void DataSharingProcessor::processStep1( + mlir::omp::PrivateClauseOps *clauseOps, + llvm::SmallVectorImpl *privateSyms) { collectSymbolsForPrivatization(); collectDefaultSymbols(); - privatize(); - defaultPrivatize(); + privatize(clauseOps, privateSyms); + defaultPrivatize(clauseOps, privateSyms); insertBarrier(); } @@ -299,14 +301,16 @@ void DataSharingProcessor::collectDefaultSymbols() { } } -void DataSharingProcessor::privatize() { +void DataSharingProcessor::privatize( + mlir::omp::PrivateClauseOps *clauseOps, + llvm::SmallVectorImpl *privateSyms) { for (const Fortran::semantics::Symbol *sym : privatizedSymbols) { if (const auto *commonDet = sym->detailsIf()) { for (const auto &mem : commonDet->objects()) - 
doPrivatize(&*mem); + doPrivatize(&*mem, clauseOps, privateSyms); } else - doPrivatize(sym); + doPrivatize(sym, clauseOps, privateSyms); } } @@ -323,7 +327,9 @@ void DataSharingProcessor::copyLastPrivatize(mlir::Operation *op) { } } -void DataSharingProcessor::defaultPrivatize() { +void DataSharingProcessor::defaultPrivatize( + mlir::omp::PrivateClauseOps *clauseOps, + llvm::SmallVectorImpl *privateSyms) { for (const Fortran::semantics::Symbol *sym : defaultSymbols) { if (!Fortran::semantics::IsProcedure(*sym) && !sym->GetUltimate().has() && @@ -331,11 +337,14 @@ void DataSharingProcessor::defaultPrivatize() { !symbolsInNestedRegions.contains(sym) && !symbolsInParentRegions.contains(sym) && !privatizedSymbols.contains(sym)) - doPrivatize(sym); + doPrivatize(sym, clauseOps, privateSyms); } } -void DataSharingProcessor::doPrivatize(const Fortran::semantics::Symbol *sym) { +void DataSharingProcessor::doPrivatize( + const Fortran::semantics::Symbol *sym, + mlir::omp::PrivateClauseOps *clauseOps, + llvm::SmallVectorImpl *privateSyms) { if (!useDelayedPrivatization) { cloneSymbol(sym); copyFirstPrivateSymbol(sym); @@ -419,10 +428,13 @@ void DataSharingProcessor::doPrivatize(const Fortran::semantics::Symbol *sym) { return result; }(); - delayedPrivatizationInfo.privatizers.push_back( - mlir::SymbolRefAttr::get(privatizerOp)); - delayedPrivatizationInfo.originalAddresses.push_back(hsb.getAddr()); - delayedPrivatizationInfo.symbols.push_back(sym); + if (clauseOps) { + clauseOps->privatizers.push_back(mlir::SymbolRefAttr::get(privatizerOp)); + clauseOps->privateVars.push_back(hsb.getAddr()); + } + + if (privateSyms) + privateSyms->push_back(sym); } } // namespace omp diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.h b/flang/lib/Lower/OpenMP/DataSharingProcessor.h index 226abe96705e3..9724b3d5ed02f 100644 --- a/flang/lib/Lower/OpenMP/DataSharingProcessor.h +++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.h @@ -19,28 +19,17 @@ #include "flang/Parser/parse-tree.h" 
#include "flang/Semantics/symbol.h" +namespace mlir { +namespace omp { +struct PrivateClauseOps; +} // namespace omp +} // namespace mlir + namespace Fortran { namespace lower { namespace omp { class DataSharingProcessor { -public: - /// Collects all the information needed for delayed privatization. This can be - /// used by ops with data-sharing clauses to properly generate their regions - /// (e.g. add region arguments) and map the original SSA values to their - /// corresponding OMP region operands. - struct DelayedPrivatizationInfo { - // The list of symbols referring to delayed privatizer ops (i.e. - // `omp.private` ops). - llvm::SmallVector privatizers; - // SSA values that correspond to "original" values being privatized. - // "Original" here means the SSA value outside the OpenMP region from which - // a clone is created inside the region. - llvm::SmallVector originalAddresses; - // Fortran symbols corresponding to the above SSA values. - llvm::SmallVector symbols; - }; - private: bool hasLastPrivateOp; mlir::OpBuilder::InsertPoint lastPrivIP; @@ -57,7 +46,6 @@ class DataSharingProcessor { Fortran::lower::pft::Evaluation &eval; bool useDelayedPrivatization; Fortran::lower::SymMap *symTable; - DelayedPrivatizationInfo delayedPrivatizationInfo; bool needBarrier(); void collectSymbols(Fortran::semantics::Symbol::Flag flag); @@ -67,9 +55,16 @@ class DataSharingProcessor { void collectSymbolsForPrivatization(); void insertBarrier(); void collectDefaultSymbols(); - void privatize(); - void defaultPrivatize(); - void doPrivatize(const Fortran::semantics::Symbol *sym); + void privatize( + mlir::omp::PrivateClauseOps *clauseOps, + llvm::SmallVectorImpl *privateSyms); + void defaultPrivatize( + mlir::omp::PrivateClauseOps *clauseOps, + llvm::SmallVectorImpl *privateSyms); + void doPrivatize( + const Fortran::semantics::Symbol *sym, + mlir::omp::PrivateClauseOps *clauseOps, + llvm::SmallVectorImpl *privateSyms); void copyLastPrivatize(mlir::Operation *op); void 
insertLastPrivateCompare(mlir::Operation *op); void cloneSymbol(const Fortran::semantics::Symbol *sym); @@ -103,17 +98,15 @@ class DataSharingProcessor { // Step2 performs the copying for lastprivates and requires knowledge of the // MLIR operation to insert the last private update. Step2 adds // dealocation code as well. - void processStep1(); + void processStep1(mlir::omp::PrivateClauseOps *clauseOps = nullptr, + llvm::SmallVectorImpl + *privateSyms = nullptr); void processStep2(mlir::Operation *op, bool isLoop); void setLoopIV(mlir::Value iv) { assert(!loopIV && "Loop iteration variable already set"); loopIV = iv; } - - const DelayedPrivatizationInfo &getDelayedPrivatizationInfo() const { - return delayedPrivatizationInfo; - } }; } // namespace omp diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 0cf2a8f97040a..d67060d1cce72 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -523,19 +523,25 @@ genMasterOp(Fortran::lower::AbstractConverter &converter, mlir::Location currentLocation) { return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) - .setGenNested(genNested), - /*resultTypes=*/mlir::TypeRange()); + .setGenNested(genNested)); } static mlir::omp::OrderedRegionOp genOrderedRegionOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location currentLocation) { + mlir::Location currentLocation, + const Fortran::parser::OmpClauseList &clauseList) { + mlir::omp::OrderedRegionClauseOps clauseOps; + + ClauseProcessor cp(converter, semaCtx, clauseList); + cp.processTODO(currentLocation, + llvm::omp::Directive::OMPD_ordered); + return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) .setGenNested(genNested), - /*simd=*/false); + clauseOps); } static mlir::omp::ParallelOp @@ -546,77 +552,62 @@ 
genParallelOp(Fortran::lower::AbstractConverter &converter, mlir::Location currentLocation, const Fortran::parser::OmpClauseList &clauseList, bool outerCombined = false) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); Fortran::lower::StatementContext stmtCtx; - mlir::Value ifClauseOperand, numThreadsClauseOperand; - mlir::omp::ClauseProcBindKindAttr procBindKindAttr; - llvm::SmallVector allocateOperands, allocatorOperands, - reductionVars; + mlir::omp::ParallelClauseOps clauseOps; + llvm::SmallVector privateSyms; llvm::SmallVector reductionTypes; - llvm::SmallVector reductionDeclSymbols; - llvm::SmallVector reductionSymbols; + llvm::SmallVector reductionSyms; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(llvm::omp::Directive::OMPD_parallel, ifClauseOperand); - cp.processNumThreads(stmtCtx, numThreadsClauseOperand); - cp.processProcBind(procBindKindAttr); + cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps); + cp.processNumThreads(stmtCtx, clauseOps); + cp.processProcBind(clauseOps); cp.processDefault(); - cp.processAllocate(allocatorOperands, allocateOperands); + cp.processAllocate(clauseOps); + if (!outerCombined) - cp.processReduction(currentLocation, reductionVars, reductionTypes, - reductionDeclSymbols, &reductionSymbols); + cp.processReduction(currentLocation, clauseOps, &reductionTypes, + &reductionSyms); + + if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars)) + clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr(); auto reductionCallback = [&](mlir::Operation *op) { - llvm::SmallVector locs(reductionVars.size(), + llvm::SmallVector locs(clauseOps.reductionVars.size(), currentLocation); - auto *block = converter.getFirOpBuilder().createBlock(&op->getRegion(0), {}, - reductionTypes, locs); + auto *block = + firOpBuilder.createBlock(&op->getRegion(0), {}, reductionTypes, locs); for (auto [arg, prv] : - llvm::zip_equal(reductionSymbols, block->getArguments())) { + llvm::zip_equal(reductionSyms, 
block->getArguments())) { converter.bindSymbol(*arg, prv); } - return reductionSymbols; + return reductionSyms; }; - mlir::UnitAttr byrefAttr; - if (ReductionProcessor::doReductionByRef(reductionVars)) - byrefAttr = converter.getFirOpBuilder().getUnitAttr(); - OpWithBodyGenInfo genInfo = OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) .setGenNested(genNested) .setOuterCombined(outerCombined) .setClauses(&clauseList) - .setReductions(&reductionSymbols, &reductionTypes) + .setReductions(&reductionSyms, &reductionTypes) .setGenRegionEntryCb(reductionCallback); - if (!enableDelayedPrivatization) { - return genOpWithBody( - genInfo, - /*resultTypes=*/mlir::TypeRange(), ifClauseOperand, - numThreadsClauseOperand, allocateOperands, allocatorOperands, - reductionVars, - reductionDeclSymbols.empty() - ? nullptr - : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(), - reductionDeclSymbols), - procBindKindAttr, /*private_vars=*/llvm::SmallVector{}, - /*privatizers=*/nullptr, byrefAttr); - } + if (!enableDelayedPrivatization) + return genOpWithBody(genInfo, clauseOps); bool privatize = !outerCombined; DataSharingProcessor dsp(converter, semaCtx, clauseList, eval, /*useDelayedPrivatization=*/true, &symTable); if (privatize) - dsp.processStep1(); - - const auto &delayedPrivatizationInfo = dsp.getDelayedPrivatizationInfo(); + dsp.processStep1(&clauseOps, &privateSyms); auto genRegionEntryCB = [&](mlir::Operation *op) { auto parallelOp = llvm::cast(op); - llvm::SmallVector reductionLocs(reductionVars.size(), - currentLocation); + llvm::SmallVector reductionLocs( + clauseOps.reductionVars.size(), currentLocation); mlir::OperandRange privateVars = parallelOp.getPrivateVars(); mlir::Region &region = parallelOp.getRegion(); @@ -631,12 +622,12 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, llvm::transform(privateVars, std::back_inserter(privateVarLocs), [](mlir::Value v) { return v.getLoc(); }); - converter.getFirOpBuilder().createBlock(&region, 
/*insertPt=*/{}, - privateVarTypes, privateVarLocs); + firOpBuilder.createBlock(&region, /*insertPt=*/{}, privateVarTypes, + privateVarLocs); llvm::SmallVector allSymbols = - reductionSymbols; - allSymbols.append(delayedPrivatizationInfo.symbols); + reductionSyms; + allSymbols.append(privateSyms); for (auto [arg, prv] : llvm::zip_equal(allSymbols, region.getArguments())) { converter.bindSymbol(*arg, prv); } @@ -646,26 +637,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, // TODO Merge with the reduction CB. genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(&dsp); - - llvm::SmallVector privatizers( - delayedPrivatizationInfo.privatizers.begin(), - delayedPrivatizationInfo.privatizers.end()); - - return genOpWithBody( - genInfo, - /*resultTypes=*/mlir::TypeRange(), ifClauseOperand, - numThreadsClauseOperand, allocateOperands, allocatorOperands, - reductionVars, - reductionDeclSymbols.empty() - ? nullptr - : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(), - reductionDeclSymbols), - procBindKindAttr, delayedPrivatizationInfo.originalAddresses, - delayedPrivatizationInfo.privatizers.empty() - ? nullptr - : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(), - privatizers), - byrefAttr); + return genOpWithBody(genInfo, clauseOps); } static mlir::omp::SectionOp @@ -689,28 +661,21 @@ genSingleOp(Fortran::lower::AbstractConverter &converter, mlir::Location currentLocation, const Fortran::parser::OmpClauseList &beginClauseList, const Fortran::parser::OmpClauseList &endClauseList) { - llvm::SmallVector allocateOperands, allocatorOperands; - llvm::SmallVector copyPrivateVars; - llvm::SmallVector copyPrivateFuncs; - mlir::UnitAttr nowaitAttr; + mlir::omp::SingleClauseOps clauseOps; ClauseProcessor cp(converter, semaCtx, beginClauseList); - cp.processAllocate(allocatorOperands, allocateOperands); + cp.processAllocate(clauseOps); + // TODO Support delayed privatization. 
ClauseProcessor ecp(converter, semaCtx, endClauseList); - ecp.processNowait(nowaitAttr); - ecp.processCopyPrivate(currentLocation, copyPrivateVars, copyPrivateFuncs); + ecp.processNowait(clauseOps); + ecp.processCopyprivate(currentLocation, clauseOps); return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) .setGenNested(genNested) .setClauses(&beginClauseList), - allocateOperands, allocatorOperands, copyPrivateVars, - copyPrivateFuncs.empty() - ? nullptr - : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(), - copyPrivateFuncs), - nowaitAttr); + clauseOps); } static mlir::omp::TaskOp @@ -720,21 +685,19 @@ genTaskOp(Fortran::lower::AbstractConverter &converter, mlir::Location currentLocation, const Fortran::parser::OmpClauseList &clauseList) { Fortran::lower::StatementContext stmtCtx; - mlir::Value ifClauseOperand, finalClauseOperand, priorityClauseOperand; - mlir::UnitAttr untiedAttr, mergeableAttr; - llvm::SmallVector dependTypeOperands; - llvm::SmallVector allocateOperands, allocatorOperands, - dependOperands; + mlir::omp::TaskClauseOps clauseOps; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(llvm::omp::Directive::OMPD_task, ifClauseOperand); - cp.processAllocate(allocatorOperands, allocateOperands); + cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps); + cp.processAllocate(clauseOps); cp.processDefault(); - cp.processFinal(stmtCtx, finalClauseOperand); - cp.processUntied(untiedAttr); - cp.processMergeable(mergeableAttr); - cp.processPriority(stmtCtx, priorityClauseOperand); - cp.processDepend(dependTypeOperands, dependOperands); + cp.processFinal(stmtCtx, clauseOps); + cp.processUntied(clauseOps); + cp.processMergeable(clauseOps); + cp.processPriority(stmtCtx, clauseOps); + cp.processDepend(clauseOps); + // TODO Support delayed privatization. 
+ cp.processTODO( currentLocation, llvm::omp::Directive::OMPD_task); @@ -742,14 +705,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter, OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) .setGenNested(genNested) .setClauses(&clauseList), - ifClauseOperand, finalClauseOperand, untiedAttr, mergeableAttr, - /*in_reduction_vars=*/mlir::ValueRange(), - /*in_reductions=*/nullptr, priorityClauseOperand, - dependTypeOperands.empty() - ? nullptr - : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(), - dependTypeOperands), - dependOperands, allocateOperands, allocatorOperands); + clauseOps); } static mlir::omp::TaskgroupOp @@ -758,17 +714,18 @@ genTaskgroupOp(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, bool genNested, mlir::Location currentLocation, const Fortran::parser::OmpClauseList &clauseList) { - llvm::SmallVector allocateOperands, allocatorOperands; + mlir::omp::TaskgroupClauseOps clauseOps; + ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processAllocate(allocatorOperands, allocateOperands); + cp.processAllocate(clauseOps); cp.processTODO(currentLocation, llvm::omp::Directive::OMPD_taskgroup); + return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) .setGenNested(genNested) .setClauses(&clauseList), - /*task_reduction_vars=*/mlir::ValueRange(), - /*task_reductions=*/nullptr, allocateOperands, allocatorOperands); + clauseOps); } // This helper function implements the functionality of "promoting" @@ -789,8 +746,7 @@ genTaskgroupOp(Fortran::lower::AbstractConverter &converter, // clause. Support for such list items in a use_device_ptr clause // is deprecated." 
static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr( - llvm::SmallVectorImpl &devicePtrOperands, - llvm::SmallVectorImpl &deviceAddrOperands, + mlir::omp::UseDeviceClauseOps &clauseOps, llvm::SmallVectorImpl &useDeviceTypes, llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl @@ -803,9 +759,10 @@ static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr( // Iterate over our use_device_ptr list and shift all non-cptr arguments into // use_device_addr. - for (auto *it = devicePtrOperands.begin(); it != devicePtrOperands.end();) { + for (auto *it = clauseOps.useDevicePtrVars.begin(); + it != clauseOps.useDevicePtrVars.end();) { if (!fir::isa_builtin_cptr_type(fir::unwrapRefType(it->getType()))) { - deviceAddrOperands.push_back(*it); + clauseOps.useDeviceAddrVars.push_back(*it); // We have to shuffle the symbols around as well, to maintain // the correct Input -> BlockArg for use_device_ptr/use_device_addr. // NOTE: However, as map's do not seem to be included currently @@ -813,11 +770,11 @@ static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr( // future alterations. I believe the reason they are not currently // is that the BlockArg assign/lowering needs to be extended // to a greater set of types. 
- auto idx = std::distance(devicePtrOperands.begin(), it); + auto idx = std::distance(clauseOps.useDevicePtrVars.begin(), it); moveElementToBack(idx, useDeviceTypes); moveElementToBack(idx, useDeviceLocs); moveElementToBack(idx, useDeviceSymbols); - it = devicePtrOperands.erase(it); + it = clauseOps.useDevicePtrVars.erase(it); continue; } ++it; @@ -831,20 +788,19 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter, mlir::Location currentLocation, const Fortran::parser::OmpClauseList &clauseList) { Fortran::lower::StatementContext stmtCtx; - mlir::Value ifClauseOperand, deviceOperand; - llvm::SmallVector mapOperands, devicePtrOperands, - deviceAddrOperands; + mlir::omp::TargetDataClauseOps clauseOps; llvm::SmallVector useDeviceTypes; llvm::SmallVector useDeviceLocs; - llvm::SmallVector useDeviceSymbols; + llvm::SmallVector useDeviceSyms; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(llvm::omp::Directive::OMPD_target_data, ifClauseOperand); - cp.processDevice(stmtCtx, deviceOperand); - cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs, - useDeviceSymbols); - cp.processUseDeviceAddr(deviceAddrOperands, useDeviceTypes, useDeviceLocs, - useDeviceSymbols); + cp.processIf(llvm::omp::Directive::OMPD_target_data, clauseOps); + cp.processDevice(stmtCtx, clauseOps); + cp.processUseDevicePtr(clauseOps, useDeviceTypes, useDeviceLocs, + useDeviceSyms); + cp.processUseDeviceAddr(clauseOps, useDeviceTypes, useDeviceLocs, + useDeviceSyms); + // This function implements the deprecated functionality of use_device_ptr // that allows users to provide non-CPTR arguments to it with the caveat // that the compiler will treat them as use_device_addr. A lot of legacy @@ -856,17 +812,16 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter, // ordering. // TODO: Perhaps create a user provideable compiler option that will // re-introduce a hard-error rather than a warning in these cases. 
- promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr( - devicePtrOperands, deviceAddrOperands, useDeviceTypes, useDeviceLocs, - useDeviceSymbols); + promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(clauseOps, useDeviceTypes, + useDeviceLocs, useDeviceSyms); cp.processMap(currentLocation, llvm::omp::Directive::OMPD_target_data, - stmtCtx, mapOperands); + stmtCtx, clauseOps); auto dataOp = converter.getFirOpBuilder().create( - currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands, - deviceAddrOperands, mapOperands); + currentLocation, clauseOps); + genBodyOfTargetDataOp(converter, semaCtx, eval, genNested, dataOp, - useDeviceTypes, useDeviceLocs, useDeviceSymbols, + useDeviceTypes, useDeviceLocs, useDeviceSyms, currentLocation); return dataOp; } @@ -879,10 +834,7 @@ static OpTy genTargetEnterExitDataUpdateOp( const Fortran::parser::OmpClauseList &clauseList) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); Fortran::lower::StatementContext stmtCtx; - mlir::Value ifClauseOperand, deviceOperand; - mlir::UnitAttr nowaitAttr; - llvm::SmallVector mapOperands, dependOperands; - llvm::SmallVector dependTypeOperands; + mlir::omp::TargetEnterExitUpdateDataClauseOps clauseOps; // GCC 9.3.0 emits a (probably) bogus warning about an unused variable. 
[[maybe_unused]] llvm::omp::Directive directive; @@ -897,25 +849,19 @@ static OpTy genTargetEnterExitDataUpdateOp( } ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(directive, ifClauseOperand); - cp.processDevice(stmtCtx, deviceOperand); - cp.processDepend(dependTypeOperands, dependOperands); - cp.processNowait(nowaitAttr); + cp.processIf(directive, clauseOps); + cp.processDevice(stmtCtx, clauseOps); + cp.processDepend(clauseOps); + cp.processNowait(clauseOps); if constexpr (std::is_same_v) { - cp.processMotionClauses(stmtCtx, mapOperands); - cp.processMotionClauses(stmtCtx, mapOperands); + cp.processMotionClauses(stmtCtx, clauseOps); + cp.processMotionClauses(stmtCtx, clauseOps); } else { - cp.processMap(currentLocation, directive, stmtCtx, mapOperands); + cp.processMap(currentLocation, directive, stmtCtx, clauseOps); } - return firOpBuilder.create( - currentLocation, ifClauseOperand, deviceOperand, - dependTypeOperands.empty() - ? nullptr - : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(), - dependTypeOperands), - dependOperands, nowaitAttr, mapOperands); + return firOpBuilder.create(currentLocation, clauseOps); } // This functions creates a block for the body of the targetOp's region. It adds @@ -925,9 +871,9 @@ genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, mlir::omp::TargetOp &targetOp, - llvm::ArrayRef mapSymTypes, + llvm::ArrayRef mapSyms, llvm::ArrayRef mapSymLocs, - llvm::ArrayRef mapSymbols, + llvm::ArrayRef mapSymTypes, const mlir::Location &currentLocation) { assert(mapSymTypes.size() == mapSymLocs.size()); @@ -956,7 +902,7 @@ genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter, }; // Bind the symbols to their corresponding block arguments. 
- for (auto [argIndex, argSymbol] : llvm::enumerate(mapSymbols)) { + for (auto [argIndex, argSymbol] : llvm::enumerate(mapSyms)) { const mlir::BlockArgument &arg = region.getArgument(argIndex); // Avoid capture of a reference to a structured binding. const Fortran::semantics::Symbol *sym = argSymbol; @@ -1080,22 +1026,20 @@ genTargetOp(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OmpClauseList &clauseList, llvm::omp::Directive directive, bool outerCombined = false) { Fortran::lower::StatementContext stmtCtx; - mlir::Value ifClauseOperand, deviceOperand, threadLimitOperand; - mlir::UnitAttr nowaitAttr; - llvm::SmallVector dependTypeOperands; - llvm::SmallVector mapOperands, dependOperands; - llvm::SmallVector mapSymTypes; + mlir::omp::TargetClauseOps clauseOps; + llvm::SmallVector mapSyms; llvm::SmallVector mapSymLocs; - llvm::SmallVector mapSymbols; + llvm::SmallVector mapSymTypes; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(llvm::omp::Directive::OMPD_target, ifClauseOperand); - cp.processDevice(stmtCtx, deviceOperand); - cp.processThreadLimit(stmtCtx, threadLimitOperand); - cp.processDepend(dependTypeOperands, dependOperands); - cp.processNowait(nowaitAttr); - cp.processMap(currentLocation, directive, stmtCtx, mapOperands, &mapSymTypes, - &mapSymLocs, &mapSymbols); + cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps); + cp.processDevice(stmtCtx, clauseOps); + cp.processThreadLimit(stmtCtx, clauseOps); + cp.processDepend(clauseOps); + cp.processNowait(clauseOps); + cp.processMap(currentLocation, directive, stmtCtx, clauseOps, &mapSyms, + &mapSymLocs, &mapSymTypes); + // TODO Support delayed privatization. cp.processTODO( - currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand, - dependTypeOperands.empty() - ? 
nullptr - : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(), - dependTypeOperands), - dependOperands, nowaitAttr, mapOperands); + currentLocation, clauseOps); - genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSymTypes, - mapSymLocs, mapSymbols, currentLocation); + genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSyms, + mapSymLocs, mapSymTypes, currentLocation); return targetOp; } @@ -1209,17 +1148,16 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OmpClauseList &clauseList, bool outerCombined = false) { Fortran::lower::StatementContext stmtCtx; - mlir::Value numTeamsClauseOperand, ifClauseOperand, threadLimitClauseOperand; - llvm::SmallVector allocateOperands, allocatorOperands, - reductionVars; - llvm::SmallVector reductionDeclSymbols; + mlir::omp::TeamsClauseOps clauseOps; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(llvm::omp::Directive::OMPD_teams, ifClauseOperand); - cp.processAllocate(allocatorOperands, allocateOperands); + cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps); + cp.processAllocate(clauseOps); cp.processDefault(); - cp.processNumTeams(stmtCtx, numTeamsClauseOperand); - cp.processThreadLimit(stmtCtx, threadLimitClauseOperand); + cp.processNumTeams(stmtCtx, clauseOps); + cp.processThreadLimit(stmtCtx, clauseOps); + // TODO Support delayed privatization. + cp.processTODO(currentLocation, llvm::omp::Directive::OMPD_teams); @@ -1228,30 +1166,20 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter, .setGenNested(genNested) .setOuterCombined(outerCombined) .setClauses(&clauseList), - /*num_teams_lower=*/nullptr, numTeamsClauseOperand, ifClauseOperand, - threadLimitClauseOperand, allocateOperands, allocatorOperands, - reductionVars, - reductionDeclSymbols.empty() - ? 
nullptr - : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(), - reductionDeclSymbols)); + clauseOps); } /// Extract the list of function and variable symbols affected by the given /// 'declare target' directive and return the intended device type for them. -static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo( +static void getDeclareTargetInfo( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct, + mlir::omp::DeclareTargetClauseOps &clauseOps, llvm::SmallVectorImpl &symbolAndClause) { - - // The default capture type - mlir::omp::DeclareTargetDeviceType deviceType = - mlir::omp::DeclareTargetDeviceType::any; const auto &spec = std::get( declareTargetConstruct.t); - if (const auto *objectList{ Fortran::parser::Unwrap(spec.u)}) { ObjectList objects{makeList(*objectList, semaCtx)}; @@ -1272,12 +1200,10 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo( cp.processTo(symbolAndClause); cp.processEnter(symbolAndClause); cp.processLink(symbolAndClause); - cp.processDeviceType(deviceType); + cp.processDeviceType(clauseOps); cp.processTODO(converter.getCurrentLocation(), llvm::omp::Directive::OMPD_declare_target); } - - return deviceType; } static void collectDeferredDeclareTargets( @@ -1287,9 +1213,10 @@ static void collectDeferredDeclareTargets( const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct, llvm::SmallVectorImpl &deferredDeclareTarget) { + mlir::omp::DeclareTargetClauseOps clauseOps; llvm::SmallVector symbolAndClause; - mlir::omp::DeclareTargetDeviceType devType = getDeclareTargetInfo( - converter, semaCtx, eval, declareTargetConstruct, symbolAndClause); + getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct, + clauseOps, symbolAndClause); // Return the device type only if at least one of the targets for the // directive is a function 
or subroutine mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); @@ -1299,8 +1226,9 @@ static void collectDeferredDeclareTargets( std::get(symClause))); if (!op) { - deferredDeclareTarget.push_back( - {std::get<0>(symClause), devType, std::get<1>(symClause)}); + deferredDeclareTarget.push_back({std::get<0>(symClause), + clauseOps.deviceType, + std::get<1>(symClause)}); } } } @@ -1312,9 +1240,10 @@ getDeclareTargetFunctionDevice( Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct) { + mlir::omp::DeclareTargetClauseOps clauseOps; llvm::SmallVector symbolAndClause; - mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo( - converter, semaCtx, eval, declareTargetConstruct, symbolAndClause); + getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct, + clauseOps, symbolAndClause); // Return the device type only if at least one of the targets for the // directive is a function or subroutine @@ -1324,7 +1253,7 @@ getDeclareTargetFunctionDevice( std::get(symClause))); if (mlir::isa_and_nonnull(op)) - return deviceType; + return clauseOps.deviceType; } return std::nullopt; @@ -1354,12 +1283,14 @@ genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter, case llvm::omp::Directive::OMPD_barrier: firOpBuilder.create(currentLocation); break; - case llvm::omp::Directive::OMPD_taskwait: - ClauseProcessor(converter, semaCtx, opClauseList) - .processTODO( - currentLocation, llvm::omp::Directive::OMPD_taskwait); - firOpBuilder.create(currentLocation); + case llvm::omp::Directive::OMPD_taskwait: { + mlir::omp::TaskwaitClauseOps clauseOps; + ClauseProcessor cp(converter, semaCtx, opClauseList); + cp.processTODO( + currentLocation, llvm::omp::Directive::OMPD_taskwait); + firOpBuilder.create(currentLocation, clauseOps); break; + } case llvm::omp::Directive::OMPD_taskyield: firOpBuilder.create(currentLocation); break; @@ -1494,32 +1425,21 @@ 
createSimdLoop(Fortran::lower::AbstractConverter &converter, dsp.processStep1(); Fortran::lower::StatementContext stmtCtx; - mlir::Value scheduleChunkClauseOperand, ifClauseOperand; - llvm::SmallVector lowerBound, upperBound, step, reductionVars; - llvm::SmallVector alignedVars, nontemporalVars; + mlir::omp::SimdLoopClauseOps clauseOps; llvm::SmallVector iv; - llvm::SmallVector reductionTypes; - llvm::SmallVector reductionDeclSymbols; - mlir::omp::ClauseOrderKindAttr orderClauseOperand; - mlir::IntegerAttr simdlenClauseOperand, safelenClauseOperand; ClauseProcessor cp(converter, semaCtx, loopOpClauseList); - cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv); - cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand); - cp.processReduction(loc, reductionVars, reductionTypes, reductionDeclSymbols); - cp.processIf(llvm::omp::Directive::OMPD_simd, ifClauseOperand); - cp.processSimdlen(simdlenClauseOperand); - cp.processSafelen(safelenClauseOperand); + cp.processCollapse(loc, eval, clauseOps, iv); + cp.processReduction(loc, clauseOps); + cp.processIf(llvm::omp::Directive::OMPD_simd, clauseOps); + cp.processSimdlen(clauseOps); + cp.processSafelen(clauseOps); + clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr(); + // TODO Support delayed privatization. 
+ cp.processTODO(loc, ompDirective); - mlir::TypeRange resultType; - auto simdLoopOp = firOpBuilder.create( - loc, resultType, lowerBound, upperBound, step, alignedVars, - /*alignment_values=*/nullptr, ifClauseOperand, nontemporalVars, - orderClauseOperand, simdlenClauseOperand, safelenClauseOperand, - /*inclusive=*/firOpBuilder.getUnitAttr()); - auto *nestedEval = getCollapsedLoopEval( eval, Fortran::lower::getCollapseValue(loopOpClauseList)); @@ -1527,11 +1447,12 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter, return genLoopVars(op, converter, loc, iv); }; - createBodyOfOp( - simdLoopOp, OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval) - .setClauses(&loopOpClauseList) - .setDataSharingProcessor(&dsp) - .setGenRegionEntryCb(ivCallback)); + genOpWithBody( + OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval) + .setClauses(&loopOpClauseList) + .setDataSharingProcessor(&dsp) + .setGenRegionEntryCb(ivCallback), + clauseOps); } static void createWsloop(Fortran::lower::AbstractConverter &converter, @@ -1546,77 +1467,50 @@ static void createWsloop(Fortran::lower::AbstractConverter &converter, dsp.processStep1(); Fortran::lower::StatementContext stmtCtx; - mlir::Value scheduleChunkClauseOperand; - llvm::SmallVector lowerBound, upperBound, step, reductionVars; - llvm::SmallVector linearVars, linearStepVars; + mlir::omp::WsloopClauseOps clauseOps; llvm::SmallVector iv; llvm::SmallVector reductionTypes; - llvm::SmallVector reductionDeclSymbols; - llvm::SmallVector reductionSymbols; - mlir::omp::ClauseOrderKindAttr orderClauseOperand; - mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand; - mlir::UnitAttr nowaitClauseOperand, byrefOperand, scheduleSimdClauseOperand; - mlir::IntegerAttr orderedClauseOperand; - mlir::omp::ScheduleModifierAttr scheduleModClauseOperand; + llvm::SmallVector reductionSyms; ClauseProcessor cp(converter, semaCtx, beginClauseList); - cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv); - 
cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand); - cp.processReduction(loc, reductionVars, reductionTypes, reductionDeclSymbols, - &reductionSymbols); - cp.processTODO(loc, ompDirective); - - if (ReductionProcessor::doReductionByRef(reductionVars)) - byrefOperand = firOpBuilder.getUnitAttr(); - - auto wsLoopOp = firOpBuilder.create( - loc, lowerBound, upperBound, step, linearVars, linearStepVars, - reductionVars, - reductionDeclSymbols.empty() - ? nullptr - : mlir::ArrayAttr::get(firOpBuilder.getContext(), - reductionDeclSymbols), - scheduleValClauseOperand, scheduleChunkClauseOperand, - /*schedule_modifiers=*/nullptr, - /*simd_modifier=*/nullptr, nowaitClauseOperand, byrefOperand, - orderedClauseOperand, orderClauseOperand, - /*inclusive=*/firOpBuilder.getUnitAttr()); - - // Handle attribute based clauses. - if (cp.processOrdered(orderedClauseOperand)) - wsLoopOp.setOrderedValAttr(orderedClauseOperand); - - if (cp.processSchedule(scheduleValClauseOperand, scheduleModClauseOperand, - scheduleSimdClauseOperand)) { - wsLoopOp.setScheduleValAttr(scheduleValClauseOperand); - wsLoopOp.setScheduleModifierAttr(scheduleModClauseOperand); - wsLoopOp.setSimdModifierAttr(scheduleSimdClauseOperand); - } + cp.processCollapse(loc, eval, clauseOps, iv); + cp.processSchedule(stmtCtx, clauseOps); + cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms); + cp.processOrdered(clauseOps); + clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr(); + // TODO Support delayed privatization. + + if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars)) + clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr(); + + cp.processTODO(loc, + ompDirective); + // In FORTRAN `nowait` clause occur at the end of `omp do` directive. 
// i.e // !$omp do // <...> // !$omp end do nowait if (endClauseList) { - if (ClauseProcessor(converter, semaCtx, *endClauseList) - .processNowait(nowaitClauseOperand)) - wsLoopOp.setNowaitAttr(nowaitClauseOperand); + ClauseProcessor ecp(converter, semaCtx, *endClauseList); + ecp.processNowait(clauseOps); } auto *nestedEval = getCollapsedLoopEval( eval, Fortran::lower::getCollapseValue(beginClauseList)); auto ivCallback = [&](mlir::Operation *op) { - return genLoopAndReductionVars(op, converter, loc, iv, reductionSymbols, + return genLoopAndReductionVars(op, converter, loc, iv, reductionSyms, reductionTypes); }; - createBodyOfOp( - wsLoopOp, OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval) - .setClauses(&beginClauseList) - .setDataSharingProcessor(&dsp) - .setReductions(&reductionSymbols, &reductionTypes) - .setGenRegionEntryCb(ivCallback)); + genOpWithBody( + OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval) + .setClauses(&beginClauseList) + .setDataSharingProcessor(&dsp) + .setReductions(&reductionSyms, &reductionTypes) + .setGenRegionEntryCb(ivCallback), + clauseOps); } static void createSimdWsloop( @@ -1704,10 +1598,11 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct) { + mlir::omp::DeclareTargetClauseOps clauseOps; llvm::SmallVector symbolAndClause; mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); - mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo( - converter, semaCtx, eval, declareTargetConstruct, symbolAndClause); + getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct, + clauseOps, symbolAndClause); for (const DeclareTargetCapturePair &symClause : symbolAndClause) { mlir::Operation *op = mod.lookupSymbol(converter.mangleName( @@ -1721,7 +1616,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, markDeclareTarget( op, converter, - 
std::get(symClause), deviceType); + std::get(symClause), + clauseOps.deviceType); } } @@ -1853,7 +1749,8 @@ genOMP(Fortran::lower::AbstractConverter &converter, !std::get_if(&clause.u) && !std::get_if(&clause.u) && !std::get_if(&clause.u) && - !std::get_if(&clause.u)) { + !std::get_if(&clause.u) && + !std::get_if(&clause.u)) { TODO(clauseLocation, "OpenMP Block construct clause"); } } @@ -1873,7 +1770,7 @@ genOMP(Fortran::lower::AbstractConverter &converter, break; case llvm::omp::Directive::OMPD_ordered: genOrderedRegionOp(converter, semaCtx, eval, /*genNested=*/true, - currentLocation); + currentLocation, beginClauseList); break; case llvm::omp::Directive::OMPD_parallel: genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/true, @@ -1964,7 +1861,6 @@ genOMP(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::Location currentLocation = converter.getCurrentLocation(); - mlir::IntegerAttr hintClauseOp; std::string name; const Fortran::parser::OmpCriticalDirective &cd = std::get(criticalConstruct.t); @@ -1973,21 +1869,28 @@ genOMP(Fortran::lower::AbstractConverter &converter, std::get>(cd.t).value().ToString(); } - const auto &clauseList = std::get(cd.t); - ClauseProcessor(converter, semaCtx, clauseList).processHint(hintClauseOp); - mlir::omp::CriticalOp criticalOp = [&]() { if (name.empty()) { return firOpBuilder.create( currentLocation, mlir::FlatSymbolRefAttr()); } + mlir::ModuleOp module = firOpBuilder.getModule(); mlir::OpBuilder modBuilder(module.getBodyRegion()); auto global = module.lookupSymbol(name); - if (!global) - global = modBuilder.create( - currentLocation, - mlir::StringAttr::get(firOpBuilder.getContext(), name), hintClauseOp); + if (!global) { + mlir::omp::CriticalClauseOps clauseOps; + const auto &clauseList = std::get(cd.t); + + ClauseProcessor cp(converter, semaCtx, clauseList); + 
cp.processHint(clauseOps); + clauseOps.nameAttr = + mlir::StringAttr::get(firOpBuilder.getContext(), name); + + global = modBuilder.create(currentLocation, + clauseOps); + } + return firOpBuilder.create( currentLocation, mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(), global.getSymName())); @@ -2104,8 +2007,7 @@ genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPSectionsConstruct §ionsConstruct) { mlir::Location currentLocation = converter.getCurrentLocation(); - llvm::SmallVector allocateOperands, allocatorOperands; - mlir::UnitAttr nowaitClauseOperand; + mlir::omp::SectionsClauseOps clauseOps; const auto &beginSectionsDirective = std::get(sectionsConstruct.t); const auto §ionsClauseList = @@ -2114,8 +2016,9 @@ genOMP(Fortran::lower::AbstractConverter &converter, // Process clauses before optional omp.parallel, so that new variables are // allocated outside of the parallel region ClauseProcessor cp(converter, semaCtx, sectionsClauseList); - cp.processSectionsReduction(currentLocation); - cp.processAllocate(allocatorOperands, allocateOperands); + cp.processSectionsReduction(currentLocation, clauseOps); + cp.processAllocate(clauseOps); + // TODO Support delayed privatization. 
llvm::omp::Directive dir = std::get(beginSectionsDirective.t) @@ -2132,16 +2035,14 @@ genOMP(Fortran::lower::AbstractConverter &converter, const auto &endSectionsClauseList = std::get(endSectionsDirective.t); ClauseProcessor(converter, semaCtx, endSectionsClauseList) - .processNowait(nowaitClauseOperand); + .processNowait(clauseOps); } // SECTIONS construct genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) .setGenNested(false), - /*reduction_vars=*/mlir::ValueRange(), - /*reductions=*/nullptr, allocateOperands, allocatorOperands, - nowaitClauseOperand); + clauseOps); const auto §ionBlocks = std::get(sectionsConstruct.t); From e291fad68b78d28bfa73caab94ddcb978db2a602 Mon Sep 17 00:00:00 2001 From: Sergio Afonso Date: Thu, 28 Mar 2024 15:14:37 +0000 Subject: [PATCH 03/13] [Flang][OpenMP][Lower] Split MLIR codegen for clauses and constructs This patch performs several cleanups with the main purpose of normalizing the code patterns used to trigger codegen for MLIR OpenMP operations and making the processing of clauses and constructs independent. The following changes are made: - Clean up unused `directive` argument to `ClauseProcessor::processMap()`. - Move general helper functions in OpenMP.cpp to the appropriate section of the file. - Create `gen<OpName>Clauses()` functions containing the clause processing code specific for the associated OpenMP construct. - Update `gen<OpName>Op()` functions to call the corresponding `gen<OpName>Clauses()` function. - Sort calls to `ClauseProcessor::process<ClauseName>()` alphabetically, to avoid inadvertently relying on some arbitrary order. Update some tests that broke due to the order change. - Normalize `genOMP()` functions so they all delegate the generation of MLIR to `gen<OpName>Op()` functions following the same pattern. - Only process `nowait` clause on `TARGET` constructs if not compiling for the target device.
A later patch can move the calls to `gen<OpName>Clauses()` out of `gen<OpName>Op()` functions and pass completed clause structures instead, in preparation for supporting composite constructs. That will make it possible to reuse clause processing for a given leaf construct when appearing alone or in a combined or composite construct, while controlling where the associated code is produced. --- flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 4 +- flang/lib/Lower/OpenMP/ClauseProcessor.h | 3 +- flang/lib/Lower/OpenMP/OpenMP.cpp | 2090 +++++++++-------- flang/test/Lower/OpenMP/FIR/target.f90 | 2 +- flang/test/Lower/OpenMP/target.f90 | 2 +- .../use-device-ptr-to-use-device-addr.f90 | 4 +- 6 files changed, 1173 insertions(+), 932 deletions(-) diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index ee1f6c2fbc7e8..e2b26b3025049 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -804,8 +804,8 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc, } bool ClauseProcessor::processMap( - mlir::Location currentLocation, const llvm::omp::Directive &directive, - Fortran::lower::StatementContext &stmtCtx, mlir::omp::MapClauseOps &result, + mlir::Location currentLocation, Fortran::lower::StatementContext &stmtCtx, + mlir::omp::MapClauseOps &result, llvm::SmallVectorImpl *mapSyms, llvm::SmallVectorImpl *mapSymLocs, llvm::SmallVectorImpl *mapSymTypes) const { diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index d933e0a913d2b..9e59d754280ef 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -102,8 +102,7 @@ class ClauseProcessor { // They may be used later on to create the block_arguments for some of the // target directives that require it.
bool processMap( - mlir::Location currentLocation, const llvm::omp::Directive &directive, - Fortran::lower::StatementContext &stmtCtx, + mlir::Location currentLocation, Fortran::lower::StatementContext &stmtCtx, mlir::omp::MapClauseOps &result, llvm::SmallVectorImpl *mapSyms = nullptr, diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index d67060d1cce72..b6de2079a973f 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -237,6 +237,276 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter, return storeOp; } +// This helper function implements the functionality of "promoting" +// non-CPTR arguments of use_device_ptr to use_device_addr +// arguments (automagic conversion of use_device_ptr -> +// use_device_addr in these cases). The way we do so currently is +// through the shuffling of operands from the devicePtrOperands to +// deviceAddrOperands where neccesary and re-organizing the types, +// locations and symbols to maintain the correct ordering of ptr/addr +// input -> BlockArg. +// +// This effectively implements some deprecated OpenMP functionality +// that some legacy applications unfortunately depend on +// (deprecated in specification version 5.2): +// +// "If a list item in a use_device_ptr clause is not of type C_PTR, +// the behavior is as if the list item appeared in a use_device_addr +// clause. Support for such list items in a use_device_ptr clause +// is deprecated." +static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr( + mlir::omp::UseDeviceClauseOps &clauseOps, + llvm::SmallVectorImpl &useDeviceTypes, + llvm::SmallVectorImpl &useDeviceLocs, + llvm::SmallVectorImpl + &useDeviceSymbols) { + auto moveElementToBack = [](size_t idx, auto &vector) { + auto *iter = std::next(vector.begin(), idx); + vector.push_back(*iter); + vector.erase(iter); + }; + + // Iterate over our use_device_ptr list and shift all non-cptr arguments into + // use_device_addr. 
+ for (auto *it = clauseOps.useDevicePtrVars.begin(); + it != clauseOps.useDevicePtrVars.end();) { + if (!fir::isa_builtin_cptr_type(fir::unwrapRefType(it->getType()))) { + clauseOps.useDeviceAddrVars.push_back(*it); + // We have to shuffle the symbols around as well, to maintain + // the correct Input -> BlockArg for use_device_ptr/use_device_addr. + // NOTE: However, as map's do not seem to be included currently + // this isn't as pertinent, but we must try to maintain for + // future alterations. I believe the reason they are not currently + // is that the BlockArg assign/lowering needs to be extended + // to a greater set of types. + auto idx = std::distance(clauseOps.useDevicePtrVars.begin(), it); + moveElementToBack(idx, useDeviceTypes); + moveElementToBack(idx, useDeviceLocs); + moveElementToBack(idx, useDeviceSymbols); + it = clauseOps.useDevicePtrVars.erase(it); + continue; + } + ++it; + } +} + +/// Extract the list of function and variable symbols affected by the given +/// 'declare target' directive and return the intended device type for them. 
+static void getDeclareTargetInfo( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct, + mlir::omp::DeclareTargetClauseOps &clauseOps, + llvm::SmallVectorImpl &symbolAndClause) { + const auto &spec = std::get( + declareTargetConstruct.t); + if (const auto *objectList{ + Fortran::parser::Unwrap(spec.u)}) { + ObjectList objects{makeList(*objectList, semaCtx)}; + // Case: declare target(func, var1, var2) + gatherFuncAndVarSyms(objects, mlir::omp::DeclareTargetCaptureClause::to, + symbolAndClause); + } else if (const auto *clauseList{ + Fortran::parser::Unwrap( + spec.u)}) { + if (clauseList->v.empty()) { + // Case: declare target, implicit capture of function + symbolAndClause.emplace_back( + mlir::omp::DeclareTargetCaptureClause::to, + eval.getOwningProcedure()->getSubprogramSymbol()); + } + + ClauseProcessor cp(converter, semaCtx, *clauseList); + cp.processDeviceType(clauseOps); + cp.processEnter(symbolAndClause); + cp.processLink(symbolAndClause); + cp.processTo(symbolAndClause); + cp.processTODO(converter.getCurrentLocation(), + llvm::omp::Directive::OMPD_declare_target); + } +} + +static void collectDeferredDeclareTargets( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct, + llvm::SmallVectorImpl + &deferredDeclareTarget) { + mlir::omp::DeclareTargetClauseOps clauseOps; + llvm::SmallVector symbolAndClause; + getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct, + clauseOps, symbolAndClause); + // Return the device type only if at least one of the targets for the + // directive is a function or subroutine + mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); + + for (const DeclareTargetCapturePair &symClause : 
symbolAndClause) { + mlir::Operation *op = mod.lookupSymbol(converter.mangleName( + std::get(symClause))); + + if (!op) { + deferredDeclareTarget.push_back({std::get<0>(symClause), + clauseOps.deviceType, + std::get<1>(symClause)}); + } + } +} + +static std::optional +getDeclareTargetFunctionDevice( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct + &declareTargetConstruct) { + mlir::omp::DeclareTargetClauseOps clauseOps; + llvm::SmallVector symbolAndClause; + getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct, + clauseOps, symbolAndClause); + + // Return the device type only if at least one of the targets for the + // directive is a function or subroutine + mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); + for (const DeclareTargetCapturePair &symClause : symbolAndClause) { + mlir::Operation *op = mod.lookupSymbol(converter.mangleName( + std::get(symClause))); + + if (mlir::isa_and_nonnull(op)) + return clauseOps.deviceType; + } + + return std::nullopt; +} + +static llvm::SmallVector +genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter, + mlir::Location &loc, + llvm::ArrayRef args) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + auto ®ion = op->getRegion(0); + + std::size_t loopVarTypeSize = 0; + for (const Fortran::semantics::Symbol *arg : args) + loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size()); + mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); + llvm::SmallVector tiv(args.size(), loopVarType); + llvm::SmallVector locs(args.size(), loc); + firOpBuilder.createBlock(®ion, {}, tiv, locs); + // The argument is not currently in memory, so make a temporary for the + // argument, and store it there, then bind that location to the argument. 
+ mlir::Operation *storeOp = nullptr; + for (auto [argIndex, argSymbol] : llvm::enumerate(args)) { + mlir::Value indexVal = fir::getBase(region.front().getArgument(argIndex)); + storeOp = + createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol); + } + firOpBuilder.setInsertionPointAfter(storeOp); + + return llvm::SmallVector(args); +} + +static void genReductionVars( + mlir::Operation *op, Fortran::lower::AbstractConverter &converter, + mlir::Location &loc, + llvm::ArrayRef reductionArgs, + llvm::ArrayRef reductionTypes) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + llvm::SmallVector blockArgLocs(reductionArgs.size(), loc); + + mlir::Block *entryBlock = firOpBuilder.createBlock( + &op->getRegion(0), {}, reductionTypes, blockArgLocs); + + // Bind the reduction arguments to their block arguments. + for (auto [arg, prv] : + llvm::zip_equal(reductionArgs, entryBlock->getArguments())) { + converter.bindSymbol(*arg, prv); + } +} + +static llvm::SmallVector +genLoopAndReductionVars( + mlir::Operation *op, Fortran::lower::AbstractConverter &converter, + mlir::Location &loc, + llvm::ArrayRef loopArgs, + llvm::ArrayRef reductionArgs, + llvm::ArrayRef reductionTypes) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + + llvm::SmallVector blockArgTypes; + llvm::SmallVector blockArgLocs; + blockArgTypes.reserve(loopArgs.size() + reductionArgs.size()); + blockArgLocs.reserve(blockArgTypes.size()); + mlir::Block *entryBlock; + + if (loopArgs.size()) { + std::size_t loopVarTypeSize = 0; + for (const Fortran::semantics::Symbol *arg : loopArgs) + loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size()); + mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); + std::fill_n(std::back_inserter(blockArgTypes), loopArgs.size(), + loopVarType); + std::fill_n(std::back_inserter(blockArgLocs), loopArgs.size(), loc); + } + if (reductionArgs.size()) { + llvm::copy(reductionTypes, 
std::back_inserter(blockArgTypes)); + std::fill_n(std::back_inserter(blockArgLocs), reductionArgs.size(), loc); + } + entryBlock = firOpBuilder.createBlock(&op->getRegion(0), {}, blockArgTypes, + blockArgLocs); + // The argument is not currently in memory, so make a temporary for the + // argument, and store it there, then bind that location to the argument. + if (loopArgs.size()) { + mlir::Operation *storeOp = nullptr; + for (auto [argIndex, argSymbol] : llvm::enumerate(loopArgs)) { + mlir::Value indexVal = + fir::getBase(op->getRegion(0).front().getArgument(argIndex)); + storeOp = + createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol); + } + firOpBuilder.setInsertionPointAfter(storeOp); + } + // Bind the reduction arguments to their block arguments + for (auto [arg, prv] : llvm::zip_equal( + reductionArgs, + llvm::drop_begin(entryBlock->getArguments(), loopArgs.size()))) { + converter.bindSymbol(*arg, prv); + } + + return llvm::SmallVector(loopArgs); +} + +static void +markDeclareTarget(mlir::Operation *op, + Fortran::lower::AbstractConverter &converter, + mlir::omp::DeclareTargetCaptureClause captureClause, + mlir::omp::DeclareTargetDeviceType deviceType) { + // TODO: Add support for program local variables with declare target applied + auto declareTargetOp = llvm::dyn_cast(op); + if (!declareTargetOp) + fir::emitFatalError( + converter.getCurrentLocation(), + "Attempt to apply declare target on unsupported operation"); + + // The function or global already has a declare target applied to it, very + // likely through implicit capture (usage in another declare target + // function/subroutine). 
It should be marked as any if it has been assigned + // both host and nohost, else we skip, as there is no change + if (declareTargetOp.isDeclareTarget()) { + if (declareTargetOp.getDeclareTargetDeviceType() != deviceType) + declareTargetOp.setDeclareTarget(mlir::omp::DeclareTargetDeviceType::any, + captureClause); + return; + } + + declareTargetOp.setDeclareTarget(deviceType, captureClause); +} + +//===----------------------------------------------------------------------===// +// Op body generation helper structures and functions +//===----------------------------------------------------------------------===// + struct OpWithBodyGenInfo { /// A type for a code-gen callback function. This takes as argument the op for /// which the code is being generated and returns the arguments of the op's @@ -508,543 +778,726 @@ static void genBodyOfTargetDataOp( genNestedEvaluations(converter, eval); } -template -static OpTy genOpWithBody(OpWithBodyGenInfo &info, Args &&...args) { - auto op = info.converter.getFirOpBuilder().create( - info.loc, std::forward(args)...); - createBodyOfOp(op, info); - return op; -} - -static mlir::omp::MasterOp -genMasterOp(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location currentLocation) { - return genOpWithBody( - OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) - .setGenNested(genNested)); -} - -static mlir::omp::OrderedRegionOp -genOrderedRegionOp(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location currentLocation, - const Fortran::parser::OmpClauseList &clauseList) { - mlir::omp::OrderedRegionClauseOps clauseOps; +// This functions creates a block for the body of the targetOp's region. It adds +// all the symbols present in mapSymbols as block arguments to this block. 
+static void +genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::omp::TargetOp &targetOp, + llvm::ArrayRef mapSyms, + llvm::ArrayRef mapSymLocs, + llvm::ArrayRef mapSymTypes, + const mlir::Location ¤tLocation) { + assert(mapSymTypes.size() == mapSymLocs.size()); - ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processTODO(currentLocation, - llvm::omp::Directive::OMPD_ordered); + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::Region ®ion = targetOp.getRegion(); - return genOpWithBody( - OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) - .setGenNested(genNested), - clauseOps); -} + auto *regionBlock = + firOpBuilder.createBlock(®ion, {}, mapSymTypes, mapSymLocs); -static mlir::omp::ParallelOp -genParallelOp(Fortran::lower::AbstractConverter &converter, - Fortran::lower::SymMap &symTable, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location currentLocation, - const Fortran::parser::OmpClauseList &clauseList, - bool outerCombined = false) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - Fortran::lower::StatementContext stmtCtx; - mlir::omp::ParallelClauseOps clauseOps; - llvm::SmallVector privateSyms; - llvm::SmallVector reductionTypes; - llvm::SmallVector reductionSyms; + // Clones the `bounds` placing them inside the target region and returns them. 
+ auto cloneBound = [&](mlir::Value bound) { + if (mlir::isMemoryEffectFree(bound.getDefiningOp())) { + mlir::Operation *clonedOp = bound.getDefiningOp()->clone(); + regionBlock->push_back(clonedOp); + return clonedOp->getResult(0); + } + TODO(converter.getCurrentLocation(), + "target map clause operand unsupported bound type"); + }; - ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps); - cp.processNumThreads(stmtCtx, clauseOps); - cp.processProcBind(clauseOps); - cp.processDefault(); - cp.processAllocate(clauseOps); + auto cloneBounds = [cloneBound](llvm::ArrayRef bounds) { + llvm::SmallVector clonedBounds; + for (mlir::Value bound : bounds) + clonedBounds.emplace_back(cloneBound(bound)); + return clonedBounds; + }; - if (!outerCombined) - cp.processReduction(currentLocation, clauseOps, &reductionTypes, - &reductionSyms); + // Bind the symbols to their corresponding block arguments. + for (auto [argIndex, argSymbol] : llvm::enumerate(mapSyms)) { + const mlir::BlockArgument &arg = region.getArgument(argIndex); + // Avoid capture of a reference to a structured binding. + const Fortran::semantics::Symbol *sym = argSymbol; + // Structure component symbols don't have bindings. 
+ if (sym->owner().IsDerivedType()) + continue; + fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym); + extVal.match( + [&](const fir::BoxValue &v) { + converter.bindSymbol(*sym, + fir::BoxValue(arg, cloneBounds(v.getLBounds()), + v.getExplicitParameters(), + v.getExplicitExtents())); + }, + [&](const fir::MutableBoxValue &v) { + converter.bindSymbol( + *sym, fir::MutableBoxValue(arg, cloneBounds(v.getLBounds()), + v.getMutableProperties())); + }, + [&](const fir::ArrayBoxValue &v) { + converter.bindSymbol( + *sym, fir::ArrayBoxValue(arg, cloneBounds(v.getExtents()), + cloneBounds(v.getLBounds()), + v.getSourceBox())); + }, + [&](const fir::CharArrayBoxValue &v) { + converter.bindSymbol( + *sym, fir::CharArrayBoxValue(arg, cloneBound(v.getLen()), + cloneBounds(v.getExtents()), + cloneBounds(v.getLBounds()))); + }, + [&](const fir::CharBoxValue &v) { + converter.bindSymbol(*sym, + fir::CharBoxValue(arg, cloneBound(v.getLen()))); + }, + [&](const fir::UnboxedValue &v) { converter.bindSymbol(*sym, arg); }, + [&](const auto &) { + TODO(converter.getCurrentLocation(), + "target map clause operand unsupported type"); + }); + } - if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars)) - clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr(); + // Check if cloning the bounds introduced any dependency on the outer region. + // If so, then either clone them as well if they are MemoryEffectFree, or else + // copy them to a new temporary and add them to the map and block_argument + // lists and replace their uses with the new temporary. 
+ llvm::SetVector valuesDefinedAbove; + mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove); + while (!valuesDefinedAbove.empty()) { + for (mlir::Value val : valuesDefinedAbove) { + mlir::Operation *valOp = val.getDefiningOp(); + if (mlir::isMemoryEffectFree(valOp)) { + mlir::Operation *clonedOp = valOp->clone(); + regionBlock->push_front(clonedOp); + val.replaceUsesWithIf( + clonedOp->getResult(0), [regionBlock](mlir::OpOperand &use) { + return use.getOwner()->getBlock() == regionBlock; + }); + } else { + auto savedIP = firOpBuilder.getInsertionPoint(); + firOpBuilder.setInsertionPointAfter(valOp); + auto copyVal = + firOpBuilder.createTemporary(val.getLoc(), val.getType()); + firOpBuilder.createStoreWithConvert(copyVal.getLoc(), val, copyVal); - auto reductionCallback = [&](mlir::Operation *op) { - llvm::SmallVector locs(clauseOps.reductionVars.size(), - currentLocation); - auto *block = - firOpBuilder.createBlock(&op->getRegion(0), {}, reductionTypes, locs); - for (auto [arg, prv] : - llvm::zip_equal(reductionSyms, block->getArguments())) { - converter.bindSymbol(*arg, prv); + llvm::SmallVector bounds; + std::stringstream name; + firOpBuilder.setInsertionPoint(targetOp); + mlir::Value mapOp = createMapInfoOp( + firOpBuilder, copyVal.getLoc(), copyVal, mlir::Value{}, name.str(), + bounds, llvm::SmallVector{}, + static_cast< + std::underlying_type_t>( + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT), + mlir::omp::VariableCaptureKind::ByCopy, copyVal.getType()); + targetOp.getMapOperandsMutable().append(mapOp); + mlir::Value clonedValArg = + region.addArgument(copyVal.getType(), copyVal.getLoc()); + firOpBuilder.setInsertionPointToStart(regionBlock); + auto loadOp = firOpBuilder.create(clonedValArg.getLoc(), + clonedValArg); + val.replaceUsesWithIf( + loadOp->getResult(0), [regionBlock](mlir::OpOperand &use) { + return use.getOwner()->getBlock() == regionBlock; + }); + firOpBuilder.setInsertionPoint(regionBlock, savedIP); + } } - return 
reductionSyms; - }; - - OpWithBodyGenInfo genInfo = - OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) - .setGenNested(genNested) - .setOuterCombined(outerCombined) - .setClauses(&clauseList) - .setReductions(&reductionSyms, &reductionTypes) - .setGenRegionEntryCb(reductionCallback); + valuesDefinedAbove.clear(); + mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove); + } - if (!enableDelayedPrivatization) - return genOpWithBody(genInfo, clauseOps); + // Insert dummy instruction to remember the insertion position. The + // marker will be deleted since there are not uses. + // In the HLFIR flow there are hlfir.declares inserted above while + // setting block arguments. + mlir::Value undefMarker = firOpBuilder.create( + targetOp.getOperation()->getLoc(), firOpBuilder.getIndexType()); - bool privatize = !outerCombined; - DataSharingProcessor dsp(converter, semaCtx, clauseList, eval, - /*useDelayedPrivatization=*/true, &symTable); + // Create blocks for unstructured regions. This has to be done since + // blocks are initially allocated with the function as the parent region. + if (eval.lowerAsUnstructured()) { + Fortran::lower::createEmptyRegionBlocks( + firOpBuilder, eval.getNestedEvaluations()); + } - if (privatize) - dsp.processStep1(&clauseOps, &privateSyms); + firOpBuilder.create(currentLocation); - auto genRegionEntryCB = [&](mlir::Operation *op) { - auto parallelOp = llvm::cast(op); + // Create the insertion point after the marker. 
+ firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp()); + if (genNested) + genNestedEvaluations(converter, eval); +} - llvm::SmallVector reductionLocs( - clauseOps.reductionVars.size(), currentLocation); +template +static OpTy genOpWithBody(OpWithBodyGenInfo &info, Args &&...args) { + auto op = info.converter.getFirOpBuilder().create( + info.loc, std::forward(args)...); + createBodyOfOp(op, info); + return op; +} - mlir::OperandRange privateVars = parallelOp.getPrivateVars(); - mlir::Region ®ion = parallelOp.getRegion(); +//===----------------------------------------------------------------------===// +// Code generation functions for clauses +//===----------------------------------------------------------------------===// - llvm::SmallVector privateVarTypes = reductionTypes; - privateVarTypes.reserve(privateVarTypes.size() + privateVars.size()); - llvm::transform(privateVars, std::back_inserter(privateVarTypes), - [](mlir::Value v) { return v.getType(); }); +static void genCriticalDeclareClauses( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, + mlir::omp::CriticalClauseOps &clauseOps, llvm::StringRef name) { + ClauseProcessor cp(converter, semaCtx, clauses); + cp.processHint(clauseOps); + clauseOps.nameAttr = + mlir::StringAttr::get(converter.getFirOpBuilder().getContext(), name); +} - llvm::SmallVector privateVarLocs = reductionLocs; - privateVarLocs.reserve(privateVarLocs.size() + privateVars.size()); - llvm::transform(privateVars, std::back_inserter(privateVarLocs), - [](mlir::Value v) { return v.getLoc(); }); +static void genFlushClauses( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const std::optional &objects, + const std::optional> + &clauses, + mlir::Location loc, llvm::SmallVectorImpl &operandRange) { + if (objects) + genObjectList2(*objects, converter, operandRange); + + 
if (clauses && clauses->size() > 0) + TODO(converter.getCurrentLocation(), "Handle OmpMemoryOrderClause"); +} - firOpBuilder.createBlock(®ion, /*insertPt=*/{}, privateVarTypes, - privateVarLocs); +static void +genOrderedRegionClauses(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const Fortran::parser::OmpClauseList &clauses, + mlir::Location loc, + mlir::omp::OrderedRegionClauseOps &clauseOps) { + ClauseProcessor cp(converter, semaCtx, clauses); + cp.processTODO(loc, llvm::omp::Directive::OMPD_ordered); +} - llvm::SmallVector allSymbols = - reductionSyms; - allSymbols.append(privateSyms); - for (auto [arg, prv] : llvm::zip_equal(allSymbols, region.getArguments())) { - converter.bindSymbol(*arg, prv); - } +static void genParallelClauses( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::StatementContext &stmtCtx, + const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, + bool processReduction, mlir::omp::ParallelClauseOps &clauseOps, + llvm::SmallVectorImpl &reductionTypes, + llvm::SmallVectorImpl &reductionSyms) { + ClauseProcessor cp(converter, semaCtx, clauses); + cp.processAllocate(clauseOps); + cp.processDefault(); + cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps); + cp.processProcBind(clauseOps); - return allSymbols; - }; + if (processReduction) { + cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms); + if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars)) + clauseOps.reductionByRefAttr = converter.getFirOpBuilder().getUnitAttr(); + } - // TODO Merge with the reduction CB. 
- genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(&dsp); - return genOpWithBody(genInfo, clauseOps); + cp.processNumThreads(stmtCtx, clauseOps); } -static mlir::omp::SectionOp -genSectionOp(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location currentLocation, - const Fortran::parser::OmpClauseList &sectionsClauseList) { - // Currently only private/firstprivate clause is handled, and - // all privatization is done within `omp.section` operations. - return genOpWithBody( - OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) - .setGenNested(genNested) - .setClauses(&sectionsClauseList)); +static void genSectionsClauses(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const Fortran::parser::OmpClauseList &clauses, + mlir::Location loc, + bool clausesFromBeginSections, + mlir::omp::SectionsClauseOps &clauseOps) { + ClauseProcessor cp(converter, semaCtx, clauses); + if (clausesFromBeginSections) { + cp.processAllocate(clauseOps); + cp.processSectionsReduction(loc, clauseOps); + // TODO Support delayed privatization. 
+ } else { + cp.processNowait(clauseOps); + } } -static mlir::omp::SingleOp -genSingleOp(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location currentLocation, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList &endClauseList) { - mlir::omp::SingleClauseOps clauseOps; +static void genSimdLoopClauses( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::StatementContext &stmtCtx, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, + mlir::omp::SimdLoopClauseOps &clauseOps, + llvm::SmallVectorImpl &iv) { + ClauseProcessor cp(converter, semaCtx, clauses); + cp.processCollapse(loc, eval, clauseOps, iv); + cp.processIf(llvm::omp::Directive::OMPD_simd, clauseOps); + cp.processReduction(loc, clauseOps); + cp.processSafelen(clauseOps); + cp.processSimdlen(clauseOps); + clauseOps.loopInclusiveAttr = converter.getFirOpBuilder().getUnitAttr(); + // TODO Support delayed privatization. - ClauseProcessor cp(converter, semaCtx, beginClauseList); - cp.processAllocate(clauseOps); + cp.processTODO( + loc, llvm::omp::Directive::OMPD_simd); +} + +static void genSingleClauses(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const Fortran::parser::OmpClauseList &beginClauses, + const Fortran::parser::OmpClauseList &endClauses, + mlir::Location loc, + mlir::omp::SingleClauseOps &clauseOps) { + ClauseProcessor bcp(converter, semaCtx, beginClauses); + bcp.processAllocate(clauseOps); // TODO Support delayed privatization. 
- ClauseProcessor ecp(converter, semaCtx, endClauseList); + ClauseProcessor ecp(converter, semaCtx, endClauses); + ecp.processCopyprivate(loc, clauseOps); ecp.processNowait(clauseOps); - ecp.processCopyprivate(currentLocation, clauseOps); +} - return genOpWithBody( - OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) - .setGenNested(genNested) - .setClauses(&beginClauseList), - clauseOps); +static void genTargetClauses( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::StatementContext &stmtCtx, + const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, + bool processHostOnlyClauses, bool processReduction, + mlir::omp::TargetClauseOps &clauseOps, + llvm::SmallVectorImpl &mapSyms, + llvm::SmallVectorImpl &mapSymLocs, + llvm::SmallVectorImpl &mapSymTypes) { + ClauseProcessor cp(converter, semaCtx, clauses); + cp.processDepend(clauseOps); + cp.processDevice(stmtCtx, clauseOps); + cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps); + cp.processMap(loc, stmtCtx, clauseOps, &mapSyms, &mapSymLocs, &mapSymTypes); + cp.processThreadLimit(stmtCtx, clauseOps); + // TODO Support delayed privatization. 
+ + if (processHostOnlyClauses) + cp.processNowait(clauseOps); + + cp.processTODO(loc, + llvm::omp::Directive::OMPD_target); } -static mlir::omp::TaskOp -genTaskOp(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location currentLocation, - const Fortran::parser::OmpClauseList &clauseList) { - Fortran::lower::StatementContext stmtCtx; - mlir::omp::TaskClauseOps clauseOps; +static void genTargetDataClauses( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::StatementContext &stmtCtx, + const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, + mlir::omp::TargetDataClauseOps &clauseOps, + llvm::SmallVectorImpl &useDeviceTypes, + llvm::SmallVectorImpl &useDeviceLocs, + llvm::SmallVectorImpl &useDeviceSyms) { + ClauseProcessor cp(converter, semaCtx, clauses); + cp.processDevice(stmtCtx, clauseOps); + cp.processIf(llvm::omp::Directive::OMPD_target_data, clauseOps); + cp.processMap(loc, stmtCtx, clauseOps); + cp.processUseDeviceAddr(clauseOps, useDeviceTypes, useDeviceLocs, + useDeviceSyms); + cp.processUseDevicePtr(clauseOps, useDeviceTypes, useDeviceLocs, + useDeviceSyms); - ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps); + // This function implements the deprecated functionality of use_device_ptr + // that allows users to provide non-CPTR arguments to it with the caveat + // that the compiler will treat them as use_device_addr. A lot of legacy + // code may still depend on this functionality, so we should support it + // in some manner. 
We do so currently by simply shifting non-cptr operands + // from the use_device_ptr list into the front of the use_device_addr list + // whilst maintaining the ordering of useDeviceLocs, useDeviceSyms and + // useDeviceTypes to use_device_ptr/use_device_addr input for BlockArg + // ordering. + // TODO: Perhaps create a user provideable compiler option that will + // re-introduce a hard-error rather than a warning in these cases. + promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(clauseOps, useDeviceTypes, + useDeviceLocs, useDeviceSyms); +} + +static void genTargetEnterExitUpdateDataClauses( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::StatementContext &stmtCtx, + const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, + llvm::omp::Directive directive, + mlir::omp::TargetEnterExitUpdateDataClauseOps &clauseOps) { + ClauseProcessor cp(converter, semaCtx, clauses); + cp.processDepend(clauseOps); + cp.processDevice(stmtCtx, clauseOps); + cp.processIf(directive, clauseOps); + cp.processNowait(clauseOps); + + if (directive == llvm::omp::Directive::OMPD_target_update) { + cp.processMotionClauses(stmtCtx, clauseOps); + cp.processMotionClauses(stmtCtx, clauseOps); + } else { + cp.processMap(loc, stmtCtx, clauseOps); + } +} + +static void genTaskClauses(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::StatementContext &stmtCtx, + const Fortran::parser::OmpClauseList &clauses, + mlir::Location loc, + mlir::omp::TaskClauseOps &clauseOps) { + ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); cp.processDefault(); + cp.processDepend(clauseOps); cp.processFinal(stmtCtx, clauseOps); - cp.processUntied(clauseOps); + cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps); cp.processMergeable(clauseOps); cp.processPriority(stmtCtx, clauseOps); - cp.processDepend(clauseOps); + cp.processUntied(clauseOps); // TODO 
Support delayed privatization. - cp.processTODO( - currentLocation, llvm::omp::Directive::OMPD_task); + cp.processTODO( + loc, llvm::omp::Directive::OMPD_task); +} - return genOpWithBody( - OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) - .setGenNested(genNested) - .setClauses(&clauseList), - clauseOps); +static void genTaskgroupClauses(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const Fortran::parser::OmpClauseList &clauses, + mlir::Location loc, + mlir::omp::TaskgroupClauseOps &clauseOps) { + ClauseProcessor cp(converter, semaCtx, clauses); + cp.processAllocate(clauseOps); + cp.processTODO(loc, + llvm::omp::Directive::OMPD_taskgroup); } -static mlir::omp::TaskgroupOp -genTaskgroupOp(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location currentLocation, - const Fortran::parser::OmpClauseList &clauseList) { - mlir::omp::TaskgroupClauseOps clauseOps; +static void genTaskwaitClauses(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const Fortran::parser::OmpClauseList &clauses, + mlir::Location loc, + mlir::omp::TaskwaitClauseOps &clauseOps) { + ClauseProcessor cp(converter, semaCtx, clauses); + cp.processTODO( + loc, llvm::omp::Directive::OMPD_taskwait); +} - ClauseProcessor cp(converter, semaCtx, clauseList); +static void genTeamsClauses(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::StatementContext &stmtCtx, + const Fortran::parser::OmpClauseList &clauses, + mlir::Location loc, + mlir::omp::TeamsClauseOps &clauseOps) { + ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); - cp.processTODO(currentLocation, - llvm::omp::Directive::OMPD_taskgroup); + cp.processDefault(); + cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps); + 
cp.processNumTeams(stmtCtx, clauseOps); + cp.processThreadLimit(stmtCtx, clauseOps); + // TODO Support delayed privatization. - return genOpWithBody( - OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) - .setGenNested(genNested) - .setClauses(&clauseList), - clauseOps); + cp.processTODO(loc, llvm::omp::Directive::OMPD_teams); } -// This helper function implements the functionality of "promoting" -// non-CPTR arguments of use_device_ptr to use_device_addr -// arguments (automagic conversion of use_device_ptr -> -// use_device_addr in these cases). The way we do so currently is -// through the shuffling of operands from the devicePtrOperands to -// deviceAddrOperands where neccesary and re-organizing the types, -// locations and symbols to maintain the correct ordering of ptr/addr -// input -> BlockArg. -// -// This effectively implements some deprecated OpenMP functionality -// that some legacy applications unfortunately depend on -// (deprecated in specification version 5.2): -// -// "If a list item in a use_device_ptr clause is not of type C_PTR, -// the behavior is as if the list item appeared in a use_device_addr -// clause. Support for such list items in a use_device_ptr clause -// is deprecated." 
-static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr( - mlir::omp::UseDeviceClauseOps &clauseOps, - llvm::SmallVectorImpl &useDeviceTypes, - llvm::SmallVectorImpl &useDeviceLocs, - llvm::SmallVectorImpl - &useDeviceSymbols) { - auto moveElementToBack = [](size_t idx, auto &vector) { - auto *iter = std::next(vector.begin(), idx); - vector.push_back(*iter); - vector.erase(iter); - }; +static void genWsloopClauses( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::StatementContext &stmtCtx, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OmpClauseList &beginClauses, + const Fortran::parser::OmpClauseList *endClauses, mlir::Location loc, + mlir::omp::WsloopClauseOps &clauseOps, + llvm::SmallVectorImpl &iv, + llvm::SmallVectorImpl &reductionTypes, + llvm::SmallVectorImpl &reductionSyms) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + ClauseProcessor bcp(converter, semaCtx, beginClauses); + bcp.processCollapse(loc, eval, clauseOps, iv); + bcp.processOrdered(clauseOps); + bcp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms); + bcp.processSchedule(stmtCtx, clauseOps); + clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr(); + // TODO Support delayed privatization. - // Iterate over our use_device_ptr list and shift all non-cptr arguments into - // use_device_addr. - for (auto *it = clauseOps.useDevicePtrVars.begin(); - it != clauseOps.useDevicePtrVars.end();) { - if (!fir::isa_builtin_cptr_type(fir::unwrapRefType(it->getType()))) { - clauseOps.useDeviceAddrVars.push_back(*it); - // We have to shuffle the symbols around as well, to maintain - // the correct Input -> BlockArg for use_device_ptr/use_device_addr. - // NOTE: However, as map's do not seem to be included currently - // this isn't as pertinent, but we must try to maintain for - // future alterations. 
I believe the reason they are not currently - // is that the BlockArg assign/lowering needs to be extended - // to a greater set of types. - auto idx = std::distance(clauseOps.useDevicePtrVars.begin(), it); - moveElementToBack(idx, useDeviceTypes); - moveElementToBack(idx, useDeviceLocs); - moveElementToBack(idx, useDeviceSymbols); - it = clauseOps.useDevicePtrVars.erase(it); - continue; + if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars)) + clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr(); + + if (endClauses) { + ClauseProcessor ecp(converter, semaCtx, *endClauses); + ecp.processNowait(clauseOps); + } + + bcp.processTODO( + loc, llvm::omp::Directive::OMPD_do); +} + +//===----------------------------------------------------------------------===// +// Code generation functions for leaf constructs +//===----------------------------------------------------------------------===// + +static mlir::omp::BarrierOp +genBarrierOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, mlir::Location loc) { + return converter.getFirOpBuilder().create(loc); +} + +static mlir::omp::CriticalOp +genCriticalOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList, + const std::optional &name) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::FlatSymbolRefAttr nameAttr; + + if (name) { + std::string nameStr = name->ToString(); + mlir::ModuleOp mod = firOpBuilder.getModule(); + auto global = mod.lookupSymbol(nameStr); + if (!global) { + mlir::omp::CriticalClauseOps clauseOps; + genCriticalDeclareClauses(converter, semaCtx, clauseList, loc, clauseOps, + nameStr); + + mlir::OpBuilder modBuilder(mod.getBodyRegion()); + global = modBuilder.create(loc, clauseOps); } - ++it; + nameAttr = 
mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(), + global.getSymName()); } + + return genOpWithBody( + OpWithBodyGenInfo(converter, semaCtx, loc, eval).setGenNested(genNested), + nameAttr); } -static mlir::omp::TargetDataOp -genTargetDataOp(Fortran::lower::AbstractConverter &converter, +static mlir::omp::DistributeOp +genDistributeOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location currentLocation, + mlir::Location loc, const Fortran::parser::OmpClauseList &clauseList) { - Fortran::lower::StatementContext stmtCtx; - mlir::omp::TargetDataClauseOps clauseOps; - llvm::SmallVector useDeviceTypes; - llvm::SmallVector useDeviceLocs; - llvm::SmallVector useDeviceSyms; + TODO(loc, "Distribute construct"); + return nullptr; +} - ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(llvm::omp::Directive::OMPD_target_data, clauseOps); - cp.processDevice(stmtCtx, clauseOps); - cp.processUseDevicePtr(clauseOps, useDeviceTypes, useDeviceLocs, - useDeviceSyms); - cp.processUseDeviceAddr(clauseOps, useDeviceTypes, useDeviceLocs, - useDeviceSyms); +static mlir::omp::FlushOp +genFlushOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, mlir::Location loc, + const std::optional &objectList, + const std::optional> + &clauseList) { + llvm::SmallVector operandRange; + genFlushClauses(converter, semaCtx, objectList, clauseList, loc, + operandRange); + + return converter.getFirOpBuilder().create( + converter.getCurrentLocation(), operandRange); +} - // This function implements the deprecated functionality of use_device_ptr - // that allows users to provide non-CPTR arguments to it with the caveat - // that the compiler will treat them as use_device_addr. A lot of legacy - // code may still depend on this functionality, so we should support it - // in some manner. 
We do so currently by simply shifting non-cptr operands - // from the use_device_ptr list into the front of the use_device_addr list - // whilst maintaining the ordering of useDeviceLocs, useDeviceSymbols and - // useDeviceTypes to use_device_ptr/use_device_addr input for BlockArg - // ordering. - // TODO: Perhaps create a user provideable compiler option that will - // re-introduce a hard-error rather than a warning in these cases. - promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(clauseOps, useDeviceTypes, - useDeviceLocs, useDeviceSyms); - cp.processMap(currentLocation, llvm::omp::Directive::OMPD_target_data, - stmtCtx, clauseOps); +static mlir::omp::MasterOp +genMasterOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::Location loc) { + return genOpWithBody( + OpWithBodyGenInfo(converter, semaCtx, loc, eval).setGenNested(genNested), + /*resultTypes=*/mlir::TypeRange()); +} + +static mlir::omp::OrderedOp +genOrderedOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList) { + TODO(loc, "OMPD_ordered"); + return nullptr; +} - auto dataOp = converter.getFirOpBuilder().create( - currentLocation, clauseOps); +static mlir::omp::OrderedRegionOp +genOrderedRegionOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList) { + mlir::omp::OrderedRegionClauseOps clauseOps; + genOrderedRegionClauses(converter, semaCtx, clauseList, loc, clauseOps); - genBodyOfTargetDataOp(converter, semaCtx, eval, genNested, dataOp, - useDeviceTypes, useDeviceLocs, useDeviceSyms, - currentLocation); - return dataOp; + return genOpWithBody( + 
OpWithBodyGenInfo(converter, semaCtx, loc, eval).setGenNested(genNested), + clauseOps); } -template -static OpTy genTargetEnterExitDataUpdateOp( - Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - mlir::Location currentLocation, - const Fortran::parser::OmpClauseList &clauseList) { +static mlir::omp::ParallelOp +genParallelOp(Fortran::lower::AbstractConverter &converter, + Fortran::lower::SymMap &symTable, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList, + bool outerCombined = false) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); Fortran::lower::StatementContext stmtCtx; - mlir::omp::TargetEnterExitUpdateDataClauseOps clauseOps; + mlir::omp::ParallelClauseOps clauseOps; + llvm::SmallVector privateSyms; + llvm::SmallVector reductionTypes; + llvm::SmallVector reductionSyms; + genParallelClauses(converter, semaCtx, stmtCtx, clauseList, loc, + /*processReduction=*/!outerCombined, clauseOps, + reductionTypes, reductionSyms); - // GCC 9.3.0 emits a (probably) bogus warning about an unused variable. 
- [[maybe_unused]] llvm::omp::Directive directive; - if constexpr (std::is_same_v) { - directive = llvm::omp::Directive::OMPD_target_enter_data; - } else if constexpr (std::is_same_v) { - directive = llvm::omp::Directive::OMPD_target_exit_data; - } else if constexpr (std::is_same_v) { - directive = llvm::omp::Directive::OMPD_target_update; - } else { - return nullptr; - } + auto reductionCallback = [&](mlir::Operation *op) { + genReductionVars(op, converter, loc, reductionSyms, reductionTypes); + return reductionSyms; + }; - ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(directive, clauseOps); - cp.processDevice(stmtCtx, clauseOps); - cp.processDepend(clauseOps); - cp.processNowait(clauseOps); + OpWithBodyGenInfo genInfo = + OpWithBodyGenInfo(converter, semaCtx, loc, eval) + .setGenNested(genNested) + .setOuterCombined(outerCombined) + .setClauses(&clauseList) + .setReductions(&reductionSyms, &reductionTypes) + .setGenRegionEntryCb(reductionCallback); - if constexpr (std::is_same_v) { - cp.processMotionClauses(stmtCtx, clauseOps); - cp.processMotionClauses(stmtCtx, clauseOps); - } else { - cp.processMap(currentLocation, directive, stmtCtx, clauseOps); - } + if (!enableDelayedPrivatization) + return genOpWithBody(genInfo, clauseOps); - return firOpBuilder.create(currentLocation, clauseOps); -} + bool privatize = !outerCombined; + DataSharingProcessor dsp(converter, semaCtx, clauseList, eval, + /*useDelayedPrivatization=*/true, &symTable); -// This functions creates a block for the body of the targetOp's region. It adds -// all the symbols present in mapSymbols as block arguments to this block. 
-static void -genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::omp::TargetOp &targetOp, - llvm::ArrayRef mapSyms, - llvm::ArrayRef mapSymLocs, - llvm::ArrayRef mapSymTypes, - const mlir::Location &currentLocation) { - assert(mapSymTypes.size() == mapSymLocs.size()); + if (privatize) + dsp.processStep1(&clauseOps, &privateSyms); - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::Region &region = targetOp.getRegion(); + auto genRegionEntryCB = [&](mlir::Operation *op) { + auto parallelOp = llvm::cast(op); - auto *regionBlock = - firOpBuilder.createBlock(&region, {}, mapSymTypes, mapSymLocs); + llvm::SmallVector reductionLocs( + clauseOps.reductionVars.size(), loc); - // Clones the `bounds` placing them inside the target region and returns them. - auto cloneBound = [&](mlir::Value bound) { - if (mlir::isMemoryEffectFree(bound.getDefiningOp())) { - mlir::Operation *clonedOp = bound.getDefiningOp()->clone(); - regionBlock->push_back(clonedOp); - return clonedOp->getResult(0); + mlir::OperandRange privateVars = parallelOp.getPrivateVars(); + mlir::Region &region = parallelOp.getRegion(); + + llvm::SmallVector privateVarTypes = reductionTypes; + privateVarTypes.reserve(privateVarTypes.size() + privateVars.size()); + llvm::transform(privateVars, std::back_inserter(privateVarTypes), + [](mlir::Value v) { return v.getType(); }); + + llvm::SmallVector privateVarLocs = reductionLocs; + privateVarLocs.reserve(privateVarLocs.size() + privateVars.size()); + llvm::transform(privateVars, std::back_inserter(privateVarLocs), + [](mlir::Value v) { return v.getLoc(); }); + + firOpBuilder.createBlock(&region, /*insertPt=*/{}, privateVarTypes, + privateVarLocs); + + llvm::SmallVector allSymbols = + reductionSyms; + allSymbols.append(privateSyms); + for (auto [arg, prv] : llvm::zip_equal(allSymbols, region.getArguments())) { + converter.bindSymbol(*arg, 
prv); } - TODO(converter.getCurrentLocation(), - "target map clause operand unsupported bound type"); - }; - auto cloneBounds = [cloneBound](llvm::ArrayRef bounds) { - llvm::SmallVector clonedBounds; - for (mlir::Value bound : bounds) - clonedBounds.emplace_back(cloneBound(bound)); - return clonedBounds; + return allSymbols; }; - // Bind the symbols to their corresponding block arguments. - for (auto [argIndex, argSymbol] : llvm::enumerate(mapSyms)) { - const mlir::BlockArgument &arg = region.getArgument(argIndex); - // Avoid capture of a reference to a structured binding. - const Fortran::semantics::Symbol *sym = argSymbol; - // Structure component symbols don't have bindings. - if (sym->owner().IsDerivedType()) - continue; - fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym); - extVal.match( - [&](const fir::BoxValue &v) { - converter.bindSymbol(*sym, - fir::BoxValue(arg, cloneBounds(v.getLBounds()), - v.getExplicitParameters(), - v.getExplicitExtents())); - }, - [&](const fir::MutableBoxValue &v) { - converter.bindSymbol( - *sym, fir::MutableBoxValue(arg, cloneBounds(v.getLBounds()), - v.getMutableProperties())); - }, - [&](const fir::ArrayBoxValue &v) { - converter.bindSymbol( - *sym, fir::ArrayBoxValue(arg, cloneBounds(v.getExtents()), - cloneBounds(v.getLBounds()), - v.getSourceBox())); - }, - [&](const fir::CharArrayBoxValue &v) { - converter.bindSymbol( - *sym, fir::CharArrayBoxValue(arg, cloneBound(v.getLen()), - cloneBounds(v.getExtents()), - cloneBounds(v.getLBounds()))); - }, - [&](const fir::CharBoxValue &v) { - converter.bindSymbol(*sym, - fir::CharBoxValue(arg, cloneBound(v.getLen()))); - }, - [&](const fir::UnboxedValue &v) { converter.bindSymbol(*sym, arg); }, - [&](const auto &) { - TODO(converter.getCurrentLocation(), - "target map clause operand unsupported type"); - }); - } + // TODO Merge with the reduction CB. 
+ genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(&dsp); + return genOpWithBody(genInfo, clauseOps); +} - // Check if cloning the bounds introduced any dependency on the outer region. - // If so, then either clone them as well if they are MemoryEffectFree, or else - // copy them to a new temporary and add them to the map and block_argument - // lists and replace their uses with the new temporary. - llvm::SetVector valuesDefinedAbove; - mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove); - while (!valuesDefinedAbove.empty()) { - for (mlir::Value val : valuesDefinedAbove) { - mlir::Operation *valOp = val.getDefiningOp(); - if (mlir::isMemoryEffectFree(valOp)) { - mlir::Operation *clonedOp = valOp->clone(); - regionBlock->push_front(clonedOp); - val.replaceUsesWithIf( - clonedOp->getResult(0), [regionBlock](mlir::OpOperand &use) { - return use.getOwner()->getBlock() == regionBlock; - }); - } else { - auto savedIP = firOpBuilder.getInsertionPoint(); - firOpBuilder.setInsertionPointAfter(valOp); - auto copyVal = - firOpBuilder.createTemporary(val.getLoc(), val.getType()); - firOpBuilder.createStoreWithConvert(copyVal.getLoc(), val, copyVal); +static mlir::omp::SectionOp +genSectionOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList) { + // Currently only private/firstprivate clause is handled, and + // all privatization is done within `omp.section` operations. 
+ return genOpWithBody( + OpWithBodyGenInfo(converter, semaCtx, loc, eval) + .setGenNested(genNested) + .setClauses(&clauseList)); +} - llvm::SmallVector bounds; - std::stringstream name; - firOpBuilder.setInsertionPoint(targetOp); - mlir::Value mapOp = createMapInfoOp( - firOpBuilder, copyVal.getLoc(), copyVal, mlir::Value{}, name.str(), - bounds, llvm::SmallVector{}, - static_cast< - std::underlying_type_t>( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT), - mlir::omp::VariableCaptureKind::ByCopy, copyVal.getType()); - targetOp.getMapOperandsMutable().append(mapOp); - mlir::Value clonedValArg = - region.addArgument(copyVal.getType(), copyVal.getLoc()); - firOpBuilder.setInsertionPointToStart(regionBlock); - auto loadOp = firOpBuilder.create(clonedValArg.getLoc(), - clonedValArg); - val.replaceUsesWithIf( - loadOp->getResult(0), [regionBlock](mlir::OpOperand &use) { - return use.getOwner()->getBlock() == regionBlock; - }); - firOpBuilder.setInsertionPoint(regionBlock, savedIP); - } - } - valuesDefinedAbove.clear(); - mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove); - } +static mlir::omp::SectionsOp +genSectionsOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, mlir::Location loc, + const mlir::omp::SectionsClauseOps &clauseOps) { + return genOpWithBody( + OpWithBodyGenInfo(converter, semaCtx, loc, eval).setGenNested(false), + clauseOps); +} - // Insert dummy instruction to remember the insertion position. The - // marker will be deleted since there are not uses. - // In the HLFIR flow there are hlfir.declares inserted above while - // setting block arguments. 
- mlir::Value undefMarker = firOpBuilder.create( - targetOp.getOperation()->getLoc(), firOpBuilder.getIndexType()); +static mlir::omp::SimdLoopOp +genSimdLoopOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList) { + DataSharingProcessor dsp(converter, semaCtx, clauseList, eval); + dsp.processStep1(); - // Create blocks for unstructured regions. This has to be done since - // blocks are initially allocated with the function as the parent region. - if (eval.lowerAsUnstructured()) { - Fortran::lower::createEmptyRegionBlocks( - firOpBuilder, eval.getNestedEvaluations()); - } + Fortran::lower::StatementContext stmtCtx; + mlir::omp::SimdLoopClauseOps clauseOps; + llvm::SmallVector iv; + genSimdLoopClauses(converter, semaCtx, stmtCtx, eval, clauseList, loc, + clauseOps, iv); - firOpBuilder.create(currentLocation); + auto *nestedEval = + getCollapsedLoopEval(eval, Fortran::lower::getCollapseValue(clauseList)); + + auto ivCallback = [&](mlir::Operation *op) { + return genLoopVars(op, converter, loc, iv); + }; + + return genOpWithBody( + OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval) + .setClauses(&clauseList) + .setDataSharingProcessor(&dsp) + .setGenRegionEntryCb(ivCallback), + clauseOps); +} + +static mlir::omp::SingleOp +genSingleOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::Location loc, + const Fortran::parser::OmpClauseList &beginClauseList, + const Fortran::parser::OmpClauseList &endClauseList) { + mlir::omp::SingleClauseOps clauseOps; + genSingleClauses(converter, semaCtx, beginClauseList, endClauseList, loc, + clauseOps); - // Create the insertion point after the marker. 
- firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp()); - if (genNested) - genNestedEvaluations(converter, eval); + return genOpWithBody( + OpWithBodyGenInfo(converter, semaCtx, loc, eval) + .setGenNested(genNested) + .setClauses(&beginClauseList), + clauseOps); } static mlir::omp::TargetOp genTargetOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location currentLocation, + mlir::Location loc, const Fortran::parser::OmpClauseList &clauseList, - llvm::omp::Directive directive, bool outerCombined = false) { + bool outerCombined = false) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); Fortran::lower::StatementContext stmtCtx; + + bool processHostOnlyClauses = + !llvm::cast(*converter.getModuleOp()) + .getIsTargetDevice(); + mlir::omp::TargetClauseOps clauseOps; llvm::SmallVector mapSyms; llvm::SmallVector mapSymLocs; llvm::SmallVector mapSymTypes; - - ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps); - cp.processDevice(stmtCtx, clauseOps); - cp.processThreadLimit(stmtCtx, clauseOps); - cp.processDepend(clauseOps); - cp.processNowait(clauseOps); - cp.processMap(currentLocation, directive, stmtCtx, clauseOps, &mapSyms, - &mapSymLocs, &mapSymTypes); - // TODO Support delayed privatization. 
- - cp.processTODO( - currentLocation, llvm::omp::Directive::OMPD_target); + genTargetClauses(converter, semaCtx, stmtCtx, clauseList, loc, + processHostOnlyClauses, /*processReduction=*/outerCombined, + clauseOps, mapSyms, mapSymLocs, mapSymTypes); // 5.8.1 Implicit Data-Mapping Attribute Rules // The following code follows the implicit data-mapping rules to map all the @@ -1131,338 +1584,145 @@ genTargetOp(Fortran::lower::AbstractConverter &converter, }; Fortran::lower::pft::visitAllSymbols(eval, captureImplicitMap); - auto targetOp = converter.getFirOpBuilder().create( - currentLocation, clauseOps); - - genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSyms, - mapSymLocs, mapSymTypes, currentLocation); - - return targetOp; -} - -static mlir::omp::TeamsOp -genTeamsOp(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location currentLocation, - const Fortran::parser::OmpClauseList &clauseList, - bool outerCombined = false) { - Fortran::lower::StatementContext stmtCtx; - mlir::omp::TeamsClauseOps clauseOps; - - ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps); - cp.processAllocate(clauseOps); - cp.processDefault(); - cp.processNumTeams(stmtCtx, clauseOps); - cp.processThreadLimit(stmtCtx, clauseOps); - // TODO Support delayed privatization. - - cp.processTODO(currentLocation, - llvm::omp::Directive::OMPD_teams); - - return genOpWithBody( - OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) - .setGenNested(genNested) - .setOuterCombined(outerCombined) - .setClauses(&clauseList), - clauseOps); -} - -/// Extract the list of function and variable symbols affected by the given -/// 'declare target' directive and return the intended device type for them. 
-static void getDeclareTargetInfo( - Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct, - mlir::omp::DeclareTargetClauseOps &clauseOps, - llvm::SmallVectorImpl &symbolAndClause) { - const auto &spec = std::get( - declareTargetConstruct.t); - if (const auto *objectList{ - Fortran::parser::Unwrap(spec.u)}) { - ObjectList objects{makeList(*objectList, semaCtx)}; - // Case: declare target(func, var1, var2) - gatherFuncAndVarSyms(objects, mlir::omp::DeclareTargetCaptureClause::to, - symbolAndClause); - } else if (const auto *clauseList{ - Fortran::parser::Unwrap( - spec.u)}) { - if (clauseList->v.empty()) { - // Case: declare target, implicit capture of function - symbolAndClause.emplace_back( - mlir::omp::DeclareTargetCaptureClause::to, - eval.getOwningProcedure()->getSubprogramSymbol()); - } - - ClauseProcessor cp(converter, semaCtx, *clauseList); - cp.processTo(symbolAndClause); - cp.processEnter(symbolAndClause); - cp.processLink(symbolAndClause); - cp.processDeviceType(clauseOps); - cp.processTODO(converter.getCurrentLocation(), - llvm::omp::Directive::OMPD_declare_target); - } -} - -static void collectDeferredDeclareTargets( - Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct, - llvm::SmallVectorImpl - &deferredDeclareTarget) { - mlir::omp::DeclareTargetClauseOps clauseOps; - llvm::SmallVector symbolAndClause; - getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct, - clauseOps, symbolAndClause); - // Return the device type only if at least one of the targets for the - // directive is a function or subroutine - mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); - - for (const DeclareTargetCapturePair &symClause : 
symbolAndClause) { - mlir::Operation *op = mod.lookupSymbol(converter.mangleName( - std::get(symClause))); - - if (!op) { - deferredDeclareTarget.push_back({std::get<0>(symClause), - clauseOps.deviceType, - std::get<1>(symClause)}); - } - } -} - -static std::optional -getDeclareTargetFunctionDevice( - Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenMPDeclareTargetConstruct - &declareTargetConstruct) { - mlir::omp::DeclareTargetClauseOps clauseOps; - llvm::SmallVector symbolAndClause; - getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct, - clauseOps, symbolAndClause); - - // Return the device type only if at least one of the targets for the - // directive is a function or subroutine - mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); - for (const DeclareTargetCapturePair &symClause : symbolAndClause) { - mlir::Operation *op = mod.lookupSymbol(converter.mangleName( - std::get(symClause))); - - if (mlir::isa_and_nonnull(op)) - return clauseOps.deviceType; - } - - return std::nullopt; -} - -//===----------------------------------------------------------------------===// -// genOMP() Code generation helper functions -//===----------------------------------------------------------------------===// - -static void -genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, bool genNested, - const Fortran::parser::OpenMPSimpleStandaloneConstruct - &simpleStandaloneConstruct) { - const auto &directive = - std::get( - simpleStandaloneConstruct.t); - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - const auto &opClauseList = - std::get(simpleStandaloneConstruct.t); - mlir::Location currentLocation = converter.genLocation(directive.source); - - switch (directive.v) { - default: - break; - case 
llvm::omp::Directive::OMPD_barrier: - firOpBuilder.create(currentLocation); - break; - case llvm::omp::Directive::OMPD_taskwait: { - mlir::omp::TaskwaitClauseOps clauseOps; - ClauseProcessor cp(converter, semaCtx, opClauseList); - cp.processTODO( - currentLocation, llvm::omp::Directive::OMPD_taskwait); - firOpBuilder.create(currentLocation, clauseOps); - break; - } - case llvm::omp::Directive::OMPD_taskyield: - firOpBuilder.create(currentLocation); - break; - case llvm::omp::Directive::OMPD_target_data: - genTargetDataOp(converter, semaCtx, eval, genNested, currentLocation, - opClauseList); - break; - case llvm::omp::Directive::OMPD_target_enter_data: - genTargetEnterExitDataUpdateOp( - converter, semaCtx, currentLocation, opClauseList); - break; - case llvm::omp::Directive::OMPD_target_exit_data: - genTargetEnterExitDataUpdateOp( - converter, semaCtx, currentLocation, opClauseList); - break; - case llvm::omp::Directive::OMPD_target_update: - genTargetEnterExitDataUpdateOp( - converter, semaCtx, currentLocation, opClauseList); - break; - case llvm::omp::Directive::OMPD_ordered: - TODO(currentLocation, "OMPD_ordered"); - } -} - -static void -genOmpFlush(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenMPFlushConstruct &flushConstruct) { - llvm::SmallVector operandRange; - if (const auto &ompObjectList = - std::get>( - flushConstruct.t)) - genObjectList2(*ompObjectList, converter, operandRange); - const auto &memOrderClause = - std::get>>( - flushConstruct.t); - if (memOrderClause && memOrderClause->size() > 0) - TODO(converter.getCurrentLocation(), "Handle OmpMemoryOrderClause"); - converter.getFirOpBuilder().create( - converter.getCurrentLocation(), operandRange); -} - -static llvm::SmallVector -genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter, - mlir::Location &loc, - llvm::ArrayRef args) { - fir::FirOpBuilder 
&firOpBuilder = converter.getFirOpBuilder(); - auto &region = op->getRegion(0); - - std::size_t loopVarTypeSize = 0; - for (const Fortran::semantics::Symbol *arg : args) - loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size()); - mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); - llvm::SmallVector tiv(args.size(), loopVarType); - llvm::SmallVector locs(args.size(), loc); - firOpBuilder.createBlock(&region, {}, tiv, locs); - // The argument is not currently in memory, so make a temporary for the - // argument, and store it there, then bind that location to the argument. - mlir::Operation *storeOp = nullptr; - for (auto [argIndex, argSymbol] : llvm::enumerate(args)) { - mlir::Value indexVal = fir::getBase(region.front().getArgument(argIndex)); - storeOp = - createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol); - } - firOpBuilder.setInsertionPointAfter(storeOp); - - return llvm::SmallVector(args); -} - -static llvm::SmallVector -genLoopAndReductionVars( - mlir::Operation *op, Fortran::lower::AbstractConverter &converter, - mlir::Location &loc, - llvm::ArrayRef loopArgs, - llvm::ArrayRef reductionArgs, - llvm::ArrayRef reductionTypes) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - - llvm::SmallVector blockArgTypes; - llvm::SmallVector blockArgLocs; - blockArgTypes.reserve(loopArgs.size() + reductionArgs.size()); - blockArgLocs.reserve(blockArgTypes.size()); - mlir::Block *entryBlock; - - if (loopArgs.size()) { - std::size_t loopVarTypeSize = 0; - for (const Fortran::semantics::Symbol *arg : loopArgs) - loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size()); - mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); - std::fill_n(std::back_inserter(blockArgTypes), loopArgs.size(), - loopVarType); - std::fill_n(std::back_inserter(blockArgLocs), loopArgs.size(), loc); - } - if (reductionArgs.size()) { - llvm::copy(reductionTypes, std::back_inserter(blockArgTypes)); -
std::fill_n(std::back_inserter(blockArgLocs), reductionArgs.size(), loc); - } - entryBlock = firOpBuilder.createBlock(&op->getRegion(0), {}, blockArgTypes, - blockArgLocs); - // The argument is not currently in memory, so make a temporary for the - // argument, and store it there, then bind that location to the argument. - if (loopArgs.size()) { - mlir::Operation *storeOp = nullptr; - for (auto [argIndex, argSymbol] : llvm::enumerate(loopArgs)) { - mlir::Value indexVal = - fir::getBase(op->getRegion(0).front().getArgument(argIndex)); - storeOp = - createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol); - } - firOpBuilder.setInsertionPointAfter(storeOp); - } - // Bind the reduction arguments to their block arguments - for (auto [arg, prv] : llvm::zip_equal( - reductionArgs, - llvm::drop_begin(entryBlock->getArguments(), loopArgs.size()))) { - converter.bindSymbol(*arg, prv); - } - - return llvm::SmallVector(loopArgs); + auto targetOp = firOpBuilder.create(loc, clauseOps); + genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSyms, + mapSymLocs, mapSymTypes, loc); + return targetOp; } -static void -createSimdLoop(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, - llvm::omp::Directive ompDirective, - const Fortran::parser::OmpClauseList &loopOpClauseList, - mlir::Location loc) { +static mlir::omp::TargetDataOp +genTargetDataOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList) { + Fortran::lower::StatementContext stmtCtx; + mlir::omp::TargetDataClauseOps clauseOps; + llvm::SmallVector useDeviceTypes; + llvm::SmallVector useDeviceLocs; + llvm::SmallVector useDeviceSyms; + genTargetDataClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps, + useDeviceTypes, useDeviceLocs, 
useDeviceSyms); + + auto targetDataOp = + converter.getFirOpBuilder().create(loc, + clauseOps); + + genBodyOfTargetDataOp(converter, semaCtx, eval, genNested, targetDataOp, + useDeviceTypes, useDeviceLocs, useDeviceSyms, loc); + return targetDataOp; +} + +template +static OpTy genTargetEnterExitUpdateDataOp( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - DataSharingProcessor dsp(converter, semaCtx, loopOpClauseList, eval); - dsp.processStep1(); + Fortran::lower::StatementContext stmtCtx; + + // GCC 9.3.0 emits a (probably) bogus warning about an unused variable. + [[maybe_unused]] llvm::omp::Directive directive; + if constexpr (std::is_same_v) { + directive = llvm::omp::Directive::OMPD_target_enter_data; + } else if constexpr (std::is_same_v) { + directive = llvm::omp::Directive::OMPD_target_exit_data; + } else if constexpr (std::is_same_v) { + directive = llvm::omp::Directive::OMPD_target_update; + } else { + llvm_unreachable("Unexpected TARGET DATA construct"); + } + + mlir::omp::TargetEnterExitUpdateDataClauseOps clauseOps; + genTargetEnterExitUpdateDataClauses(converter, semaCtx, stmtCtx, clauseList, + loc, directive, clauseOps); + + return firOpBuilder.create(loc, clauseOps); +} +static mlir::omp::TaskOp +genTaskOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList) { Fortran::lower::StatementContext stmtCtx; - mlir::omp::SimdLoopClauseOps clauseOps; - llvm::SmallVector iv; + mlir::omp::TaskClauseOps clauseOps; + genTaskClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps); - ClauseProcessor cp(converter, semaCtx, loopOpClauseList); - cp.processCollapse(loc, eval, clauseOps, iv); - 
cp.processReduction(loc, clauseOps); - cp.processIf(llvm::omp::Directive::OMPD_simd, clauseOps); - cp.processSimdlen(clauseOps); - cp.processSafelen(clauseOps); - clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr(); - // TODO Support delayed privatization. + return genOpWithBody( + OpWithBodyGenInfo(converter, semaCtx, loc, eval) + .setGenNested(genNested) + .setClauses(&clauseList), + clauseOps); +} - cp.processTODO(loc, ompDirective); +static mlir::omp::TaskgroupOp +genTaskgroupOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList) { + mlir::omp::TaskgroupClauseOps clauseOps; + genTaskgroupClauses(converter, semaCtx, clauseList, loc, clauseOps); - auto *nestedEval = getCollapsedLoopEval( - eval, Fortran::lower::getCollapseValue(loopOpClauseList)); + return genOpWithBody( + OpWithBodyGenInfo(converter, semaCtx, loc, eval) + .setGenNested(genNested) + .setClauses(&clauseList), + clauseOps); +} - auto ivCallback = [&](mlir::Operation *op) { - return genLoopVars(op, converter, loc, iv); - }; +static mlir::omp::TaskloopOp +genTaskloopOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList) { + TODO(loc, "Taskloop construct"); +} - genOpWithBody( - OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval) - .setClauses(&loopOpClauseList) - .setDataSharingProcessor(&dsp) - .setGenRegionEntryCb(ivCallback), +static mlir::omp::TaskwaitOp +genTaskwaitOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList) { + mlir::omp::TaskwaitClauseOps clauseOps; + genTaskwaitClauses(converter, semaCtx, 
clauseList, loc, clauseOps); + return converter.getFirOpBuilder().create(loc, + clauseOps); +} + +static mlir::omp::TaskyieldOp +genTaskyieldOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, mlir::Location loc) { + return converter.getFirOpBuilder().create(loc); +} + +static mlir::omp::TeamsOp +genTeamsOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::Location loc, const Fortran::parser::OmpClauseList &clauseList, + bool outerCombined = false) { + Fortran::lower::StatementContext stmtCtx; + mlir::omp::TeamsClauseOps clauseOps; + genTeamsClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps); + + return genOpWithBody( + OpWithBodyGenInfo(converter, semaCtx, loc, eval) + .setGenNested(genNested) + .setOuterCombined(outerCombined) + .setClauses(&clauseList), clauseOps); } -static void createWsloop(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, - llvm::omp::Directive ompDirective, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList *endClauseList, - mlir::Location loc) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); +static mlir::omp::WsloopOp +genWsloopOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, mlir::Location loc, + const Fortran::parser::OmpClauseList &beginClauseList, + const Fortran::parser::OmpClauseList *endClauseList) { DataSharingProcessor dsp(converter, semaCtx, beginClauseList, eval); dsp.processStep1(); @@ -1471,30 +1731,9 @@ static void createWsloop(Fortran::lower::AbstractConverter &converter, llvm::SmallVector iv; llvm::SmallVector reductionTypes; llvm::SmallVector reductionSyms; - - ClauseProcessor 
cp(converter, semaCtx, beginClauseList); - cp.processCollapse(loc, eval, clauseOps, iv); - cp.processSchedule(stmtCtx, clauseOps); - cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms); - cp.processOrdered(clauseOps); - clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr(); - // TODO Support delayed privatization. - - if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars)) - clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr(); - - cp.processTODO(loc, - ompDirective); - - // In FORTRAN `nowait` clause occur at the end of `omp do` directive. - // i.e - // !$omp do - // <...> - // !$omp end do nowait - if (endClauseList) { - ClauseProcessor ecp(converter, semaCtx, *endClauseList); - ecp.processNowait(clauseOps); - } + genWsloopClauses(converter, semaCtx, stmtCtx, eval, beginClauseList, + endClauseList, loc, clauseOps, iv, reductionTypes, + reductionSyms); auto *nestedEval = getCollapsedLoopEval( eval, Fortran::lower::getCollapseValue(beginClauseList)); @@ -1504,7 +1743,7 @@ static void createWsloop(Fortran::lower::AbstractConverter &converter, reductionTypes); }; - genOpWithBody( + return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval) .setClauses(&beginClauseList) .setDataSharingProcessor(&dsp) @@ -1513,7 +1752,11 @@ static void createWsloop(Fortran::lower::AbstractConverter &converter, clauseOps); } -static void createSimdWsloop( +//===----------------------------------------------------------------------===// +// Code generation functions for composite constructs +//===----------------------------------------------------------------------===// + +static void genCompositeDoSimd( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, llvm::omp::Directive ompDirective, @@ -1521,7 +1764,7 @@ static void createSimdWsloop( const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) { ClauseProcessor cp(converter, semaCtx, 
beginClauseList); cp.processTODO(loc, + clause::Order, clause::Safelen, clause::Simdlen>(loc, ompDirective); // TODO: Add support for vectorization - add vectorization hints inside loop // body. @@ -1531,34 +1774,7 @@ static void createSimdWsloop( // When support for vectorization is enabled, then we need to add handling of // if clause. Currently if clause can be skipped because we always assume // SIMD length = 1. - createWsloop(converter, semaCtx, eval, ompDirective, beginClauseList, - endClauseList, loc); -} - -static void -markDeclareTarget(mlir::Operation *op, - Fortran::lower::AbstractConverter &converter, - mlir::omp::DeclareTargetCaptureClause captureClause, - mlir::omp::DeclareTargetDeviceType deviceType) { - // TODO: Add support for program local variables with declare target applied - auto declareTargetOp = llvm::dyn_cast(op); - if (!declareTargetOp) - fir::emitFatalError( - converter.getCurrentLocation(), - "Attempt to apply declare target on unsupported operation"); - - // The function or global already has a declare target applied to it, very - // likely through implicit capture (usage in another declare target - // function/subroutine). 
It should be marked as any if it has been assigned - // both host and nohost, else we skip, as there is no change - if (declareTargetOp.isDeclareTarget()) { - if (declareTargetOp.getDeclareTargetDeviceType() != deviceType) - declareTargetOp.setDeclareTarget(mlir::omp::DeclareTargetDeviceType::any, - captureClause); - return; - } - - declareTargetOp.setDeclareTarget(deviceType, captureClause); + genWsloopOp(converter, semaCtx, eval, loc, beginClauseList, endClauseList); } //===----------------------------------------------------------------------===// @@ -1653,6 +1869,102 @@ genOMP(Fortran::lower::AbstractConverter &converter, ompDeclConstruct.u); } +//===----------------------------------------------------------------------===// +// OpenMPStandaloneConstruct visitors +//===----------------------------------------------------------------------===// + +static void genOMP(Fortran::lower::AbstractConverter &converter, + Fortran::lower::SymMap &symTable, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPSimpleStandaloneConstruct + &simpleStandaloneConstruct) { + const auto &directive = + std::get( + simpleStandaloneConstruct.t); + const auto &clauseList = + std::get(simpleStandaloneConstruct.t); + mlir::Location currentLocation = converter.genLocation(directive.source); + + switch (directive.v) { + default: + break; + case llvm::omp::Directive::OMPD_barrier: + genBarrierOp(converter, semaCtx, eval, currentLocation); + break; + case llvm::omp::Directive::OMPD_taskwait: + genTaskwaitOp(converter, semaCtx, eval, currentLocation, clauseList); + break; + case llvm::omp::Directive::OMPD_taskyield: + genTaskyieldOp(converter, semaCtx, eval, currentLocation); + break; + case llvm::omp::Directive::OMPD_target_data: + genTargetDataOp(converter, semaCtx, eval, /*genNested=*/true, + currentLocation, clauseList); + break; + case llvm::omp::Directive::OMPD_target_enter_data: + genTargetEnterExitUpdateDataOp( + 
converter, semaCtx, currentLocation, clauseList); + break; + case llvm::omp::Directive::OMPD_target_exit_data: + genTargetEnterExitUpdateDataOp( + converter, semaCtx, currentLocation, clauseList); + break; + case llvm::omp::Directive::OMPD_target_update: + genTargetEnterExitUpdateDataOp( + converter, semaCtx, currentLocation, clauseList); + break; + case llvm::omp::Directive::OMPD_ordered: + genOrderedOp(converter, semaCtx, eval, currentLocation, clauseList); + break; + } +} + +static void +genOMP(Fortran::lower::AbstractConverter &converter, + Fortran::lower::SymMap &symTable, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPFlushConstruct &flushConstruct) { + const auto &verbatim = std::get(flushConstruct.t); + const auto &objectList = + std::get>(flushConstruct.t); + const auto &clauseList = + std::get>>( + flushConstruct.t); + mlir::Location currentLocation = converter.genLocation(verbatim.source); + genFlushOp(converter, semaCtx, eval, currentLocation, objectList, clauseList); +} + +static void +genOMP(Fortran::lower::AbstractConverter &converter, + Fortran::lower::SymMap &symTable, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPCancelConstruct &cancelConstruct) { + TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct"); +} + +static void genOMP(Fortran::lower::AbstractConverter &converter, + Fortran::lower::SymMap &symTable, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPCancellationPointConstruct + &cancellationPointConstruct) { + TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct"); +} + +static void +genOMP(Fortran::lower::AbstractConverter &converter, + Fortran::lower::SymMap &symTable, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, + const 
Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) { + std::visit( + [&](auto &&s) { return genOMP(converter, symTable, semaCtx, eval, s); }, + standaloneConstruct.u); +} + //===----------------------------------------------------------------------===// // OpenMPConstruct visitors //===----------------------------------------------------------------------===// @@ -1782,7 +2094,7 @@ genOMP(Fortran::lower::AbstractConverter &converter, break; case llvm::omp::Directive::OMPD_target: genTargetOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation, - beginClauseList, directive.v); + beginClauseList); break; case llvm::omp::Directive::OMPD_target_data: genTargetDataOp(converter, semaCtx, eval, /*genNested=*/true, @@ -1798,8 +2110,7 @@ genOMP(Fortran::lower::AbstractConverter &converter, break; case llvm::omp::Directive::OMPD_teams: genTeamsOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation, - beginClauseList, - /*outerCombined=*/false); + beginClauseList); break; case llvm::omp::Directive::OMPD_workshare: // FIXME: Workshare is not a commonly used OpenMP construct, an @@ -1821,8 +2132,7 @@ genOMP(Fortran::lower::AbstractConverter &converter, if ((llvm::omp::allTargetSet & llvm::omp::blockConstructSet) .test(directive.v)) { genTargetOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation, - beginClauseList, directive.v, - /*outerCombined=*/true); + beginClauseList, /*outerCombined=*/true); combinedDirective = true; } if ((llvm::omp::allTeamsSet & llvm::omp::blockConstructSet) @@ -1859,44 +2169,13 @@ genOMP(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::Location currentLocation = converter.getCurrentLocation(); - std::string name; - const Fortran::parser::OmpCriticalDirective &cd = + const auto &cd = 
std::get(criticalConstruct.t); - if (std::get>(cd.t).has_value()) { - name = - std::get>(cd.t).value().ToString(); - } - - mlir::omp::CriticalOp criticalOp = [&]() { - if (name.empty()) { - return firOpBuilder.create( - currentLocation, mlir::FlatSymbolRefAttr()); - } - - mlir::ModuleOp module = firOpBuilder.getModule(); - mlir::OpBuilder modBuilder(module.getBodyRegion()); - auto global = module.lookupSymbol(name); - if (!global) { - mlir::omp::CriticalClauseOps clauseOps; - const auto &clauseList = std::get(cd.t); - - ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processHint(clauseOps); - clauseOps.nameAttr = - mlir::StringAttr::get(firOpBuilder.getContext(), name); - - global = modBuilder.create(currentLocation, - clauseOps); - } - - return firOpBuilder.create( - currentLocation, mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(), - global.getSymName())); - }(); - auto genInfo = OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval); - createBodyOfOp(criticalOp, genInfo); + const auto &clauseList = std::get(cd.t); + const auto &name = std::get>(cd.t); + mlir::Location currentLocation = converter.getCurrentLocation(); + genCriticalOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation, + clauseList, name); } static void @@ -1915,7 +2194,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OpenMPLoopConstruct &loopConstruct) { const auto &beginLoopDirective = std::get(loopConstruct.t); - const auto &loopOpClauseList = + const auto &beginClauseList = std::get(beginLoopDirective.t); mlir::Location currentLocation = converter.genLocation(beginLoopDirective.source); @@ -1936,33 +2215,31 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, bool validDirective = false; if (llvm::omp::topTaskloopSet.test(ompDirective)) { validDirective = true; - TODO(currentLocation, "Taskloop construct"); + genTaskloopOp(converter, semaCtx, eval, currentLocation, beginClauseList); } else { // Create 
omp.{target, teams, distribute, parallel} nested operations if ((llvm::omp::allTargetSet & llvm::omp::loopConstructSet) .test(ompDirective)) { validDirective = true; genTargetOp(converter, semaCtx, eval, /*genNested=*/false, - currentLocation, loopOpClauseList, ompDirective, - /*outerCombined=*/true); + currentLocation, beginClauseList, /*outerCombined=*/true); } if ((llvm::omp::allTeamsSet & llvm::omp::loopConstructSet) .test(ompDirective)) { validDirective = true; genTeamsOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation, - loopOpClauseList, - /*outerCombined=*/true); + beginClauseList, /*outerCombined=*/true); } if (llvm::omp::allDistributeSet.test(ompDirective)) { validDirective = true; - TODO(currentLocation, "Distribute construct"); + genDistributeOp(converter, semaCtx, eval, /*genNested=*/false, + currentLocation, beginClauseList); } if ((llvm::omp::allParallelSet & llvm::omp::loopConstructSet) .test(ompDirective)) { validDirective = true; genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/false, - currentLocation, loopOpClauseList, - /*outerCombined=*/true); + currentLocation, beginClauseList, /*outerCombined=*/true); } } if ((llvm::omp::allDoSet | llvm::omp::allSimdSet).test(ompDirective)) @@ -1976,17 +2253,15 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, if (llvm::omp::allDoSimdSet.test(ompDirective)) { // 2.9.3.2 Workshare SIMD construct - createSimdWsloop(converter, semaCtx, eval, ompDirective, loopOpClauseList, - endClauseList, currentLocation); - + genCompositeDoSimd(converter, semaCtx, eval, ompDirective, beginClauseList, + endClauseList, currentLocation); } else if (llvm::omp::allSimdSet.test(ompDirective)) { // 2.9.3.1 SIMD construct - createSimdLoop(converter, semaCtx, eval, ompDirective, loopOpClauseList, - currentLocation); - genOpenMPReduction(converter, semaCtx, loopOpClauseList); + genSimdLoopOp(converter, semaCtx, eval, currentLocation, beginClauseList); + genOpenMPReduction(converter, 
semaCtx, beginClauseList); } else { - createWsloop(converter, semaCtx, eval, ompDirective, loopOpClauseList, - endClauseList, currentLocation); + genWsloopOp(converter, semaCtx, eval, currentLocation, beginClauseList, + endClauseList); } } @@ -2006,44 +2281,39 @@ genOMP(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPSectionsConstruct &sectionsConstruct) { - mlir::Location currentLocation = converter.getCurrentLocation(); - mlir::omp::SectionsClauseOps clauseOps; const auto &beginSectionsDirective = std::get(sectionsConstruct.t); - const auto &sectionsClauseList = + const auto &beginClauseList = std::get(beginSectionsDirective.t); // Process clauses before optional omp.parallel, so that new variables are // allocated outside of the parallel region - ClauseProcessor cp(converter, semaCtx, sectionsClauseList); - cp.processSectionsReduction(currentLocation, clauseOps); - cp.processAllocate(clauseOps); - // TODO Support delayed privatization.
+ mlir::Location currentLocation = converter.getCurrentLocation(); + mlir::omp::SectionsClauseOps clauseOps; + genSectionsClauses(converter, semaCtx, beginClauseList, currentLocation, + /*clausesFromBeginSections=*/true, clauseOps); + // Parallel wrapper of PARALLEL SECTIONS construct llvm::omp::Directive dir = std::get(beginSectionsDirective.t) .v; - - // Parallel wrapper of PARALLEL SECTIONS construct if (dir == llvm::omp::Directive::OMPD_parallel_sections) { genParallelOp(converter, symTable, semaCtx, eval, - /*genNested=*/false, currentLocation, sectionsClauseList, + /*genNested=*/false, currentLocation, beginClauseList, /*outerCombined=*/true); } else { const auto &endSectionsDirective = std::get(sectionsConstruct.t); - const auto &endSectionsClauseList = + const auto &endClauseList = std::get(endSectionsDirective.t); - ClauseProcessor(converter, semaCtx, endSectionsClauseList) - .processNowait(clauseOps); + genSectionsClauses(converter, semaCtx, endClauseList, currentLocation, + /*clausesFromBeginSections=*/false, clauseOps); } - // SECTIONS construct - genOpWithBody( - OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) - .setGenNested(false), - clauseOps); + // SECTIONS construct. + genSectionsOp(converter, semaCtx, eval, currentLocation, clauseOps); + // Generate nested SECTION operations recursively. 
const auto &sectionBlocks = std::get(sectionsConstruct.t); auto &firOpBuilder = converter.getFirOpBuilder(); @@ -2052,40 +2322,12 @@ genOMP(Fortran::lower::AbstractConverter &converter, llvm::zip(sectionBlocks.v, eval.getNestedEvaluations())) { symTable.pushScope(); genSectionOp(converter, semaCtx, neval, /*genNested=*/true, currentLocation, - sectionsClauseList); + beginClauseList); symTable.popScope(); firOpBuilder.restoreInsertionPoint(ip); } } -static void -genOMP(Fortran::lower::AbstractConverter &converter, - Fortran::lower::SymMap &symTable, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) { - std::visit( - Fortran::common::visitors{ - [&](const Fortran::parser::OpenMPSimpleStandaloneConstruct - &simpleStandaloneConstruct) { - genOmpSimpleStandalone(converter, semaCtx, eval, - /*genNested=*/true, - simpleStandaloneConstruct); - }, - [&](const Fortran::parser::OpenMPFlushConstruct &flushConstruct) { - genOmpFlush(converter, semaCtx, eval, flushConstruct); - }, - [&](const Fortran::parser::OpenMPCancelConstruct &cancelConstruct) { - TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct"); - }, - [&](const Fortran::parser::OpenMPCancellationPointConstruct - &cancellationPointConstruct) { - TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct"); - }, - }, - standaloneConstruct.u); -} - static void genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::SymMap &symTable, Fortran::semantics::SemanticsContext &semaCtx, diff --git a/flang/test/Lower/OpenMP/FIR/target.f90 b/flang/test/Lower/OpenMP/FIR/target.f90 index 821196b83c3b9..d3f2a1c7a1593 100644 --- a/flang/test/Lower/OpenMP/FIR/target.f90 +++ b/flang/test/Lower/OpenMP/FIR/target.f90 @@ -411,8 +411,8 @@ end subroutine omp_target_implicit_bounds !CHECK-LABEL: func.func @_QPomp_target_thread_limit() { subroutine omp_target_thread_limit integer :: a - !CHECK: %[[VAL_1:.*]] =
arith.constant 64 : i32 !CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}}) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "a"} + !CHECK: %[[VAL_1:.*]] = arith.constant 64 : i32 !CHECK: omp.target thread_limit(%[[VAL_1]] : i32) map_entries(%[[MAP]] -> %[[ARG_0:.*]] : !fir.ref) { !CHECK: ^bb0(%[[ARG_0]]: !fir.ref): !$omp target map(tofrom: a) thread_limit(64) diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90 index 6f72b5a34d069..51b66327dfb24 100644 --- a/flang/test/Lower/OpenMP/target.f90 +++ b/flang/test/Lower/OpenMP/target.f90 @@ -490,8 +490,8 @@ end subroutine omp_target_implicit_bounds !CHECK-LABEL: func.func @_QPomp_target_thread_limit() { subroutine omp_target_thread_limit integer :: a - !CHECK: %[[VAL_1:.*]] = arith.constant 64 : i32 !CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}}) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "a"} + !CHECK: %[[VAL_1:.*]] = arith.constant 64 : i32 !CHECK: omp.target thread_limit(%[[VAL_1]] : i32) map_entries(%[[MAP]] -> %{{.*}} : !fir.ref) { !CHECK: ^bb0(%{{.*}}: !fir.ref): !$omp target map(tofrom: a) thread_limit(64) diff --git a/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90 b/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90 index 33b5971656010..d849dd206b943 100644 --- a/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90 +++ b/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90 @@ -21,7 +21,7 @@ subroutine only_use_device_ptr !CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr() !CHECK: omp.target_data use_device_ptr({{.*}} : !fir.ref>) use_device_addr(%{{.*}}, %{{.*}} : !fir.ref>>>, !fir.ref>>>) { -!CHECK: ^bb0(%{{.*}}: !fir.ref>, %{{.*}}: !fir.ref>>>, %{{.*}}: !fir.ref>>>): +!CHECK: ^bb0(%{{.*}}: !fir.ref>>>, %{{.*}}: !fir.ref>, %{{.*}}: !fir.ref>>>): subroutine mix_use_device_ptr_and_addr use iso_c_binding integer, pointer, dimension(:) :: array @@ -47,7 +47,7 @@ subroutine only_use_device_addr !CHECK: func.func 
@{{.*}}mix_use_device_ptr_and_addr_and_map() !CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref, !fir.ref) use_device_ptr(%{{.*}} : !fir.ref>) use_device_addr(%{{.*}}, %{{.*}} : !fir.ref>>>, !fir.ref>>>) { -!CHECK: ^bb0(%{{.*}}: !fir.ref>, %{{.*}}: !fir.ref>>>, %{{.*}}: !fir.ref>>>): +!CHECK: ^bb0(%{{.*}}: !fir.ref>>>, %{{.*}}: !fir.ref>, %{{.*}}: !fir.ref>>>): subroutine mix_use_device_ptr_and_addr_and_map use iso_c_binding integer :: i, j From ec0ed50b0d5f9606f0e9a1a3a9999f601bec310f Mon Sep 17 00:00:00 2001 From: Sergio Afonso Date: Fri, 29 Mar 2024 13:57:40 +0000 Subject: [PATCH 04/13] [Flang][OpenMP][Lower] Refactor lowering of compound constructs This patch simplifies the lowering from PFT to MLIR of OpenMP compound constructs (i.e. combined and composite). The new approach consists of iteratively processing the outermost leaf construct of the given combined construct until it cannot be split further. Both leaf constructs and composite ones have `gen...()` functions that are called when appropriate. This approach enables treating a leaf construct the same way regardless of if it appeared as part of a combined construct, and it also enables the lowering of composite constructs as a single unit. Previous corner cases are now handled in a more straightforward way and comments pointing to the relevant spec section are added. Directive sets are also completed with missing LOOP related constructs. 
--- .../flang/Semantics/openmp-directive-sets.h | 57 ++- flang/lib/Lower/OpenMP/OpenMP.cpp | 432 ++++++++++++------ 2 files changed, 335 insertions(+), 154 deletions(-) diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h index 91773ae3ea9a3..842d251b682aa 100644 --- a/flang/include/flang/Semantics/openmp-directive-sets.h +++ b/flang/include/flang/Semantics/openmp-directive-sets.h @@ -32,14 +32,14 @@ static const OmpDirectiveSet topDistributeSet{ static const OmpDirectiveSet allDistributeSet{ OmpDirectiveSet{ - llvm::omp::OMPD_target_teams_distribute, - llvm::omp::OMPD_target_teams_distribute_parallel_do, - llvm::omp::OMPD_target_teams_distribute_parallel_do_simd, - llvm::omp::OMPD_target_teams_distribute_simd, - llvm::omp::OMPD_teams_distribute, - llvm::omp::OMPD_teams_distribute_parallel_do, - llvm::omp::OMPD_teams_distribute_parallel_do_simd, - llvm::omp::OMPD_teams_distribute_simd, + Directive::OMPD_target_teams_distribute, + Directive::OMPD_target_teams_distribute_parallel_do, + Directive::OMPD_target_teams_distribute_parallel_do_simd, + Directive::OMPD_target_teams_distribute_simd, + Directive::OMPD_teams_distribute, + Directive::OMPD_teams_distribute_parallel_do, + Directive::OMPD_teams_distribute_parallel_do_simd, + Directive::OMPD_teams_distribute_simd, } | topDistributeSet, }; @@ -63,10 +63,24 @@ static const OmpDirectiveSet allDoSet{ } | topDoSet, }; +static const OmpDirectiveSet topLoopSet{ + Directive::OMPD_loop, +}; + +static const OmpDirectiveSet allLoopSet{ + OmpDirectiveSet{ + Directive::OMPD_parallel_loop, + Directive::OMPD_target_parallel_loop, + Directive::OMPD_target_teams_loop, + Directive::OMPD_teams_loop, + } | topLoopSet, +}; + static const OmpDirectiveSet topParallelSet{ Directive::OMPD_parallel, Directive::OMPD_parallel_do, Directive::OMPD_parallel_do_simd, + Directive::OMPD_parallel_loop, Directive::OMPD_parallel_masked_taskloop, 
Directive::OMPD_parallel_masked_taskloop_simd, Directive::OMPD_parallel_master_taskloop, @@ -82,6 +96,7 @@ static const OmpDirectiveSet allParallelSet{ Directive::OMPD_target_parallel, Directive::OMPD_target_parallel_do, Directive::OMPD_target_parallel_do_simd, + Directive::OMPD_target_parallel_loop, Directive::OMPD_target_teams_distribute_parallel_do, Directive::OMPD_target_teams_distribute_parallel_do_simd, Directive::OMPD_teams_distribute_parallel_do, @@ -118,12 +133,14 @@ static const OmpDirectiveSet topTargetSet{ Directive::OMPD_target_parallel, Directive::OMPD_target_parallel_do, Directive::OMPD_target_parallel_do_simd, + Directive::OMPD_target_parallel_loop, Directive::OMPD_target_simd, Directive::OMPD_target_teams, Directive::OMPD_target_teams_distribute, Directive::OMPD_target_teams_distribute_parallel_do, Directive::OMPD_target_teams_distribute_parallel_do_simd, Directive::OMPD_target_teams_distribute_simd, + Directive::OMPD_target_teams_loop, }; static const OmpDirectiveSet allTargetSet{topTargetSet}; @@ -156,11 +173,12 @@ static const OmpDirectiveSet topTeamsSet{ static const OmpDirectiveSet allTeamsSet{ OmpDirectiveSet{ - llvm::omp::OMPD_target_teams, - llvm::omp::OMPD_target_teams_distribute, - llvm::omp::OMPD_target_teams_distribute_parallel_do, - llvm::omp::OMPD_target_teams_distribute_parallel_do_simd, - llvm::omp::OMPD_target_teams_distribute_simd, + Directive::OMPD_target_teams, + Directive::OMPD_target_teams_distribute, + Directive::OMPD_target_teams_distribute_parallel_do, + Directive::OMPD_target_teams_distribute_parallel_do_simd, + Directive::OMPD_target_teams_distribute_simd, + Directive::OMPD_target_teams_loop, } | topTeamsSet, }; @@ -178,6 +196,14 @@ static const OmpDirectiveSet allDistributeSimdSet{ static const OmpDirectiveSet allDoSimdSet{allDoSet & allSimdSet}; static const OmpDirectiveSet allTaskloopSimdSet{allTaskloopSet & allSimdSet}; +static const OmpDirectiveSet compositeConstructSet{ + Directive::OMPD_distribute_parallel_do, + 
Directive::OMPD_distribute_parallel_do_simd, + Directive::OMPD_distribute_simd, + Directive::OMPD_do_simd, + Directive::OMPD_taskloop_simd, +}; + static const OmpDirectiveSet blockConstructSet{ Directive::OMPD_master, Directive::OMPD_ordered, @@ -201,12 +227,14 @@ static const OmpDirectiveSet loopConstructSet{ Directive::OMPD_distribute_simd, Directive::OMPD_do, Directive::OMPD_do_simd, + Directive::OMPD_loop, Directive::OMPD_masked_taskloop, Directive::OMPD_masked_taskloop_simd, Directive::OMPD_master_taskloop, Directive::OMPD_master_taskloop_simd, Directive::OMPD_parallel_do, Directive::OMPD_parallel_do_simd, + Directive::OMPD_parallel_loop, Directive::OMPD_parallel_masked_taskloop, Directive::OMPD_parallel_masked_taskloop_simd, Directive::OMPD_parallel_master_taskloop, @@ -214,17 +242,20 @@ static const OmpDirectiveSet loopConstructSet{ Directive::OMPD_simd, Directive::OMPD_target_parallel_do, Directive::OMPD_target_parallel_do_simd, + Directive::OMPD_target_parallel_loop, Directive::OMPD_target_simd, Directive::OMPD_target_teams_distribute, Directive::OMPD_target_teams_distribute_parallel_do, Directive::OMPD_target_teams_distribute_parallel_do_simd, Directive::OMPD_target_teams_distribute_simd, + Directive::OMPD_target_teams_loop, Directive::OMPD_taskloop, Directive::OMPD_taskloop_simd, Directive::OMPD_teams_distribute, Directive::OMPD_teams_distribute_parallel_do, Directive::OMPD_teams_distribute_parallel_do_simd, Directive::OMPD_teams_distribute_simd, + Directive::OMPD_teams_loop, Directive::OMPD_tile, Directive::OMPD_unroll, }; diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 692d81f9188be..edae453972d3d 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -710,6 +710,81 @@ genOpenMPReduction(Fortran::lower::AbstractConverter &converter, } } +/// Split a combined directive into an outer leaf directive and the (possibly +/// combined) rest of the combined directive. 
Composite directives and +/// non-compound directives are not split, in which case it will return the +/// input directive as its first output and an empty value as its second output. +static std::pair> +splitCombinedDirective(llvm::omp::Directive dir) { + using D = llvm::omp::Directive; + switch (dir) { + case D::OMPD_masked_taskloop: + return {D::OMPD_masked, D::OMPD_taskloop}; + case D::OMPD_masked_taskloop_simd: + return {D::OMPD_masked, D::OMPD_taskloop_simd}; + case D::OMPD_master_taskloop: + return {D::OMPD_master, D::OMPD_taskloop}; + case D::OMPD_master_taskloop_simd: + return {D::OMPD_master, D::OMPD_taskloop_simd}; + case D::OMPD_parallel_do: + return {D::OMPD_parallel, D::OMPD_do}; + case D::OMPD_parallel_do_simd: + return {D::OMPD_parallel, D::OMPD_do_simd}; + case D::OMPD_parallel_masked: + return {D::OMPD_parallel, D::OMPD_masked}; + case D::OMPD_parallel_masked_taskloop: + return {D::OMPD_parallel, D::OMPD_masked_taskloop}; + case D::OMPD_parallel_masked_taskloop_simd: + return {D::OMPD_parallel, D::OMPD_masked_taskloop_simd}; + case D::OMPD_parallel_master: + return {D::OMPD_parallel, D::OMPD_master}; + case D::OMPD_parallel_master_taskloop: + return {D::OMPD_parallel, D::OMPD_master_taskloop}; + case D::OMPD_parallel_master_taskloop_simd: + return {D::OMPD_parallel, D::OMPD_master_taskloop_simd}; + case D::OMPD_parallel_sections: + return {D::OMPD_parallel, D::OMPD_sections}; + case D::OMPD_parallel_workshare: + return {D::OMPD_parallel, D::OMPD_workshare}; + case D::OMPD_target_parallel: + return {D::OMPD_target, D::OMPD_parallel}; + case D::OMPD_target_parallel_do: + return {D::OMPD_target, D::OMPD_parallel_do}; + case D::OMPD_target_parallel_do_simd: + return {D::OMPD_target, D::OMPD_parallel_do_simd}; + case D::OMPD_target_simd: + return {D::OMPD_target, D::OMPD_simd}; + case D::OMPD_target_teams: + return {D::OMPD_target, D::OMPD_teams}; + case D::OMPD_target_teams_distribute: + return {D::OMPD_target, D::OMPD_teams_distribute}; + case 
D::OMPD_target_teams_distribute_parallel_do: + return {D::OMPD_target, D::OMPD_teams_distribute_parallel_do}; + case D::OMPD_target_teams_distribute_parallel_do_simd: + return {D::OMPD_target, D::OMPD_teams_distribute_parallel_do_simd}; + case D::OMPD_target_teams_distribute_simd: + return {D::OMPD_target, D::OMPD_teams_distribute_simd}; + case D::OMPD_teams_distribute: + return {D::OMPD_teams, D::OMPD_distribute}; + case D::OMPD_teams_distribute_parallel_do: + return {D::OMPD_teams, D::OMPD_distribute_parallel_do}; + case D::OMPD_teams_distribute_parallel_do_simd: + return {D::OMPD_teams, D::OMPD_distribute_parallel_do_simd}; + case D::OMPD_teams_distribute_simd: + return {D::OMPD_teams, D::OMPD_distribute_simd}; + case D::OMPD_parallel_loop: + return {D::OMPD_parallel, D::OMPD_loop}; + case D::OMPD_target_parallel_loop: + return {D::OMPD_target, D::OMPD_parallel_loop}; + case D::OMPD_target_teams_loop: + return {D::OMPD_target, D::OMPD_teams_loop}; + case D::OMPD_teams_loop: + return {D::OMPD_teams, D::OMPD_loop}; + default: + return {dir, std::nullopt}; + } +} + //===----------------------------------------------------------------------===// // Op body generation helper structures and functions //===----------------------------------------------------------------------===// @@ -1962,16 +2037,44 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter, // Code generation functions for composite constructs //===----------------------------------------------------------------------===// -static void genCompositeDoSimd( +static void genCompositeDistributeParallelDo( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OmpClauseList &beginClauseList, + const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) { + TODO(loc, "Composite DISTRIBUTE PARALLEL DO"); +} + +static void genCompositeDistributeParallelDoSimd( + 
Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OmpClauseList &beginClauseList, + const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) { + TODO(loc, "Composite DISTRIBUTE PARALLEL DO SIMD"); +} + +static void genCompositeDistributeSimd( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, llvm::omp::Directive ompDirective, + Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OmpClauseList &beginClauseList, const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) { + TODO(loc, "Composite DISTRIBUTE SIMD"); +} + +static void +genCompositeDoSimd(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OmpClauseList &beginClauseList, + const Fortran::parser::OmpClauseList *endClauseList, + mlir::Location loc) { ClauseProcessor cp(converter, semaCtx, beginClauseList); cp.processTODO(loc, - ompDirective); + clause::Order, clause::Safelen, clause::Simdlen>( + loc, llvm::omp::OMPD_do_simd); // TODO: Add support for vectorization - add vectorization hints inside loop // body. // OpenMP standard does not specify the length of vector instructions. 
@@ -1983,6 +2086,16 @@ static void genCompositeDoSimd( genWsloopOp(converter, semaCtx, eval, loc, beginClauseList, endClauseList); } +static void +genCompositeTaskloopSimd(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OmpClauseList &beginClauseList, + const Fortran::parser::OmpClauseList *endClauseList, + mlir::Location loc) { + TODO(loc, "Composite TASKLOOP SIMD"); +} + //===----------------------------------------------------------------------===// // OpenMPDeclarativeConstruct visitors //===----------------------------------------------------------------------===// @@ -2240,13 +2353,18 @@ genOMP(Fortran::lower::AbstractConverter &converter, std::get(blockConstruct.t); const auto &endBlockDirective = std::get(blockConstruct.t); - const auto &directive = - std::get(beginBlockDirective.t); + mlir::Location currentLocation = + converter.genLocation(beginBlockDirective.source); + const auto origDirective = + std::get(beginBlockDirective.t).v; const auto &beginClauseList = std::get(beginBlockDirective.t); const auto &endClauseList = std::get(endBlockDirective.t); + assert(llvm::omp::blockConstructSet.test(origDirective) && + "Expected block construct"); + for (const Fortran::parser::OmpClause &clause : beginClauseList.v) { mlir::Location clauseLocation = converter.genLocation(clause.source); if (!std::get_if(&clause.u) && @@ -2280,93 +2398,74 @@ genOMP(Fortran::lower::AbstractConverter &converter, TODO(clauseLocation, "OpenMP Block construct clause"); } - bool singleDirective = true; - mlir::Location currentLocation = converter.genLocation(directive.source); - switch (directive.v) { - case llvm::omp::Directive::OMPD_master: - genMasterOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation); - break; - case llvm::omp::Directive::OMPD_ordered: - genOrderedRegionOp(converter, semaCtx, eval, /*genNested=*/true, - currentLocation, beginClauseList); - 
break; - case llvm::omp::Directive::OMPD_parallel: - genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/true, - currentLocation, beginClauseList); - break; - case llvm::omp::Directive::OMPD_single: - genSingleOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation, - beginClauseList, endClauseList); - break; - case llvm::omp::Directive::OMPD_target: - genTargetOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation, + std::optional nextDir = origDirective; + bool outermostLeafConstruct = true; + while (nextDir) { + llvm::omp::Directive leafDir; + std::tie(leafDir, nextDir) = splitCombinedDirective(*nextDir); + const bool genNested = !nextDir; + const bool outerCombined = outermostLeafConstruct && nextDir.has_value(); + switch (leafDir) { + case llvm::omp::Directive::OMPD_master: + // 2.16 MASTER construct. + genMasterOp(converter, semaCtx, eval, genNested, currentLocation); + break; + case llvm::omp::Directive::OMPD_ordered: + // 2.17.9 ORDERED construct. + genOrderedRegionOp(converter, semaCtx, eval, genNested, currentLocation, + beginClauseList); + break; + case llvm::omp::Directive::OMPD_parallel: + // 2.6 PARALLEL construct. + genParallelOp(converter, symTable, semaCtx, eval, genNested, + currentLocation, beginClauseList, outerCombined); + break; + case llvm::omp::Directive::OMPD_single: + // 2.8.2 SINGLE construct. + genSingleOp(converter, semaCtx, eval, genNested, currentLocation, + beginClauseList, endClauseList); + break; + case llvm::omp::Directive::OMPD_target: + // 2.12.5 TARGET construct. + genTargetOp(converter, semaCtx, eval, genNested, currentLocation, + beginClauseList, outerCombined); + break; + case llvm::omp::Directive::OMPD_target_data: + // 2.12.2 TARGET DATA construct. + genTargetDataOp(converter, semaCtx, eval, genNested, currentLocation, + beginClauseList); + break; + case llvm::omp::Directive::OMPD_task: + // 2.10.1 TASK construct. 
+ genTaskOp(converter, semaCtx, eval, genNested, currentLocation, beginClauseList); - break; - case llvm::omp::Directive::OMPD_target_data: - genTargetDataOp(converter, semaCtx, eval, /*genNested=*/true, - currentLocation, beginClauseList); - break; - case llvm::omp::Directive::OMPD_task: - genTaskOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation, - beginClauseList); - break; - case llvm::omp::Directive::OMPD_taskgroup: - genTaskgroupOp(converter, semaCtx, eval, /*genNested=*/true, - currentLocation, beginClauseList); - break; - case llvm::omp::Directive::OMPD_teams: - genTeamsOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation, - beginClauseList); - break; - case llvm::omp::Directive::OMPD_workshare: - // FIXME: Workshare is not a commonly used OpenMP construct, an - // implementation for this feature will come later. For the codes - // that use this construct, add a single construct for now. - genSingleOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation, - beginClauseList, endClauseList); - break; - default: - singleDirective = false; - break; - } - - if (singleDirective) - return; - - // Codegen for combined directives - bool combinedDirective = false; - if ((llvm::omp::allTargetSet & llvm::omp::blockConstructSet) - .test(directive.v)) { - genTargetOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation, - beginClauseList, /*outerCombined=*/true); - combinedDirective = true; - } - if ((llvm::omp::allTeamsSet & llvm::omp::blockConstructSet) - .test(directive.v)) { - genTeamsOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation, - beginClauseList); - combinedDirective = true; - } - if ((llvm::omp::allParallelSet & llvm::omp::blockConstructSet) - .test(directive.v)) { - bool outerCombined = - directive.v != llvm::omp::Directive::OMPD_target_parallel; - genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/false, - currentLocation, beginClauseList, outerCombined); - combinedDirective = true; - 
} - if ((llvm::omp::workShareSet & llvm::omp::blockConstructSet) - .test(directive.v)) { - genSingleOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation, - beginClauseList, endClauseList); - combinedDirective = true; + break; + case llvm::omp::Directive::OMPD_taskgroup: + // 2.17.6 TASKGROUP construct. + genTaskgroupOp(converter, semaCtx, eval, genNested, currentLocation, + beginClauseList); + break; + case llvm::omp::Directive::OMPD_teams: + // 2.7 TEAMS construct. + // FIXME Pass the outerCombined argument or rename it to better describe + // what it represents if it must always be `false` in this context. + genTeamsOp(converter, semaCtx, eval, genNested, currentLocation, + beginClauseList); + break; + case llvm::omp::Directive::OMPD_workshare: + // 2.8.3 WORKSHARE construct. + // FIXME: Workshare is not a commonly used OpenMP construct, an + // implementation for this feature will come later. For the codes + // that use this construct, add a single construct for now. + genSingleOp(converter, semaCtx, eval, genNested, currentLocation, + beginClauseList, endClauseList); + break; + default: + llvm_unreachable("Unexpected block construct"); + break; + } + outermostLeafConstruct = false; } - if (!combinedDirective) - TODO(currentLocation, "Unhandled block directive (" + - llvm::omp::getOpenMPDirectiveName(directive.v) + - ")"); - - genNestedEvaluations(converter, eval); } static void @@ -2404,9 +2503,12 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, std::get(beginLoopDirective.t); mlir::Location currentLocation = converter.genLocation(beginLoopDirective.source); - const auto ompDirective = + const auto origDirective = std::get(beginLoopDirective.t).v; + assert(llvm::omp::loopConstructSet.test(origDirective) && + "Expected loop construct"); + const auto *endClauseList = [&]() { using RetTy = const Fortran::parser::OmpClauseList *; if (auto &endLoopDirective = @@ -2418,57 +2520,105 @@ static void 
genOMP(Fortran::lower::AbstractConverter &converter, return RetTy(); }(); - bool validDirective = false; - if (llvm::omp::topTaskloopSet.test(ompDirective)) { - validDirective = true; - genTaskloopOp(converter, semaCtx, eval, currentLocation, beginClauseList); - } else { - // Create omp.{target, teams, distribute, parallel} nested operations - if ((llvm::omp::allTargetSet & llvm::omp::loopConstructSet) - .test(ompDirective)) { - validDirective = true; - genTargetOp(converter, semaCtx, eval, /*genNested=*/false, - currentLocation, beginClauseList, /*outerCombined=*/true); - } - if ((llvm::omp::allTeamsSet & llvm::omp::loopConstructSet) - .test(ompDirective)) { - validDirective = true; - genTeamsOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation, - beginClauseList, /*outerCombined=*/true); - } - if (llvm::omp::allDistributeSet.test(ompDirective)) { - validDirective = true; - genDistributeOp(converter, semaCtx, eval, /*genNested=*/false, - currentLocation, beginClauseList); - } - if ((llvm::omp::allParallelSet & llvm::omp::loopConstructSet) - .test(ompDirective)) { - validDirective = true; - genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/false, - currentLocation, beginClauseList, /*outerCombined=*/true); + std::optional nextDir = origDirective; + while (nextDir) { + llvm::omp::Directive leafDir; + std::tie(leafDir, nextDir) = splitCombinedDirective(*nextDir); + if (llvm::omp::compositeConstructSet.test(leafDir)) { + assert(!nextDir && "Composite construct cannot be split"); + switch (leafDir) { + case llvm::omp::Directive::OMPD_distribute_parallel_do: + // 2.9.4.3 DISTRIBUTE PARALLEL Worksharing-Loop construct. + genCompositeDistributeParallelDo(converter, semaCtx, eval, + beginClauseList, endClauseList, + currentLocation); + break; + case llvm::omp::Directive::OMPD_distribute_parallel_do_simd: + // 2.9.4.4 DISTRIBUTE PARALLEL Worksharing-Loop SIMD construct. 
+ genCompositeDistributeParallelDoSimd(converter, semaCtx, eval, + beginClauseList, endClauseList, + currentLocation); + break; + case llvm::omp::Directive::OMPD_distribute_simd: + // 2.9.4.2 DISTRIBUTE SIMD construct. + genCompositeDistributeSimd(converter, semaCtx, eval, beginClauseList, + endClauseList, currentLocation); + break; + case llvm::omp::Directive::OMPD_do_simd: + // 2.9.3.2 Worksharing-Loop SIMD construct. + genCompositeDoSimd(converter, semaCtx, eval, beginClauseList, + endClauseList, currentLocation); + break; + case llvm::omp::Directive::OMPD_taskloop_simd: + // 2.10.3 TASKLOOP SIMD construct. + genCompositeTaskloopSimd(converter, semaCtx, eval, beginClauseList, + endClauseList, currentLocation); + break; + default: + llvm_unreachable("Unexpected composite construct"); + } + } else { + const bool genNested = !nextDir; + switch (leafDir) { + case llvm::omp::Directive::OMPD_distribute: + // 2.9.4.1 DISTRIBUTE construct. + genDistributeOp(converter, semaCtx, eval, genNested, currentLocation, + beginClauseList); + break; + case llvm::omp::Directive::OMPD_do: + // 2.9.2 Worksharing-Loop construct. + genWsloopOp(converter, semaCtx, eval, currentLocation, beginClauseList, + endClauseList); + break; + case llvm::omp::Directive::OMPD_parallel: + // 2.6 PARALLEL construct. + // FIXME This is not necessarily always the outer leaf construct of a + // combined construct in this context (e.g. distribute parallel do). + // Maybe rename the argument if it represents something else or + // initialize it properly. + genParallelOp(converter, symTable, semaCtx, eval, genNested, + currentLocation, beginClauseList, + /*outerCombined=*/true); + break; + case llvm::omp::Directive::OMPD_simd: + // 2.9.3.1 SIMD construct. + genSimdLoopOp(converter, semaCtx, eval, currentLocation, + beginClauseList); + genOpenMPReduction(converter, semaCtx, beginClauseList); + break; + case llvm::omp::Directive::OMPD_target: + // 2.12.5 TARGET construct.
+ genTargetOp(converter, semaCtx, eval, genNested, currentLocation, + beginClauseList, /*outerCombined=*/true); + break; + case llvm::omp::Directive::OMPD_taskloop: + // 2.10.2 TASKLOOP construct. + genTaskloopOp(converter, semaCtx, eval, currentLocation, + beginClauseList); + break; + case llvm::omp::Directive::OMPD_teams: + // 2.7 TEAMS construct. + // FIXME This is not necessarily always the outer leaf construct of a + // combined construct in this context (e.g. target teams distribute). + // Maybe rename the argument if it represents something else or + // initialize it properly. + genTeamsOp(converter, semaCtx, eval, genNested, currentLocation, + beginClauseList, /*outerCombined=*/true); + break; + case llvm::omp::Directive::OMPD_loop: + case llvm::omp::Directive::OMPD_masked: + case llvm::omp::Directive::OMPD_master: + case llvm::omp::Directive::OMPD_tile: + case llvm::omp::Directive::OMPD_unroll: + TODO(currentLocation, "Unhandled loop directive (" + + llvm::omp::getOpenMPDirectiveName(leafDir) + + ")"); + break; + default: + llvm_unreachable("Unexpected loop construct"); + } } } - if ((llvm::omp::allDoSet | llvm::omp::allSimdSet).test(ompDirective)) - validDirective = true; - - if (!validDirective) { - TODO(currentLocation, "Unhandled loop directive (" + - llvm::omp::getOpenMPDirectiveName(ompDirective) + - ")"); - } - - if (llvm::omp::allDoSimdSet.test(ompDirective)) { - // 2.9.3.2 Workshare SIMD construct - genCompositeDoSimd(converter, semaCtx, eval, ompDirective, beginClauseList, - endClauseList, currentLocation); - } else if (llvm::omp::allSimdSet.test(ompDirective)) { - // 2.9.3.1 SIMD construct - genSimdLoopOp(converter, semaCtx, eval, currentLocation, beginClauseList); - genOpenMPReduction(converter, semaCtx, beginClauseList); - } else { - genWsloopOp(converter, semaCtx, eval, currentLocation, beginClauseList, - endClauseList); - } } static void From f725face892cef4faf9f17d4b549541bdbcd7e08 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date:
Fri, 29 Mar 2024 09:20:41 -0500 Subject: [PATCH 05/13] [flang][OpenMP] Move clause/object conversion to happen early, in genOMP This removes the last use of genOmpObectList2, which has now been removed. --- flang/lib/Lower/OpenMP/ClauseProcessor.h | 5 +- flang/lib/Lower/OpenMP/DataSharingProcessor.h | 5 +- flang/lib/Lower/OpenMP/OpenMP.cpp | 424 +++++++++--------- flang/lib/Lower/OpenMP/Utils.cpp | 30 +- flang/lib/Lower/OpenMP/Utils.h | 6 +- 5 files changed, 218 insertions(+), 252 deletions(-) diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index db7a1b8335f81..f4d659b70cfee 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -49,9 +49,8 @@ class ClauseProcessor { public: ClauseProcessor(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &clauses) - : converter(converter), semaCtx(semaCtx), - clauses(makeClauses(clauses, semaCtx)) {} + const List &clauses) + : converter(converter), semaCtx(semaCtx), clauses(clauses) {} // 'Unique' clauses: They can appear at most once in the clause list. 
bool processCollapse( diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.h b/flang/lib/Lower/OpenMP/DataSharingProcessor.h index c11ee299c5d08..ef7b14327278e 100644 --- a/flang/lib/Lower/OpenMP/DataSharingProcessor.h +++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.h @@ -78,13 +78,12 @@ class DataSharingProcessor { public: DataSharingProcessor(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &opClauseList, + const List &clauses, Fortran::lower::pft::Evaluation &eval, bool useDelayedPrivatization = false, Fortran::lower::SymMap *symTable = nullptr) : hasLastPrivateOp(false), converter(converter), - firOpBuilder(converter.getFirOpBuilder()), - clauses(omp::makeClauses(opClauseList, semaCtx)), eval(eval), + firOpBuilder(converter.getFirOpBuilder()), clauses(clauses), eval(eval), useDelayedPrivatization(useDelayedPrivatization), symTable(symTable) {} // Privatisation is split into two steps. diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index edae453972d3d..23dc25ac1ae9a 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -17,6 +17,7 @@ #include "DataSharingProcessor.h" #include "DirectivesCommon.h" #include "ReductionProcessor.h" +#include "Utils.h" #include "flang/Common/idioms.h" #include "flang/Lower/Bridge.h" #include "flang/Lower/ConvertExpr.h" @@ -310,14 +311,15 @@ static void getDeclareTargetInfo( } else if (const auto *clauseList{ Fortran::parser::Unwrap( spec.u)}) { - if (clauseList->v.empty()) { + List clauses = makeClauses(*clauseList, semaCtx); + if (clauses.empty()) { // Case: declare target, implicit capture of function symbolAndClause.emplace_back( mlir::omp::DeclareTargetCaptureClause::to, eval.getOwningProcedure()->getSubprogramSymbol()); } - ClauseProcessor cp(converter, semaCtx, *clauseList); + ClauseProcessor cp(converter, semaCtx, clauses); cp.processDeviceType(clauseOps); 
cp.processEnter(symbolAndClause); cp.processLink(symbolAndClause); @@ -597,14 +599,11 @@ static void removeStoreOp(mlir::Operation *reductionOp, mlir::Value symVal) { // TODO: Generate the reduction operation during lowering instead of creating // and removing operations since this is not a robust approach. Also, removing // ops in the builder (instead of a rewriter) is probably not the best approach. -static void -genOpenMPReduction(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &clauseList) { +static void genOpenMPReduction(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const List &clauses) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - List clauses{makeClauses(clauseList, semaCtx)}; - for (const Clause &clause : clauses) { if (const auto &reductionClause = std::get_if(&clause.u)) { @@ -812,7 +811,7 @@ struct OpWithBodyGenInfo { return *this; } - OpWithBodyGenInfo &setClauses(const Fortran::parser::OmpClauseList *value) { + OpWithBodyGenInfo &setClauses(const List *value) { clauses = value; return *this; } @@ -848,7 +847,7 @@ struct OpWithBodyGenInfo { /// [in] is this an outer operation - prevents privatization. bool outerCombined = false; /// [in] list of clauses to process. - const Fortran::parser::OmpClauseList *clauses = nullptr; + const List *clauses = nullptr; /// [in] if provided, processes the construct's data-sharing attributes. 
DataSharingProcessor *dsp = nullptr; /// [in] if provided, list of reduction symbols @@ -1226,36 +1225,33 @@ static OpTy genOpWithBody(OpWithBodyGenInfo &info, Args &&...args) { // Code generation functions for clauses //===----------------------------------------------------------------------===// -static void genCriticalDeclareClauses( - Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, - mlir::omp::CriticalClauseOps &clauseOps, llvm::StringRef name) { +static void +genCriticalDeclareClauses(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const List &clauses, mlir::Location loc, + mlir::omp::CriticalClauseOps &clauseOps, + llvm::StringRef name) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processHint(clauseOps); clauseOps.nameAttr = mlir::StringAttr::get(converter.getFirOpBuilder().getContext(), name); } -static void genFlushClauses( - Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - const std::optional &objects, - const std::optional> - &clauses, - mlir::Location loc, llvm::SmallVectorImpl &operandRange) { - if (objects) - genObjectList2(*objects, converter, operandRange); - - if (clauses && clauses->size() > 0) +static void genFlushClauses(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const ObjectList &objects, + const List &clauses, mlir::Location loc, + llvm::SmallVectorImpl &operandRange) { + genObjectList(objects, converter, operandRange); + + if (clauses.size() > 0) TODO(converter.getCurrentLocation(), "Handle OmpMemoryOrderClause"); } static void genOrderedRegionClauses(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &clauses, - mlir::Location loc, + const List &clauses, mlir::Location loc, 
mlir::omp::OrderedRegionClauseOps &clauseOps) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processTODO(loc, llvm::omp::Directive::OMPD_ordered); @@ -1264,9 +1260,9 @@ genOrderedRegionClauses(Fortran::lower::AbstractConverter &converter, static void genParallelClauses( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, - bool processReduction, mlir::omp::ParallelClauseOps &clauseOps, + Fortran::lower::StatementContext &stmtCtx, const List &clauses, + mlir::Location loc, bool processReduction, + mlir::omp::ParallelClauseOps &clauseOps, llvm::SmallVectorImpl &reductionTypes, llvm::SmallVectorImpl &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); @@ -1286,8 +1282,7 @@ static void genParallelClauses( static void genSectionsClauses(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &clauses, - mlir::Location loc, + const List &clauses, mlir::Location loc, bool clausesFromBeginSections, mlir::omp::SectionsClauseOps &clauseOps) { ClauseProcessor cp(converter, semaCtx, clauses); @@ -1304,9 +1299,8 @@ static void genSimdLoopClauses( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::StatementContext &stmtCtx, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, - mlir::omp::SimdLoopClauseOps &clauseOps, + Fortran::lower::pft::Evaluation &eval, const List &clauses, + mlir::Location loc, mlir::omp::SimdLoopClauseOps &clauseOps, llvm::SmallVectorImpl &iv) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processCollapse(loc, eval, clauseOps, iv); @@ -1324,9 +1318,8 @@ static void genSimdLoopClauses( static void genSingleClauses(Fortran::lower::AbstractConverter &converter, 
Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &beginClauses, - const Fortran::parser::OmpClauseList &endClauses, - mlir::Location loc, + const List &beginClauses, + const List &endClauses, mlir::Location loc, mlir::omp::SingleClauseOps &clauseOps) { ClauseProcessor bcp(converter, semaCtx, beginClauses); bcp.processAllocate(clauseOps); @@ -1340,9 +1333,8 @@ static void genSingleClauses(Fortran::lower::AbstractConverter &converter, static void genTargetClauses( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, - bool processHostOnlyClauses, bool processReduction, + Fortran::lower::StatementContext &stmtCtx, const List &clauses, + mlir::Location loc, bool processHostOnlyClauses, bool processReduction, mlir::omp::TargetClauseOps &clauseOps, llvm::SmallVectorImpl &mapSyms, llvm::SmallVectorImpl &mapSymLocs, @@ -1368,9 +1360,8 @@ static void genTargetClauses( static void genTargetDataClauses( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, - mlir::omp::TargetDataClauseOps &clauseOps, + Fortran::lower::StatementContext &stmtCtx, const List &clauses, + mlir::Location loc, mlir::omp::TargetDataClauseOps &clauseOps, llvm::SmallVectorImpl &useDeviceTypes, llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl &useDeviceSyms) { @@ -1401,9 +1392,8 @@ static void genTargetDataClauses( static void genTargetEnterExitUpdateDataClauses( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, - llvm::omp::Directive directive, + Fortran::lower::StatementContext &stmtCtx, const List 
&clauses, + mlir::Location loc, llvm::omp::Directive directive, mlir::omp::TargetEnterExitUpdateDataClauseOps &clauseOps) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processDepend(clauseOps); @@ -1422,8 +1412,7 @@ static void genTargetEnterExitUpdateDataClauses( static void genTaskClauses(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::OmpClauseList &clauses, - mlir::Location loc, + const List &clauses, mlir::Location loc, mlir::omp::TaskClauseOps &clauseOps) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); @@ -1442,8 +1431,7 @@ static void genTaskClauses(Fortran::lower::AbstractConverter &converter, static void genTaskgroupClauses(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &clauses, - mlir::Location loc, + const List &clauses, mlir::Location loc, mlir::omp::TaskgroupClauseOps &clauseOps) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); @@ -1453,8 +1441,7 @@ static void genTaskgroupClauses(Fortran::lower::AbstractConverter &converter, static void genTaskwaitClauses(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &clauses, - mlir::Location loc, + const List &clauses, mlir::Location loc, mlir::omp::TaskwaitClauseOps &clauseOps) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processTODO( @@ -1464,8 +1451,7 @@ static void genTaskwaitClauses(Fortran::lower::AbstractConverter &converter, static void genTeamsClauses(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::OmpClauseList &clauses, - mlir::Location loc, + const List &clauses, mlir::Location loc, mlir::omp::TeamsClauseOps &clauseOps) { 
ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); @@ -1482,9 +1468,8 @@ static void genWsloopClauses( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::StatementContext &stmtCtx, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OmpClauseList &beginClauses, - const Fortran::parser::OmpClauseList *endClauses, mlir::Location loc, + Fortran::lower::pft::Evaluation &eval, const List &beginClauses, + const List &endClauses, mlir::Location loc, mlir::omp::WsloopClauseOps &clauseOps, llvm::SmallVectorImpl &iv, llvm::SmallVectorImpl &reductionTypes, @@ -1501,8 +1486,8 @@ static void genWsloopClauses( if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars)) clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr(); - if (endClauses) { - ClauseProcessor ecp(converter, semaCtx, *endClauses); + if (!endClauses.empty()) { + ClauseProcessor ecp(converter, semaCtx, endClauses); ecp.processNowait(clauseOps); } @@ -1525,8 +1510,7 @@ static mlir::omp::CriticalOp genCriticalOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList, + mlir::Location loc, const List &clauses, const std::optional &name) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::FlatSymbolRefAttr nameAttr; @@ -1537,7 +1521,7 @@ genCriticalOp(Fortran::lower::AbstractConverter &converter, auto global = mod.lookupSymbol(nameStr); if (!global) { mlir::omp::CriticalClauseOps clauseOps; - genCriticalDeclareClauses(converter, semaCtx, clauseList, loc, clauseOps, + genCriticalDeclareClauses(converter, semaCtx, clauses, loc, clauseOps, nameStr); mlir::OpBuilder modBuilder(mod.getBodyRegion()); @@ -1556,8 +1540,7 @@ static mlir::omp::DistributeOp genDistributeOp(Fortran::lower::AbstractConverter &converter, 
Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + mlir::Location loc, const List &clauses) { TODO(loc, "Distribute construct"); return nullptr; } @@ -1566,12 +1549,9 @@ static mlir::omp::FlushOp genFlushOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, mlir::Location loc, - const std::optional &objectList, - const std::optional> - &clauseList) { + const ObjectList &objects, const List &clauses) { llvm::SmallVector operandRange; - genFlushClauses(converter, semaCtx, objectList, clauseList, loc, - operandRange); + genFlushClauses(converter, semaCtx, objects, clauses, loc, operandRange); return converter.getFirOpBuilder().create( converter.getCurrentLocation(), operandRange); @@ -1591,7 +1571,7 @@ static mlir::omp::OrderedOp genOrderedOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + const List &clauses) { TODO(loc, "OMPD_ordered"); return nullptr; } @@ -1600,10 +1580,9 @@ static mlir::omp::OrderedRegionOp genOrderedRegionOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + mlir::Location loc, const List &clauses) { mlir::omp::OrderedRegionClauseOps clauseOps; - genOrderedRegionClauses(converter, semaCtx, clauseList, loc, clauseOps); + genOrderedRegionClauses(converter, semaCtx, clauses, loc, clauseOps); return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, loc, eval).setGenNested(genNested), @@ -1615,8 +1594,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, Fortran::lower::SymMap &symTable, 
Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList, + mlir::Location loc, const List &clauses, bool outerCombined = false) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); Fortran::lower::StatementContext stmtCtx; @@ -1624,7 +1602,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector privateSyms; llvm::SmallVector reductionTypes; llvm::SmallVector reductionSyms; - genParallelClauses(converter, semaCtx, stmtCtx, clauseList, loc, + genParallelClauses(converter, semaCtx, stmtCtx, clauses, loc, /*processReduction=*/!outerCombined, clauseOps, reductionTypes, reductionSyms); @@ -1637,7 +1615,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, OpWithBodyGenInfo(converter, semaCtx, loc, eval) .setGenNested(genNested) .setOuterCombined(outerCombined) - .setClauses(&clauseList) + .setClauses(&clauses) .setReductions(&reductionSyms, &reductionTypes) .setGenRegionEntryCb(reductionCallback); @@ -1645,7 +1623,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, return genOpWithBody(genInfo, clauseOps); bool privatize = !outerCombined; - DataSharingProcessor dsp(converter, semaCtx, clauseList, eval, + DataSharingProcessor dsp(converter, semaCtx, clauses, eval, /*useDelayedPrivatization=*/true, &symTable); if (privatize) @@ -1692,14 +1670,13 @@ static mlir::omp::SectionOp genSectionOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + mlir::Location loc, const List &clauses) { // Currently only private/firstprivate clause is handled, and // all privatization is done within `omp.section` operations. 
return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, loc, eval) .setGenNested(genNested) - .setClauses(&clauseList)); + .setClauses(&clauses)); } static mlir::omp::SectionsOp @@ -1716,18 +1693,17 @@ static mlir::omp::SimdLoopOp genSimdLoopOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { - DataSharingProcessor dsp(converter, semaCtx, clauseList, eval); + const List &clauses) { + DataSharingProcessor dsp(converter, semaCtx, clauses, eval); dsp.processStep1(); Fortran::lower::StatementContext stmtCtx; mlir::omp::SimdLoopClauseOps clauseOps; llvm::SmallVector iv; - genSimdLoopClauses(converter, semaCtx, stmtCtx, eval, clauseList, loc, - clauseOps, iv); + genSimdLoopClauses(converter, semaCtx, stmtCtx, eval, clauses, loc, clauseOps, + iv); - auto *nestedEval = - getCollapsedLoopEval(eval, Fortran::lower::getCollapseValue(clauseList)); + auto *nestedEval = getCollapsedLoopEval(eval, getCollapseValue(clauses)); auto ivCallback = [&](mlir::Operation *op) { return genLoopVars(op, converter, loc, iv); @@ -1735,7 +1711,7 @@ genSimdLoopOp(Fortran::lower::AbstractConverter &converter, return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval) - .setClauses(&clauseList) + .setClauses(&clauses) .setDataSharingProcessor(&dsp) .setGenRegionEntryCb(ivCallback), clauseOps); @@ -1745,17 +1721,16 @@ static mlir::omp::SingleOp genSingleOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList &endClauseList) { + mlir::Location loc, const List &beginClauses, + const List &endClauses) { mlir::omp::SingleClauseOps clauseOps; - genSingleClauses(converter, semaCtx, beginClauseList, endClauseList, loc, 
+ genSingleClauses(converter, semaCtx, beginClauses, endClauses, loc, clauseOps); return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, loc, eval) .setGenNested(genNested) - .setClauses(&beginClauseList), + .setClauses(&beginClauses), clauseOps); } @@ -1763,8 +1738,7 @@ static mlir::omp::TargetOp genTargetOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList, + mlir::Location loc, const List &clauses, bool outerCombined = false) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); Fortran::lower::StatementContext stmtCtx; @@ -1777,7 +1751,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector mapSyms; llvm::SmallVector mapSymLocs; llvm::SmallVector mapSymTypes; - genTargetClauses(converter, semaCtx, stmtCtx, clauseList, loc, + genTargetClauses(converter, semaCtx, stmtCtx, clauses, loc, processHostOnlyClauses, /*processReduction=*/outerCombined, clauseOps, mapSyms, mapSymLocs, mapSymTypes); @@ -1875,14 +1849,13 @@ static mlir::omp::TargetDataOp genTargetDataOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + mlir::Location loc, const List &clauses) { Fortran::lower::StatementContext stmtCtx; mlir::omp::TargetDataClauseOps clauseOps; llvm::SmallVector useDeviceTypes; llvm::SmallVector useDeviceLocs; llvm::SmallVector useDeviceSyms; - genTargetDataClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps, + genTargetDataClauses(converter, semaCtx, stmtCtx, clauses, loc, clauseOps, useDeviceTypes, useDeviceLocs, useDeviceSyms); auto targetDataOp = @@ -1894,11 +1867,11 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter, return targetDataOp; } -template -static OpTy 
genTargetEnterExitUpdateDataOp( - Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { +template static OpTy +genTargetEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + mlir::Location loc, + const List &clauses) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); Fortran::lower::StatementContext stmtCtx; @@ -1915,8 +1888,8 @@ static OpTy genTargetEnterExitUpdateDataOp( } mlir::omp::TargetEnterExitUpdateDataClauseOps clauseOps; - genTargetEnterExitUpdateDataClauses(converter, semaCtx, stmtCtx, clauseList, - loc, directive, clauseOps); + genTargetEnterExitUpdateDataClauses(converter, semaCtx, stmtCtx, clauses, loc, + directive, clauseOps); return firOpBuilder.create(loc, clauseOps); } @@ -1925,16 +1898,15 @@ static mlir::omp::TaskOp genTaskOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + mlir::Location loc, const List &clauses) { Fortran::lower::StatementContext stmtCtx; mlir::omp::TaskClauseOps clauseOps; - genTaskClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps); + genTaskClauses(converter, semaCtx, stmtCtx, clauses, loc, clauseOps); return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, loc, eval) .setGenNested(genNested) - .setClauses(&clauseList), + .setClauses(&clauses), clauseOps); } @@ -1942,15 +1914,14 @@ static mlir::omp::TaskgroupOp genTaskgroupOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + mlir::Location loc, const List &clauses) { mlir::omp::TaskgroupClauseOps clauseOps; - 
genTaskgroupClauses(converter, semaCtx, clauseList, loc, clauseOps); + genTaskgroupClauses(converter, semaCtx, clauses, loc, clauseOps); return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, loc, eval) .setGenNested(genNested) - .setClauses(&clauseList), + .setClauses(&clauses), clauseOps); } @@ -1958,7 +1929,7 @@ static mlir::omp::TaskloopOp genTaskloopOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + const List &clauses) { TODO(loc, "Taskloop construct"); } @@ -1966,9 +1937,9 @@ static mlir::omp::TaskwaitOp genTaskwaitOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + const List &clauses) { mlir::omp::TaskwaitClauseOps clauseOps; - genTaskwaitClauses(converter, semaCtx, clauseList, loc, clauseOps); + genTaskwaitClauses(converter, semaCtx, clauses, loc, clauseOps); return converter.getFirOpBuilder().create(loc, clauseOps); } @@ -1984,17 +1955,17 @@ static mlir::omp::TeamsOp genTeamsOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, const Fortran::parser::OmpClauseList &clauseList, + mlir::Location loc, const List &clauses, bool outerCombined = false) { Fortran::lower::StatementContext stmtCtx; mlir::omp::TeamsClauseOps clauseOps; - genTeamsClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps); + genTeamsClauses(converter, semaCtx, stmtCtx, clauses, loc, clauseOps); return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, loc, eval) .setGenNested(genNested) .setOuterCombined(outerCombined) - .setClauses(&clauseList), + .setClauses(&clauses), clauseOps); } @@ -2002,9 +1973,8 @@ static mlir::omp::WsloopOp 
genWsloopOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, mlir::Location loc, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList *endClauseList) { - DataSharingProcessor dsp(converter, semaCtx, beginClauseList, eval); + const List &beginClauses, const List &endClauses) { + DataSharingProcessor dsp(converter, semaCtx, beginClauses, eval); dsp.processStep1(); Fortran::lower::StatementContext stmtCtx; @@ -2012,12 +1982,10 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector iv; llvm::SmallVector reductionTypes; llvm::SmallVector reductionSyms; - genWsloopClauses(converter, semaCtx, stmtCtx, eval, beginClauseList, - endClauseList, loc, clauseOps, iv, reductionTypes, - reductionSyms); + genWsloopClauses(converter, semaCtx, stmtCtx, eval, beginClauses, endClauses, + loc, clauseOps, iv, reductionTypes, reductionSyms); - auto *nestedEval = getCollapsedLoopEval( - eval, Fortran::lower::getCollapseValue(beginClauseList)); + auto *nestedEval = getCollapsedLoopEval(eval, getCollapseValue(beginClauses)); auto ivCallback = [&](mlir::Operation *op) { return genLoopAndReductionVars(op, converter, loc, iv, reductionSyms, @@ -2026,7 +1994,7 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter, return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval) - .setClauses(&beginClauseList) + .setClauses(&beginClauses) .setDataSharingProcessor(&dsp) .setReductions(&reductionSyms, &reductionTypes) .setGenRegionEntryCb(ivCallback), @@ -2041,8 +2009,8 @@ static void genCompositeDistributeParallelDo( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) { + const List &beginClauses, + const List &endClauses, 
mlir::Location loc) { TODO(loc, "Composite DISTRIBUTE PARALLEL DO"); } @@ -2050,8 +2018,8 @@ static void genCompositeDistributeParallelDoSimd( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) { + const List &beginClauses, + const List &endClauses, mlir::Location loc) { TODO(loc, "Composite DISTRIBUTE PARALLEL DO SIMD"); } @@ -2059,8 +2027,8 @@ static void genCompositeDistributeSimd( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) { + const List &beginClauses, + const List &endClauses, mlir::Location loc) { TODO(loc, "Composite DISTRIBUTE SIMD"); } @@ -2068,10 +2036,10 @@ static void genCompositeDoSimd(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList *endClauseList, + const List &beginClauses, + const List &endClauses, mlir::Location loc) { - ClauseProcessor cp(converter, semaCtx, beginClauseList); + ClauseProcessor cp(converter, semaCtx, beginClauses); cp.processTODO( loc, llvm::omp::OMPD_do_simd); @@ -2083,15 +2051,15 @@ genCompositeDoSimd(Fortran::lower::AbstractConverter &converter, // When support for vectorization is enabled, then we need to add handling of // if clause. Currently if clause can be skipped because we always assume // SIMD length = 1. 
- genWsloopOp(converter, semaCtx, eval, loc, beginClauseList, endClauseList); + genWsloopOp(converter, semaCtx, eval, loc, beginClauses, endClauses); } static void genCompositeTaskloopSimd(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList *endClauseList, + const List &beginClauses, + const List &endClauses, mlir::Location loc) { TODO(loc, "Composite TASKLOOP SIMD"); } @@ -2201,8 +2169,9 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, const auto &directive = std::get( simpleStandaloneConstruct.t); - const auto &clauseList = - std::get(simpleStandaloneConstruct.t); + List clauses = makeClauses( + std::get(simpleStandaloneConstruct.t), + semaCtx); mlir::Location currentLocation = converter.genLocation(directive.source); switch (directive.v) { @@ -2212,29 +2181,29 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, genBarrierOp(converter, semaCtx, eval, currentLocation); break; case llvm::omp::Directive::OMPD_taskwait: - genTaskwaitOp(converter, semaCtx, eval, currentLocation, clauseList); + genTaskwaitOp(converter, semaCtx, eval, currentLocation, clauses); break; case llvm::omp::Directive::OMPD_taskyield: genTaskyieldOp(converter, semaCtx, eval, currentLocation); break; case llvm::omp::Directive::OMPD_target_data: genTargetDataOp(converter, semaCtx, eval, /*genNested=*/true, - currentLocation, clauseList); + currentLocation, clauses); break; case llvm::omp::Directive::OMPD_target_enter_data: genTargetEnterExitUpdateDataOp( - converter, semaCtx, currentLocation, clauseList); + converter, semaCtx, currentLocation, clauses); break; case llvm::omp::Directive::OMPD_target_exit_data: genTargetEnterExitUpdateDataOp( - converter, semaCtx, currentLocation, clauseList); + converter, semaCtx, currentLocation, clauses); break; case 
llvm::omp::Directive::OMPD_target_update: genTargetEnterExitUpdateDataOp( - converter, semaCtx, currentLocation, clauseList); + converter, semaCtx, currentLocation, clauses); break; case llvm::omp::Directive::OMPD_ordered: - genOrderedOp(converter, semaCtx, eval, currentLocation, clauseList); + genOrderedOp(converter, semaCtx, eval, currentLocation, clauses); break; } } @@ -2251,8 +2220,14 @@ genOMP(Fortran::lower::AbstractConverter &converter, const auto &clauseList = std::get>>( flushConstruct.t); + ObjectList objects = + objectList ? makeObjects(*objectList, semaCtx) : ObjectList{}; + List clauses = + clauseList ? makeList(*clauseList, + [&](auto &&s) { return makeClause(s.v, semaCtx); }) + : List{}; mlir::Location currentLocation = converter.genLocation(verbatim.source); - genFlushOp(converter, semaCtx, eval, currentLocation, objectList, clauseList); + genFlushOp(converter, semaCtx, eval, currentLocation, objects, clauses); } static void @@ -2357,44 +2332,44 @@ genOMP(Fortran::lower::AbstractConverter &converter, converter.genLocation(beginBlockDirective.source); const auto origDirective = std::get(beginBlockDirective.t).v; - const auto &beginClauseList = - std::get(beginBlockDirective.t); - const auto &endClauseList = - std::get(endBlockDirective.t); + List beginClauses = makeClauses( + std::get(beginBlockDirective.t), semaCtx); + List endClauses = makeClauses( + std::get(endBlockDirective.t), semaCtx); assert(llvm::omp::blockConstructSet.test(origDirective) && "Expected block construct"); - for (const Fortran::parser::OmpClause &clause : beginClauseList.v) { + for (const Clause &clause : beginClauses) { mlir::Location clauseLocation = converter.genLocation(clause.source); - if (!std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) 
&& - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u)) { + if (!std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u)) { TODO(clauseLocation, "OpenMP Block construct clause"); } } - for (const auto &clause : endClauseList.v) { + for (const Clause &clause : endClauses) { mlir::Location clauseLocation = converter.genLocation(clause.source); - if (!std::get_if(&clause.u) && - !std::get_if(&clause.u)) + if (!std::get_if(&clause.u) && + !std::get_if(&clause.u)) TODO(clauseLocation, "OpenMP Block construct clause"); } @@ -2413,44 +2388,44 @@ genOMP(Fortran::lower::AbstractConverter &converter, case llvm::omp::Directive::OMPD_ordered: // 2.17.9 ORDERED construct. genOrderedRegionOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList); + beginClauses); break; case llvm::omp::Directive::OMPD_parallel: // 2.6 PARALLEL construct. genParallelOp(converter, symTable, semaCtx, eval, genNested, - currentLocation, beginClauseList, outerCombined); + currentLocation, beginClauses, outerCombined); break; case llvm::omp::Directive::OMPD_single: // 2.8.2 SINGLE construct. 
genSingleOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList, endClauseList); + beginClauses, endClauses); break; case llvm::omp::Directive::OMPD_target: // 2.12.5 TARGET construct. genTargetOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList, outerCombined); + beginClauses, outerCombined); break; case llvm::omp::Directive::OMPD_target_data: // 2.12.2 TARGET DATA construct. genTargetDataOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList); + beginClauses); break; case llvm::omp::Directive::OMPD_task: // 2.10.1 TASK construct. genTaskOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList); + beginClauses); break; case llvm::omp::Directive::OMPD_taskgroup: // 2.17.6 TASKGROUP construct. genTaskgroupOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList); + beginClauses); break; case llvm::omp::Directive::OMPD_teams: // 2.7 TEAMS construct. // FIXME Pass the outerCombined argument or rename it to better describe // what it represents if it must always be `false` in this context. genTeamsOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList); + beginClauses); break; case llvm::omp::Directive::OMPD_workshare: // 2.8.3 WORKSHARE construct. @@ -2458,7 +2433,7 @@ genOMP(Fortran::lower::AbstractConverter &converter, // implementation for this feature will come later. For the codes // that use this construct, add a single construct for now. 
genSingleOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList, endClauseList); + beginClauses, endClauses); break; default: llvm_unreachable("Unexpected block construct"); @@ -2476,11 +2451,12 @@ genOMP(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) { const auto &cd = std::get(criticalConstruct.t); - const auto &clauseList = std::get(cd.t); + List clauses = + makeClauses(std::get(cd.t), semaCtx); const auto &name = std::get>(cd.t); mlir::Location currentLocation = converter.getCurrentLocation(); genCriticalOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation, - clauseList, name); + clauses, name); } static void @@ -2499,8 +2475,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OpenMPLoopConstruct &loopConstruct) { const auto &beginLoopDirective = std::get(loopConstruct.t); - const auto &beginClauseList = - std::get(beginLoopDirective.t); + List beginClauses = makeClauses( + std::get(beginLoopDirective.t), semaCtx); mlir::Location currentLocation = converter.genLocation(beginLoopDirective.source); const auto origDirective = @@ -2509,15 +2485,15 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, assert(llvm::omp::loopConstructSet.test(origDirective) && "Expected loop construct"); - const auto *endClauseList = [&]() { - using RetTy = const Fortran::parser::OmpClauseList *; + List endClauses = [&]() { if (auto &endLoopDirective = std::get>( loopConstruct.t)) { - return RetTy( - &std::get((*endLoopDirective).t)); + return makeClauses( + std::get(endLoopDirective->t), + semaCtx); } - return RetTy(); + return List{}; }(); std::optional nextDir = origDirective; @@ -2530,29 +2506,29 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, case llvm::omp::Directive::OMPD_distribute_parallel_do: // 2.9.4.3 DISTRIBUTE PARALLEL Worksharing-Loop construct. 
genCompositeDistributeParallelDo(converter, semaCtx, eval, - beginClauseList, endClauseList, + beginClauses, endClauses, currentLocation); break; case llvm::omp::Directive::OMPD_distribute_parallel_do_simd: // 2.9.4.4 DISTRIBUTE PARALLEL Worksharing-Loop SIMD construct. genCompositeDistributeParallelDoSimd(converter, semaCtx, eval, - beginClauseList, endClauseList, + beginClauses, endClauses, currentLocation); break; case llvm::omp::Directive::OMPD_distribute_simd: // 2.9.4.2 DISTRIBUTE SIMD construct. - genCompositeDistributeSimd(converter, semaCtx, eval, beginClauseList, - endClauseList, currentLocation); + genCompositeDistributeSimd(converter, semaCtx, eval, beginClauses, + endClauses, currentLocation); break; case llvm::omp::Directive::OMPD_do_simd: // 2.9.3.2 Worksharing-Loop SIMD construct. - genCompositeDoSimd(converter, semaCtx, eval, beginClauseList, - endClauseList, currentLocation); + genCompositeDoSimd(converter, semaCtx, eval, beginClauses, + endClauses, currentLocation); break; case llvm::omp::Directive::OMPD_taskloop_simd: // 2.10.3 TASKLOOP SIMD construct. - genCompositeTaskloopSimd(converter, semaCtx, eval, beginClauseList, - endClauseList, currentLocation); + genCompositeTaskloopSimd(converter, semaCtx, eval, beginClauses, + endClauses, currentLocation); break; default: llvm_unreachable("Unexpected composite construct"); @@ -2563,12 +2539,12 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, case llvm::omp::Directive::OMPD_distribute: // 2.9.4.1 DISTRIBUTE construct. genDistributeOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList); + beginClauses); break; case llvm::omp::Directive::OMPD_do: // 2.9.2 Worksharing-Loop construct. - genWsloopOp(converter, semaCtx, eval, currentLocation, beginClauseList, - endClauseList); + genWsloopOp(converter, semaCtx, eval, currentLocation, beginClauses, + endClauses); break; case llvm::omp::Directive::OMPD_parallel: // 2.6 PARALLEL construct. 
@@ -2577,24 +2553,24 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, // Maybe rename the argument if it represents something else or // initialize it properly. genParallelOp(converter, symTable, semaCtx, eval, genNested, - currentLocation, beginClauseList, + currentLocation, beginClauses, /*outerCombined=*/true); break; case llvm::omp::Directive::OMPD_simd: // 2.9.3.1 SIMD construct. genSimdLoopOp(converter, semaCtx, eval, currentLocation, - beginClauseList); - genOpenMPReduction(converter, semaCtx, beginClauseList); + beginClauses); + genOpenMPReduction(converter, semaCtx, beginClauses); break; case llvm::omp::Directive::OMPD_target: // 2.12.5 TARGET construct. genTargetOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList, /*outerCombined=*/true); + beginClauses, /*outerCombined=*/true); break; case llvm::omp::Directive::OMPD_taskloop: // 2.10.2 TASKLOOP construct. genTaskloopOp(converter, semaCtx, eval, currentLocation, - beginClauseList); + beginClauses); break; case llvm::omp::Directive::OMPD_teams: // 2.7 TEAMS construct. @@ -2603,7 +2579,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, // Maybe rename the argument if it represents something else or // initialize it properly. 
genTeamsOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList, /*outerCombined=*/true); + beginClauses, /*outerCombined=*/true); break; case llvm::omp::Directive::OMPD_loop: case llvm::omp::Directive::OMPD_masked: @@ -2639,14 +2615,15 @@ genOMP(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OpenMPSectionsConstruct §ionsConstruct) { const auto &beginSectionsDirective = std::get(sectionsConstruct.t); - const auto &beginClauseList = - std::get(beginSectionsDirective.t); + List beginClauses = makeClauses( + std::get(beginSectionsDirective.t), + semaCtx); // Process clauses before optional omp.parallel, so that new variables are // allocated outside of the parallel region mlir::Location currentLocation = converter.getCurrentLocation(); mlir::omp::SectionsClauseOps clauseOps; - genSectionsClauses(converter, semaCtx, beginClauseList, currentLocation, + genSectionsClauses(converter, semaCtx, beginClauses, currentLocation, /*clausesFromBeginSections=*/true, clauseOps); // Parallel wrapper of PARALLEL SECTIONS construct @@ -2655,14 +2632,15 @@ genOMP(Fortran::lower::AbstractConverter &converter, .v; if (dir == llvm::omp::Directive::OMPD_parallel_sections) { genParallelOp(converter, symTable, semaCtx, eval, - /*genNested=*/false, currentLocation, beginClauseList, + /*genNested=*/false, currentLocation, beginClauses, /*outerCombined=*/true); } else { const auto &endSectionsDirective = std::get(sectionsConstruct.t); - const auto &endClauseList = - std::get(endSectionsDirective.t); - genSectionsClauses(converter, semaCtx, endClauseList, currentLocation, + List endClauses = makeClauses( + std::get(endSectionsDirective.t), + semaCtx); + genSectionsClauses(converter, semaCtx, endClauses, currentLocation, /*clausesFromBeginSections=*/false, clauseOps); } @@ -2678,7 +2656,7 @@ genOMP(Fortran::lower::AbstractConverter &converter, llvm::zip(sectionBlocks.v, eval.getNestedEvaluations())) { symTable.pushScope(); genSectionOp(converter, semaCtx, 
neval, /*genNested=*/true, currentLocation, - beginClauseList); + beginClauses); symTable.popScope(); firOpBuilder.restoreInsertionPoint(ip); } diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp index b9c0660aa4da8..da3f2be73e509 100644 --- a/flang/lib/Lower/OpenMP/Utils.cpp +++ b/flang/lib/Lower/OpenMP/Utils.cpp @@ -36,6 +36,17 @@ namespace Fortran { namespace lower { namespace omp { +int64_t getCollapseValue(const List &clauses) { + auto iter = llvm::find_if(clauses, [](const Clause &clause) { + return clause.id == llvm::omp::Clause::OMPC_collapse; + }); + if (iter != clauses.end()) { + const auto &collapse = std::get(iter->u); + return evaluate::ToInt64(collapse.v).value(); + } + return 1; +} + void genObjectList(const ObjectList &objects, Fortran::lower::AbstractConverter &converter, llvm::SmallVectorImpl &operands) { @@ -52,25 +63,6 @@ void genObjectList(const ObjectList &objects, } } -void genObjectList2(const Fortran::parser::OmpObjectList &objectList, - Fortran::lower::AbstractConverter &converter, - llvm::SmallVectorImpl &operands) { - auto addOperands = [&](Fortran::lower::SymbolRef sym) { - const mlir::Value variable = converter.getSymbolAddress(sym); - if (variable) { - operands.push_back(variable); - } else if (const auto *details = - sym->detailsIf()) { - operands.push_back(converter.getSymbolAddress(details->symbol())); - converter.copySymbolBinding(details->symbol(), sym); - } - }; - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); - addOperands(*sym); - } -} - mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter, std::size_t loopVarTypeSize) { // OpenMP runtime requires 32-bit or 64-bit loop variables. 
diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h index 4074bf73987d5..b3a9f7f30c98b 100644 --- a/flang/lib/Lower/OpenMP/Utils.h +++ b/flang/lib/Lower/OpenMP/Utils.h @@ -58,6 +58,8 @@ void gatherFuncAndVarSyms( const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause, llvm::SmallVectorImpl &symbolAndClause); +int64_t getCollapseValue(const List &clauses); + Fortran::semantics::Symbol * getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject); @@ -65,10 +67,6 @@ void genObjectList(const ObjectList &objects, Fortran::lower::AbstractConverter &converter, llvm::SmallVectorImpl &operands); -void genObjectList2(const Fortran::parser::OmpObjectList &objectList, - Fortran::lower::AbstractConverter &converter, - llvm::SmallVectorImpl &operands); - } // namespace omp } // namespace lower } // namespace Fortran From 291dc48d5e0b7e0ee39681a1276bd1d63f456b01 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Mon, 1 Apr 2024 10:07:45 -0500 Subject: [PATCH 06/13] [Frontend][OpenMP] Refactor getLeafConstructs, add getCompoundConstruct Emit a special leaf construct table in DirectiveEmitter.cpp, which will allow both decomposition of a construct into leafs, and composition of constituent constructs into a single compound construct (if possible). 
--- llvm/include/llvm/Frontend/OpenMP/OMP.h | 7 + llvm/lib/Frontend/OpenMP/OMP.cpp | 64 +++++- llvm/test/TableGen/directive1.td | 19 +- llvm/test/TableGen/directive2.td | 19 +- llvm/unittests/Frontend/CMakeLists.txt | 1 + llvm/unittests/Frontend/OpenMPComposeTest.cpp | 41 ++++ llvm/utils/TableGen/DirectiveEmitter.cpp | 194 +++++++++++------- 7 files changed, 258 insertions(+), 87 deletions(-) create mode 100644 llvm/unittests/Frontend/OpenMPComposeTest.cpp diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.h b/llvm/include/llvm/Frontend/OpenMP/OMP.h index a85cd9d344c6d..4ed47f15dfe59 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMP.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.h @@ -15,4 +15,11 @@ #include "llvm/Frontend/OpenMP/OMP.h.inc" +#include "llvm/ADT/ArrayRef.h" + +namespace llvm::omp { +ArrayRef getLeafConstructs(Directive D); +Directive getCompoundConstruct(ArrayRef Parts); +} // namespace llvm::omp + #endif // LLVM_FRONTEND_OPENMP_OMP_H diff --git a/llvm/lib/Frontend/OpenMP/OMP.cpp b/llvm/lib/Frontend/OpenMP/OMP.cpp index 4f2f95392648b..dd99d3d074fd1 100644 --- a/llvm/lib/Frontend/OpenMP/OMP.cpp +++ b/llvm/lib/Frontend/OpenMP/OMP.cpp @@ -8,12 +8,74 @@ #include "llvm/Frontend/OpenMP/OMP.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/ErrorHandling.h" +#include +#include +#include + using namespace llvm; -using namespace omp; +using namespace llvm::omp; #define GEN_DIRECTIVES_IMPL #include "llvm/Frontend/OpenMP/OMP.inc" + +namespace llvm::omp { +ArrayRef getLeafConstructs(Directive D) { + auto Idx = static_cast(D); + if (Idx < 0 || Idx >= static_cast(Directive_enumSize)) + return {}; + const auto *Row = LeafConstructTable[LeafConstructTableOrdering[Idx]]; + return ArrayRef(&Row[2], &Row[2] + static_cast(Row[1])); +} + +Directive getCompoundConstruct(ArrayRef Parts) { + if (Parts.empty()) + return OMPD_unknown; + 
+ // Parts don't have to be leafs, so expand them into leafs first. + // Store the expanded leafs in the same format as rows in the leaf + // table (generated by tablegen). + SmallVector RawLeafs(2); + for (Directive P : Parts) { + ArrayRef Ls = getLeafConstructs(P); + if (!Ls.empty()) + RawLeafs.append(Ls.begin(), Ls.end()); + else + RawLeafs.push_back(P); + } + + auto GivenLeafs{ArrayRef(RawLeafs).drop_front(2)}; + if (GivenLeafs.size() == 1) + return GivenLeafs.front(); + RawLeafs[1] = static_cast(GivenLeafs.size()); + + auto Iter = llvm::lower_bound( + LeafConstructTable, + static_cast>(RawLeafs.data()), + [](const auto *RowA, const auto *RowB) { + const auto *BeginA = &RowA[2]; + const auto *EndA = BeginA + static_cast(RowA[1]); + const auto *BeginB = &RowB[2]; + const auto *EndB = BeginB + static_cast(RowB[1]); + if (BeginA == EndA && BeginB == EndB) + return static_cast(RowA[0]) < static_cast(RowB[0]); + return std::lexicographical_compare(BeginA, EndA, BeginB, EndB); + }); + + if (Iter == std::end(LeafConstructTable)) + return OMPD_unknown; + + // Verify that we got a match. + Directive Found = (*Iter)[0]; + ArrayRef FoundLeafs = getLeafConstructs(Found); + if (FoundLeafs == GivenLeafs) + return Found; + return OMPD_unknown; +} +} // namespace llvm::omp diff --git a/llvm/test/TableGen/directive1.td b/llvm/test/TableGen/directive1.td index 3184f625ead92..e6150210e7e9a 100644 --- a/llvm/test/TableGen/directive1.td +++ b/llvm/test/TableGen/directive1.td @@ -52,6 +52,7 @@ def TDL_DirA : Directive<"dira"> { // CHECK-EMPTY: // CHECK-NEXT: #include "llvm/ADT/ArrayRef.h" // CHECK-NEXT: #include "llvm/ADT/BitmaskEnum.h" +// CHECK-NEXT: #include // CHECK-EMPTY: // CHECK-NEXT: namespace llvm { // CHECK-NEXT: class StringRef; @@ -112,7 +113,7 @@ def TDL_DirA : Directive<"dira"> { // CHECK-NEXT: /// Return true if \p C is a valid clause for \p D in version \p Version. 
// CHECK-NEXT: bool isAllowedClauseForDirective(Directive D, Clause C, unsigned Version); // CHECK-EMPTY: -// CHECK-NEXT: llvm::ArrayRef getLeafConstructs(Directive D); +// CHECK-NEXT: constexpr std::size_t getMaxLeafCount() { return 0; } // CHECK-NEXT: Association getDirectiveAssociation(Directive D); // CHECK-NEXT: AKind getAKind(StringRef); // CHECK-NEXT: llvm::StringRef getTdlAKindName(AKind); @@ -359,13 +360,6 @@ def TDL_DirA : Directive<"dira"> { // IMPL-NEXT: llvm_unreachable("Invalid Tdl Directive kind"); // IMPL-NEXT: } // IMPL-EMPTY: -// IMPL-NEXT: llvm::ArrayRef llvm::tdl::getLeafConstructs(llvm::tdl::Directive Dir) { -// IMPL-NEXT: switch (Dir) { -// IMPL-NEXT: default: -// IMPL-NEXT: return ArrayRef{}; -// IMPL-NEXT: } // switch (Dir) -// IMPL-NEXT: } -// IMPL-EMPTY: // IMPL-NEXT: llvm::tdl::Association llvm::tdl::getDirectiveAssociation(llvm::tdl::Directive Dir) { // IMPL-NEXT: switch (Dir) { // IMPL-NEXT: case llvm::tdl::Directive::TDLD_dira: @@ -374,4 +368,13 @@ def TDL_DirA : Directive<"dira"> { // IMPL-NEXT: llvm_unreachable("Unexpected directive"); // IMPL-NEXT: } // IMPL-EMPTY: +// IMPL-NEXT: static_assert(sizeof(llvm::tdl::Directive) == sizeof(int)); +// IMPL-NEXT: {{.*}} static const llvm::tdl::Directive LeafConstructTable[][2] = { +// IMPL-NEXT: llvm::tdl::TDLD_dira, static_cast(0), +// IMPL-NEXT: }; +// IMPL-EMPTY: +// IMPL-NEXT: {{.*}} static const int LeafConstructTableOrdering[] = { +// IMPL-NEXT: 0, +// IMPL-NEXT: }; +// IMPL-EMPTY: // IMPL-NEXT: #endif // GEN_DIRECTIVES_IMPL diff --git a/llvm/test/TableGen/directive2.td b/llvm/test/TableGen/directive2.td index d6fa4835c8dfd..1750022e1f94e 100644 --- a/llvm/test/TableGen/directive2.td +++ b/llvm/test/TableGen/directive2.td @@ -45,6 +45,7 @@ def TDL_DirA : Directive<"dira"> { // CHECK-NEXT: #define LLVM_Tdl_INC // CHECK-EMPTY: // CHECK-NEXT: #include "llvm/ADT/ArrayRef.h" +// CHECK-NEXT: #include // CHECK-EMPTY: // CHECK-NEXT: namespace llvm { // CHECK-NEXT: class StringRef; @@ -88,7 
+89,7 @@ def TDL_DirA : Directive<"dira"> { // CHECK-NEXT: /// Return true if \p C is a valid clause for \p D in version \p Version. // CHECK-NEXT: bool isAllowedClauseForDirective(Directive D, Clause C, unsigned Version); // CHECK-EMPTY: -// CHECK-NEXT: llvm::ArrayRef getLeafConstructs(Directive D); +// CHECK-NEXT: constexpr std::size_t getMaxLeafCount() { return 0; } // CHECK-NEXT: Association getDirectiveAssociation(Directive D); // CHECK-NEXT: } // namespace tdl // CHECK-NEXT: } // namespace llvm @@ -290,13 +291,6 @@ def TDL_DirA : Directive<"dira"> { // IMPL-NEXT: llvm_unreachable("Invalid Tdl Directive kind"); // IMPL-NEXT: } // IMPL-EMPTY: -// IMPL-NEXT: llvm::ArrayRef llvm::tdl::getLeafConstructs(llvm::tdl::Directive Dir) { -// IMPL-NEXT: switch (Dir) { -// IMPL-NEXT: default: -// IMPL-NEXT: return ArrayRef{}; -// IMPL-NEXT: } // switch (Dir) -// IMPL-NEXT: } -// IMPL-EMPTY: // IMPL-NEXT: llvm::tdl::Association llvm::tdl::getDirectiveAssociation(llvm::tdl::Directive Dir) { // IMPL-NEXT: switch (Dir) { // IMPL-NEXT: case llvm::tdl::Directive::TDLD_dira: @@ -305,4 +299,13 @@ def TDL_DirA : Directive<"dira"> { // IMPL-NEXT: llvm_unreachable("Unexpected directive"); // IMPL-NEXT: } // IMPL-EMPTY: +// IMPL-NEXT: static_assert(sizeof(llvm::tdl::Directive) == sizeof(int)); +// IMPL-NEXT: {{.*}} static const llvm::tdl::Directive LeafConstructTable[][2] = { +// IMPL-NEXT: llvm::tdl::TDLD_dira, static_cast(0), +// IMPL-NEXT: }; +// IMPL-EMPTY: +// IMPL-NEXT: {{.*}} static const int LeafConstructTableOrdering[] = { +// IMPL-NEXT: 0, +// IMPL-NEXT: }; +// IMPL-EMPTY: // IMPL-NEXT: #endif // GEN_DIRECTIVES_IMPL diff --git a/llvm/unittests/Frontend/CMakeLists.txt b/llvm/unittests/Frontend/CMakeLists.txt index c6f60142d6276..ddb6a16cbb984 100644 --- a/llvm/unittests/Frontend/CMakeLists.txt +++ b/llvm/unittests/Frontend/CMakeLists.txt @@ -14,6 +14,7 @@ add_llvm_unittest(LLVMFrontendTests OpenMPContextTest.cpp OpenMPIRBuilderTest.cpp OpenMPParsingTest.cpp + 
OpenMPComposeTest.cpp DEPENDS acc_gen diff --git a/llvm/unittests/Frontend/OpenMPComposeTest.cpp b/llvm/unittests/Frontend/OpenMPComposeTest.cpp new file mode 100644 index 0000000000000..29b1be4eb3432 --- /dev/null +++ b/llvm/unittests/Frontend/OpenMPComposeTest.cpp @@ -0,0 +1,41 @@ +//===- llvm/unittests/Frontend/OpenMPComposeTest.cpp ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Frontend/OpenMP/OMP.h" +#include "gtest/gtest.h" + +using namespace llvm; +using namespace llvm::omp; + +TEST(Composition, GetLeafConstructs) { + ArrayRef L1 = getLeafConstructs(OMPD_loop); + ASSERT_EQ(L1, (ArrayRef{})); + ArrayRef L2 = getLeafConstructs(OMPD_parallel_for); + ASSERT_EQ(L2, (ArrayRef{OMPD_parallel, OMPD_for})); + ArrayRef L3 = getLeafConstructs(OMPD_parallel_for_simd); + ASSERT_EQ(L3, (ArrayRef{OMPD_parallel, OMPD_for, OMPD_simd})); +} + +TEST(Composition, GetCompoundConstruct) { + Directive C1 = + getCompoundConstruct({OMPD_target, OMPD_teams, OMPD_distribute}); + ASSERT_EQ(C1, OMPD_target_teams_distribute); + Directive C2 = getCompoundConstruct({OMPD_target}); + ASSERT_EQ(C2, OMPD_target); + Directive C3 = getCompoundConstruct({OMPD_target, OMPD_masked}); + ASSERT_EQ(C3, OMPD_unknown); + Directive C4 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute}); + ASSERT_EQ(C4, OMPD_target_teams_distribute); + Directive C5 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute}); + ASSERT_EQ(C5, OMPD_target_teams_distribute); + Directive C6 = getCompoundConstruct({}); + ASSERT_EQ(C6, OMPD_unknown); + Directive C7 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd}); + ASSERT_EQ(C7, OMPD_parallel_for_simd); +} diff --git 
a/llvm/utils/TableGen/DirectiveEmitter.cpp b/llvm/utils/TableGen/DirectiveEmitter.cpp index e0edf1720f8ac..2d2b774849189 100644 --- a/llvm/utils/TableGen/DirectiveEmitter.cpp +++ b/llvm/utils/TableGen/DirectiveEmitter.cpp @@ -20,6 +20,9 @@ #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" +#include +#include + using namespace llvm; namespace { @@ -39,7 +42,8 @@ class IfDefScope { }; } // namespace -// Generate enum class +// Generate enum class. Entries are emitted in the order in which they appear +// in the `Records` vector. static void GenerateEnumClass(const std::vector &Records, raw_ostream &OS, StringRef Enum, StringRef Prefix, const DirectiveLanguage &DirLang, @@ -175,6 +179,16 @@ bool DirectiveLanguage::HasValidityErrors() const { return HasDuplicateClausesInDirectives(getDirectives()); } +// Count the maximum number of leaf constituents per construct. +static size_t GetMaxLeafCount(const DirectiveLanguage &DirLang) { + size_t MaxCount = 0; + for (Record *R : DirLang.getDirectives()) { + size_t Count = Directive{R}.getLeafConstructs().size(); + MaxCount = std::max(MaxCount, Count); + } + return MaxCount; +} + // Generate the declaration section for the enumeration in the directive // language static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) { @@ -189,6 +203,7 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) { if (DirLang.hasEnableBitmaskEnumInNamespace()) OS << "#include \"llvm/ADT/BitmaskEnum.h\"\n"; + OS << "#include \n"; // for size_t OS << "\n"; OS << "namespace llvm {\n"; OS << "class StringRef;\n"; @@ -244,7 +259,8 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) { OS << "bool isAllowedClauseForDirective(Directive D, " << "Clause C, unsigned Version);\n"; OS << "\n"; - OS << "llvm::ArrayRef getLeafConstructs(Directive D);\n"; + OS << "constexpr std::size_t getMaxLeafCount() { return " + << GetMaxLeafCount(DirLang) << "; }\n"; OS << "Association 
getDirectiveAssociation(Directive D);\n"; if (EnumHelperFuncs.length() > 0) { OS << EnumHelperFuncs; @@ -396,6 +412,19 @@ GenerateCaseForVersionedClauses(const std::vector &Clauses, } } +static std::string GetDirectiveName(const DirectiveLanguage &DirLang, + const Record *Rec) { + Directive Dir{Rec}; + return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + "::" + + DirLang.getDirectivePrefix() + Dir.getFormattedName()) + .str(); +} + +static std::string GetDirectiveType(const DirectiveLanguage &DirLang) { + return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + "::Directive") + .str(); +} + // Generate the isAllowedClauseForDirective function implementation. static void GenerateIsAllowedClause(const DirectiveLanguage &DirLang, raw_ostream &OS) { @@ -450,77 +479,102 @@ static void GenerateIsAllowedClause(const DirectiveLanguage &DirLang, OS << "}\n"; // End of function isAllowedClauseForDirective } -// Generate the getLeafConstructs function implementation. -static void GenerateGetLeafConstructs(const DirectiveLanguage &DirLang, - raw_ostream &OS) { - auto getQualifiedName = [&](StringRef Formatted) -> std::string { - return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + - "::Directive::" + DirLang.getDirectivePrefix() + Formatted) - .str(); - }; - - // For each list of leaves, generate a static local object, then - // return a reference to that object for a given directive, e.g. +static void EmitLeafTable(const DirectiveLanguage &DirLang, raw_ostream &OS, + StringRef TableName) { + // The leaf constructs are emitted in a form of a 2D table, where each + // row corresponds to a directive (and there is a row for each directive). 
// - // static ListTy leafConstructs_A_B = { A, B }; - // static ListTy leafConstructs_C_D_E = { C, D, E }; - // switch (Dir) { - // case A_B: - // return leafConstructs_A_B; - // case C_D_E: - // return leafConstructs_C_D_E; - // } - - // Map from a record that defines a directive to the name of the - // local object with the list of its leaves. - DenseMap ListNames; - - std::string DirectiveTypeName = - std::string("llvm::") + DirLang.getCppNamespace().str() + "::Directive"; - - OS << '\n'; - - // ArrayRef<...> llvm::::GetLeafConstructs(llvm::::Directive Dir) - OS << "llvm::ArrayRef<" << DirectiveTypeName - << "> llvm::" << DirLang.getCppNamespace() << "::getLeafConstructs(" - << DirectiveTypeName << " Dir) "; - OS << "{\n"; - - // Generate the locals. - for (Record *R : DirLang.getDirectives()) { - Directive Dir{R}; + // Each row consists of + // - the id of the directive itself, + // - number of leaf constructs that will follow (0 for leafs), + // - ids of the leaf constructs (none if the directive is itself a leaf). + // The total number of these entries is at most MaxLeafCount+2. If this + // number is less than that, it is padded to occupy exactly MaxLeafCount+2 + // entries in memory. + // + // The rows are stored in the table in the lexicographical order. This + // is intended to enable binary search when mapping a sequence of leafs + // back to the compound directive. + // The consequence of that is that in order to find a row corresponding + // to the given directive, we'd need to scan the first element of each + // row. To avoid this, an auxiliary ordering table is created, such that + // row for Dir_A = table[auxiliary[Dir_A]]. 
+ + std::vector Directives = DirLang.getDirectives(); + DenseMap DirId; // Record * -> llvm::omp::Directive + + for (auto [Idx, Rec] : llvm::enumerate(Directives)) + DirId.insert(std::make_pair(Rec, Idx)); + + using LeafList = std::vector; + int MaxLeafCount = GetMaxLeafCount(DirLang); + + // The initial leaf table, rows order is same as directive order. + std::vector LeafTable(Directives.size()); + for (auto [Idx, Rec] : llvm::enumerate(Directives)) { + Directive Dir{Rec}; + std::vector Leaves = Dir.getLeafConstructs(); + + auto &List = LeafTable[Idx]; + List.resize(MaxLeafCount + 2); + List[0] = Idx; // The id of the directive itself. + List[1] = Leaves.size(); // The number of leaves to follow. + + for (int I = 0; I != MaxLeafCount; ++I) + List[I + 2] = + static_cast(I) < Leaves.size() ? DirId.at(Leaves[I]) : -1; + } - std::vector LeafConstructs = Dir.getLeafConstructs(); - if (LeafConstructs.empty()) - continue; + // Avoid sorting the vector array, instead sort an index array. + // It will also be useful later to create the auxiliary indexing array. + std::vector Ordering(Directives.size()); + std::iota(Ordering.begin(), Ordering.end(), 0); + + llvm::sort(Ordering, [&](int A, int B) { + auto &LeavesA = LeafTable[A]; + auto &LeavesB = LeafTable[B]; + if (LeavesA[1] == 0 && LeavesB[1] == 0) + return LeavesA[0] < LeavesB[0]; + return std::lexicographical_compare(&LeavesA[2], &LeavesA[2] + LeavesA[1], + &LeavesB[2], &LeavesB[2] + LeavesB[1]); + }); - std::string ListName = "leafConstructs_" + Dir.getFormattedName(); - OS << " static const " << DirectiveTypeName << ' ' << ListName - << "[] = {\n"; - for (Record *L : LeafConstructs) { - Directive LeafDir{L}; - OS << " " << getQualifiedName(LeafDir.getFormattedName()) << ",\n"; + // Emit the table + + // The directives are emitted into a scoped enum, for which the underlying + // type is `int` (by default). 
The code above uses `int` to store directive + // ids, so make sure that we catch it when something changes in the + // underlying type. + std::string DirectiveType = GetDirectiveType(DirLang); + OS << "\nstatic_assert(sizeof(" << DirectiveType << ") == sizeof(int));\n"; + + OS << "[[maybe_unused]] static const " << DirectiveType << ' ' << TableName + << "[][" << MaxLeafCount + 2 << "] = {\n"; + for (size_t I = 0, E = Directives.size(); I != E; ++I) { + auto &Leaves = LeafTable[Ordering[I]]; + OS << " " << GetDirectiveName(DirLang, Directives[Leaves[0]]); + OS << ", static_cast<" << DirectiveType << ">(" << Leaves[1] << "),"; + for (size_t I = 2, E = Leaves.size(); I != E; ++I) { + int Idx = Leaves[I]; + if (Idx >= 0) + OS << ' ' << GetDirectiveName(DirLang, Directives[Leaves[I]]) << ','; + else + OS << " static_cast<" << DirectiveType << ">(-1),"; } - OS << " };\n"; - ListNames.insert(std::make_pair(R, std::move(ListName))); - } - - if (!ListNames.empty()) OS << '\n'; - OS << " switch (Dir) {\n"; - for (Record *R : DirLang.getDirectives()) { - auto F = ListNames.find(R); - if (F == ListNames.end()) - continue; - - Directive Dir{R}; - OS << " case " << getQualifiedName(Dir.getFormattedName()) << ":\n"; - OS << " return " << F->second << ";\n"; } - OS << " default:\n"; - OS << " return ArrayRef<" << DirectiveTypeName << ">{};\n"; - OS << " } // switch (Dir)\n"; - OS << "}\n"; + OS << "};\n\n"; + + // Emit the auxiliary index table: it's the inverse of the `Ordering` + // table above. 
+ OS << "[[maybe_unused]] static const int " << TableName << "Ordering[] = {\n"; + OS << " "; + std::vector Reverse(Ordering.size()); + for (int I = 0, E = Ordering.size(); I != E; ++I) + Reverse[Ordering[I]] = I; + for (int Idx : Reverse) + OS << ' ' << Idx << ','; + OS << "\n};\n"; } static void GenerateGetDirectiveAssociation(const DirectiveLanguage &DirLang, @@ -1105,11 +1159,11 @@ void EmitDirectivesBasicImpl(const DirectiveLanguage &DirLang, // isAllowedClauseForDirective(Directive D, Clause C, unsigned Version) GenerateIsAllowedClause(DirLang, OS); - // getLeafConstructs(Directive D) - GenerateGetLeafConstructs(DirLang, OS); - // getDirectiveAssociation(Directive D) GenerateGetDirectiveAssociation(DirLang, OS); + + // Leaf table for getLeafConstructs, etc. + EmitLeafTable(DirLang, OS, "LeafConstructTable"); } // Generate the implemenation section for the enumeration in the directive From 0d92781c7a52ed2fbab33ae6e7b3dae61cfd42ae Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Tue, 2 Apr 2024 08:20:15 -0500 Subject: [PATCH 07/13] Address review comments --- llvm/lib/Frontend/OpenMP/OMP.cpp | 10 ++++++++-- llvm/unittests/Frontend/OpenMPComposeTest.cpp | 10 ++++------ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/llvm/lib/Frontend/OpenMP/OMP.cpp b/llvm/lib/Frontend/OpenMP/OMP.cpp index dd99d3d074fd1..7504c9076fde1 100644 --- a/llvm/lib/Frontend/OpenMP/OMP.cpp +++ b/llvm/lib/Frontend/OpenMP/OMP.cpp @@ -27,8 +27,8 @@ using namespace llvm::omp; namespace llvm::omp { ArrayRef getLeafConstructs(Directive D) { - auto Idx = static_cast(D); - if (Idx < 0 || Idx >= static_cast(Directive_enumSize)) + auto Idx = static_cast(D); + if (Idx >= Directive_enumSize) return {}; const auto *Row = LeafConstructTable[LeafConstructTableOrdering[Idx]]; return ArrayRef(&Row[2], &Row[2] + static_cast(Row[1])); @@ -50,6 +50,12 @@ Directive getCompoundConstruct(ArrayRef Parts) { RawLeafs.push_back(P); } + // RawLeafs will be used as key in the binary search. 
The search doesn't + // guarantee that the exact same entry will be found (since RawLeafs may + // not correspond to any compound directive). Because of that, we will + // need to compare the search result with the given set of leafs. + // Also, if there is only one leaf in the list, it corresponds to itself, + // no search is necessary. auto GivenLeafs{ArrayRef(RawLeafs).drop_front(2)}; if (GivenLeafs.size() == 1) return GivenLeafs.front(); diff --git a/llvm/unittests/Frontend/OpenMPComposeTest.cpp b/llvm/unittests/Frontend/OpenMPComposeTest.cpp index 29b1be4eb3432..c3e0880ece864 100644 --- a/llvm/unittests/Frontend/OpenMPComposeTest.cpp +++ b/llvm/unittests/Frontend/OpenMPComposeTest.cpp @@ -32,10 +32,8 @@ TEST(Composition, GetCompoundConstruct) { ASSERT_EQ(C3, OMPD_unknown); Directive C4 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute}); ASSERT_EQ(C4, OMPD_target_teams_distribute); - Directive C5 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute}); - ASSERT_EQ(C5, OMPD_target_teams_distribute); - Directive C6 = getCompoundConstruct({}); - ASSERT_EQ(C6, OMPD_unknown); - Directive C7 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd}); - ASSERT_EQ(C7, OMPD_parallel_for_simd); + Directive C5 = getCompoundConstruct({}); + ASSERT_EQ(C5, OMPD_unknown); + Directive C6 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd}); + ASSERT_EQ(C6, OMPD_parallel_for_simd); } From 46770f8dfe25528e970e5908aae8b2a788655bfc Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 29 Mar 2024 09:20:41 -0500 Subject: [PATCH 08/13] [flang][OpenMP] Move clause/object conversion to happen early, in genOMP This removes the last use of genOmpObjectList2, which has now been removed. 
--- flang/lib/Lower/OpenMP/ClauseProcessor.h | 5 +- flang/lib/Lower/OpenMP/DataSharingProcessor.h | 5 +- flang/lib/Lower/OpenMP/OpenMP.cpp | 424 +++++++++--------- flang/lib/Lower/OpenMP/Utils.cpp | 30 +- flang/lib/Lower/OpenMP/Utils.h | 6 +- 5 files changed, 218 insertions(+), 252 deletions(-) diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index db7a1b8335f81..f4d659b70cfee 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -49,9 +49,8 @@ class ClauseProcessor { public: ClauseProcessor(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &clauses) - : converter(converter), semaCtx(semaCtx), - clauses(makeClauses(clauses, semaCtx)) {} + const List &clauses) + : converter(converter), semaCtx(semaCtx), clauses(clauses) {} // 'Unique' clauses: They can appear at most once in the clause list. bool processCollapse( diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.h b/flang/lib/Lower/OpenMP/DataSharingProcessor.h index c11ee299c5d08..ef7b14327278e 100644 --- a/flang/lib/Lower/OpenMP/DataSharingProcessor.h +++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.h @@ -78,13 +78,12 @@ class DataSharingProcessor { public: DataSharingProcessor(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &opClauseList, + const List &clauses, Fortran::lower::pft::Evaluation &eval, bool useDelayedPrivatization = false, Fortran::lower::SymMap *symTable = nullptr) : hasLastPrivateOp(false), converter(converter), - firOpBuilder(converter.getFirOpBuilder()), - clauses(omp::makeClauses(opClauseList, semaCtx)), eval(eval), + firOpBuilder(converter.getFirOpBuilder()), clauses(clauses), eval(eval), useDelayedPrivatization(useDelayedPrivatization), symTable(symTable) {} // Privatisation is split into two steps. 
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index edae453972d3d..23dc25ac1ae9a 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -17,6 +17,7 @@ #include "DataSharingProcessor.h" #include "DirectivesCommon.h" #include "ReductionProcessor.h" +#include "Utils.h" #include "flang/Common/idioms.h" #include "flang/Lower/Bridge.h" #include "flang/Lower/ConvertExpr.h" @@ -310,14 +311,15 @@ static void getDeclareTargetInfo( } else if (const auto *clauseList{ Fortran::parser::Unwrap( spec.u)}) { - if (clauseList->v.empty()) { + List clauses = makeClauses(*clauseList, semaCtx); + if (clauses.empty()) { // Case: declare target, implicit capture of function symbolAndClause.emplace_back( mlir::omp::DeclareTargetCaptureClause::to, eval.getOwningProcedure()->getSubprogramSymbol()); } - ClauseProcessor cp(converter, semaCtx, *clauseList); + ClauseProcessor cp(converter, semaCtx, clauses); cp.processDeviceType(clauseOps); cp.processEnter(symbolAndClause); cp.processLink(symbolAndClause); @@ -597,14 +599,11 @@ static void removeStoreOp(mlir::Operation *reductionOp, mlir::Value symVal) { // TODO: Generate the reduction operation during lowering instead of creating // and removing operations since this is not a robust approach. Also, removing // ops in the builder (instead of a rewriter) is probably not the best approach. 
-static void -genOpenMPReduction(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &clauseList) { +static void genOpenMPReduction(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const List &clauses) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - List clauses{makeClauses(clauseList, semaCtx)}; - for (const Clause &clause : clauses) { if (const auto &reductionClause = std::get_if(&clause.u)) { @@ -812,7 +811,7 @@ struct OpWithBodyGenInfo { return *this; } - OpWithBodyGenInfo &setClauses(const Fortran::parser::OmpClauseList *value) { + OpWithBodyGenInfo &setClauses(const List *value) { clauses = value; return *this; } @@ -848,7 +847,7 @@ struct OpWithBodyGenInfo { /// [in] is this an outer operation - prevents privatization. bool outerCombined = false; /// [in] list of clauses to process. - const Fortran::parser::OmpClauseList *clauses = nullptr; + const List *clauses = nullptr; /// [in] if provided, processes the construct's data-sharing attributes. 
DataSharingProcessor *dsp = nullptr; /// [in] if provided, list of reduction symbols @@ -1226,36 +1225,33 @@ static OpTy genOpWithBody(OpWithBodyGenInfo &info, Args &&...args) { // Code generation functions for clauses //===----------------------------------------------------------------------===// -static void genCriticalDeclareClauses( - Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, - mlir::omp::CriticalClauseOps &clauseOps, llvm::StringRef name) { +static void +genCriticalDeclareClauses(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const List &clauses, mlir::Location loc, + mlir::omp::CriticalClauseOps &clauseOps, + llvm::StringRef name) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processHint(clauseOps); clauseOps.nameAttr = mlir::StringAttr::get(converter.getFirOpBuilder().getContext(), name); } -static void genFlushClauses( - Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - const std::optional &objects, - const std::optional> - &clauses, - mlir::Location loc, llvm::SmallVectorImpl &operandRange) { - if (objects) - genObjectList2(*objects, converter, operandRange); - - if (clauses && clauses->size() > 0) +static void genFlushClauses(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const ObjectList &objects, + const List &clauses, mlir::Location loc, + llvm::SmallVectorImpl &operandRange) { + genObjectList(objects, converter, operandRange); + + if (clauses.size() > 0) TODO(converter.getCurrentLocation(), "Handle OmpMemoryOrderClause"); } static void genOrderedRegionClauses(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &clauses, - mlir::Location loc, + const List &clauses, mlir::Location loc, 
mlir::omp::OrderedRegionClauseOps &clauseOps) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processTODO(loc, llvm::omp::Directive::OMPD_ordered); @@ -1264,9 +1260,9 @@ genOrderedRegionClauses(Fortran::lower::AbstractConverter &converter, static void genParallelClauses( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, - bool processReduction, mlir::omp::ParallelClauseOps &clauseOps, + Fortran::lower::StatementContext &stmtCtx, const List &clauses, + mlir::Location loc, bool processReduction, + mlir::omp::ParallelClauseOps &clauseOps, llvm::SmallVectorImpl &reductionTypes, llvm::SmallVectorImpl &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); @@ -1286,8 +1282,7 @@ static void genParallelClauses( static void genSectionsClauses(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &clauses, - mlir::Location loc, + const List &clauses, mlir::Location loc, bool clausesFromBeginSections, mlir::omp::SectionsClauseOps &clauseOps) { ClauseProcessor cp(converter, semaCtx, clauses); @@ -1304,9 +1299,8 @@ static void genSimdLoopClauses( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::StatementContext &stmtCtx, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, - mlir::omp::SimdLoopClauseOps &clauseOps, + Fortran::lower::pft::Evaluation &eval, const List &clauses, + mlir::Location loc, mlir::omp::SimdLoopClauseOps &clauseOps, llvm::SmallVectorImpl &iv) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processCollapse(loc, eval, clauseOps, iv); @@ -1324,9 +1318,8 @@ static void genSimdLoopClauses( static void genSingleClauses(Fortran::lower::AbstractConverter &converter, 
Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &beginClauses, - const Fortran::parser::OmpClauseList &endClauses, - mlir::Location loc, + const List &beginClauses, + const List &endClauses, mlir::Location loc, mlir::omp::SingleClauseOps &clauseOps) { ClauseProcessor bcp(converter, semaCtx, beginClauses); bcp.processAllocate(clauseOps); @@ -1340,9 +1333,8 @@ static void genSingleClauses(Fortran::lower::AbstractConverter &converter, static void genTargetClauses( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, - bool processHostOnlyClauses, bool processReduction, + Fortran::lower::StatementContext &stmtCtx, const List &clauses, + mlir::Location loc, bool processHostOnlyClauses, bool processReduction, mlir::omp::TargetClauseOps &clauseOps, llvm::SmallVectorImpl &mapSyms, llvm::SmallVectorImpl &mapSymLocs, @@ -1368,9 +1360,8 @@ static void genTargetClauses( static void genTargetDataClauses( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, - mlir::omp::TargetDataClauseOps &clauseOps, + Fortran::lower::StatementContext &stmtCtx, const List &clauses, + mlir::Location loc, mlir::omp::TargetDataClauseOps &clauseOps, llvm::SmallVectorImpl &useDeviceTypes, llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl &useDeviceSyms) { @@ -1401,9 +1392,8 @@ static void genTargetDataClauses( static void genTargetEnterExitUpdateDataClauses( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, - llvm::omp::Directive directive, + Fortran::lower::StatementContext &stmtCtx, const List 
&clauses, + mlir::Location loc, llvm::omp::Directive directive, mlir::omp::TargetEnterExitUpdateDataClauseOps &clauseOps) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processDepend(clauseOps); @@ -1422,8 +1412,7 @@ static void genTargetEnterExitUpdateDataClauses( static void genTaskClauses(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::OmpClauseList &clauses, - mlir::Location loc, + const List &clauses, mlir::Location loc, mlir::omp::TaskClauseOps &clauseOps) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); @@ -1442,8 +1431,7 @@ static void genTaskClauses(Fortran::lower::AbstractConverter &converter, static void genTaskgroupClauses(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &clauses, - mlir::Location loc, + const List &clauses, mlir::Location loc, mlir::omp::TaskgroupClauseOps &clauseOps) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); @@ -1453,8 +1441,7 @@ static void genTaskgroupClauses(Fortran::lower::AbstractConverter &converter, static void genTaskwaitClauses(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &clauses, - mlir::Location loc, + const List &clauses, mlir::Location loc, mlir::omp::TaskwaitClauseOps &clauseOps) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processTODO( @@ -1464,8 +1451,7 @@ static void genTaskwaitClauses(Fortran::lower::AbstractConverter &converter, static void genTeamsClauses(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::OmpClauseList &clauses, - mlir::Location loc, + const List &clauses, mlir::Location loc, mlir::omp::TeamsClauseOps &clauseOps) { 
ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); @@ -1482,9 +1468,8 @@ static void genWsloopClauses( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::StatementContext &stmtCtx, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OmpClauseList &beginClauses, - const Fortran::parser::OmpClauseList *endClauses, mlir::Location loc, + Fortran::lower::pft::Evaluation &eval, const List &beginClauses, + const List &endClauses, mlir::Location loc, mlir::omp::WsloopClauseOps &clauseOps, llvm::SmallVectorImpl &iv, llvm::SmallVectorImpl &reductionTypes, @@ -1501,8 +1486,8 @@ static void genWsloopClauses( if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars)) clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr(); - if (endClauses) { - ClauseProcessor ecp(converter, semaCtx, *endClauses); + if (!endClauses.empty()) { + ClauseProcessor ecp(converter, semaCtx, endClauses); ecp.processNowait(clauseOps); } @@ -1525,8 +1510,7 @@ static mlir::omp::CriticalOp genCriticalOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList, + mlir::Location loc, const List &clauses, const std::optional &name) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::FlatSymbolRefAttr nameAttr; @@ -1537,7 +1521,7 @@ genCriticalOp(Fortran::lower::AbstractConverter &converter, auto global = mod.lookupSymbol(nameStr); if (!global) { mlir::omp::CriticalClauseOps clauseOps; - genCriticalDeclareClauses(converter, semaCtx, clauseList, loc, clauseOps, + genCriticalDeclareClauses(converter, semaCtx, clauses, loc, clauseOps, nameStr); mlir::OpBuilder modBuilder(mod.getBodyRegion()); @@ -1556,8 +1540,7 @@ static mlir::omp::DistributeOp genDistributeOp(Fortran::lower::AbstractConverter &converter, 
Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + mlir::Location loc, const List &clauses) { TODO(loc, "Distribute construct"); return nullptr; } @@ -1566,12 +1549,9 @@ static mlir::omp::FlushOp genFlushOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, mlir::Location loc, - const std::optional &objectList, - const std::optional> - &clauseList) { + const ObjectList &objects, const List &clauses) { llvm::SmallVector operandRange; - genFlushClauses(converter, semaCtx, objectList, clauseList, loc, - operandRange); + genFlushClauses(converter, semaCtx, objects, clauses, loc, operandRange); return converter.getFirOpBuilder().create( converter.getCurrentLocation(), operandRange); @@ -1591,7 +1571,7 @@ static mlir::omp::OrderedOp genOrderedOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + const List &clauses) { TODO(loc, "OMPD_ordered"); return nullptr; } @@ -1600,10 +1580,9 @@ static mlir::omp::OrderedRegionOp genOrderedRegionOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + mlir::Location loc, const List &clauses) { mlir::omp::OrderedRegionClauseOps clauseOps; - genOrderedRegionClauses(converter, semaCtx, clauseList, loc, clauseOps); + genOrderedRegionClauses(converter, semaCtx, clauses, loc, clauseOps); return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, loc, eval).setGenNested(genNested), @@ -1615,8 +1594,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, Fortran::lower::SymMap &symTable, 
Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList, + mlir::Location loc, const List &clauses, bool outerCombined = false) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); Fortran::lower::StatementContext stmtCtx; @@ -1624,7 +1602,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector privateSyms; llvm::SmallVector reductionTypes; llvm::SmallVector reductionSyms; - genParallelClauses(converter, semaCtx, stmtCtx, clauseList, loc, + genParallelClauses(converter, semaCtx, stmtCtx, clauses, loc, /*processReduction=*/!outerCombined, clauseOps, reductionTypes, reductionSyms); @@ -1637,7 +1615,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, OpWithBodyGenInfo(converter, semaCtx, loc, eval) .setGenNested(genNested) .setOuterCombined(outerCombined) - .setClauses(&clauseList) + .setClauses(&clauses) .setReductions(&reductionSyms, &reductionTypes) .setGenRegionEntryCb(reductionCallback); @@ -1645,7 +1623,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, return genOpWithBody(genInfo, clauseOps); bool privatize = !outerCombined; - DataSharingProcessor dsp(converter, semaCtx, clauseList, eval, + DataSharingProcessor dsp(converter, semaCtx, clauses, eval, /*useDelayedPrivatization=*/true, &symTable); if (privatize) @@ -1692,14 +1670,13 @@ static mlir::omp::SectionOp genSectionOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + mlir::Location loc, const List &clauses) { // Currently only private/firstprivate clause is handled, and // all privatization is done within `omp.section` operations. 
return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, loc, eval) .setGenNested(genNested) - .setClauses(&clauseList)); + .setClauses(&clauses)); } static mlir::omp::SectionsOp @@ -1716,18 +1693,17 @@ static mlir::omp::SimdLoopOp genSimdLoopOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { - DataSharingProcessor dsp(converter, semaCtx, clauseList, eval); + const List &clauses) { + DataSharingProcessor dsp(converter, semaCtx, clauses, eval); dsp.processStep1(); Fortran::lower::StatementContext stmtCtx; mlir::omp::SimdLoopClauseOps clauseOps; llvm::SmallVector iv; - genSimdLoopClauses(converter, semaCtx, stmtCtx, eval, clauseList, loc, - clauseOps, iv); + genSimdLoopClauses(converter, semaCtx, stmtCtx, eval, clauses, loc, clauseOps, + iv); - auto *nestedEval = - getCollapsedLoopEval(eval, Fortran::lower::getCollapseValue(clauseList)); + auto *nestedEval = getCollapsedLoopEval(eval, getCollapseValue(clauses)); auto ivCallback = [&](mlir::Operation *op) { return genLoopVars(op, converter, loc, iv); @@ -1735,7 +1711,7 @@ genSimdLoopOp(Fortran::lower::AbstractConverter &converter, return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval) - .setClauses(&clauseList) + .setClauses(&clauses) .setDataSharingProcessor(&dsp) .setGenRegionEntryCb(ivCallback), clauseOps); @@ -1745,17 +1721,16 @@ static mlir::omp::SingleOp genSingleOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList &endClauseList) { + mlir::Location loc, const List &beginClauses, + const List &endClauses) { mlir::omp::SingleClauseOps clauseOps; - genSingleClauses(converter, semaCtx, beginClauseList, endClauseList, loc, 
+ genSingleClauses(converter, semaCtx, beginClauses, endClauses, loc, clauseOps); return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, loc, eval) .setGenNested(genNested) - .setClauses(&beginClauseList), + .setClauses(&beginClauses), clauseOps); } @@ -1763,8 +1738,7 @@ static mlir::omp::TargetOp genTargetOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList, + mlir::Location loc, const List &clauses, bool outerCombined = false) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); Fortran::lower::StatementContext stmtCtx; @@ -1777,7 +1751,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector mapSyms; llvm::SmallVector mapSymLocs; llvm::SmallVector mapSymTypes; - genTargetClauses(converter, semaCtx, stmtCtx, clauseList, loc, + genTargetClauses(converter, semaCtx, stmtCtx, clauses, loc, processHostOnlyClauses, /*processReduction=*/outerCombined, clauseOps, mapSyms, mapSymLocs, mapSymTypes); @@ -1875,14 +1849,13 @@ static mlir::omp::TargetDataOp genTargetDataOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + mlir::Location loc, const List &clauses) { Fortran::lower::StatementContext stmtCtx; mlir::omp::TargetDataClauseOps clauseOps; llvm::SmallVector useDeviceTypes; llvm::SmallVector useDeviceLocs; llvm::SmallVector useDeviceSyms; - genTargetDataClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps, + genTargetDataClauses(converter, semaCtx, stmtCtx, clauses, loc, clauseOps, useDeviceTypes, useDeviceLocs, useDeviceSyms); auto targetDataOp = @@ -1894,11 +1867,11 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter, return targetDataOp; } -template -static OpTy 
genTargetEnterExitUpdateDataOp( - Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { +template static OpTy +genTargetEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + mlir::Location loc, + const List &clauses) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); Fortran::lower::StatementContext stmtCtx; @@ -1915,8 +1888,8 @@ static OpTy genTargetEnterExitUpdateDataOp( } mlir::omp::TargetEnterExitUpdateDataClauseOps clauseOps; - genTargetEnterExitUpdateDataClauses(converter, semaCtx, stmtCtx, clauseList, - loc, directive, clauseOps); + genTargetEnterExitUpdateDataClauses(converter, semaCtx, stmtCtx, clauses, loc, + directive, clauseOps); return firOpBuilder.create(loc, clauseOps); } @@ -1925,16 +1898,15 @@ static mlir::omp::TaskOp genTaskOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + mlir::Location loc, const List &clauses) { Fortran::lower::StatementContext stmtCtx; mlir::omp::TaskClauseOps clauseOps; - genTaskClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps); + genTaskClauses(converter, semaCtx, stmtCtx, clauses, loc, clauseOps); return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, loc, eval) .setGenNested(genNested) - .setClauses(&clauseList), + .setClauses(&clauses), clauseOps); } @@ -1942,15 +1914,14 @@ static mlir::omp::TaskgroupOp genTaskgroupOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + mlir::Location loc, const List &clauses) { mlir::omp::TaskgroupClauseOps clauseOps; - 
genTaskgroupClauses(converter, semaCtx, clauseList, loc, clauseOps); + genTaskgroupClauses(converter, semaCtx, clauses, loc, clauseOps); return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, loc, eval) .setGenNested(genNested) - .setClauses(&clauseList), + .setClauses(&clauses), clauseOps); } @@ -1958,7 +1929,7 @@ static mlir::omp::TaskloopOp genTaskloopOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + const List &clauses) { TODO(loc, "Taskloop construct"); } @@ -1966,9 +1937,9 @@ static mlir::omp::TaskwaitOp genTaskwaitOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + const List &clauses) { mlir::omp::TaskwaitClauseOps clauseOps; - genTaskwaitClauses(converter, semaCtx, clauseList, loc, clauseOps); + genTaskwaitClauses(converter, semaCtx, clauses, loc, clauseOps); return converter.getFirOpBuilder().create(loc, clauseOps); } @@ -1984,17 +1955,17 @@ static mlir::omp::TeamsOp genTeamsOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, const Fortran::parser::OmpClauseList &clauseList, + mlir::Location loc, const List &clauses, bool outerCombined = false) { Fortran::lower::StatementContext stmtCtx; mlir::omp::TeamsClauseOps clauseOps; - genTeamsClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps); + genTeamsClauses(converter, semaCtx, stmtCtx, clauses, loc, clauseOps); return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, loc, eval) .setGenNested(genNested) .setOuterCombined(outerCombined) - .setClauses(&clauseList), + .setClauses(&clauses), clauseOps); } @@ -2002,9 +1973,8 @@ static mlir::omp::WsloopOp 
genWsloopOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, mlir::Location loc, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList *endClauseList) { - DataSharingProcessor dsp(converter, semaCtx, beginClauseList, eval); + const List &beginClauses, const List &endClauses) { + DataSharingProcessor dsp(converter, semaCtx, beginClauses, eval); dsp.processStep1(); Fortran::lower::StatementContext stmtCtx; @@ -2012,12 +1982,10 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector iv; llvm::SmallVector reductionTypes; llvm::SmallVector reductionSyms; - genWsloopClauses(converter, semaCtx, stmtCtx, eval, beginClauseList, - endClauseList, loc, clauseOps, iv, reductionTypes, - reductionSyms); + genWsloopClauses(converter, semaCtx, stmtCtx, eval, beginClauses, endClauses, + loc, clauseOps, iv, reductionTypes, reductionSyms); - auto *nestedEval = getCollapsedLoopEval( - eval, Fortran::lower::getCollapseValue(beginClauseList)); + auto *nestedEval = getCollapsedLoopEval(eval, getCollapseValue(beginClauses)); auto ivCallback = [&](mlir::Operation *op) { return genLoopAndReductionVars(op, converter, loc, iv, reductionSyms, @@ -2026,7 +1994,7 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter, return genOpWithBody( OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval) - .setClauses(&beginClauseList) + .setClauses(&beginClauses) .setDataSharingProcessor(&dsp) .setReductions(&reductionSyms, &reductionTypes) .setGenRegionEntryCb(ivCallback), @@ -2041,8 +2009,8 @@ static void genCompositeDistributeParallelDo( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) { + const List &beginClauses, + const List &endClauses, 
mlir::Location loc) { TODO(loc, "Composite DISTRIBUTE PARALLEL DO"); } @@ -2050,8 +2018,8 @@ static void genCompositeDistributeParallelDoSimd( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) { + const List &beginClauses, + const List &endClauses, mlir::Location loc) { TODO(loc, "Composite DISTRIBUTE PARALLEL DO SIMD"); } @@ -2059,8 +2027,8 @@ static void genCompositeDistributeSimd( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) { + const List &beginClauses, + const List &endClauses, mlir::Location loc) { TODO(loc, "Composite DISTRIBUTE SIMD"); } @@ -2068,10 +2036,10 @@ static void genCompositeDoSimd(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList *endClauseList, + const List &beginClauses, + const List &endClauses, mlir::Location loc) { - ClauseProcessor cp(converter, semaCtx, beginClauseList); + ClauseProcessor cp(converter, semaCtx, beginClauses); cp.processTODO( loc, llvm::omp::OMPD_do_simd); @@ -2083,15 +2051,15 @@ genCompositeDoSimd(Fortran::lower::AbstractConverter &converter, // When support for vectorization is enabled, then we need to add handling of // if clause. Currently if clause can be skipped because we always assume // SIMD length = 1. 
- genWsloopOp(converter, semaCtx, eval, loc, beginClauseList, endClauseList); + genWsloopOp(converter, semaCtx, eval, loc, beginClauses, endClauses); } static void genCompositeTaskloopSimd(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList *endClauseList, + const List &beginClauses, + const List &endClauses, mlir::Location loc) { TODO(loc, "Composite TASKLOOP SIMD"); } @@ -2201,8 +2169,9 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, const auto &directive = std::get( simpleStandaloneConstruct.t); - const auto &clauseList = - std::get(simpleStandaloneConstruct.t); + List clauses = makeClauses( + std::get(simpleStandaloneConstruct.t), + semaCtx); mlir::Location currentLocation = converter.genLocation(directive.source); switch (directive.v) { @@ -2212,29 +2181,29 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, genBarrierOp(converter, semaCtx, eval, currentLocation); break; case llvm::omp::Directive::OMPD_taskwait: - genTaskwaitOp(converter, semaCtx, eval, currentLocation, clauseList); + genTaskwaitOp(converter, semaCtx, eval, currentLocation, clauses); break; case llvm::omp::Directive::OMPD_taskyield: genTaskyieldOp(converter, semaCtx, eval, currentLocation); break; case llvm::omp::Directive::OMPD_target_data: genTargetDataOp(converter, semaCtx, eval, /*genNested=*/true, - currentLocation, clauseList); + currentLocation, clauses); break; case llvm::omp::Directive::OMPD_target_enter_data: genTargetEnterExitUpdateDataOp( - converter, semaCtx, currentLocation, clauseList); + converter, semaCtx, currentLocation, clauses); break; case llvm::omp::Directive::OMPD_target_exit_data: genTargetEnterExitUpdateDataOp( - converter, semaCtx, currentLocation, clauseList); + converter, semaCtx, currentLocation, clauses); break; case 
llvm::omp::Directive::OMPD_target_update: genTargetEnterExitUpdateDataOp( - converter, semaCtx, currentLocation, clauseList); + converter, semaCtx, currentLocation, clauses); break; case llvm::omp::Directive::OMPD_ordered: - genOrderedOp(converter, semaCtx, eval, currentLocation, clauseList); + genOrderedOp(converter, semaCtx, eval, currentLocation, clauses); break; } } @@ -2251,8 +2220,14 @@ genOMP(Fortran::lower::AbstractConverter &converter, const auto &clauseList = std::get>>( flushConstruct.t); + ObjectList objects = + objectList ? makeObjects(*objectList, semaCtx) : ObjectList{}; + List clauses = + clauseList ? makeList(*clauseList, + [&](auto &&s) { return makeClause(s.v, semaCtx); }) + : List{}; mlir::Location currentLocation = converter.genLocation(verbatim.source); - genFlushOp(converter, semaCtx, eval, currentLocation, objectList, clauseList); + genFlushOp(converter, semaCtx, eval, currentLocation, objects, clauses); } static void @@ -2357,44 +2332,44 @@ genOMP(Fortran::lower::AbstractConverter &converter, converter.genLocation(beginBlockDirective.source); const auto origDirective = std::get(beginBlockDirective.t).v; - const auto &beginClauseList = - std::get(beginBlockDirective.t); - const auto &endClauseList = - std::get(endBlockDirective.t); + List beginClauses = makeClauses( + std::get(beginBlockDirective.t), semaCtx); + List endClauses = makeClauses( + std::get(endBlockDirective.t), semaCtx); assert(llvm::omp::blockConstructSet.test(origDirective) && "Expected block construct"); - for (const Fortran::parser::OmpClause &clause : beginClauseList.v) { + for (const Clause &clause : beginClauses) { mlir::Location clauseLocation = converter.genLocation(clause.source); - if (!std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) 
&& - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u) && - !std::get_if(&clause.u)) { + if (!std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u) && + !std::get_if(&clause.u)) { TODO(clauseLocation, "OpenMP Block construct clause"); } } - for (const auto &clause : endClauseList.v) { + for (const Clause &clause : endClauses) { mlir::Location clauseLocation = converter.genLocation(clause.source); - if (!std::get_if(&clause.u) && - !std::get_if(&clause.u)) + if (!std::get_if(&clause.u) && + !std::get_if(&clause.u)) TODO(clauseLocation, "OpenMP Block construct clause"); } @@ -2413,44 +2388,44 @@ genOMP(Fortran::lower::AbstractConverter &converter, case llvm::omp::Directive::OMPD_ordered: // 2.17.9 ORDERED construct. genOrderedRegionOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList); + beginClauses); break; case llvm::omp::Directive::OMPD_parallel: // 2.6 PARALLEL construct. genParallelOp(converter, symTable, semaCtx, eval, genNested, - currentLocation, beginClauseList, outerCombined); + currentLocation, beginClauses, outerCombined); break; case llvm::omp::Directive::OMPD_single: // 2.8.2 SINGLE construct. 
genSingleOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList, endClauseList); + beginClauses, endClauses); break; case llvm::omp::Directive::OMPD_target: // 2.12.5 TARGET construct. genTargetOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList, outerCombined); + beginClauses, outerCombined); break; case llvm::omp::Directive::OMPD_target_data: // 2.12.2 TARGET DATA construct. genTargetDataOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList); + beginClauses); break; case llvm::omp::Directive::OMPD_task: // 2.10.1 TASK construct. genTaskOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList); + beginClauses); break; case llvm::omp::Directive::OMPD_taskgroup: // 2.17.6 TASKGROUP construct. genTaskgroupOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList); + beginClauses); break; case llvm::omp::Directive::OMPD_teams: // 2.7 TEAMS construct. // FIXME Pass the outerCombined argument or rename it to better describe // what it represents if it must always be `false` in this context. genTeamsOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList); + beginClauses); break; case llvm::omp::Directive::OMPD_workshare: // 2.8.3 WORKSHARE construct. @@ -2458,7 +2433,7 @@ genOMP(Fortran::lower::AbstractConverter &converter, // implementation for this feature will come later. For the codes // that use this construct, add a single construct for now. 
genSingleOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList, endClauseList); + beginClauses, endClauses); break; default: llvm_unreachable("Unexpected block construct"); @@ -2476,11 +2451,12 @@ genOMP(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) { const auto &cd = std::get(criticalConstruct.t); - const auto &clauseList = std::get(cd.t); + List clauses = + makeClauses(std::get(cd.t), semaCtx); const auto &name = std::get>(cd.t); mlir::Location currentLocation = converter.getCurrentLocation(); genCriticalOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation, - clauseList, name); + clauses, name); } static void @@ -2499,8 +2475,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OpenMPLoopConstruct &loopConstruct) { const auto &beginLoopDirective = std::get(loopConstruct.t); - const auto &beginClauseList = - std::get(beginLoopDirective.t); + List beginClauses = makeClauses( + std::get(beginLoopDirective.t), semaCtx); mlir::Location currentLocation = converter.genLocation(beginLoopDirective.source); const auto origDirective = @@ -2509,15 +2485,15 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, assert(llvm::omp::loopConstructSet.test(origDirective) && "Expected loop construct"); - const auto *endClauseList = [&]() { - using RetTy = const Fortran::parser::OmpClauseList *; + List endClauses = [&]() { if (auto &endLoopDirective = std::get>( loopConstruct.t)) { - return RetTy( - &std::get((*endLoopDirective).t)); + return makeClauses( + std::get(endLoopDirective->t), + semaCtx); } - return RetTy(); + return List{}; }(); std::optional nextDir = origDirective; @@ -2530,29 +2506,29 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, case llvm::omp::Directive::OMPD_distribute_parallel_do: // 2.9.4.3 DISTRIBUTE PARALLEL Worksharing-Loop construct. 
genCompositeDistributeParallelDo(converter, semaCtx, eval, - beginClauseList, endClauseList, + beginClauses, endClauses, currentLocation); break; case llvm::omp::Directive::OMPD_distribute_parallel_do_simd: // 2.9.4.4 DISTRIBUTE PARALLEL Worksharing-Loop SIMD construct. genCompositeDistributeParallelDoSimd(converter, semaCtx, eval, - beginClauseList, endClauseList, + beginClauses, endClauses, currentLocation); break; case llvm::omp::Directive::OMPD_distribute_simd: // 2.9.4.2 DISTRIBUTE SIMD construct. - genCompositeDistributeSimd(converter, semaCtx, eval, beginClauseList, - endClauseList, currentLocation); + genCompositeDistributeSimd(converter, semaCtx, eval, beginClauses, + endClauses, currentLocation); break; case llvm::omp::Directive::OMPD_do_simd: // 2.9.3.2 Worksharing-Loop SIMD construct. - genCompositeDoSimd(converter, semaCtx, eval, beginClauseList, - endClauseList, currentLocation); + genCompositeDoSimd(converter, semaCtx, eval, beginClauses, + endClauses, currentLocation); break; case llvm::omp::Directive::OMPD_taskloop_simd: // 2.10.3 TASKLOOP SIMD construct. - genCompositeTaskloopSimd(converter, semaCtx, eval, beginClauseList, - endClauseList, currentLocation); + genCompositeTaskloopSimd(converter, semaCtx, eval, beginClauses, + endClauses, currentLocation); break; default: llvm_unreachable("Unexpected composite construct"); @@ -2563,12 +2539,12 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, case llvm::omp::Directive::OMPD_distribute: // 2.9.4.1 DISTRIBUTE construct. genDistributeOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList); + beginClauses); break; case llvm::omp::Directive::OMPD_do: // 2.9.2 Worksharing-Loop construct. - genWsloopOp(converter, semaCtx, eval, currentLocation, beginClauseList, - endClauseList); + genWsloopOp(converter, semaCtx, eval, currentLocation, beginClauses, + endClauses); break; case llvm::omp::Directive::OMPD_parallel: // 2.6 PARALLEL construct. 
@@ -2577,24 +2553,24 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, // Maybe rename the argument if it represents something else or // initialize it properly. genParallelOp(converter, symTable, semaCtx, eval, genNested, - currentLocation, beginClauseList, + currentLocation, beginClauses, /*outerCombined=*/true); break; case llvm::omp::Directive::OMPD_simd: // 2.9.3.1 SIMD construct. genSimdLoopOp(converter, semaCtx, eval, currentLocation, - beginClauseList); - genOpenMPReduction(converter, semaCtx, beginClauseList); + beginClauses); + genOpenMPReduction(converter, semaCtx, beginClauses); break; case llvm::omp::Directive::OMPD_target: // 2.12.5 TARGET construct. genTargetOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList, /*outerCombined=*/true); + beginClauses, /*outerCombined=*/true); break; case llvm::omp::Directive::OMPD_taskloop: // 2.10.2 TASKLOOP construct. genTaskloopOp(converter, semaCtx, eval, currentLocation, - beginClauseList); + beginClauses); break; case llvm::omp::Directive::OMPD_teams: // 2.7 TEAMS construct. @@ -2603,7 +2579,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, // Maybe rename the argument if it represents something else or // initialize it properly. 
genTeamsOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList, /*outerCombined=*/true); + beginClauses, /*outerCombined=*/true); break; case llvm::omp::Directive::OMPD_loop: case llvm::omp::Directive::OMPD_masked: @@ -2639,14 +2615,15 @@ genOMP(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OpenMPSectionsConstruct §ionsConstruct) { const auto &beginSectionsDirective = std::get(sectionsConstruct.t); - const auto &beginClauseList = - std::get(beginSectionsDirective.t); + List beginClauses = makeClauses( + std::get(beginSectionsDirective.t), + semaCtx); // Process clauses before optional omp.parallel, so that new variables are // allocated outside of the parallel region mlir::Location currentLocation = converter.getCurrentLocation(); mlir::omp::SectionsClauseOps clauseOps; - genSectionsClauses(converter, semaCtx, beginClauseList, currentLocation, + genSectionsClauses(converter, semaCtx, beginClauses, currentLocation, /*clausesFromBeginSections=*/true, clauseOps); // Parallel wrapper of PARALLEL SECTIONS construct @@ -2655,14 +2632,15 @@ genOMP(Fortran::lower::AbstractConverter &converter, .v; if (dir == llvm::omp::Directive::OMPD_parallel_sections) { genParallelOp(converter, symTable, semaCtx, eval, - /*genNested=*/false, currentLocation, beginClauseList, + /*genNested=*/false, currentLocation, beginClauses, /*outerCombined=*/true); } else { const auto &endSectionsDirective = std::get(sectionsConstruct.t); - const auto &endClauseList = - std::get(endSectionsDirective.t); - genSectionsClauses(converter, semaCtx, endClauseList, currentLocation, + List endClauses = makeClauses( + std::get(endSectionsDirective.t), + semaCtx); + genSectionsClauses(converter, semaCtx, endClauses, currentLocation, /*clausesFromBeginSections=*/false, clauseOps); } @@ -2678,7 +2656,7 @@ genOMP(Fortran::lower::AbstractConverter &converter, llvm::zip(sectionBlocks.v, eval.getNestedEvaluations())) { symTable.pushScope(); genSectionOp(converter, semaCtx, 
neval, /*genNested=*/true, currentLocation, - beginClauseList); + beginClauses); symTable.popScope(); firOpBuilder.restoreInsertionPoint(ip); } diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp index b9c0660aa4da8..da3f2be73e509 100644 --- a/flang/lib/Lower/OpenMP/Utils.cpp +++ b/flang/lib/Lower/OpenMP/Utils.cpp @@ -36,6 +36,17 @@ namespace Fortran { namespace lower { namespace omp { +int64_t getCollapseValue(const List &clauses) { + auto iter = llvm::find_if(clauses, [](const Clause &clause) { + return clause.id == llvm::omp::Clause::OMPC_collapse; + }); + if (iter != clauses.end()) { + const auto &collapse = std::get(iter->u); + return evaluate::ToInt64(collapse.v).value(); + } + return 1; +} + void genObjectList(const ObjectList &objects, Fortran::lower::AbstractConverter &converter, llvm::SmallVectorImpl &operands) { @@ -52,25 +63,6 @@ void genObjectList(const ObjectList &objects, } } -void genObjectList2(const Fortran::parser::OmpObjectList &objectList, - Fortran::lower::AbstractConverter &converter, - llvm::SmallVectorImpl &operands) { - auto addOperands = [&](Fortran::lower::SymbolRef sym) { - const mlir::Value variable = converter.getSymbolAddress(sym); - if (variable) { - operands.push_back(variable); - } else if (const auto *details = - sym->detailsIf()) { - operands.push_back(converter.getSymbolAddress(details->symbol())); - converter.copySymbolBinding(details->symbol(), sym); - } - }; - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); - addOperands(*sym); - } -} - mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter, std::size_t loopVarTypeSize) { // OpenMP runtime requires 32-bit or 64-bit loop variables. 
diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h index 4074bf73987d5..b3a9f7f30c98b 100644 --- a/flang/lib/Lower/OpenMP/Utils.h +++ b/flang/lib/Lower/OpenMP/Utils.h @@ -58,6 +58,8 @@ void gatherFuncAndVarSyms( const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause, llvm::SmallVectorImpl &symbolAndClause); +int64_t getCollapseValue(const List &clauses); + Fortran::semantics::Symbol * getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject); @@ -65,10 +67,6 @@ void genObjectList(const ObjectList &objects, Fortran::lower::AbstractConverter &converter, llvm::SmallVectorImpl &operands); -void genObjectList2(const Fortran::parser::OmpObjectList &objectList, - Fortran::lower::AbstractConverter &converter, - llvm::SmallVectorImpl &operands); - } // namespace omp } // namespace lower } // namespace Fortran From 065b54c4ddf2b356333269aecbee00b5a23ca1ea Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Wed, 17 Apr 2024 09:37:17 -0500 Subject: [PATCH 09/13] clang-format --- flang/lib/Lower/OpenMP/OpenMP.cpp | 92 +++++++++++++++---------------- 1 file changed, 43 insertions(+), 49 deletions(-) diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 6a7699aee5931..4424788e0132e 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -1022,22 +1022,23 @@ static OpTy genOpWithBody(OpWithBodyGenInfo &info, Args &&...args) { // Code generation functions for clauses //===----------------------------------------------------------------------===// -static void genCriticalDeclareClauses( - Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - const List &clauses, mlir::Location loc, - mlir::omp::CriticalClauseOps &clauseOps, llvm::StringRef name) { +static void +genCriticalDeclareClauses(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const List &clauses, mlir::Location 
loc, + mlir::omp::CriticalClauseOps &clauseOps, + llvm::StringRef name) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processHint(clauseOps); clauseOps.nameAttr = mlir::StringAttr::get(converter.getFirOpBuilder().getContext(), name); } -static void genFlushClauses( - Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - const ObjectList &objects, const List &clauses, - mlir::Location loc, llvm::SmallVectorImpl &operandRange) { +static void genFlushClauses(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const ObjectList &objects, + const List &clauses, mlir::Location loc, + llvm::SmallVectorImpl &operandRange) { if (!objects.empty()) genObjectList(objects, converter, operandRange); @@ -1048,9 +1049,8 @@ static void genFlushClauses( static void genLoopNestClauses( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, - const List &clauses, mlir::Location loc, - mlir::omp::LoopNestClauseOps &clauseOps, + Fortran::lower::pft::Evaluation &eval, const List &clauses, + mlir::Location loc, mlir::omp::LoopNestClauseOps &clauseOps, llvm::SmallVectorImpl &iv) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processCollapse(loc, eval, clauseOps, iv); @@ -1069,9 +1069,9 @@ genOrderedRegionClauses(Fortran::lower::AbstractConverter &converter, static void genParallelClauses( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::StatementContext &stmtCtx, - const List &clauses, mlir::Location loc, - bool processReduction, mlir::omp::ParallelClauseOps &clauseOps, + Fortran::lower::StatementContext &stmtCtx, const List &clauses, + mlir::Location loc, bool processReduction, + mlir::omp::ParallelClauseOps &clauseOps, llvm::SmallVectorImpl &reductionTypes, llvm::SmallVectorImpl &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); @@ 
-1136,9 +1136,8 @@ static void genSingleClauses(Fortran::lower::AbstractConverter &converter, static void genTargetClauses( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::StatementContext &stmtCtx, - const List &clauses, mlir::Location loc, - bool processHostOnlyClauses, bool processReduction, + Fortran::lower::StatementContext &stmtCtx, const List &clauses, + mlir::Location loc, bool processHostOnlyClauses, bool processReduction, mlir::omp::TargetClauseOps &clauseOps, llvm::SmallVectorImpl &mapSyms, llvm::SmallVectorImpl &mapLocs, @@ -1708,10 +1707,11 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter, } template -static OpTy genTargetEnterExitUpdateDataOp( - Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, mlir::Location loc, - const List &clauses) { +static OpTy +genTargetEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + mlir::Location loc, + const List &clauses) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); Fortran::lower::StatementContext stmtCtx; @@ -1852,8 +1852,7 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter, static void genCompositeDistributeParallelDo( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, - const List &beginClauses, + Fortran::lower::pft::Evaluation &eval, const List &beginClauses, const List &endClauses, mlir::Location loc) { TODO(loc, "Composite DISTRIBUTE PARALLEL DO"); } @@ -1861,28 +1860,26 @@ static void genCompositeDistributeParallelDo( static void genCompositeDistributeParallelDoSimd( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, - const List &beginClauses, + Fortran::lower::pft::Evaluation &eval, const List &beginClauses, const List 
&endClauses, mlir::Location loc) { TODO(loc, "Composite DISTRIBUTE PARALLEL DO SIMD"); } -static void genCompositeDistributeSimd( - Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, - const List &beginClauses, - const List &endClauses, mlir::Location loc) { +static void +genCompositeDistributeSimd(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, + const List &beginClauses, + const List &endClauses, mlir::Location loc) { TODO(loc, "Composite DISTRIBUTE SIMD"); } -static void -genCompositeDoSimd(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, - const List &beginClauses, - const List &endClauses, - mlir::Location loc) { +static void genCompositeDoSimd(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, + const List &beginClauses, + const List &endClauses, + mlir::Location loc) { ClauseProcessor cp(converter, semaCtx, beginClauses); cp.processTODO( @@ -1903,8 +1900,7 @@ genCompositeTaskloopSimd(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, const List &beginClauses, - const List &endClauses, - mlir::Location loc) { + const List &endClauses, mlir::Location loc) { TODO(loc, "Composite TASKLOOP SIMD"); } @@ -2351,9 +2347,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, switch (leafDir) { case llvm::omp::Directive::OMPD_distribute_parallel_do: // 2.9.4.3 DISTRIBUTE PARALLEL Worksharing-Loop construct. 
- genCompositeDistributeParallelDo(converter, semaCtx, eval, - beginClauses, endClauses, - currentLocation); + genCompositeDistributeParallelDo(converter, semaCtx, eval, beginClauses, + endClauses, currentLocation); break; case llvm::omp::Directive::OMPD_distribute_parallel_do_simd: // 2.9.4.4 DISTRIBUTE PARALLEL Worksharing-Loop SIMD construct. @@ -2368,8 +2363,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, break; case llvm::omp::Directive::OMPD_do_simd: // 2.9.3.2 Worksharing-Loop SIMD construct. - genCompositeDoSimd(converter, semaCtx, eval, beginClauses, - endClauses, currentLocation); + genCompositeDoSimd(converter, semaCtx, eval, beginClauses, endClauses, + currentLocation); break; case llvm::omp::Directive::OMPD_taskloop_simd: // 2.10.3 TASKLOOP SIMD construct. @@ -2413,8 +2408,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, break; case llvm::omp::Directive::OMPD_taskloop: // 2.10.2 TASKLOOP construct. - genTaskloopOp(converter, semaCtx, eval, currentLocation, - beginClauses); + genTaskloopOp(converter, semaCtx, eval, currentLocation, beginClauses); break; case llvm::omp::Directive::OMPD_teams: // 2.7 TEAMS construct. 
From 883043d931f3e3ce899c897b49e6400d5e25419d Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Thu, 18 Apr 2024 17:06:19 -0500 Subject: [PATCH 10/13] Address issue with "end directives" --- llvm/lib/Frontend/OpenMP/OMP.cpp | 4 +- llvm/unittests/Frontend/OpenMPComposeTest.cpp | 2 + llvm/utils/TableGen/DirectiveEmitter.cpp | 38 ++++++++++++++++++- 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/llvm/lib/Frontend/OpenMP/OMP.cpp b/llvm/lib/Frontend/OpenMP/OMP.cpp index 7504c9076fde1..1bf5f5e96ba65 100644 --- a/llvm/lib/Frontend/OpenMP/OMP.cpp +++ b/llvm/lib/Frontend/OpenMP/OMP.cpp @@ -61,8 +61,8 @@ Directive getCompoundConstruct(ArrayRef Parts) { return GivenLeafs.front(); RawLeafs[1] = static_cast(GivenLeafs.size()); - auto Iter = llvm::lower_bound( - LeafConstructTable, + auto Iter = std::lower_bound( + LeafConstructTable, LeafConstructTableEndDirective, static_cast>(RawLeafs.data()), [](const auto *RowA, const auto *RowB) { const auto *BeginA = &RowA[2]; diff --git a/llvm/unittests/Frontend/OpenMPComposeTest.cpp b/llvm/unittests/Frontend/OpenMPComposeTest.cpp index c3e0880ece864..9a8a253ef1026 100644 --- a/llvm/unittests/Frontend/OpenMPComposeTest.cpp +++ b/llvm/unittests/Frontend/OpenMPComposeTest.cpp @@ -36,4 +36,6 @@ TEST(Composition, GetCompoundConstruct) { ASSERT_EQ(C5, OMPD_unknown); Directive C6 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd}); ASSERT_EQ(C6, OMPD_parallel_for_simd); + Directive C7 = getCompoundConstruct({OMPD_do, OMPD_simd}); + ASSERT_EQ(C7, OMPD_do_simd); // Make sure it's not OMPD_end_do_simd } diff --git a/llvm/utils/TableGen/DirectiveEmitter.cpp b/llvm/utils/TableGen/DirectiveEmitter.cpp index 2d2b774849189..20fac6ac0ea4c 100644 --- a/llvm/utils/TableGen/DirectiveEmitter.cpp +++ b/llvm/utils/TableGen/DirectiveEmitter.cpp @@ -12,6 +12,8 @@ //===----------------------------------------------------------------------===// #include "llvm/TableGen/DirectiveEmitter.h" +#include "llvm/ADT/DenseMap.h" +#include 
"llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSet.h" @@ -501,7 +503,7 @@ static void EmitLeafTable(const DirectiveLanguage &DirLang, raw_ostream &OS, // row for Dir_A = table[auxiliary[Dir_A]]. std::vector Directives = DirLang.getDirectives(); - DenseMap DirId; // Record * -> llvm::omp::Directive + DenseMap DirId; // Record * -> llvm::omp::Directive for (auto [Idx, Rec] : llvm::enumerate(Directives)) DirId.insert(std::make_pair(Rec, Idx)); @@ -525,6 +527,25 @@ static void EmitLeafTable(const DirectiveLanguage &DirLang, raw_ostream &OS, static_cast(I) < Leaves.size() ? DirId.at(Leaves[I]) : -1; } + // Some Fortran directives are delimited, i.e. they have the form of + // "directive"---"end directive". If "directive" is a compound construct, + // then the set of leaf constituents will be nonempty and the same for + // both directives. Given this set of leafs, looking up the corresponding + // compound directive should return "directive", and not "end directive". + // To avoid this problem, gather all "end directives" at the end of the + // leaf table, and only do the search on the initial segment of the table + // that excludes the "end directives". + // It's safe to find all directives whose names begin with "end ". The + // problem only exists for compound directives, like "end do simd". + // All existing directives with names starting with "end " are either + // "end directives" for an existing "directive", or leaf directives + // (such as "end declare target"). + DenseSet EndDirectives; + for (auto [Rec, Id] : DirId) { + if (Directive{Rec}.getName().starts_with_insensitive("end ")) + EndDirectives.insert(Id); + } + // Avoid sorting the vector array, instead sort an index array. // It will also be useful later to create the auxiliary indexing array. 
std::vector Ordering(Directives.size()); @@ -533,8 +554,13 @@ static void EmitLeafTable(const DirectiveLanguage &DirLang, raw_ostream &OS, llvm::sort(Ordering, [&](int A, int B) { auto &LeavesA = LeafTable[A]; auto &LeavesB = LeafTable[B]; + int DirA = LeavesA[0], DirB = LeavesB[0]; + // First of all, end directives compare greater than non-end directives. + int IsEndA = EndDirectives.count(DirA), IsEndB = EndDirectives.count(DirB); + if (IsEndA != IsEndB) + return IsEndA < IsEndB; if (LeavesA[1] == 0 && LeavesB[1] == 0) - return LeavesA[0] < LeavesB[0]; + return DirA < DirB; return std::lexicographical_compare(&LeavesA[2], &LeavesA[2] + LeavesA[1], &LeavesB[2], &LeavesB[2] + LeavesB[1]); }); @@ -565,6 +591,14 @@ static void EmitLeafTable(const DirectiveLanguage &DirLang, raw_ostream &OS, } OS << "};\n\n"; + // Emit a marker where the first "end directive" is. + auto FirstE = llvm::find_if(Ordering, [&](int RowIdx) { + return EndDirectives.count(LeafTable[RowIdx][0]); + }); + OS << "[[maybe_unused]] static auto " << TableName + << "EndDirective = " << TableName << " + " + << std::distance(Ordering.begin(), FirstE) << ";\n\n"; + // Emit the auxiliary index table: it's the inverse of the `Ordering` // table above. 
OS << "[[maybe_unused]] static const int " << TableName << "Ordering[] = {\n"; From 810514d95b7e7dd9413d2883bf210fc5cd517533 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 19 Apr 2024 07:51:31 -0500 Subject: [PATCH 11/13] Fix tests --- llvm/test/TableGen/directive1.td | 2 ++ llvm/test/TableGen/directive2.td | 2 ++ 2 files changed, 4 insertions(+) diff --git a/llvm/test/TableGen/directive1.td b/llvm/test/TableGen/directive1.td index e6150210e7e9a..526dcb3c3bf0a 100644 --- a/llvm/test/TableGen/directive1.td +++ b/llvm/test/TableGen/directive1.td @@ -373,6 +373,8 @@ def TDL_DirA : Directive<"dira"> { // IMPL-NEXT: llvm::tdl::TDLD_dira, static_cast(0), // IMPL-NEXT: }; // IMPL-EMPTY: +// IMPL-NEXT: {{.*}} static auto LeafConstructTableEndDirective = LeafConstructTable + 1; +// IMPL-EMPTY: // IMPL-NEXT: {{.*}} static const int LeafConstructTableOrdering[] = { // IMPL-NEXT: 0, // IMPL-NEXT: }; diff --git a/llvm/test/TableGen/directive2.td b/llvm/test/TableGen/directive2.td index 1750022e1f94e..9df8a06d3e517 100644 --- a/llvm/test/TableGen/directive2.td +++ b/llvm/test/TableGen/directive2.td @@ -304,6 +304,8 @@ def TDL_DirA : Directive<"dira"> { // IMPL-NEXT: llvm::tdl::TDLD_dira, static_cast(0), // IMPL-NEXT: }; // IMPL-EMPTY: +// IMPL-NEXT: {{.*}} static auto LeafConstructTableEndDirective = LeafConstructTable + 1; +// IMPL-EMPTY: // IMPL-NEXT: {{.*}} static const int LeafConstructTableOrdering[] = { // IMPL-NEXT: 0, // IMPL-NEXT: }; From c7c02ec6431a3828da0a006b80a1ed2cb8ffd1b7 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 19 Apr 2024 07:59:50 -0500 Subject: [PATCH 12/13] clang-format --- llvm/unittests/Frontend/OpenMPComposeTest.cpp | 2 +- llvm/utils/TableGen/DirectiveEmitter.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/llvm/unittests/Frontend/OpenMPComposeTest.cpp b/llvm/unittests/Frontend/OpenMPComposeTest.cpp index 9a8a253ef1026..c5fbe6ec6adfe 100644 --- a/llvm/unittests/Frontend/OpenMPComposeTest.cpp +++ 
b/llvm/unittests/Frontend/OpenMPComposeTest.cpp @@ -37,5 +37,5 @@ TEST(Composition, GetCompoundConstruct) { Directive C6 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd}); ASSERT_EQ(C6, OMPD_parallel_for_simd); Directive C7 = getCompoundConstruct({OMPD_do, OMPD_simd}); - ASSERT_EQ(C7, OMPD_do_simd); // Make sure it's not OMPD_end_do_simd + ASSERT_EQ(C7, OMPD_do_simd); // Make sure it's not OMPD_end_do_simd } diff --git a/llvm/utils/TableGen/DirectiveEmitter.cpp b/llvm/utils/TableGen/DirectiveEmitter.cpp index 20fac6ac0ea4c..69d9c5e8325ab 100644 --- a/llvm/utils/TableGen/DirectiveEmitter.cpp +++ b/llvm/utils/TableGen/DirectiveEmitter.cpp @@ -417,8 +417,8 @@ GenerateCaseForVersionedClauses(const std::vector &Clauses, static std::string GetDirectiveName(const DirectiveLanguage &DirLang, const Record *Rec) { Directive Dir{Rec}; - return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + "::" + - DirLang.getDirectivePrefix() + Dir.getFormattedName()) + return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + + "::" + DirLang.getDirectivePrefix() + Dir.getFormattedName()) .str(); } From 19f06c853acfe039dfb9eca538be8f0815c0668c Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Mon, 22 Apr 2024 09:28:42 -0500 Subject: [PATCH 13/13] Address review feedback --- llvm/lib/Frontend/OpenMP/OMP.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llvm/lib/Frontend/OpenMP/OMP.cpp b/llvm/lib/Frontend/OpenMP/OMP.cpp index 1bf5f5e96ba65..e958bced3a422 100644 --- a/llvm/lib/Frontend/OpenMP/OMP.cpp +++ b/llvm/lib/Frontend/OpenMP/OMP.cpp @@ -29,9 +29,9 @@ namespace llvm::omp { ArrayRef getLeafConstructs(Directive D) { auto Idx = static_cast(D); if (Idx >= Directive_enumSize) - return {}; + return std::nullopt; const auto *Row = LeafConstructTable[LeafConstructTableOrdering[Idx]]; - return ArrayRef(&Row[2], &Row[2] + static_cast(Row[1])); + return ArrayRef(&Row[2], static_cast(Row[1])); } Directive getCompoundConstruct(ArrayRef Parts) { @@ -64,7 +64,7 @@ 
Directive getCompoundConstruct(ArrayRef Parts) { auto Iter = std::lower_bound( LeafConstructTable, LeafConstructTableEndDirective, static_cast>(RawLeafs.data()), - [](const auto *RowA, const auto *RowB) { + [](const llvm::omp::Directive *RowA, const llvm::omp::Directive *RowB) { const auto *BeginA = &RowA[2]; const auto *EndA = BeginA + static_cast(RowA[1]); const auto *BeginB = &RowB[2];