Skip to content

[flang][OpenMP] Main splitting functionality dev-complete #82003

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

31 changes: 31 additions & 0 deletions flang/include/flang/Evaluate/tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,14 @@ inline Expr<SomeType> AsGenericExpr(Expr<SomeType> &&x) { return std::move(x); }
std::optional<Expr<SomeType>> AsGenericExpr(DataRef &&);
std::optional<Expr<SomeType>> AsGenericExpr(const Symbol &);

// Propagate std::optional from input to output.
template <typename A>
std::optional<Expr<SomeType>> AsGenericExpr(std::optional<A> &&x) {
if (!x)
return std::nullopt;
return AsGenericExpr(std::move(*x));
}

template <typename A>
common::IfNoLvalue<Expr<SomeKind<ResultType<A>::category>>, A> AsCategoryExpr(
A &&x) {
Expand Down Expand Up @@ -430,6 +438,29 @@ template <typename A> std::optional<CoarrayRef> ExtractCoarrayRef(const A &x) {
}
}

struct ExtractSubstringHelper {
template <typename T> static std::optional<Substring> visit(T &&) {
return std::nullopt;
}

static std::optional<Substring> visit(const Substring &e) { return e; }

template <typename T>
static std::optional<Substring> visit(const Designator<T> &e) {
return std::visit([](auto &&s) { return visit(s); }, e.u);
}

template <typename T>
static std::optional<Substring> visit(const Expr<T> &e) {
return std::visit([](auto &&s) { return visit(s); }, e.u);
}
};

template <typename A>
std::optional<Substring> ExtractSubstring(const A &x) {
return ExtractSubstringHelper::visit(x);
}

// If an expression is simply a whole symbol data designator,
// extract and return that symbol, else null.
template <typename A> const Symbol *UnwrapWholeSymbolDataRef(const A &x) {
Expand Down
389 changes: 235 additions & 154 deletions flang/lib/Lower/DirectivesCommon.h

Large diffs are not rendered by default.

54 changes: 35 additions & 19 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,11 @@ getSymbolFromAccObject(const Fortran::parser::AccObject &accObject) {
Fortran::parser::GetLastName(arrayElement->base);
return *name.symbol;
}
if (const auto *component =
Fortran::parser::Unwrap<Fortran::parser::StructureComponent>(
*designator)) {
return *component->component.symbol;
}
} else if (const auto *name =
std::get_if<Fortran::parser::Name>(&accObject.u)) {
return *name->symbol;
Expand All @@ -286,17 +291,20 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
mlir::acc::DataClause dataClause, bool structured,
bool implicit, bool setDeclareAttr = false) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
for (const auto &accObject : objectList.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
Fortran::semantics::MaybeExpr designator =
std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
stmtCtx, accObject, operandLocation,
asFortran, bounds,
/*treatIndexAsSection=*/true);
mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
converter, builder, semanticsContext, stmtCtx, symbol, designator,
operandLocation, asFortran, bounds,
/*treatIndexAsSection=*/true);

// If the input value is optional and is not a descriptor, we use the
// rawInput directly.
Expand All @@ -321,16 +329,19 @@ static void genDeclareDataOperandOperations(
llvm::SmallVectorImpl<mlir::Value> &dataOperands,
mlir::acc::DataClause dataClause, bool structured, bool implicit) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
for (const auto &accObject : objectList.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
Fortran::semantics::MaybeExpr designator =
std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
stmtCtx, accObject, operandLocation,
asFortran, bounds);
mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
converter, builder, semanticsContext, stmtCtx, symbol, designator,
operandLocation, asFortran, bounds);
EntryOp op = createDataEntryOp<EntryOp>(
builder, operandLocation, info.addr, asFortran, bounds, structured,
implicit, dataClause, info.addr.getType());
Expand All @@ -339,8 +350,7 @@ static void genDeclareDataOperandOperations(
if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(info.addr.getType()))) {
mlir::OpBuilder modBuilder(builder.getModule().getBodyRegion());
modBuilder.setInsertionPointAfter(builder.getFunction());
std::string prefix =
converter.mangleName(getSymbolFromAccObject(accObject));
std::string prefix = converter.mangleName(symbol);
createDeclareAllocFuncWithArg<EntryOp>(
modBuilder, builder, operandLocation, info.addr.getType(), prefix,
asFortran, dataClause);
Expand Down Expand Up @@ -783,16 +793,19 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
llvm::SmallVectorImpl<mlir::Value> &dataOperands,
llvm::SmallVector<mlir::Attribute> &privatizations) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
for (const auto &accObject : objectList.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
Fortran::semantics::MaybeExpr designator =
std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
stmtCtx, accObject, operandLocation,
asFortran, bounds);
mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
converter, builder, semanticsContext, stmtCtx, symbol, designator,
operandLocation, asFortran, bounds);
RecipeOp recipe;
mlir::Type retTy = getTypeFromBounds(bounds, info.addr.getType());
if constexpr (std::is_same_v<RecipeOp, mlir::acc::PrivateRecipeOp>) {
Expand Down Expand Up @@ -1361,16 +1374,19 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
const auto &op =
std::get<Fortran::parser::AccReductionOperator>(objectList.t);
mlir::acc::ReductionOperator mlirOp = getReductionOperator(op);
Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
for (const auto &accObject : objects.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
Fortran::semantics::MaybeExpr designator =
std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
stmtCtx, accObject, operandLocation,
asFortran, bounds);
mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
converter, builder, semanticsContext, stmtCtx, symbol, designator,
operandLocation, asFortran, bounds);

mlir::Type reductionTy = fir::unwrapRefType(info.addr.getType());
if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(reductionTy))
Expand Down
Loading