Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flang/include/flang/Optimizer/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ createOMPEarlyOutliningPass();
std::unique_ptr<mlir::Pass> createOMPFunctionFilteringPass();
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createOMPMarkDeclareTargetPass();
std::unique_ptr<mlir::Pass> createOMPLoopIndexMemToRegPass();

// declarative passes
#define GEN_PASS_REGISTRATION
Expand Down
9 changes: 9 additions & 0 deletions flang/include/flang/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -326,4 +326,13 @@ def OMPFunctionFiltering : Pass<"omp-function-filtering"> {
];
}

def OMPLoopIndexMemToReg : Pass<"omp-loop-index-mem2reg", "mlir::func::FuncOp"> {
let summary = "Pushes allocations for index variables of OpenMP loops into "
"the loop region and, if they are never passed by reference, "
"they are replaced by the corresponding entry block arguments, "
"removing all redundant allocations in the process.";
let constructor = "::fir::createOMPLoopIndexMemToRegPass()";
let dependentDialects = ["fir::FIROpsDialect", "mlir::omp::OpenMPDialect"];
}

#endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES
1 change: 1 addition & 0 deletions flang/include/flang/Tools/CLOptions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ inline void createOpenMPFIRPassPipeline(
pm.addPass(fir::createOMPEarlyOutliningPass());
pm.addPass(fir::createOMPFunctionFilteringPass());
}
pm.addPass(fir::createOMPLoopIndexMemToRegPass());
}

#if !defined(FLANG_EXCLUDE_CODEGEN)
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ add_flang_library(FIRTransforms
OMPEarlyOutlining.cpp
OMPFunctionFiltering.cpp
OMPMarkDeclareTarget.cpp
OMPLoopIndexMemToReg.cpp

DEPENDS
FIRDialect
Expand Down
250 changes: 250 additions & 0 deletions flang/lib/Optimizer/Transforms/OMPLoopIndexMemToReg.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
//===- OMPWsLoopIndexMem2Reg.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements transforms to push allocations into an OpenMP loop
// operation region when they are used to store loop indices. Then, they are
// removed together with any associated load or store operations if their
// address is not needed, in which case uses of their values are replaced for
// the block argument from which they were originally initialized.
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/Transforms/Passes.h"

#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
#include "mlir/IR/BuiltinOps.h"
#include <llvm/ADT/MapVector.h>
#include <llvm/ADT/SmallSet.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/Support/Casting.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/Value.h>
#include <mlir/IR/ValueRange.h>
#include <mlir/Support/LLVM.h>

namespace fir {
#define GEN_PASS_DEF_OMPLOOPINDEXMEMTOREG
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

using namespace mlir;

template <typename LoopOpTy>
class LoopProcessorHelper {
LoopOpTy loop;

bool allUsesInLoop(ValueRange stores) {
for (Value store : stores) {
for (OpOperand &use : store.getUses()) {
Operation *owner = use.getOwner();
if (owner->getParentOfType<LoopOpTy>() != loop.getOperation())
return false;
}
}
return true;
}

/// Check whether a given hlfir.declare known to only be used inside of the
/// loop and initialized by a fir.alloca operation also only used inside of
/// the loop can be removed and replaced by the block argument representing
/// the corresponding loop index.
static bool isDeclareRemovable(hlfir::DeclareOp declareOp) {
fir::AllocaOp allocaOp = llvm::dyn_cast_if_present<fir::AllocaOp>(
declareOp.getMemref().getDefiningOp());

// Check that the hlfir.declare is initialized by a fir.alloca that is only
// used as argument to that operation.
if (!allocaOp || !allocaOp.getResult().hasOneUse())
return false;

// Check that uses of the pointers can be replaced by the block argument.
for (OpOperand &use : declareOp.getOriginalBase().getUses()) {
Operation *owner = use.getOwner();
if (!isa<fir::StoreOp>(owner))
return false;
}
for (OpOperand &use : declareOp.getBase().getUses()) {
Operation *owner = use.getOwner();
if (!isa<fir::LoadOp>(owner))
return false;
}

return true;
}

/// Check whether a given fir.alloca known to only be used inside of the loop
/// can be removed and replaced by the block argument representing the
/// corresponding loop index.
static bool isAllocaRemovable(fir::AllocaOp allocaOp) {
// Check that uses of the pointer are all fir.load and fir.store.
for (OpOperand &use : allocaOp.getResult().getUses()) {
Operation *owner = use.getOwner();
if (!isa<fir::LoadOp>(owner) && !isa<fir::StoreOp>(owner))
return false;
}

return true;
}

/// Try to push an hlfir.declare operation defined outside of the loop inside,
/// if all uses of that operation and the corresponding fir.alloca are
/// contained inside of the loop.
LogicalResult pushDeclareIntoLoop(hlfir::DeclareOp declareOp) {
// Check that all uses are inside of the loop.
if (!allUsesInLoop(declareOp->getResults()))
return failure();

// Push hlfir.declare into the beginning of the loop region.
Block &b = loop.getRegion().getBlocks().front();
declareOp->moveBefore(&b, b.begin());

// Find associated fir.alloca and push into the beginning of the loop
// region.
fir::AllocaOp allocaOp =
cast<fir::AllocaOp>(declareOp.getMemref().getDefiningOp());
Value allocaVal = allocaOp.getResult();

if (!allUsesInLoop(allocaVal))
return failure();

allocaOp->moveBefore(&b, b.begin());
return success();
}

/// Try to push a fir.alloca operation defined outside of the loop inside,
/// if all uses of that operation are contained inside of the loop.
LogicalResult pushAllocaIntoLoop(fir::AllocaOp allocaOp) {
Value store = allocaOp.getResult();

// Check that all uses are inside of the loop.
if (!allUsesInLoop(store))
return failure();

// Push fir.alloca into the beginning of the loop region.
Block &b = loop.getRegion().getBlocks().front();
allocaOp->moveBefore(&b, b.begin());
return success();
}

void processLoopArg(BlockArgument arg, llvm::ArrayRef<Value> argStores,
SmallPtrSetImpl<Operation *> &opsToDelete) {
llvm::SmallPtrSet<Operation *, 16> toDelete;
for (Value store : argStores) {
Operation *op = store.getDefiningOp();

// Skip argument if storage not defined by an operation.
if (!op)
return;

// Support HLFIR flow as well as regular FIR flow.
if (auto declareOp = dyn_cast<hlfir::DeclareOp>(op)) {
if (succeeded(pushDeclareIntoLoop(declareOp)) &&
isDeclareRemovable(declareOp)) {
// Mark hlfir.declare, fir.alloca and related uses for deletion.
for (OpOperand &use : declareOp.getOriginalBase().getUses())
toDelete.insert(use.getOwner());

for (OpOperand &use : declareOp.getBase().getUses())
toDelete.insert(use.getOwner());

Operation *allocaOp = declareOp.getMemref().getDefiningOp();
toDelete.insert(declareOp);
toDelete.insert(allocaOp);
}
} else if (auto allocaOp = dyn_cast<fir::AllocaOp>(op)) {
if (succeeded(pushAllocaIntoLoop(allocaOp)) &&
isAllocaRemovable(allocaOp)) {
// Do not make any further modifications if an address to the index
// is necessary. Otherwise, the values can be used directly from the
// loop region first block's arguments.

// Mark fir.alloca and related uses for deletion.
for (OpOperand &use : allocaOp.getResult().getUses())
toDelete.insert(use.getOwner());

// Delete now-unused fir.alloca.
toDelete.insert(allocaOp);
}
} else {
return;
}
}

// Only consider marked operations if all load, store and allocation
// operations associated with the given loop index can be removed.
opsToDelete.insert(toDelete.begin(), toDelete.end());

for (Operation *op : toDelete) {
// Replace all fir.load operations with the index as returned by the
// OpenMP loop operation.
if (isa<fir::LoadOp>(op))
op->replaceAllUsesWith(ValueRange(arg));
// Drop all uses of fir.alloca and hlfir.declare because their defining
// operations will be deleted as well.
else if (isa<fir::AllocaOp>(op) || isa<hlfir::DeclareOp>(op))
op->dropAllUses();
}
}

public:
explicit LoopProcessorHelper(LoopOpTy loop) : loop(loop) {}

void process() {
llvm::SmallPtrSet<Operation *, 16> opsToDelete;
llvm::SmallVector<llvm::SmallVector<Value>> storeAddresses;
llvm::ArrayRef<BlockArgument> loopArgs = loop.getRegion().getArguments();

// Collect arguments of the loop operation.
for (BlockArgument arg : loopArgs) {
// Find fir.store uses of these indices and gather all addresses where
// they are stored.
llvm::SmallVector<Value> &argStores = storeAddresses.emplace_back();
for (OpOperand &argUse : arg.getUses())
if (auto storeOp = dyn_cast<fir::StoreOp>(argUse.getOwner()))
argStores.push_back(storeOp.getMemref());
}

// Process all loop indices and mark them for deletion independently of each
// other.
for (auto it : llvm::zip(loopArgs, storeAddresses))
processLoopArg(std::get<0>(it), std::get<1>(it), opsToDelete);

// Delete marked operations.
for (Operation *op : opsToDelete)
op->erase();
}
};

namespace {
class OMPLoopIndexMemToRegPass
: public fir::impl::OMPLoopIndexMemToRegBase<OMPLoopIndexMemToRegPass> {
public:
void runOnOperation() override {
func::FuncOp func = getOperation();

func->walk(
[&](omp::WsLoopOp loop) { LoopProcessorHelper(loop).process(); });

func.walk(
[&](omp::SimdLoopOp loop) { LoopProcessorHelper(loop).process(); });

func.walk(
[&](omp::TaskLoopOp loop) { LoopProcessorHelper(loop).process(); });
}
};
} // namespace

std::unique_ptr<Pass> fir::createOMPLoopIndexMemToRegPass() {
return std::make_unique<OMPLoopIndexMemToRegPass>();
}
53 changes: 24 additions & 29 deletions flang/test/Lower/OpenMP/FIR/copyin.f90
Original file line number Diff line number Diff line change
Expand Up @@ -138,17 +138,15 @@ subroutine copyin_derived_type()
! CHECK: %[[VAL_1:.*]] = fir.address_of(@_QFcombined_parallel_worksharing_loopEx6) : !fir.ref<i32>
! CHECK: %[[VAL_2:.*]] = omp.threadprivate %[[VAL_1]] : !fir.ref<i32> -> !fir.ref<i32>
! CHECK: omp.parallel {
! CHECK: %[[VAL_3:.*]] = fir.alloca i32 {adapt.valuebyref, pinned}
! CHECK: %[[VAL_4:.*]] = omp.threadprivate %[[VAL_1]] : !fir.ref<i32> -> !fir.ref<i32>
! CHECK: %[[VAL_5:.*]] = fir.load %[[VAL_2]] : !fir.ref<i32>
! CHECK: fir.store %[[VAL_5]] to %[[VAL_4]] : !fir.ref<i32>
! CHECK: %[[VAL_3:.*]] = omp.threadprivate %[[VAL_1]] : !fir.ref<i32> -> !fir.ref<i32>
! CHECK: %[[VAL_4:.*]] = fir.load %[[VAL_2]] : !fir.ref<i32>
! CHECK: fir.store %[[VAL_4]] to %[[VAL_3]] : !fir.ref<i32>
! CHECK: omp.barrier
! CHECK: %[[VAL_6:.*]] = arith.constant 1 : i32
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_4]] : !fir.ref<i32>
! CHECK: %[[VAL_8:.*]] = arith.constant 1 : i32
! CHECK: omp.wsloop for (%[[VAL_9:.*]]) : i32 = (%[[VAL_6]]) to (%[[VAL_7]]) inclusive step (%[[VAL_8]]) {
! CHECK: fir.store %[[VAL_9]] to %[[VAL_3]] : !fir.ref<i32>
! CHECK: fir.call @_QPsub4(%[[VAL_4]]) {{.*}}: (!fir.ref<i32>) -> ()
! CHECK: %[[VAL_5:.*]] = arith.constant 1 : i32
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_3]] : !fir.ref<i32>
! CHECK: %[[VAL_7:.*]] = arith.constant 1 : i32
! CHECK: omp.wsloop for (%[[VAL_9:.*]]) : i32 = (%[[VAL_5]]) to (%[[VAL_6]]) inclusive step (%[[VAL_7]]) {
! CHECK: fir.call @_QPsub4(%[[VAL_3]]) {{.*}}: (!fir.ref<i32>) -> ()
! CHECK: omp.yield
! CHECK: }
! CHECK: omp.terminator
Expand Down Expand Up @@ -269,30 +267,27 @@ subroutine common_1()
!CHECK: %[[val_7:.*]] = fir.coordinate_of %[[val_6]], %[[val_c4]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
!CHECK: %[[val_8:.*]] = fir.convert %[[val_7]] : (!fir.ref<i8>) -> !fir.ref<i32>
!CHECK: omp.parallel {
!CHECK: %[[val_9:.*]] = fir.alloca i32 {adapt.valuebyref, pinned}
!CHECK: %[[val_10:.*]] = omp.threadprivate %[[val_1]] : !fir.ref<!fir.array<8xi8>> -> !fir.ref<!fir.array<8xi8>>
!CHECK: %[[val_11:.*]] = fir.convert %[[val_10]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
!CHECK: %[[val_9:.*]] = omp.threadprivate %[[val_1]] : !fir.ref<!fir.array<8xi8>> -> !fir.ref<!fir.array<8xi8>>
!CHECK: %[[val_10:.*]] = fir.convert %[[val_9]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
!CHECK: %[[val_c0_0:.*]] = arith.constant 0 : index
!CHECK: %[[val_12:.*]] = fir.coordinate_of %[[val_11]], %[[val_c0_0]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
!CHECK: %[[val_13:.*]] = fir.convert %[[val_12]] : (!fir.ref<i8>) -> !fir.ref<i32>
!CHECK: %[[val_14:.*]] = fir.convert %[[val_10]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
!CHECK: %[[val_11:.*]] = fir.coordinate_of %[[val_10]], %[[val_c0_0]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
!CHECK: %[[val_12:.*]] = fir.convert %[[val_11]] : (!fir.ref<i8>) -> !fir.ref<i32>
!CHECK: %[[val_13:.*]] = fir.convert %[[val_9]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
!CHECK: %[[val_c4_1:.*]] = arith.constant 4 : index
!CHECK: %[[val_15:.*]] = fir.coordinate_of %[[val_14]], %[[val_c4_1]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
!CHECK: %[[val_16:.*]] = fir.convert %[[val_15]] : (!fir.ref<i8>) -> !fir.ref<i32>
!CHECK: %[[val_17:.*]] = fir.load %[[val_5]] : !fir.ref<i32>
!CHECK: fir.store %[[val_17]] to %[[val_13]] : !fir.ref<i32>
!CHECK: %[[val_18:.*]] = fir.load %[[val_8]] : !fir.ref<i32>
!CHECK: fir.store %[[val_18]] to %[[val_16]] : !fir.ref<i32>
!CHECK: %[[val_14:.*]] = fir.coordinate_of %[[val_13]], %[[val_c4_1]] : (!fir.ref<!fir.array<?xi8>>, index) -> !fir.ref<i8>
!CHECK: %[[val_15:.*]] = fir.convert %[[val_14]] : (!fir.ref<i8>) -> !fir.ref<i32>
!CHECK: %[[val_16:.*]] = fir.load %[[val_5]] : !fir.ref<i32>
!CHECK: fir.store %[[val_16]] to %[[val_12]] : !fir.ref<i32>
!CHECK: %[[val_17:.*]] = fir.load %[[val_8]] : !fir.ref<i32>
!CHECK: fir.store %[[val_17]] to %[[val_15]] : !fir.ref<i32>
!CHECK: omp.barrier
!CHECK: %[[val_c1_i32:.*]] = arith.constant 1 : i32
!CHECK: %[[val_19:.*]] = fir.load %[[val_13]] : !fir.ref<i32>
!CHECK: %[[val_18:.*]] = fir.load %[[val_12]] : !fir.ref<i32>
!CHECK: %[[val_c1_i32_2:.*]] = arith.constant 1 : i32
!CHECK: omp.wsloop for (%[[arg:.*]]) : i32 = (%[[val_c1_i32]]) to (%[[val_19]]) inclusive step (%[[val_c1_i32_2]]) {
!CHECK: fir.store %[[arg]] to %[[val_9]] : !fir.ref<i32>
!CHECK: %[[val_20:.*]] = fir.load %[[val_16]] : !fir.ref<i32>
!CHECK: %[[val_21:.*]] = fir.load %[[val_9]] : !fir.ref<i32>
!CHECK: %[[val_22:.*]] = arith.addi %[[val_20]], %[[val_21]] : i32
!CHECK: fir.store %[[val_22]] to %[[val_16]] : !fir.ref<i32>
!CHECK: omp.wsloop for (%[[arg:.*]]) : i32 = (%[[val_c1_i32]]) to (%[[val_18]]) inclusive step (%[[val_c1_i32_2]]) {
!CHECK: %[[val_19:.*]] = fir.load %[[val_15]] : !fir.ref<i32>
!CHECK: %[[val_20:.*]] = arith.addi %[[val_19]], %[[arg]] : i32
!CHECK: fir.store %[[val_20]] to %[[val_15]] : !fir.ref<i32>
!CHECK: omp.yield
!CHECK: }
!CHECK: omp.terminator
Expand Down
2 changes: 0 additions & 2 deletions flang/test/Lower/OpenMP/FIR/lastprivate-commonblock.f90
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
! RUN: %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s

!CHECK: func.func @_QPlastprivate_common() {
!CHECK: %[[val_0:.*]] = fir.alloca i32 {adapt.valuebyref, pinned}
!CHECK: %[[val_1:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFlastprivate_commonEi"}
!CHECK: %[[val_2:.*]] = fir.address_of(@c_) : !fir.ref<!fir.array<8xi8>>
!CHECK: %[[val_3:.*]] = fir.convert %[[val_2]] : (!fir.ref<!fir.array<8xi8>>) -> !fir.ref<!fir.array<?xi8>>
Expand All @@ -18,7 +17,6 @@
!CHECK: %[[val_c100_i32:.*]] = arith.constant 100 : i32
!CHECK: %[[val_c1_i32_0:.*]] = arith.constant 1 : i32
!CHECK: omp.wsloop for (%[[arg:.*]]) : i32 = (%[[val_c1_i32]]) to (%[[val_c100_i32]]) inclusive step (%[[val_c1_i32_0]]) {
!CHECK: fir.store %[[arg]] to %[[val_0]] : !fir.ref<i32>
!CHECK: %[[val_11:.*]] = arith.cmpi eq, %[[arg]], %[[val_c100_i32]] : i32
!CHECK: fir.if %[[val_11]] {
!CHECK: %[[val_12:.*]] = fir.load %[[val_9]] : !fir.ref<f32>
Expand Down
Loading