@@ -3352,6 +3352,57 @@ genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
33523352 return args;
33533353}
33543354
3355+ static llvm::SmallVector<const Fortran::semantics::Symbol *>
3356+ genLoopAndReductionVars (mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
3357+ mlir::Location &loc,
3358+ const llvm::SmallVector<const Fortran::semantics::Symbol *> &loopArgs,
3359+ const llvm::SmallVector<const Fortran::semantics::Symbol *> &reductionArgs,
3360+ llvm::SmallVector<mlir::Type> &reductionTypes) {
3361+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
3362+
3363+ llvm::SmallVector<mlir::Type> blockArgTypes;
3364+ llvm::SmallVector<mlir::Location> blockArgLocs;
3365+ blockArgTypes.reserve (loopArgs.size () + reductionArgs.size ());
3366+ blockArgLocs.reserve (blockArgTypes.size ());
3367+ mlir::Block *entryBlock;
3368+
3369+ if (loopArgs.size ()) {
3370+ std::size_t loopVarTypeSize = 0 ;
3371+ for (const Fortran::semantics::Symbol *arg : loopArgs)
3372+ loopVarTypeSize = std::max (loopVarTypeSize, arg->GetUltimate ().size ());
3373+ mlir::Type loopVarType = getLoopVarType (converter, loopVarTypeSize);
3374+ std::fill_n (std::back_inserter (blockArgTypes), loopArgs.size (),
3375+ loopVarType);
3376+ std::fill_n (std::back_inserter (blockArgLocs), loopArgs.size (), loc);
3377+ }
3378+ if (reductionArgs.size ()) {
3379+ llvm::copy (reductionTypes, std::back_inserter (blockArgTypes));
3380+ std::fill_n (std::back_inserter (blockArgLocs), reductionArgs.size (), loc);
3381+ }
3382+ entryBlock = firOpBuilder.createBlock (&op->getRegion (0 ), {}, blockArgTypes,
3383+ blockArgLocs);
3384+ // The argument is not currently in memory, so make a temporary for the
3385+ // argument, and store it there, then bind that location to the argument.
3386+ if (loopArgs.size ()) {
3387+ mlir::Operation *storeOp = nullptr ;
3388+ for (auto [argIndex, argSymbol] : llvm::enumerate (loopArgs)) {
3389+ mlir::Value indexVal =
3390+ fir::getBase (op->getRegion (0 ).front ().getArgument (argIndex));
3391+ storeOp =
3392+ createAndSetPrivatizedLoopVar (converter, loc, indexVal, argSymbol);
3393+ }
3394+ firOpBuilder.setInsertionPointAfter (storeOp);
3395+ }
3396+ // Bind the reduction arguments to their block arguments
3397+ for (auto [arg, prv] : llvm::zip_equal (
3398+ reductionArgs,
3399+ llvm::drop_begin (entryBlock->getArguments (), loopArgs.size ()))) {
3400+ converter.bindSymbol (*arg, prv);
3401+ }
3402+
3403+ return loopArgs;
3404+ }
3405+
33553406static void
33563407createSimdLoop (Fortran::lower::AbstractConverter &converter,
33573408 Fortran::semantics::SemanticsContext &semaCtx,
@@ -3429,6 +3480,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
34293480 llvm::SmallVector<mlir::Value> linearVars, linearStepVars;
34303481 llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
34313482 llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
3483+ llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
34323484 mlir::omp::ClauseOrderKindAttr orderClauseOperand;
34333485 mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand;
34343486 mlir::UnitAttr nowaitClauseOperand, scheduleSimdClauseOperand;
@@ -3440,7 +3492,8 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
34403492 cp.processCollapse (loc, eval, lowerBound, upperBound, step, iv,
34413493 loopVarTypeSize);
34423494 cp.processScheduleChunk (stmtCtx, scheduleChunkClauseOperand);
3443- cp.processReduction (loc, reductionVars, reductionDeclSymbols);
3495+ cp.processReduction (loc, reductionVars, reductionDeclSymbols,
3496+ &reductionSymbols);
34443497 cp.processTODO <Fortran::parser::OmpClause::Linear,
34453498 Fortran::parser::OmpClause::Order>(loc, ompDirective);
34463499
@@ -3484,14 +3537,20 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
34843537 auto *nestedEval = getCollapsedLoopEval (
34853538 eval, Fortran::lower::getCollapseValue (beginClauseList));
34863539
3540+ llvm::SmallVector<mlir::Type> reductionTypes;
3541+ reductionTypes.reserve (reductionVars.size ());
3542+ llvm::transform (reductionVars, std::back_inserter (reductionTypes),
3543+ [](mlir::Value v) { return v.getType (); });
3544+
34873545 auto ivCallback = [&](mlir::Operation *op) {
3488- return genLoopVars (op, converter, loc, iv);
3546+ return genLoopAndReductionVars (op, converter, loc, iv, reductionSymbols, reductionTypes );
34893547 };
34903548
34913549 createBodyOfOp<mlir::omp::WsLoopOp>(
34923550 wsLoopOp, OpWithBodyGenInfo (converter, semaCtx, loc, *nestedEval)
34933551 .setClauses (&beginClauseList)
34943552 .setDataSharingProcessor (&dsp)
3553+ .setReductions (&reductionSymbols, &reductionTypes)
34953554 .setGenRegionEntryCb (ivCallback));
34963555}
34973556
@@ -3594,12 +3653,11 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
35943653 // 2.9.3.1 SIMD construct
35953654 createSimdLoop (converter, semaCtx, eval, ompDirective, loopOpClauseList,
35963655 currentLocation);
3656+ genOpenMPReduction (converter, semaCtx, loopOpClauseList);
35973657 } else {
35983658 createWsLoop (converter, semaCtx, eval, ompDirective, loopOpClauseList,
35993659 endClauseList, currentLocation);
36003660 }
3601-
3602- genOpenMPReduction (converter, semaCtx, loopOpClauseList);
36033661}
36043662
36053663static void
0 commit comments