@@ -2366,12 +2366,6 @@ static void createBodyOfOp(Op &op, OpWithBodyGenInfo &info) {
23662366 return undef.getDefiningOp ();
23672367 };
23682368
2369- llvm::SmallVector<mlir::Type> blockArgTypes;
2370- llvm::SmallVector<mlir::Location> blockArgLocs;
2371- blockArgTypes.reserve (loopArgs.size () + reductionArgs.size ());
2372- blockArgLocs.reserve (blockArgTypes.size ());
2373- mlir::Block *entryBlock;
2374-
23752369 // If an argument for the region is provided then create the block with that
23762370 // argument. Also update the symbol's address with the mlir argument value.
23772371 // e.g. For loops the argument is the induction variable. And all further
@@ -3358,6 +3352,57 @@ genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
33583352 return args;
33593353}
33603354
3355+ static llvm::SmallVector<const Fortran::semantics::Symbol *>
3356+ genLoopAndReductionVars (mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
3357+ mlir::Location &loc,
3358+ const llvm::SmallVector<const Fortran::semantics::Symbol *> &loopArgs,
3359+ const llvm::SmallVector<const Fortran::semantics::Symbol *> &reductionArgs,
3360+ llvm::SmallVector<mlir::Type> &reductionTypes) {
3361+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
3362+
3363+ llvm::SmallVector<mlir::Type> blockArgTypes;
3364+ llvm::SmallVector<mlir::Location> blockArgLocs;
3365+ blockArgTypes.reserve (loopArgs.size () + reductionArgs.size ());
3366+ blockArgLocs.reserve (blockArgTypes.size ());
3367+ mlir::Block *entryBlock;
3368+
3369+ if (loopArgs.size ()) {
3370+ std::size_t loopVarTypeSize = 0 ;
3371+ for (const Fortran::semantics::Symbol *arg : loopArgs)
3372+ loopVarTypeSize = std::max (loopVarTypeSize, arg->GetUltimate ().size ());
3373+ mlir::Type loopVarType = getLoopVarType (converter, loopVarTypeSize);
3374+ std::fill_n (std::back_inserter (blockArgTypes), loopArgs.size (),
3375+ loopVarType);
3376+ std::fill_n (std::back_inserter (blockArgLocs), loopArgs.size (), loc);
3377+ }
3378+ if (reductionArgs.size ()) {
3379+ llvm::copy (reductionTypes, std::back_inserter (blockArgTypes));
3380+ std::fill_n (std::back_inserter (blockArgLocs), reductionArgs.size (), loc);
3381+ }
3382+ entryBlock = firOpBuilder.createBlock (&op->getRegion (0 ), {}, blockArgTypes,
3383+ blockArgLocs);
3384+ // The argument is not currently in memory, so make a temporary for the
3385+ // argument, and store it there, then bind that location to the argument.
3386+ if (loopArgs.size ()) {
3387+ mlir::Operation *storeOp = nullptr ;
3388+ for (auto [argIndex, argSymbol] : llvm::enumerate (loopArgs)) {
3389+ mlir::Value indexVal =
3390+ fir::getBase (op->getRegion (0 ).front ().getArgument (argIndex));
3391+ storeOp =
3392+ createAndSetPrivatizedLoopVar (converter, loc, indexVal, argSymbol);
3393+ }
3394+ firOpBuilder.setInsertionPointAfter (storeOp);
3395+ }
3396+ // Bind the reduction arguments to their block arguments
3397+ for (auto [arg, prv] : llvm::zip_equal (
3398+ reductionArgs,
3399+ llvm::drop_begin (entryBlock->getArguments (), loopArgs.size ()))) {
3400+ converter.bindSymbol (*arg, prv);
3401+ }
3402+
3403+ return loopArgs;
3404+ }
3405+
33613406static void
33623407createSimdLoop (Fortran::lower::AbstractConverter &converter,
33633408 Fortran::semantics::SemanticsContext &semaCtx,
@@ -3492,19 +3537,20 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
34923537 auto *nestedEval = getCollapsedLoopEval (
34933538 eval, Fortran::lower::getCollapseValue (beginClauseList));
34943539
3540+ llvm::SmallVector<mlir::Type> reductionTypes;
3541+ reductionTypes.reserve (reductionVars.size ());
3542+ llvm::transform (reductionVars, std::back_inserter (reductionTypes),
3543+ [](mlir::Value v) { return v.getType (); });
3544+
34953545 auto ivCallback = [&](mlir::Operation *op) {
3496- return genLoopVars (op, converter, loc, iv);
3546+ return genLoopAndReductionVars (op, converter, loc, iv, reductionSymbols, reductionTypes );
34973547 };
34983548
3499- // llvm::SmallVector<mlir::Type> reductionTypes;
3500- // reductionTypes.reserve(reductionVars.size());
3501- // llvm::transform(reductionVars, std::back_inserter(reductionTypes),
3502- // [](mlir::Value v) { return v.getType(); });
3503-
35043549 createBodyOfOp<mlir::omp::WsLoopOp>(
35053550 wsLoopOp, OpWithBodyGenInfo (converter, semaCtx, loc, *nestedEval)
35063551 .setClauses (&beginClauseList)
35073552 .setDataSharingProcessor (&dsp)
3553+ .setReductions (&reductionSymbols, &reductionTypes)
35083554 .setGenRegionEntryCb (ivCallback));
35093555}
35103556
0 commit comments