diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h index 4bbf9777a54cc..2fb863738d62d 100644 --- a/flang/include/flang/Parser/dump-parse-tree.h +++ b/flang/include/flang/Parser/dump-parse-tree.h @@ -236,6 +236,7 @@ class ParseTreeDumper { NODE(parser, CUFKernelDoConstruct) NODE(CUFKernelDoConstruct, StarOrExpr) NODE(CUFKernelDoConstruct, Directive) + NODE(CUFKernelDoConstruct, LaunchConfiguration) NODE(parser, CUFReduction) NODE(parser, CycleStmt) NODE(parser, DataComponentDefStmt) diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h index 5f5650304f998..ce0b6167de9fc 100644 --- a/flang/include/flang/Parser/parse-tree.h +++ b/flang/include/flang/Parser/parse-tree.h @@ -4527,12 +4527,17 @@ struct CUFReduction { struct CUFKernelDoConstruct { TUPLE_CLASS_BOILERPLATE(CUFKernelDoConstruct); WRAPPER_CLASS(StarOrExpr, std::optional); + struct LaunchConfiguration { + TUPLE_CLASS_BOILERPLATE(LaunchConfiguration); + std::tuple, std::list, + std::optional> + t; + }; struct Directive { TUPLE_CLASS_BOILERPLATE(Directive); CharBlock source; - std::tuple, std::list, - std::list, std::optional, - std::list> + std::tuple, + std::optional, std::list> t; }; std::tuple> t; diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 0e3011e73902d..da53edf7e734b 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -2862,14 +2862,11 @@ class FirConverter : public Fortran::lower::AbstractConverter { if (nestedLoops > 1) n = builder->getIntegerAttr(builder->getI64Type(), nestedLoops); - const std::list &grid = - std::get<1>(dir.t); - const std::list &block = - std::get<2>(dir.t); - const std::optional &stream = - std::get<3>(dir.t); + const auto &launchConfig = std::get>(dir.t); + const std::list &cufreds = - std::get<4>(dir.t); + std::get<2>(dir.t); llvm::SmallVector reduceOperands; llvm::SmallVector reduceAttrs; @@ -2913,35 +2910,45 @@ class FirConverter : public Fortran::lower::AbstractConverter { builder->createIntegerConstant(loc, builder->getI32Type(), 0); llvm::SmallVector gridValues; - if (!isOnlyStars(grid)) { - for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr : - grid) { - if (expr.v) { - gridValues.push_back(fir::getBase( - genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx))); - } else { - gridValues.push_back(zero); + llvm::SmallVector blockValues; + mlir::Value streamValue; + + if (launchConfig) { + const std::list &grid = + std::get<0>(launchConfig->t); + const std::list + &block = std::get<1>(launchConfig->t); + const std::optional &stream = + std::get<2>(launchConfig->t); + if (!isOnlyStars(grid)) { + for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr : + grid) { + if (expr.v) { + gridValues.push_back(fir::getBase( + genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx))); + } else { + gridValues.push_back(zero); + } } } - } - llvm::SmallVector blockValues; - if (!isOnlyStars(block)) { - for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr : - block) { - if (expr.v) { - blockValues.push_back(fir::getBase( - genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx))); - } else { - blockValues.push_back(zero); + if (!isOnlyStars(block)) { + for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr : + block) { + if (expr.v) { + blockValues.push_back(fir::getBase( + genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx))); + } else { + blockValues.push_back(zero); + } } } + + if (stream) + streamValue = builder->createConvert( + loc, builder->getI32Type(), + fir::getBase( + genExprValue(*Fortran::semantics::GetExpr(*stream), stmtCtx))); } - mlir::Value streamValue; - if (stream) - streamValue = builder->createConvert( - loc, builder->getI32Type(), - fir::getBase( - genExprValue(*Fortran::semantics::GetExpr(*stream), stmtCtx))); const auto &outerDoConstruct = std::get>(kernel.t); diff --git a/flang/lib/Parser/executable-parsers.cpp b/flang/lib/Parser/executable-parsers.cpp index 5057e89164c9f..730165613d91d 100644 --- a/flang/lib/Parser/executable-parsers.cpp +++ b/flang/lib/Parser/executable-parsers.cpp @@ -563,11 +563,15 @@ TYPE_PARSER(("REDUCTION"_tok || "REDUCE"_tok) >> parenthesized(construct(Parser{}, ":" >> nonemptyList(scalar(variable))))) +TYPE_PARSER("<<<" >> + construct(gridOrBlock, + "," >> gridOrBlock, + maybe((", 0 ,"_tok || ", STREAM ="_tok) >> scalarIntExpr) / ">>>")) + TYPE_PARSER(sourced(beginDirective >> "$CUF KERNEL DO"_tok >> construct( - maybe(parenthesized(scalarIntConstantExpr)), "<<<" >> gridOrBlock, - "," >> gridOrBlock, - maybe((", 0 ,"_tok || ", STREAM ="_tok) >> scalarIntExpr) / ">>>", + maybe(parenthesized(scalarIntConstantExpr)), + maybe(Parser{}), many(Parser{}) / endDirective))) TYPE_CONTEXT_PARSER("!$CUF KERNEL DO construct"_en_US, extension(construct( diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp index 20022f8fa984c..4b511da69832c 100644 --- a/flang/lib/Parser/unparse.cpp +++ b/flang/lib/Parser/unparse.cpp @@ -2932,11 +2932,9 @@ class UnparseVisitor { Word("*"); } } - void Unparse(const CUFKernelDoConstruct::Directive &x) { - Word("!$CUF KERNEL DO"); - Walk(" (", std::get>(x.t), ")"); + void Unparse(const CUFKernelDoConstruct::LaunchConfiguration &x) { Word(" <<<"); - const auto &grid{std::get<1>(x.t)}; + const auto &grid{std::get<0>(x.t)}; if (grid.empty()) { Word("*"); } else if (grid.size() == 1) { @@ -2945,7 +2943,7 @@ class UnparseVisitor { Walk("(", grid, ",", ")"); } Word(","); - const auto &block{std::get<2>(x.t)}; + const auto &block{std::get<1>(x.t)}; if (block.empty()) { Word("*"); } else if (block.size() == 1) { @@ -2953,10 +2951,16 @@ class UnparseVisitor { } else { Walk("(", block, ",", ")"); } - if (const auto &stream{std::get<3>(x.t)}) { + if (const auto &stream{std::get<2>(x.t)}) { Word(",STREAM="), Walk(*stream); } Word(">>>"); + } + void Unparse(const CUFKernelDoConstruct::Directive &x) { + Word("!$CUF KERNEL DO"); + Walk(" (", std::get>(x.t), ")"); + Walk(std::get>( + x.t)); Walk(" ", std::get>(x.t), " "); Word("\n"); } diff --git a/flang/test/Parser/cuf-sanity-common b/flang/test/Parser/cuf-sanity-common index 7005ef07b2265..816e03bed7220 100644 --- a/flang/test/Parser/cuf-sanity-common +++ b/flang/test/Parser/cuf-sanity-common @@ -31,6 +31,9 @@ module m !$cuf kernel do <<<1, (2, 3), stream = 1>>> do j = 1, 10 end do + !$cuf kernel do + do j = 1, 10 + end do !$cuf kernel do <<<*, *>>> reduce(+:x,y) reduce(*:z) do j = 1, 10 x = x + a(j)