@@ -174,10 +174,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
174174 if (op.getHint ())
175175 op.emitWarning (" hint clause discarded" );
176176 };
177- auto checkHostEval = [&todo](auto op, LogicalResult &result) {
178- if (!op.getHostEvalVars ().empty ())
179- result = todo (" host_eval" );
180- };
181177 auto checkIf = [&todo](auto op, LogicalResult &result) {
182178 if (op.getIfExpr ())
183179 result = todo (" if" );
@@ -224,10 +220,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
224220 op.getReductionSyms ())
225221 result = todo (" reduction" );
226222 };
227- auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
228- if (op.getThreadLimit ())
229- result = todo (" thread_limit" );
230- };
231223 auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
232224 if (!op.getTaskReductionVars ().empty () || op.getTaskReductionByref () ||
233225 op.getTaskReductionSyms ())
@@ -290,7 +282,16 @@ static LogicalResult checkImplementationStatus(Operation &op) {
290282 checkAllocate (op, result);
291283 checkDevice (op, result);
292284 checkHasDeviceAddr (op, result);
293- checkHostEval (op, result);
285+
286+ // Host evaluated clauses are supported, except for target SPMD loop
287+ // bounds.
288+ for (BlockArgument arg :
289+ cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs ())
290+ for (Operation *user : arg.getUsers ())
291+ if (isa<omp::LoopNestOp>(user))
292+ result = op.emitError (" not yet implemented: host evaluation of "
293+ " loop bounds in omp.target operation" );
294+
294295 checkIf (op, result);
295296 checkInReduction (op, result);
296297 checkIsDevicePtr (op, result);
@@ -311,7 +312,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
311312 " structures in omp.target operation" );
312313 }
313314 }
314- checkThreadLimit (op, result);
315315 })
316316 .Default ([](Operation &) {
317317 // Assume all clauses for an operation can be translated unless they are
@@ -3815,6 +3815,215 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
38153815 return builder.saveIP ();
38163816}
38173817
3818+ // / Follow uses of `host_eval`-defined block arguments of the given `omp.target`
3819+ // / operation and populate output variables with their corresponding host value
3820+ // / (i.e. operand evaluated outside of the target region), based on their uses
3821+ // / inside of the target region.
3822+ // /
3823+ // / Loop bounds and steps are only optionally populated, if output vectors are
3824+ // / provided.
3825+ static void extractHostEvalClauses (omp::TargetOp targetOp, Value &numThreads,
3826+ Value &numTeamsLower, Value &numTeamsUpper,
3827+ Value &threadLimit) {
3828+ auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
3829+ for (auto item : llvm::zip_equal (targetOp.getHostEvalVars (),
3830+ blockArgIface.getHostEvalBlockArgs ())) {
3831+ Value hostEvalVar = std::get<0 >(item), blockArg = std::get<1 >(item);
3832+
3833+ for (Operation *user : blockArg.getUsers ()) {
3834+ llvm::TypeSwitch<Operation *>(user)
3835+ .Case ([&](omp::TeamsOp teamsOp) {
3836+ if (teamsOp.getNumTeamsLower () == blockArg)
3837+ numTeamsLower = hostEvalVar;
3838+ else if (teamsOp.getNumTeamsUpper () == blockArg)
3839+ numTeamsUpper = hostEvalVar;
3840+ else if (teamsOp.getThreadLimit () == blockArg)
3841+ threadLimit = hostEvalVar;
3842+ else
3843+ llvm_unreachable (" unsupported host_eval use" );
3844+ })
3845+ .Case ([&](omp::ParallelOp parallelOp) {
3846+ if (parallelOp.getNumThreads () == blockArg)
3847+ numThreads = hostEvalVar;
3848+ else
3849+ llvm_unreachable (" unsupported host_eval use" );
3850+ })
3851+ .Case ([&](omp::LoopNestOp loopOp) {
3852+ // TODO: Extract bounds and step values.
3853+ })
3854+ .Default ([](Operation *) {
3855+ llvm_unreachable (" unsupported host_eval use" );
3856+ });
3857+ }
3858+ }
3859+ }
3860+
3861+ // / If \p op is of the given type parameter, return it casted to that type.
3862+ // / Otherwise, if its immediate parent operation (or some other higher-level
3863+ // / parent, if \p immediateParent is false) is of that type, return that parent
3864+ // / casted to the given type.
3865+ // /
3866+ // / If \p op is \c null or neither it or its parent(s) are of the specified
3867+ // / type, return a \c null operation.
3868+ template <typename OpTy>
3869+ static OpTy castOrGetParentOfType (Operation *op, bool immediateParent = false ) {
3870+ if (!op)
3871+ return OpTy ();
3872+
3873+ if (OpTy casted = dyn_cast<OpTy>(op))
3874+ return casted;
3875+
3876+ if (immediateParent)
3877+ return dyn_cast_if_present<OpTy>(op->getParentOp ());
3878+
3879+ return op->getParentOfType <OpTy>();
3880+ }
3881+
3882+ // / Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
3883+ // / values as stated by the corresponding clauses, if constant.
3884+ // /
3885+ // / These default values must be set before the creation of the outlined LLVM
3886+ // / function for the target region, so that they can be used to initialize the
3887+ // / corresponding global `ConfigurationEnvironmentTy` structure.
3888+ static void
3889+ initTargetDefaultAttrs (omp::TargetOp targetOp,
3890+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
3891+ bool isTargetDevice) {
3892+ // TODO: Handle constant 'if' clauses.
3893+ Operation *capturedOp = targetOp.getInnermostCapturedOmpOp ();
3894+
3895+ Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
3896+ if (!isTargetDevice) {
3897+ extractHostEvalClauses (targetOp, numThreads, numTeamsLower, numTeamsUpper,
3898+ threadLimit);
3899+ } else {
3900+ // In the target device, values for these clauses are not passed as
3901+ // host_eval, but instead evaluated prior to entry to the region. This
3902+ // ensures values are mapped and available inside of the target region.
3903+ if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
3904+ numTeamsLower = teamsOp.getNumTeamsLower ();
3905+ numTeamsUpper = teamsOp.getNumTeamsUpper ();
3906+ threadLimit = teamsOp.getThreadLimit ();
3907+ }
3908+
3909+ if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
3910+ numThreads = parallelOp.getNumThreads ();
3911+ }
3912+
3913+ auto extractConstInteger = [](Value value) -> std::optional<int64_t > {
3914+ if (auto constOp =
3915+ dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp ()))
3916+ if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue ()))
3917+ return constAttr.getInt ();
3918+
3919+ return std::nullopt ;
3920+ };
3921+
3922+ // Handle clauses impacting the number of teams.
3923+
3924+ int32_t minTeamsVal = 1 , maxTeamsVal = -1 ;
3925+ if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
3926+ // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now, match
3927+ // clang and set min and max to the same value.
3928+ if (numTeamsUpper) {
3929+ if (auto val = extractConstInteger (numTeamsUpper))
3930+ minTeamsVal = maxTeamsVal = *val;
3931+ } else {
3932+ minTeamsVal = maxTeamsVal = 0 ;
3933+ }
3934+ } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
3935+ /* immediateParent=*/ true ) ||
3936+ castOrGetParentOfType<omp::SimdOp>(capturedOp,
3937+ /* immediateParent=*/ true )) {
3938+ minTeamsVal = maxTeamsVal = 1 ;
3939+ } else {
3940+ minTeamsVal = maxTeamsVal = -1 ;
3941+ }
3942+
3943+ // Handle clauses impacting the number of threads.
3944+
3945+ auto setMaxValueFromClause = [&extractConstInteger](Value clauseValue,
3946+ int32_t &result) {
3947+ if (!clauseValue)
3948+ return ;
3949+
3950+ if (auto val = extractConstInteger (clauseValue))
3951+ result = *val;
3952+
3953+ // Found an applicable clause, so it's not undefined. Mark as unknown
3954+ // because it's not constant.
3955+ if (result < 0 )
3956+ result = 0 ;
3957+ };
3958+
3959+ // Extract 'thread_limit' clause from 'target' and 'teams' directives.
3960+ int32_t targetThreadLimitVal = -1 , teamsThreadLimitVal = -1 ;
3961+ setMaxValueFromClause (targetOp.getThreadLimit (), targetThreadLimitVal);
3962+ setMaxValueFromClause (threadLimit, teamsThreadLimitVal);
3963+
3964+ // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
3965+ int32_t maxThreadsVal = -1 ;
3966+ if (castOrGetParentOfType<omp::ParallelOp>(capturedOp))
3967+ setMaxValueFromClause (numThreads, maxThreadsVal);
3968+ else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
3969+ /* immediateParent=*/ true ))
3970+ maxThreadsVal = 1 ;
3971+
3972+ // For max values, < 0 means unset, == 0 means set but unknown. Select the
3973+ // minimum value between 'max_threads' and 'thread_limit' clauses that were
3974+ // set.
3975+ int32_t combinedMaxThreadsVal = targetThreadLimitVal;
3976+ if (combinedMaxThreadsVal < 0 ||
3977+ (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
3978+ combinedMaxThreadsVal = teamsThreadLimitVal;
3979+
3980+ if (combinedMaxThreadsVal < 0 ||
3981+ (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
3982+ combinedMaxThreadsVal = maxThreadsVal;
3983+
3984+ // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
3985+ attrs.MinTeams = minTeamsVal;
3986+ attrs.MaxTeams .front () = maxTeamsVal;
3987+ attrs.MinThreads = 1 ;
3988+ attrs.MaxThreads .front () = combinedMaxThreadsVal;
3989+ }
3990+
3991+ // / Gather LLVM runtime values for all clauses evaluated in the host that are
3992+ // / passed to the kernel invocation.
3993+ // /
3994+ // / This function must be called only when compiling for the host. Also, it will
3995+ // / only provide correct results if it's called after the body of \c targetOp
3996+ // / has been fully generated.
3997+ static void
3998+ initTargetRuntimeAttrs (llvm::IRBuilderBase &builder,
3999+ LLVM::ModuleTranslation &moduleTranslation,
4000+ omp::TargetOp targetOp,
4001+ llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
4002+ Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
4003+ extractHostEvalClauses (targetOp, numThreads, numTeamsLower, numTeamsUpper,
4004+ teamsThreadLimit);
4005+
4006+ // TODO: Handle constant 'if' clauses.
4007+ if (Value targetThreadLimit = targetOp.getThreadLimit ())
4008+ attrs.TargetThreadLimit .front () =
4009+ moduleTranslation.lookupValue (targetThreadLimit);
4010+
4011+ if (numTeamsLower)
4012+ attrs.MinTeams = moduleTranslation.lookupValue (numTeamsLower);
4013+
4014+ if (numTeamsUpper)
4015+ attrs.MaxTeams .front () = moduleTranslation.lookupValue (numTeamsUpper);
4016+
4017+ if (teamsThreadLimit)
4018+ attrs.TeamsThreadLimit .front () =
4019+ moduleTranslation.lookupValue (teamsThreadLimit);
4020+
4021+ if (numThreads)
4022+ attrs.MaxThreads = moduleTranslation.lookupValue (numThreads);
4023+
4024+ // TODO: Populate attrs.LoopTripCount if it is target SPMD.
4025+ }
4026+
38184027static LogicalResult
38194028convertOmpTarget (Operation &opInst, llvm::IRBuilderBase &builder,
38204029 LLVM::ModuleTranslation &moduleTranslation) {
@@ -3824,12 +4033,13 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
38244033
38254034 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
38264035 bool isTargetDevice = ompBuilder->Config .isTargetDevice ();
4036+
38274037 auto parentFn = opInst.getParentOfType <LLVM::LLVMFuncOp>();
4038+ auto blockIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
38284039 auto &targetRegion = targetOp.getRegion ();
38294040 DataLayout dl = DataLayout (opInst.getParentOfType <ModuleOp>());
38304041 SmallVector<Value> mapVars = targetOp.getMapVars ();
3831- ArrayRef<BlockArgument> mapBlockArgs =
3832- cast<omp::BlockArgOpenMPOpInterface>(opInst).getMapBlockArgs ();
4042+ ArrayRef<BlockArgument> mapBlockArgs = blockIface.getMapBlockArgs ();
38334043 llvm::Function *llvmOutlinedFn = nullptr ;
38344044
38354045 // TODO: It can also be false if a compile-time constant `false` IF clause is
@@ -3872,7 +4082,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
38724082 OperandRange privateVars = targetOp.getPrivateVars ();
38734083 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms ();
38744084 MutableArrayRef<BlockArgument> privateBlockArgs =
3875- cast<omp::BlockArgOpenMPOpInterface>(opInst) .getPrivateBlockArgs ();
4085+ blockIface .getPrivateBlockArgs ();
38764086
38774087 for (auto [privVar, privatizerNameAttr, privBlockArg] :
38784088 llvm::zip_equal (privateVars, *privateSyms, privateBlockArgs)) {
@@ -3951,13 +4161,29 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
39514161 allocaIP, codeGenIP);
39524162 };
39534163
3954- // TODO: Populate default and runtime attributes based on the construct and
3955- // clauses.
3956- llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = {
3957- /* MaxTeams=*/ {-1 }, /* MinTeams=*/ 0 , /* MaxThreads=*/ {0 }, /* MinThreads=*/ 0 };
4164+ llvm::SmallVector<llvm::Value *, 4 > kernelInput;
4165+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
4166+ initTargetDefaultAttrs (targetOp, defaultAttrs, isTargetDevice);
4167+
4168+ // Collect host-evaluated values needed to properly launch the kernel from the
4169+ // host.
39584170 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
4171+ if (!isTargetDevice)
4172+ initTargetRuntimeAttrs (builder, moduleTranslation, targetOp, runtimeAttrs);
4173+
4174+ // Pass host-evaluated values as parameters to the kernel / host fallback,
4175+ // except if they are constants. In any case, map the MLIR block argument to
4176+ // the corresponding LLVM values.
4177+ SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars ();
4178+ ArrayRef<BlockArgument> hostEvalBlockArgs = blockIface.getHostEvalBlockArgs ();
4179+ for (auto [arg, var] : llvm::zip_equal (hostEvalBlockArgs, hostEvalVars)) {
4180+ llvm::Value *value = moduleTranslation.lookupValue (var);
4181+ moduleTranslation.mapValue (arg, value);
4182+
4183+ if (!llvm::isa<llvm::Constant>(value))
4184+ kernelInput.push_back (value);
4185+ }
39594186
3960- llvm::SmallVector<llvm::Value *, 4 > kernelInput;
39614187 for (size_t i = 0 ; i < mapVars.size (); ++i) {
39624188 // declare target arguments are not passed to kernels as arguments
39634189 // TODO: We currently do not handle cases where a member is explicitly
0 commit comments