@@ -174,10 +174,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
174174 if (op.getHint ())
175175 op.emitWarning (" hint clause discarded" );
176176 };
177- auto checkHostEval = [&todo](auto op, LogicalResult &result) {
178- if (!op.getHostEvalVars ().empty ())
179- result = todo (" host_eval" );
180- };
181177 auto checkIf = [&todo](auto op, LogicalResult &result) {
182178 if (op.getIfExpr ())
183179 result = todo (" if" );
@@ -228,10 +224,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
228224 op.getReductionSyms ())
229225 result = todo (" reduction" );
230226 };
231- auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
232- if (op.getThreadLimit ())
233- result = todo (" thread_limit" );
234- };
235227 auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
236228 if (!op.getTaskReductionVars ().empty () || op.getTaskReductionByref () ||
237229 op.getTaskReductionSyms ())
@@ -295,7 +287,16 @@ static LogicalResult checkImplementationStatus(Operation &op) {
295287 checkAllocate (op, result);
296288 checkDevice (op, result);
297289 checkHasDeviceAddr (op, result);
298- checkHostEval (op, result);
290+
291+ // Host evaluated clauses are supported, except for target SPMD loop
292+ // bounds.
293+ for (BlockArgument arg :
294+ cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs ())
295+ for (Operation *user : arg.getUsers ())
296+ if (isa<omp::LoopNestOp>(user))
297+ result = op.emitError (" not yet implemented: host evaluation of "
298+ " loop bounds in omp.target operation" );
299+
299300 checkIf (op, result);
300301 checkInReduction (op, result);
301302 checkIsDevicePtr (op, result);
@@ -316,7 +317,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
316317 " structures in omp.target operation" );
317318 }
318319 }
319- checkThreadLimit (op, result);
320320 })
321321 .Default ([](Operation &) {
322322 // Assume all clauses for an operation can be translated unless they are
@@ -3800,6 +3800,215 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
38003800 return builder.saveIP ();
38013801}
38023802
3803+ // / Follow uses of `host_eval`-defined block arguments of the given `omp.target`
3804+ // / operation and populate output variables with their corresponding host value
3805+ // / (i.e. operand evaluated outside of the target region), based on their uses
3806+ // / inside of the target region.
3807+ // /
3808+ // / Loop bounds and steps are only optionally populated, if output vectors are
3809+ // / provided.
3810+ static void extractHostEvalClauses (omp::TargetOp targetOp, Value &numThreads,
3811+ Value &numTeamsLower, Value &numTeamsUpper,
3812+ Value &threadLimit) {
3813+ auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
3814+ for (auto item : llvm::zip_equal (targetOp.getHostEvalVars (),
3815+ blockArgIface.getHostEvalBlockArgs ())) {
3816+ Value hostEvalVar = std::get<0 >(item), blockArg = std::get<1 >(item);
3817+
3818+ for (Operation *user : blockArg.getUsers ()) {
3819+ llvm::TypeSwitch<Operation *>(user)
3820+ .Case ([&](omp::TeamsOp teamsOp) {
3821+ if (teamsOp.getNumTeamsLower () == blockArg)
3822+ numTeamsLower = hostEvalVar;
3823+ else if (teamsOp.getNumTeamsUpper () == blockArg)
3824+ numTeamsUpper = hostEvalVar;
3825+ else if (teamsOp.getThreadLimit () == blockArg)
3826+ threadLimit = hostEvalVar;
3827+ else
3828+ llvm_unreachable (" unsupported host_eval use" );
3829+ })
3830+ .Case ([&](omp::ParallelOp parallelOp) {
3831+ if (parallelOp.getNumThreads () == blockArg)
3832+ numThreads = hostEvalVar;
3833+ else
3834+ llvm_unreachable (" unsupported host_eval use" );
3835+ })
3836+ .Case ([&](omp::LoopNestOp loopOp) {
3837+ // TODO: Extract bounds and step values.
3838+ })
3839+ .Default ([](Operation *) {
3840+ llvm_unreachable (" unsupported host_eval use" );
3841+ });
3842+ }
3843+ }
3844+ }
3845+
3846+ // / If \p op is of the given type parameter, return it casted to that type.
3847+ // / Otherwise, if its immediate parent operation (or some other higher-level
3848+ // / parent, if \p immediateParent is false) is of that type, return that parent
3849+ // / casted to the given type.
3850+ // /
3851+ // / If \p op is \c null or neither it or its parent(s) are of the specified
3852+ // / type, return a \c null operation.
3853+ template <typename OpTy>
3854+ static OpTy castOrGetParentOfType (Operation *op, bool immediateParent = false ) {
3855+ if (!op)
3856+ return OpTy ();
3857+
3858+ if (OpTy casted = dyn_cast<OpTy>(op))
3859+ return casted;
3860+
3861+ if (immediateParent)
3862+ return dyn_cast_if_present<OpTy>(op->getParentOp ());
3863+
3864+ return op->getParentOfType <OpTy>();
3865+ }
3866+
3867+ // / Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
3868+ // / values as stated by the corresponding clauses, if constant.
3869+ // /
3870+ // / These default values must be set before the creation of the outlined LLVM
3871+ // / function for the target region, so that they can be used to initialize the
3872+ // / corresponding global `ConfigurationEnvironmentTy` structure.
3873+ static void
3874+ initTargetDefaultAttrs (omp::TargetOp targetOp,
3875+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
3876+ bool isTargetDevice) {
3877+ // TODO: Handle constant 'if' clauses.
3878+ Operation *capturedOp = targetOp.getInnermostCapturedOmpOp ();
3879+
3880+ Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
3881+ if (!isTargetDevice) {
3882+ extractHostEvalClauses (targetOp, numThreads, numTeamsLower, numTeamsUpper,
3883+ threadLimit);
3884+ } else {
3885+ // In the target device, values for these clauses are not passed as
3886+ // host_eval, but instead evaluated prior to entry to the region. This
3887+ // ensures values are mapped and available inside of the target region.
3888+ if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
3889+ numTeamsLower = teamsOp.getNumTeamsLower ();
3890+ numTeamsUpper = teamsOp.getNumTeamsUpper ();
3891+ threadLimit = teamsOp.getThreadLimit ();
3892+ }
3893+
3894+ if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
3895+ numThreads = parallelOp.getNumThreads ();
3896+ }
3897+
3898+ auto extractConstInteger = [](Value value) -> std::optional<int64_t > {
3899+ if (auto constOp =
3900+ dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp ()))
3901+ if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue ()))
3902+ return constAttr.getInt ();
3903+
3904+ return std::nullopt ;
3905+ };
3906+
3907+ // Handle clauses impacting the number of teams.
3908+
3909+ int32_t minTeamsVal = 1 , maxTeamsVal = -1 ;
3910+ if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
3911+ // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now, match
3912+ // clang and set min and max to the same value.
3913+ if (numTeamsUpper) {
3914+ if (auto val = extractConstInteger (numTeamsUpper))
3915+ minTeamsVal = maxTeamsVal = *val;
3916+ } else {
3917+ minTeamsVal = maxTeamsVal = 0 ;
3918+ }
3919+ } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
3920+ /* immediateParent=*/ true ) ||
3921+ castOrGetParentOfType<omp::SimdOp>(capturedOp,
3922+ /* immediateParent=*/ true )) {
3923+ minTeamsVal = maxTeamsVal = 1 ;
3924+ } else {
3925+ minTeamsVal = maxTeamsVal = -1 ;
3926+ }
3927+
3928+ // Handle clauses impacting the number of threads.
3929+
3930+ auto setMaxValueFromClause = [&extractConstInteger](Value clauseValue,
3931+ int32_t &result) {
3932+ if (!clauseValue)
3933+ return ;
3934+
3935+ if (auto val = extractConstInteger (clauseValue))
3936+ result = *val;
3937+
3938+ // Found an applicable clause, so it's not undefined. Mark as unknown
3939+ // because it's not constant.
3940+ if (result < 0 )
3941+ result = 0 ;
3942+ };
3943+
3944+ // Extract 'thread_limit' clause from 'target' and 'teams' directives.
3945+ int32_t targetThreadLimitVal = -1 , teamsThreadLimitVal = -1 ;
3946+ setMaxValueFromClause (targetOp.getThreadLimit (), targetThreadLimitVal);
3947+ setMaxValueFromClause (threadLimit, teamsThreadLimitVal);
3948+
3949+ // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
3950+ int32_t maxThreadsVal = -1 ;
3951+ if (castOrGetParentOfType<omp::ParallelOp>(capturedOp))
3952+ setMaxValueFromClause (numThreads, maxThreadsVal);
3953+ else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
3954+ /* immediateParent=*/ true ))
3955+ maxThreadsVal = 1 ;
3956+
3957+ // For max values, < 0 means unset, == 0 means set but unknown. Select the
3958+ // minimum value between 'max_threads' and 'thread_limit' clauses that were
3959+ // set.
3960+ int32_t combinedMaxThreadsVal = targetThreadLimitVal;
3961+ if (combinedMaxThreadsVal < 0 ||
3962+ (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
3963+ combinedMaxThreadsVal = teamsThreadLimitVal;
3964+
3965+ if (combinedMaxThreadsVal < 0 ||
3966+ (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
3967+ combinedMaxThreadsVal = maxThreadsVal;
3968+
3969+ // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
3970+ attrs.MinTeams = minTeamsVal;
3971+ attrs.MaxTeams .front () = maxTeamsVal;
3972+ attrs.MinThreads = 1 ;
3973+ attrs.MaxThreads .front () = combinedMaxThreadsVal;
3974+ }
3975+
3976+ // / Gather LLVM runtime values for all clauses evaluated in the host that are
3977+ // / passed to the kernel invocation.
3978+ // /
3979+ // / This function must be called only when compiling for the host. Also, it will
3980+ // / only provide correct results if it's called after the body of \c targetOp
3981+ // / has been fully generated.
3982+ static void
3983+ initTargetRuntimeAttrs (llvm::IRBuilderBase &builder,
3984+ LLVM::ModuleTranslation &moduleTranslation,
3985+ omp::TargetOp targetOp,
3986+ llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
3987+ Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
3988+ extractHostEvalClauses (targetOp, numThreads, numTeamsLower, numTeamsUpper,
3989+ teamsThreadLimit);
3990+
3991+ // TODO: Handle constant 'if' clauses.
3992+ if (Value targetThreadLimit = targetOp.getThreadLimit ())
3993+ attrs.TargetThreadLimit .front () =
3994+ moduleTranslation.lookupValue (targetThreadLimit);
3995+
3996+ if (numTeamsLower)
3997+ attrs.MinTeams = moduleTranslation.lookupValue (numTeamsLower);
3998+
3999+ if (numTeamsUpper)
4000+ attrs.MaxTeams .front () = moduleTranslation.lookupValue (numTeamsUpper);
4001+
4002+ if (teamsThreadLimit)
4003+ attrs.TeamsThreadLimit .front () =
4004+ moduleTranslation.lookupValue (teamsThreadLimit);
4005+
4006+ if (numThreads)
4007+ attrs.MaxThreads = moduleTranslation.lookupValue (numThreads);
4008+
4009+ // TODO: Populate attrs.LoopTripCount if it is target SPMD.
4010+ }
4011+
38034012static LogicalResult
38044013convertOmpTarget (Operation &opInst, llvm::IRBuilderBase &builder,
38054014 LLVM::ModuleTranslation &moduleTranslation) {
@@ -3809,12 +4018,13 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
38094018
38104019 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
38114020 bool isTargetDevice = ompBuilder->Config .isTargetDevice ();
4021+
38124022 auto parentFn = opInst.getParentOfType <LLVM::LLVMFuncOp>();
4023+ auto blockIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
38134024 auto &targetRegion = targetOp.getRegion ();
38144025 DataLayout dl = DataLayout (opInst.getParentOfType <ModuleOp>());
38154026 SmallVector<Value> mapVars = targetOp.getMapVars ();
3816- ArrayRef<BlockArgument> mapBlockArgs =
3817- cast<omp::BlockArgOpenMPOpInterface>(opInst).getMapBlockArgs ();
4027+ ArrayRef<BlockArgument> mapBlockArgs = blockIface.getMapBlockArgs ();
38184028 llvm::Function *llvmOutlinedFn = nullptr ;
38194029
38204030 // TODO: It can also be false if a compile-time constant `false` IF clause is
@@ -3857,7 +4067,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
38574067 OperandRange privateVars = targetOp.getPrivateVars ();
38584068 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms ();
38594069 MutableArrayRef<BlockArgument> privateBlockArgs =
3860- cast<omp::BlockArgOpenMPOpInterface>(opInst) .getPrivateBlockArgs ();
4070+ blockIface .getPrivateBlockArgs ();
38614071
38624072 for (auto [privVar, privatizerNameAttr, privBlockArg] :
38634073 llvm::zip_equal (privateVars, *privateSyms, privateBlockArgs)) {
@@ -3936,13 +4146,29 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
39364146 allocaIP, codeGenIP);
39374147 };
39384148
3939- // TODO: Populate default and runtime attributes based on the construct and
3940- // clauses.
3941- llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = {
3942- /* MaxTeams=*/ {-1 }, /* MinTeams=*/ 0 , /* MaxThreads=*/ {0 }, /* MinThreads=*/ 0 };
4149+ llvm::SmallVector<llvm::Value *, 4 > kernelInput;
4150+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
4151+ initTargetDefaultAttrs (targetOp, defaultAttrs, isTargetDevice);
4152+
4153+ // Collect host-evaluated values needed to properly launch the kernel from the
4154+ // host.
39434155 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
4156+ if (!isTargetDevice)
4157+ initTargetRuntimeAttrs (builder, moduleTranslation, targetOp, runtimeAttrs);
4158+
4159+ // Pass host-evaluated values as parameters to the kernel / host fallback,
4160+ // except if they are constants. In any case, map the MLIR block argument to
4161+ // the corresponding LLVM values.
4162+ SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars ();
4163+ ArrayRef<BlockArgument> hostEvalBlockArgs = blockIface.getHostEvalBlockArgs ();
4164+ for (auto [arg, var] : llvm::zip_equal (hostEvalBlockArgs, hostEvalVars)) {
4165+ llvm::Value *value = moduleTranslation.lookupValue (var);
4166+ moduleTranslation.mapValue (arg, value);
4167+
4168+ if (!llvm::isa<llvm::Constant>(value))
4169+ kernelInput.push_back (value);
4170+ }
39444171
3945- llvm::SmallVector<llvm::Value *, 4 > kernelInput;
39464172 for (size_t i = 0 ; i < mapVars.size (); ++i) {
39474173 // declare target arguments are not passed to kernels as arguments
39484174 // TODO: We currently do not handle cases where a member is explicitly
0 commit comments