Skip to content

Commit ef4fbb5

Browse files
authored
[flang][OpenACC] add pass to bufferize fir.box recipes (#163543)
When working on privatization, it is easier to work with fir.box explicitly in memory; otherwise, there is no way to express that the fir.box will end up being a descriptor address in FIR, which makes data management hard to deal with. However, introducing `fir.ref<fir.box>` early can pessimize early HLFIR optimizations, because the extra memory indirection makes the aliasing of `fir.ref<fir.box>` harder to reason about. This patch introduces a pass that turns acc `!fir.box<T>` recipes into `!fir.ref<!fir.box<T>>` recipes and updates the related recipe usages to use `!fir.ref<!fir.box<T>>` (creating new alloca+store+load). It is added to flang and not OpenACC because it is specific to the `fir.box` type, so it makes little sense to make it an OpenACC generic pass and to create a new OpenACC dialect type interface for this use case.
1 parent 10be254 commit ef4fbb5

File tree

10 files changed

+597
-0
lines changed

10 files changed

+597
-0
lines changed

flang/include/flang/Optimizer/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ add_subdirectory(CodeGen)
22
add_subdirectory(Dialect)
33
add_subdirectory(HLFIR)
44
add_subdirectory(Transforms)
5+
add_subdirectory(OpenACC)
56
add_subdirectory(OpenMP)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
set(LLVM_TARGET_DEFINITIONS Passes.td)
2+
mlir_tablegen(Passes.h.inc -gen-pass-decls -name FIROpenACC)
3+
4+
add_public_tablegen_target(FIROpenACCPassesIncGen)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
//===- Passes.h - OpenACC pass entry points -------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This header declares the OpenACC passes specific to Fortran and FIR.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef FORTRAN_OPTIMIZER_OPENACC_PASSES_H
14+
#define FORTRAN_OPTIMIZER_OPENACC_PASSES_H
15+
16+
#include "mlir/IR/BuiltinOps.h"
17+
#include "mlir/Pass/Pass.h"
18+
#include "mlir/Pass/PassRegistry.h"
19+
20+
#include <memory>
21+
22+
namespace fir::acc {

// Expand the tablegen-generated pass declarations and registration hooks
// inside the fir::acc namespace.
#define GEN_PASS_DECL
#define GEN_PASS_REGISTRATION
#include "flang/Optimizer/OpenACC/Passes.h.inc"

/// Creates the pass that rewrites acc recipes operating on fir.box<T> so
/// that they operate on fir.ref<fir.box<T>>, and updates the uses of those
/// recipes accordingly.
std::unique_ptr<mlir::Pass> createACCRecipeBufferizationPass();

} // namespace fir::acc
32+
33+
#endif // FORTRAN_OPTIMIZER_OPENACC_PASSES_H
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//===-- Passes.td - flang OpenACC pass definitions -----------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef FORTRAN_OPTIMIZER_OPENACC_PASSES
10+
#define FORTRAN_OPTIMIZER_OPENACC_PASSES
11+
12+
include "mlir/Pass/PassBase.td"
13+
14+
// Pass definition for OpenACC recipe bufferization. It is anchored on
// mlir::ModuleOp and uses "fir-acc-recipe-bufferization" as its pass argument.
def ACCRecipeBufferization
15+
: Pass<"fir-acc-recipe-bufferization", "mlir::ModuleOp"> {
16+
let summary = "Rewrite acc.*.recipe box values to ref<box> and update uses";
17+
let description = [{
18+
Bufferizes OpenACC recipes that operate on fir.box<T> so their type and
19+
region block arguments become fir.ref<fir.box<T>> instead. This applies to
20+
acc.private.recipe, acc.firstprivate.recipe (including copy region), and
21+
acc.reduction.recipe (including combiner region).
22+
23+
For affected regions, the pass inserts required loads at the beginning of
24+
the region to preserve original uses after argument type changes. For yields
25+
of box values, the pass allocates a local fir.ref<fir.box<T>> and stores the
26+
yielded fir.box<T> into it so the region yields a reference to a box.
27+
28+
For acc.private, acc.firstprivate, and acc.reduction operations that use a
29+
bufferized recipe, the pass allocates a host-side fir.ref<fir.box<T>> before
30+
the data op and rewires the data op to use the new memory. Other users of
31+
the original data operation result (outside the paired compute op) are
32+
updated to load through the reference.
33+
}];
34+
}
35+
36+
#endif // FORTRAN_OPTIMIZER_OPENACC_PASSES
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
add_subdirectory(Support)
2+
add_subdirectory(Transforms)
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
//===- ACCRecipeBufferization.cpp -----------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Bufferize OpenACC recipes that yield fir.box<T> to operate on
10+
// fir.ref<fir.box<T>> and update uses accordingly.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "flang/Optimizer/Dialect/FIROps.h"
15+
#include "flang/Optimizer/OpenACC/Passes.h"
16+
#include "mlir/Dialect/OpenACC/OpenACC.h"
17+
#include "mlir/IR/Block.h"
18+
#include "mlir/IR/Builders.h"
19+
#include "mlir/IR/BuiltinOps.h"
20+
#include "mlir/IR/SymbolTable.h"
21+
#include "mlir/IR/Value.h"
22+
#include "mlir/IR/Visitors.h"
23+
#include "llvm/ADT/TypeSwitch.h"
24+
25+
namespace fir::acc {
26+
#define GEN_PASS_DEF_ACCRECIPEBUFFERIZATION
27+
#include "flang/Optimizer/OpenACC/Passes.h.inc"
28+
} // namespace fir::acc
29+
30+
namespace {
31+
32+
class BufferizeInterface {
33+
public:
34+
static std::optional<mlir::Type> mustBufferize(mlir::Type recipeType) {
35+
if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(recipeType))
36+
return fir::ReferenceType::get(boxTy);
37+
return std::nullopt;
38+
}
39+
40+
static mlir::Operation *load(mlir::OpBuilder &builder, mlir::Location loc,
41+
mlir::Value value) {
42+
return builder.create<fir::LoadOp>(loc, value);
43+
}
44+
45+
static mlir::Value placeInMemory(mlir::OpBuilder &builder, mlir::Location loc,
46+
mlir::Value value) {
47+
auto alloca = builder.create<fir::AllocaOp>(loc, value.getType());
48+
builder.create<fir::StoreOp>(loc, value, alloca);
49+
return alloca;
50+
}
51+
};
52+
53+
/// Rewrites a recipe \p region so that it works on \p newType instead of
/// \p oldType.
///
/// Every entry block argument of type \p oldType is retyped to \p newType;
/// when such an argument has uses, a load is inserted at the top of the entry
/// block and the previous uses are redirected to the loaded value. Every
/// acc.yield directly contained in the region (in any of its blocks) that
/// yields a value of type \p oldType is updated to yield the address of a
/// new temporary holding that value.
///
/// \param region  recipe region to update; nothing is done when it is empty.
/// \param loc     location given to the created operations.
/// \param oldType type of the values before bufferization.
/// \param newType in-memory type replacing \p oldType.
static void bufferizeRegionArgsAndYields(mlir::Region &region,
                                         mlir::Location loc, mlir::Type oldType,
                                         mlir::Type newType) {
  if (region.empty())
    return;

  // The builder starts at the beginning of the entry block, so the loads end
  // up ahead of all the pre-existing uses of the arguments.
  mlir::OpBuilder builder(&region);
  for (mlir::BlockArgument arg : region.getArguments()) {
    if (arg.getType() != oldType)
      continue;
    arg.setType(newType);
    if (arg.use_empty())
      continue;
    mlir::Operation *loadOp = BufferizeInterface::load(builder, loc, arg);
    // The load itself must keep reading the (now in-memory) argument.
    arg.replaceAllUsesExcept(loadOp->getResult(0), loadOp);
  }

  // A region may have several blocks, each ending with its own acc.yield.
  // Snapshot the yields first because new operations are inserted while they
  // are being processed.
  llvm::SmallVector<mlir::acc::YieldOp> yields =
      llvm::to_vector(region.getOps<mlir::acc::YieldOp>());
  for (mlir::acc::YieldOp yield : yields) {
    llvm::SmallVector<mlir::Value> newOperands;
    newOperands.reserve(yield.getNumOperands());
    bool changed = false;
    builder.setInsertionPoint(yield);
    for (mlir::Value oldYieldArg : yield.getOperands()) {
      if (oldYieldArg.getType() != oldType) {
        newOperands.push_back(oldYieldArg);
        continue;
      }
      newOperands.push_back(
          BufferizeInterface::placeInMemory(builder, loc, oldYieldArg));
      changed = true;
    }
    if (changed)
      yield->setOperands(newOperands);
  }
}
89+
90+
/// Updates the data operations of \p computeOp that refer to the recipe named
/// \p recipeSymName, after that recipe has been bufferized.
///
/// \p recipes and \p operands are walked in lockstep. For each operand whose
/// recipe symbol matches, the defining acc.private / acc.firstprivate /
/// acc.reduction operation gets its variable replaced by the address of a
/// new temporary holding the original variable, and its result is retyped to
/// that address type. Users of the data operation result other than
/// \p computeOp are rewritten to use a load of the result.
///
/// \param recipes       recipe symbol references of the compute operation;
///                      may be null, in which case nothing is done.
/// \param operands      data operation results paired with \p recipes.
/// \param recipeSymName name of the bufferized recipe to look for.
/// \param computeOp     operation owning \p operands; its uses are kept.
static void updateRecipeUse(mlir::ArrayAttr recipes, mlir::ValueRange operands,
                            llvm::StringRef recipeSymName,
                            mlir::Operation *computeOp) {
  if (!recipes)
    return;
  for (auto [recipeSym, oldRes] : llvm::zip(recipes, operands)) {
    if (llvm::cast<mlir::SymbolRefAttr>(recipeSym).getLeafReference() !=
        recipeSymName)
      continue;

    mlir::Operation *dataOp = oldRes.getDefiningOp();
    assert(dataOp && "dataOp must be paired with computeOp");
    mlir::Location loc = dataOp->getLoc();
    mlir::OpBuilder builder(dataOp);
    llvm::TypeSwitch<mlir::Operation *, void>(dataOp)
        .Case<mlir::acc::PrivateOp, mlir::acc::FirstprivateOp,
              mlir::acc::ReductionOp>([&](auto privateOp) {
          // Spill the variable right after its definition and make the data
          // operation work on the resulting address.
          builder.setInsertionPointAfterValue(privateOp.getVar());
          mlir::Value alloca = BufferizeInterface::placeInMemory(
              builder, loc, privateOp.getVar());
          privateOp.getVarMutable().assign(alloca);
          privateOp.getAccVar().setType(alloca.getType());
        });

    // Snapshot the users: the use list is modified while it is processed.
    llvm::SmallVector<mlir::Operation *> users(oldRes.getUsers().begin(),
                                               oldRes.getUsers().end());
    for (mlir::Operation *useOp : users) {
      if (useOp == computeOp)
        continue;
      // An operation using the result in several operands appears once per
      // use in the snapshot. All its uses are replaced on the first visit, so
      // skip the later ones instead of emitting a dead load.
      if (!llvm::is_contained(useOp->getOperands(), oldRes))
        continue;
      builder.setInsertionPoint(useOp);
      mlir::Operation *load = BufferizeInterface::load(builder, loc, oldRes);
      useOp->replaceUsesOfWith(oldRes, load->getResult(0));
    }
  }
}
125+
126+
class ACCRecipeBufferization
127+
: public fir::acc::impl::ACCRecipeBufferizationBase<
128+
ACCRecipeBufferization> {
129+
public:
130+
void runOnOperation() override {
131+
mlir::ModuleOp module = getOperation();
132+
133+
llvm::SmallVector<llvm::StringRef> recipeNames;
134+
module.walk([&](mlir::Operation *recipe) {
135+
llvm::TypeSwitch<mlir::Operation *, void>(recipe)
136+
.Case<mlir::acc::PrivateRecipeOp, mlir::acc::FirstprivateRecipeOp,
137+
mlir::acc::ReductionRecipeOp>([&](auto recipe) {
138+
mlir::Type oldType = recipe.getType();
139+
auto bufferizedType =
140+
BufferizeInterface::mustBufferize(recipe.getType());
141+
if (!bufferizedType)
142+
return;
143+
recipe.setTypeAttr(mlir::TypeAttr::get(*bufferizedType));
144+
mlir::Location loc = recipe.getLoc();
145+
using RecipeOp = decltype(recipe);
146+
bufferizeRegionArgsAndYields(recipe.getInitRegion(), loc, oldType,
147+
*bufferizedType);
148+
if constexpr (std::is_same_v<RecipeOp,
149+
mlir::acc::FirstprivateRecipeOp>)
150+
bufferizeRegionArgsAndYields(recipe.getCopyRegion(), loc, oldType,
151+
*bufferizedType);
152+
if constexpr (std::is_same_v<RecipeOp,
153+
mlir::acc::ReductionRecipeOp>)
154+
bufferizeRegionArgsAndYields(recipe.getCombinerRegion(), loc,
155+
oldType, *bufferizedType);
156+
bufferizeRegionArgsAndYields(recipe.getDestroyRegion(), loc,
157+
oldType, *bufferizedType);
158+
recipeNames.push_back(recipe.getSymName());
159+
});
160+
});
161+
if (recipeNames.empty())
162+
return;
163+
164+
module.walk([&](mlir::Operation *op) {
165+
llvm::TypeSwitch<mlir::Operation *, void>(op)
166+
.Case<mlir::acc::LoopOp, mlir::acc::ParallelOp, mlir::acc::SerialOp>(
167+
[&](auto computeOp) {
168+
for (llvm::StringRef recipeName : recipeNames) {
169+
if (computeOp.getPrivatizationRecipes())
170+
updateRecipeUse(computeOp.getPrivatizationRecipesAttr(),
171+
computeOp.getPrivateOperands(), recipeName,
172+
op);
173+
if (computeOp.getFirstprivatizationRecipes())
174+
updateRecipeUse(
175+
computeOp.getFirstprivatizationRecipesAttr(),
176+
computeOp.getFirstprivateOperands(), recipeName, op);
177+
if (computeOp.getReductionRecipes())
178+
updateRecipeUse(computeOp.getReductionRecipesAttr(),
179+
computeOp.getReductionOperands(),
180+
recipeName, op);
181+
}
182+
});
183+
});
184+
}
185+
};
186+
187+
} // namespace
188+
189+
std::unique_ptr<mlir::Pass> fir::acc::createACCRecipeBufferizationPass() {
190+
return std::make_unique<ACCRecipeBufferization>();
191+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
add_flang_library(FIROpenACCTransforms
2+
ACCRecipeBufferization.cpp
3+
4+
DEPENDS
5+
FIROpenACCPassesIncGen
6+
7+
LINK_LIBS
8+
MLIRIR
9+
MLIRPass
10+
FIRDialect
11+
MLIROpenACCDialect
12+
)

0 commit comments

Comments
 (0)