@@ -396,9 +396,9 @@ collectReductionDecls(T loop,
396396
397397// / Translates the blocks contained in the given region and appends them at
398398// / the current insertion point of `builder`. The operations of the entry block
399- // / are appended to the current insertion block, which is not expected to have a
400- // / terminator. If set, `continuationBlockArgs` is populated with translated
401- // / values that correspond to the values omp.yield'ed from the region.
399+ // / are appended to the current insertion block. If set, `continuationBlockArgs`
400+ // / is populated with translated values that correspond to the values
401+ // / omp.yield'ed from the region.
402402static LogicalResult inlineConvertOmpRegions (
403403 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
404404 LLVM::ModuleTranslation &moduleTranslation,
@@ -409,7 +409,14 @@ static LogicalResult inlineConvertOmpRegions(
409409 // Special case for single-block regions that don't create additional blocks:
410410 // insert operations without creating additional blocks.
411411 if (llvm::hasSingleElement (region)) {
412+ llvm::Instruction *potentialTerminator =
413+ builder.GetInsertBlock ()->empty () ? nullptr
414+ : &builder.GetInsertBlock ()->back ();
415+
416+ if (potentialTerminator && potentialTerminator->isTerminator ())
417+ potentialTerminator->removeFromParent ();
412418 moduleTranslation.mapBlock (&region.front (), builder.GetInsertBlock ());
419+
413420 if (failed (moduleTranslation.convertBlock (
414421 region.front (), /* ignoreArguments=*/ true , builder)))
415422 return failure ();
@@ -423,6 +430,10 @@ static LogicalResult inlineConvertOmpRegions(
423430 // Drop the mapping that is no longer necessary so that the same region can
424431 // be processed multiple times.
425432 moduleTranslation.forgetMapping (region);
433+
434+ if (potentialTerminator && potentialTerminator->isTerminator ())
435+ potentialTerminator->insertAfter (&builder.GetInsertBlock ()->back ());
436+
426437 return success ();
427438 }
428439
@@ -1000,11 +1011,50 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
10001011 return success ();
10011012}
10021013
1014+ // / A RAII class that on construction replaces the region arguments of the
1015+ // / parallel op (which correspond to private variables) with the actual private
1016+ // / variables they correspond to. This prepares the parallel op so that it
1017+ // / matches what is expected by the OMPIRBuilder.
1018+ // /
1019+ // / On destruction, it restores the original state of the operation so that on
1020+ // / the MLIR side, the op is not affected by conversion to LLVM IR.
1021+ class OmpParallelOpConversionManager {
1022+ public:
1023+ OmpParallelOpConversionManager (omp::ParallelOp opInst)
1024+ : region(opInst.getRegion()), privateVars(opInst.getPrivateVars()),
1025+ privateArgBeginIdx (opInst.getNumReductionVars()),
1026+ privateArgEndIdx(privateArgBeginIdx + privateVars.size()) {
1027+ auto privateVarsIt = privateVars.begin ();
1028+
1029+ for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1030+ ++argIdx, ++privateVarsIt)
1031+ mlir::replaceAllUsesInRegionWith (region.getArgument (argIdx),
1032+ *privateVarsIt, region);
1033+ }
1034+
1035+ ~OmpParallelOpConversionManager () {
1036+ auto privateVarsIt = privateVars.begin ();
1037+
1038+ for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1039+ ++argIdx, ++privateVarsIt)
1040+ mlir::replaceAllUsesInRegionWith (*privateVarsIt,
1041+ region.getArgument (argIdx), region);
1042+ }
1043+
1044+ private:
1045+ Region ®ion;
1046+ OperandRange privateVars;
1047+ unsigned privateArgBeginIdx;
1048+ unsigned privateArgEndIdx;
1049+ };
1050+
10031051// / Converts the OpenMP parallel operation to LLVM IR.
10041052static LogicalResult
10051053convertOmpParallel (omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10061054 LLVM::ModuleTranslation &moduleTranslation) {
10071055 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1056+ OmpParallelOpConversionManager raii (opInst);
1057+
10081058 // TODO: support error propagation in OpenMPIRBuilder and use it instead of
10091059 // relying on captured variables.
10101060 LogicalResult bodyGenStatus = success ();
@@ -1086,12 +1136,81 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10861136
10871137 // TODO: Perform appropriate actions according to the data-sharing
10881138 // attribute (shared, private, firstprivate, ...) of variables.
1089- // Currently defaults to shared .
1139+ // Currently shared and private are supported .
10901140 auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
10911141 llvm::Value &, llvm::Value &vPtr,
10921142 llvm::Value *&replacementValue) -> InsertPointTy {
10931143 replacementValue = &vPtr;
10941144
1145+ // If this is a private value, this lambda will return the corresponding
1146+ // mlir value and its `PrivateClauseOp`. Otherwise, empty values are
1147+ // returned.
1148+ auto [privVar, privatizerClone] =
1149+ [&]() -> std::pair<mlir::Value, omp::PrivateClauseOp> {
1150+ if (!opInst.getPrivateVars ().empty ()) {
1151+ auto privVars = opInst.getPrivateVars ();
1152+ auto privatizers = opInst.getPrivatizers ();
1153+
1154+ for (auto [privVar, privatizerAttr] :
1155+ llvm::zip_equal (privVars, *privatizers)) {
1156+ // Find the MLIR private variable corresponding to the LLVM value
1157+ // being privatized.
1158+ llvm::Value *llvmPrivVar = moduleTranslation.lookupValue (privVar);
1159+ if (llvmPrivVar != &vPtr)
1160+ continue ;
1161+
1162+ SymbolRefAttr privSym = llvm::cast<SymbolRefAttr>(privatizerAttr);
1163+ omp::PrivateClauseOp privatizer =
1164+ SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
1165+ opInst, privSym);
1166+
1167+ // Clone the privatizer in case it is used by more than one parallel
1168+ // region. The privatizer is processed in-place (see below) before it
1169+ // gets inlined in the parallel region and therefore processing the
1170+ // original op is dangerous.
1171+ return {privVar, privatizer.clone ()};
1172+ }
1173+ }
1174+
1175+ return {mlir::Value (), omp::PrivateClauseOp ()};
1176+ }();
1177+
1178+ if (privVar) {
1179+ if (privatizerClone.getDataSharingType () ==
1180+ omp::DataSharingClauseType::FirstPrivate) {
1181+ privatizerClone.emitOpError (
1182+ " TODO: delayed privatization is not "
1183+ " supported for `firstprivate` clauses yet." );
1184+ bodyGenStatus = failure ();
1185+ return codeGenIP;
1186+ }
1187+
1188+ Region &allocRegion = privatizerClone.getAllocRegion ();
1189+
1190+ // Replace the privatizer block argument with mlir value being privatized.
1191+ // This way, the body of the privatizer will be changed from using the
1192+ // region/block argument to the value being privatized.
1193+ auto allocRegionArg = allocRegion.getArgument (0 );
1194+ replaceAllUsesInRegionWith (allocRegionArg, privVar, allocRegion);
1195+
1196+ auto oldIP = builder.saveIP ();
1197+ builder.restoreIP (allocaIP);
1198+
1199+ SmallVector<llvm::Value *, 1 > yieldedValues;
1200+ if (failed (inlineConvertOmpRegions (allocRegion, " omp.privatizer" , builder,
1201+ moduleTranslation, &yieldedValues))) {
1202+ opInst.emitError (" failed to inline `alloc` region of an `omp.private` "
1203+ " op in the parallel region" );
1204+ bodyGenStatus = failure ();
1205+ } else {
1206+ assert (yieldedValues.size () == 1 );
1207+ replacementValue = yieldedValues.front ();
1208+ }
1209+
1210+ privatizerClone.erase ();
1211+ builder.restoreIP (oldIP);
1212+ }
1213+
10951214 return codeGenIP;
10961215 };
10971216
@@ -1635,7 +1754,7 @@ getRefPtrIfDeclareTarget(mlir::Value value,
16351754// A small helper structure to contain data gathered
16361755// for map lowering and coalesce it into one area and
16371756// avoiding extra computations such as searches in the
1638- // llvm module for lowered mapped varibles or checking
1757+ // llvm module for lowered mapped variables or checking
16391758// if something is declare target (and retrieving the
16401759// value) more than necessary.
16411760struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
@@ -2854,26 +2973,26 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
28542973 moduleTranslation);
28552974 return failure ();
28562975 })
2857- .Case (
2858- " omp.requires " ,
2859- [&](Attribute attr) {
2860- if ( auto requiresAttr = attr.dyn_cast <omp::ClauseRequiresAttr>()) {
2861- using Requires = omp::ClauseRequires;
2862- Requires flags = requiresAttr.getValue ();
2863- llvm::OpenMPIRBuilderConfig &config =
2864- moduleTranslation.getOpenMPBuilder ()->Config ;
2865- config.setHasRequiresReverseOffload (
2866- bitEnumContainsAll (flags, Requires::reverse_offload));
2867- config.setHasRequiresUnifiedAddress (
2868- bitEnumContainsAll (flags, Requires::unified_address));
2869- config.setHasRequiresUnifiedSharedMemory (
2870- bitEnumContainsAll (flags, Requires::unified_shared_memory));
2871- config.setHasRequiresDynamicAllocators (
2872- bitEnumContainsAll (flags, Requires::dynamic_allocators));
2873- return success ();
2874- }
2875- return failure ();
2876- })
2976+ .Case (" omp.requires " ,
2977+ [&](Attribute attr) {
2978+ if ( auto requiresAttr =
2979+ attr.dyn_cast <omp::ClauseRequiresAttr>()) {
2980+ using Requires = omp::ClauseRequires;
2981+ Requires flags = requiresAttr.getValue ();
2982+ llvm::OpenMPIRBuilderConfig &config =
2983+ moduleTranslation.getOpenMPBuilder ()->Config ;
2984+ config.setHasRequiresReverseOffload (
2985+ bitEnumContainsAll (flags, Requires::reverse_offload));
2986+ config.setHasRequiresUnifiedAddress (
2987+ bitEnumContainsAll (flags, Requires::unified_address));
2988+ config.setHasRequiresUnifiedSharedMemory (
2989+ bitEnumContainsAll (flags, Requires::unified_shared_memory));
2990+ config.setHasRequiresDynamicAllocators (
2991+ bitEnumContainsAll (flags, Requires::dynamic_allocators));
2992+ return success ();
2993+ }
2994+ return failure ();
2995+ })
28772996 .Default ([](Attribute) {
28782997 // Fall through for omp attributes that do not require lowering.
28792998 return success ();
@@ -2988,12 +3107,13 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
29883107 .Case ([&](omp::TargetOp) {
29893108 return convertOmpTarget (*op, builder, moduleTranslation);
29903109 })
2991- .Case <omp::MapInfoOp, omp::DataBoundsOp>([&](auto op) {
2992- // No-op, should be handled by relevant owning operations e.g.
2993- // TargetOp, EnterDataOp, ExitDataOp, DataOp etc. and then
2994- // discarded
2995- return success ();
2996- })
3110+ .Case <omp::MapInfoOp, omp::DataBoundsOp, omp::PrivateClauseOp>(
3111+ [&](auto op) {
3112+ // No-op, should be handled by relevant owning operations e.g.
3113+ // TargetOp, EnterDataOp, ExitDataOp, DataOp etc. and then
3114+ // discarded
3115+ return success ();
3116+ })
29973117 .Default ([&](Operation *inst) {
29983118 return inst->emitError (" unsupported OpenMP operation: " )
29993119 << inst->getName ();
0 commit comments