4747
4848using namespace mlir ;
4949
50+ llvm::SmallDenseMap<llvm::Value *, llvm::Type *> ReductionVarToType;
51+ llvm::OpenMPIRBuilder::InsertPointTy
52+ parallelAllocaIP; // TODO: change this alloca IP to point to originalvar
53+ // allocaIP. ReductionDecl need to be linked to scan var.
5054namespace {
5155static llvm::omp::ScheduleKind
5256convertToScheduleKind (std::optional<omp::ClauseScheduleKind> schedKind) {
@@ -86,7 +90,9 @@ class OpenMPLoopInfoStackFrame
8690 : public LLVM::ModuleTranslation::StackFrameBase<OpenMPLoopInfoStackFrame> {
8791public:
8892 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (OpenMPLoopInfoStackFrame)
89- llvm::CanonicalLoopInfo *loopInfo = nullptr ;
93+ // For constructs like scan, one Loop info frame can contain multiple
94+ // Canonical Loops
95+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
9096};
9197
9298// / Custom error class to signal translation errors that don't need reporting,
@@ -169,6 +175,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
169175 if (op.getDistScheduleChunkSize ())
170176 result = todo (" dist_schedule with chunk_size" );
171177 };
178+ auto checkExclusive = [&todo](auto op, LogicalResult &result) {
179+ if (!op.getExclusiveVars ().empty ())
180+ result = todo (" exclusive" );
181+ };
172182 auto checkHint = [](auto op, LogicalResult &) {
173183 if (op.getHint ())
174184 op.emitWarning (" hint clause discarded" );
@@ -232,8 +242,8 @@ static LogicalResult checkImplementationStatus(Operation &op) {
232242 op.getReductionSyms ())
233243 result = todo (" reduction" );
234244 if (op.getReductionMod () &&
235- op.getReductionMod ().value () != omp::ReductionModifier::defaultmod )
236- result = todo (" reduction with modifier" );
245+ op.getReductionMod ().value () == omp::ReductionModifier::task )
246+ result = todo (" reduction with task modifier" );
237247 };
238248 auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
239249 if (!op.getTaskReductionVars ().empty () || op.getTaskReductionByref () ||
@@ -253,6 +263,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
253263 checkOrder (op, result);
254264 })
255265 .Case ([&](omp::OrderedRegionOp op) { checkParLevelSimd (op, result); })
266+ .Case ([&](omp::ScanOp op) { checkExclusive (op, result); })
256267 .Case ([&](omp::SectionsOp op) {
257268 checkAllocate (op, result);
258269 checkPrivate (op, result);
@@ -382,15 +393,15 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
382393// / Find the loop information structure for the loop nest being translated. It
383394// / will return a `null` value unless called from the translation function for
384395// / a loop wrapper operation after successfully translating its body.
385- static llvm::CanonicalLoopInfo *
386- findCurrentLoopInfo (LLVM::ModuleTranslation &moduleTranslation) {
387- llvm::CanonicalLoopInfo *loopInfo = nullptr ;
396+ static SmallVector< llvm::CanonicalLoopInfo *>
397+ findCurrentLoopInfos (LLVM::ModuleTranslation &moduleTranslation) {
398+ SmallVector< llvm::CanonicalLoopInfo *> loopInfos ;
388399 moduleTranslation.stackWalk <OpenMPLoopInfoStackFrame>(
389400 [&](OpenMPLoopInfoStackFrame &frame) {
390- loopInfo = frame.loopInfo ;
401+ loopInfos = frame.loopInfos ;
391402 return WalkResult::interrupt ();
392403 });
393- return loopInfo ;
404+ return loopInfos ;
394405}
395406
396407// / Converts the given region that appears within an OpenMP dialect operation to
@@ -1133,6 +1144,11 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
11331144 // variables. Although this could be done after allocas, we don't want to mess
11341145 // up with the alloca insertion point.
11351146 for (unsigned i = 0 ; i < op.getNumReductionVars (); ++i) {
1147+
1148+ llvm::Type *reductionType =
1149+ moduleTranslation.convertType (reductionDecls[i].getType ());
1150+ ReductionVarToType[privateReductionVariables[i]] = reductionType;
1151+
11361152 SmallVector<llvm::Value *, 1 > phis;
11371153
11381154 // map block argument to initializer region
@@ -1206,9 +1222,11 @@ static void collectReductionInfo(
12061222 atomicGen = owningAtomicReductionGens[i];
12071223 llvm::Value *variable =
12081224 moduleTranslation.lookupValue (loop.getReductionVars ()[i]);
1225+ llvm::Type *reductionType =
1226+ moduleTranslation.convertType (reductionDecls[i].getType ());
1227+ ReductionVarToType[privateReductionVariables[i]] = reductionType;
12091228 reductionInfos.push_back (
1210- {moduleTranslation.convertType (reductionDecls[i].getType ()), variable,
1211- privateReductionVariables[i],
1229+ {reductionType, variable, privateReductionVariables[i],
12121230 /* EvaluationKind=*/ llvm::OpenMPIRBuilder::EvalKind::Scalar,
12131231 owningReductionGens[i],
12141232 /* ReductionGenClang=*/ nullptr , atomicGen});
@@ -2342,27 +2360,60 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
23422360 if (failed (handleError (regionBlock, opInst)))
23432361 return failure ();
23442362
2345- builder.SetInsertPoint (*regionBlock, (*regionBlock)->begin ());
2346- llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo (moduleTranslation);
2347-
2348- llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
2349- ompBuilder->applyWorkshareLoop (
2350- ompLoc.DL , loopInfo, allocaIP, loopNeedsBarrier,
2351- convertToScheduleKind (schedule), chunk, isSimd,
2352- scheduleMod == omp::ScheduleModifier::monotonic,
2353- scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
2354- workshareLoopType);
2355-
2356- if (failed (handleError (wsloopIP, opInst)))
2357- return failure ();
2358-
2359- // Process the reductions if required.
2360- if (failed (createReductionsAndCleanup (
2361- wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
2362- privateReductionVariables, isByRef, wsloopOp.getNowait (),
2363- /* isTeamsReduction=*/ false )))
2364- return failure ();
2363+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
2364+ findCurrentLoopInfos (moduleTranslation);
2365+ auto inputLoopFinishIp = loopInfos.front ()->getAfterIP ();
2366+ bool isInScanRegion =
2367+ wsloopOp.getReductionMod () && (wsloopOp.getReductionMod ().value () ==
2368+ mlir::omp::ReductionModifier::inscan);
2369+ if (isInScanRegion) {
2370+ builder.restoreIP (inputLoopFinishIp);
2371+ SmallVector<OwningReductionGen> owningReductionGens;
2372+ SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
2373+ SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
2374+ collectReductionInfo (wsloopOp, builder, moduleTranslation, reductionDecls,
2375+ owningReductionGens, owningAtomicReductionGens,
2376+ privateReductionVariables, reductionInfos);
2377+ llvm::BasicBlock *cont = splitBB (builder, false , " omp.scan.loop.cont" );
2378+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy redIP =
2379+ ompBuilder->emitScanReduction (builder.saveIP (), reductionInfos);
2380+ if (failed (handleError (redIP, opInst)))
2381+ return failure ();
23652382
2383+ builder.restoreIP (*redIP);
2384+ builder.CreateBr (cont);
2385+ }
2386+ for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
2387+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
2388+ ompBuilder->applyWorkshareLoop (
2389+ ompLoc.DL , loopInfo, allocaIP, loopNeedsBarrier,
2390+ convertToScheduleKind (schedule), chunk, isSimd,
2391+ scheduleMod == omp::ScheduleModifier::monotonic,
2392+ scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
2393+ workshareLoopType);
2394+
2395+ if (failed (handleError (wsloopIP, opInst)))
2396+ return failure ();
2397+ }
2398+ builder.SetInsertPoint (*regionBlock, (*regionBlock)->begin ());
2399+ if (isInScanRegion) {
2400+ SmallVector<Region *> reductionRegions;
2401+ llvm::transform (reductionDecls, std::back_inserter (reductionRegions),
2402+ [](omp::DeclareReductionOp reductionDecl) {
2403+ return &reductionDecl.getCleanupRegion ();
2404+ });
2405+ if (failed (inlineOmpRegionCleanup (
2406+ reductionRegions, privateReductionVariables, moduleTranslation,
2407+ builder, " omp.reduction.cleanup" )))
2408+ return failure ();
2409+ } else {
2410+ // Process the reductions if required.
2411+ if (failed (createReductionsAndCleanup (
2412+ wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
2413+ privateReductionVariables, isByRef, wsloopOp.getNowait (),
2414+ /* isTeamsReduction=*/ false )))
2415+ return failure ();
2416+ }
23662417 return cleanupPrivateVars (builder, moduleTranslation, wsloopOp.getLoc (),
23672418 privateVarsInfo.llvmVars ,
23682419 privateVarsInfo.privatizers );
@@ -2528,6 +2579,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
25282579
25292580 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
25302581 findAllocaInsertPoint (builder, moduleTranslation);
2582+ parallelAllocaIP = allocaIP;
25312583 llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
25322584
25332585 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
@@ -2553,6 +2605,64 @@ convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
25532605 llvm_unreachable (" Unknown ClauseOrderKind kind" );
25542606}
25552607
2608+ static LogicalResult
2609+ convertOmpScan (Operation &opInst, llvm::IRBuilderBase &builder,
2610+ LLVM::ModuleTranslation &moduleTranslation) {
2611+ if (failed (checkImplementationStatus (opInst)))
2612+ return failure ();
2613+ auto scanOp = cast<omp::ScanOp>(opInst);
2614+ bool isInclusive = scanOp.hasInclusiveVars ();
2615+ SmallVector<llvm::Value *> llvmScanVars;
2616+ SmallVector<llvm::Type *> llvmScanVarsType;
2617+ mlir::OperandRange mlirScanVars = scanOp.getInclusiveVars ();
2618+ if (!isInclusive)
2619+ mlirScanVars = scanOp.getExclusiveVars ();
2620+ for (auto val : mlirScanVars) {
2621+ llvm::Value *llvmVal = moduleTranslation.lookupValue (val);
2622+ llvmScanVars.push_back (llvmVal);
2623+ llvmScanVarsType.push_back (ReductionVarToType[llvmVal]);
2624+ val.getDefiningOp ();
2625+ }
2626+ auto parallelOp = scanOp->getParentOfType <omp::ParallelOp>();
2627+ if (!parallelOp) {
2628+ return failure ();
2629+ }
2630+ llvm::OpenMPIRBuilder::InsertPointTy allocaIP = parallelAllocaIP;
2631+ llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
2632+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2633+ moduleTranslation.getOpenMPBuilder ()->createScan (
2634+ ompLoc, allocaIP, llvmScanVars, llvmScanVarsType, isInclusive);
2635+ if (failed (handleError (afterIP, opInst)))
2636+ return failure ();
2637+
2638+ builder.restoreIP (*afterIP);
2639+
2640+ // TODO: The argument of LoopnestOp is stored into the index variable and this
2641+ // variable is used across scan operation. However that makes the mlir
2642+ // invalid.(`Intra-iteration dependences from a statement in the structured
2643+ // block sequence that precede a scan directive to a statement in the
2644+ // structured block sequence that follows a scan directive must not exist,
2645+ // except for dependences for the list items specified in an inclusive or
2646+ // exclusive clause.`). The argument of LoopNestOp need to be loaded again
2647+ // after ScanOp again so mlir generated is valid.
2648+ auto parentOp = scanOp->getParentOp ();
2649+ auto loopOp = cast<omp::LoopNestOp>(parentOp);
2650+ if (loopOp) {
2651+ auto &firstBlock = *(scanOp->getParentRegion ()->getBlocks ()).begin ();
2652+ auto &ins = *(firstBlock.begin ());
2653+ if (isa<LLVM::StoreOp>(ins)) {
2654+ LLVM::StoreOp storeOp = dyn_cast<LLVM::StoreOp>(ins);
2655+ auto src = moduleTranslation.lookupValue (storeOp->getOperand (0 ));
2656+ if (src == moduleTranslation.lookupValue (
2657+ (loopOp.getRegion ().getArguments ())[0 ])) {
2658+ auto dest = moduleTranslation.lookupValue (storeOp->getOperand (1 ));
2659+ builder.CreateStore (src, dest);
2660+ }
2661+ }
2662+ }
2663+ return success ();
2664+ }
2665+
25562666// / Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
25572667static LogicalResult
25582668convertOmpSimd (Operation &opInst, llvm::IRBuilderBase &builder,
@@ -2626,13 +2736,15 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
26262736 return failure ();
26272737
26282738 builder.SetInsertPoint (*regionBlock, (*regionBlock)->begin ());
2629- llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo (moduleTranslation);
2630- ompBuilder->applySimd (loopInfo, alignedVars,
2631- simdOp.getIfExpr ()
2632- ? moduleTranslation.lookupValue (simdOp.getIfExpr ())
2633- : nullptr ,
2634- order, simdlen, safelen);
2635-
2739+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
2740+ findCurrentLoopInfos (moduleTranslation);
2741+ for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
2742+ ompBuilder->applySimd (
2743+ loopInfo, alignedVars,
2744+ simdOp.getIfExpr () ? moduleTranslation.lookupValue (simdOp.getIfExpr ())
2745+ : nullptr ,
2746+ order, simdlen, safelen);
2747+ }
26362748 return cleanupPrivateVars (builder, moduleTranslation, simdOp.getLoc (),
26372749 privateVarsInfo.llvmVars ,
26382750 privateVarsInfo.privatizers );
@@ -2698,16 +2810,51 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
26982810 ompLoc.DL );
26992811 computeIP = loopInfos.front ()->getPreheaderIP ();
27002812 }
2813+ if (auto wsloopOp = loopOp->getParentOfType <omp::WsloopOp>()) {
2814+ bool isInScanRegion =
2815+ wsloopOp.getReductionMod () && (wsloopOp.getReductionMod ().value () ==
2816+ mlir::omp::ReductionModifier::inscan);
2817+ if (isInScanRegion) {
2818+ // TODO: Handle nesting if Scan loop is nested in a loop
2819+ assert (loopOp.getNumLoops () == 1 );
2820+ llvm::Expected<SmallVector<llvm::CanonicalLoopInfo *>> loopResults =
2821+ ompBuilder->createCanonicalScanLoops (
2822+ loc, bodyGen, lowerBound, upperBound, step,
2823+ /* IsSigned=*/ true , loopOp.getLoopInclusive (), computeIP,
2824+ " loop" );
2825+
2826+ if (failed (handleError (loopResults, *loopOp)))
2827+ return failure ();
2828+ auto inputLoop = loopResults->front ();
2829+ auto scanLoop = loopResults->back ();
2830+ moduleTranslation.stackWalk <OpenMPLoopInfoStackFrame>(
2831+ [&](OpenMPLoopInfoStackFrame &frame) {
2832+ frame.loopInfos .push_back (inputLoop);
2833+ frame.loopInfos .push_back (scanLoop);
2834+ return WalkResult::interrupt ();
2835+ });
2836+ builder.restoreIP (scanLoop->getAfterIP ());
2837+ return success ();
2838+ } else {
2839+ llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
2840+ ompBuilder->createCanonicalLoop (
2841+ loc, bodyGen, lowerBound, upperBound, step,
2842+ /* IsSigned=*/ true , loopOp.getLoopInclusive (), computeIP);
2843+ if (failed (handleError (loopResult, *loopOp)))
2844+ return failure ();
2845+ loopInfos.push_back (*loopResult);
2846+ }
2847+ } else {
2848+ llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
2849+ ompBuilder->createCanonicalLoop (
2850+ loc, bodyGen, lowerBound, upperBound, step,
2851+ /* IsSigned=*/ true , loopOp.getLoopInclusive (), computeIP);
27012852
2702- llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
2703- ompBuilder->createCanonicalLoop (
2704- loc, bodyGen, lowerBound, upperBound, step,
2705- /* IsSigned=*/ true , loopOp.getLoopInclusive (), computeIP);
2706-
2707- if (failed (handleError (loopResult, *loopOp)))
2708- return failure ();
2853+ if (failed (handleError (loopResult, *loopOp)))
2854+ return failure ();
27092855
2710- loopInfos.push_back (*loopResult);
2856+ loopInfos.push_back (*loopResult);
2857+ }
27112858 }
27122859
27132860 // Collapse loops. Store the insertion point because LoopInfos may get
@@ -2719,7 +2866,8 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
27192866 // after applying transformations.
27202867 moduleTranslation.stackWalk <OpenMPLoopInfoStackFrame>(
27212868 [&](OpenMPLoopInfoStackFrame &frame) {
2722- frame.loopInfo = ompBuilder->collapseLoops (ompLoc.DL , loopInfos, {});
2869+ frame.loopInfos .push_back (
2870+ ompBuilder->collapseLoops (ompLoc.DL , loopInfos, {}));
27232871 return WalkResult::interrupt ();
27242872 });
27252873
@@ -4328,19 +4476,20 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
43284476 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
43294477 bool loopNeedsBarrier = false ;
43304478 llvm::Value *chunk = nullptr ;
4331-
4332- llvm::CanonicalLoopInfo *loopInfo =
4333- findCurrentLoopInfo (moduleTranslation);
4334- llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
4335- ompBuilder->applyWorkshareLoop (
4336- ompLoc.DL , loopInfo, allocaIP, loopNeedsBarrier,
4337- convertToScheduleKind (schedule), chunk, isSimd,
4338- scheduleMod == omp::ScheduleModifier::monotonic,
4339- scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
4340- workshareLoopType);
4341-
4342- if (!wsloopIP)
4343- return wsloopIP.takeError ();
4479+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
4480+ findCurrentLoopInfos (moduleTranslation);
4481+ for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
4482+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
4483+ ompBuilder->applyWorkshareLoop (
4484+ ompLoc.DL , loopInfo, allocaIP, loopNeedsBarrier,
4485+ convertToScheduleKind (schedule), chunk, isSimd,
4486+ scheduleMod == omp::ScheduleModifier::monotonic,
4487+ scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
4488+ workshareLoopType);
4489+
4490+ if (!wsloopIP)
4491+ return wsloopIP.takeError ();
4492+ }
43444493 }
43454494
43464495 if (failed (cleanupPrivateVars (builder, moduleTranslation,
@@ -5373,6 +5522,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
53735522 .Case ([&](omp::SimdOp) {
53745523 return convertOmpSimd (*op, builder, moduleTranslation);
53755524 })
5525+ .Case ([&](omp::ScanOp) {
5526+ return convertOmpScan (*op, builder, moduleTranslation);
5527+ })
53765528 .Case ([&](omp::AtomicReadOp) {
53775529 return convertOmpAtomicRead (*op, builder, moduleTranslation);
53785530 })
0 commit comments