@@ -174,10 +174,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
174174 if (op.getHint ())
175175 op.emitWarning (" hint clause discarded" );
176176 };
177- auto checkHostEval = [&todo](auto op, LogicalResult &result) {
178- if (!op.getHostEvalVars ().empty ())
179- result = todo (" host_eval" );
180- };
181177 auto checkIf = [&todo](auto op, LogicalResult &result) {
182178 if (op.getIfExpr ())
183179 result = todo (" if" );
@@ -224,10 +220,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
224220 op.getReductionSyms ())
225221 result = todo (" reduction" );
226222 };
227- auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
228- if (op.getThreadLimit ())
229- result = todo (" thread_limit" );
230- };
231223 auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
232224 if (!op.getTaskReductionVars ().empty () || op.getTaskReductionByref () ||
233225 op.getTaskReductionSyms ())
@@ -289,7 +281,16 @@ static LogicalResult checkImplementationStatus(Operation &op) {
289281 checkBare (op, result);
290282 checkDevice (op, result);
291283 checkHasDeviceAddr (op, result);
292- checkHostEval (op, result);
284+
285+ // Host evaluated clauses are supported, except for target SPMD loop
286+ // bounds.
287+ for (BlockArgument arg :
288+ cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs ())
289+ for (Operation *user : arg.getUsers ())
290+ if (isa<omp::LoopNestOp>(user))
291+ result = op.emitError (" not yet implemented: host evaluation of "
292+ " loop bounds in omp.target operation" );
293+
293294 checkIf (op, result);
294295 checkInReduction (op, result);
295296 checkIsDevicePtr (op, result);
@@ -306,7 +307,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
306307 result = todo (" firstprivate" );
307308 }
308309 }
309- checkThreadLimit (op, result);
310310 })
311311 .Default ([](Operation &) {
312312 // Assume all clauses for an operation can be translated unless they are
@@ -3889,6 +3889,215 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
38893889 return builder.saveIP ();
38903890}
38913891
3892+ // / Follow uses of `host_eval`-defined block arguments of the given `omp.target`
3893+ // / operation and populate output variables with their corresponding host value
3894+ // / (i.e. operand evaluated outside of the target region), based on their uses
3895+ // / inside of the target region.
3896+ // /
3897+ // / Loop bounds and steps are only optionally populated, if output vectors are
3898+ // / provided.
3899+ static void extractHostEvalClauses (omp::TargetOp targetOp, Value &numThreads,
3900+ Value &numTeamsLower, Value &numTeamsUpper,
3901+ Value &threadLimit) {
3902+ auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
3903+ for (auto item : llvm::zip_equal (targetOp.getHostEvalVars (),
3904+ blockArgIface.getHostEvalBlockArgs ())) {
3905+ Value hostEvalVar = std::get<0 >(item), blockArg = std::get<1 >(item);
3906+
3907+ for (Operation *user : blockArg.getUsers ()) {
3908+ llvm::TypeSwitch<Operation *>(user)
3909+ .Case ([&](omp::TeamsOp teamsOp) {
3910+ if (teamsOp.getNumTeamsLower () == blockArg)
3911+ numTeamsLower = hostEvalVar;
3912+ else if (teamsOp.getNumTeamsUpper () == blockArg)
3913+ numTeamsUpper = hostEvalVar;
3914+ else if (teamsOp.getThreadLimit () == blockArg)
3915+ threadLimit = hostEvalVar;
3916+ else
3917+ llvm_unreachable (" unsupported host_eval use" );
3918+ })
3919+ .Case ([&](omp::ParallelOp parallelOp) {
3920+ if (parallelOp.getNumThreads () == blockArg)
3921+ numThreads = hostEvalVar;
3922+ else
3923+ llvm_unreachable (" unsupported host_eval use" );
3924+ })
3925+ .Case ([&](omp::LoopNestOp loopOp) {
3926+ // TODO: Extract bounds and step values.
3927+ })
3928+ .Default ([](Operation *) {
3929+ llvm_unreachable (" unsupported host_eval use" );
3930+ });
3931+ }
3932+ }
3933+ }
3934+
3935+ // / If \p op is of the given type parameter, return it casted to that type.
3936+ // / Otherwise, if its immediate parent operation (or some other higher-level
3937+ // / parent, if \p immediateParent is false) is of that type, return that parent
3938+ // / casted to the given type.
3939+ // /
3940+ // / If \p op is \c null or neither it or its parent(s) are of the specified
3941+ // / type, return a \c null operation.
3942+ template <typename OpTy>
3943+ static OpTy castOrGetParentOfType (Operation *op, bool immediateParent = false ) {
3944+ if (!op)
3945+ return OpTy ();
3946+
3947+ if (OpTy casted = dyn_cast<OpTy>(op))
3948+ return casted;
3949+
3950+ if (immediateParent)
3951+ return dyn_cast_if_present<OpTy>(op->getParentOp ());
3952+
3953+ return op->getParentOfType <OpTy>();
3954+ }
3955+
3956+ // / Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
3957+ // / values as stated by the corresponding clauses, if constant.
3958+ // /
3959+ // / These default values must be set before the creation of the outlined LLVM
3960+ // / function for the target region, so that they can be used to initialize the
3961+ // / corresponding global `ConfigurationEnvironmentTy` structure.
3962+ static void
3963+ initTargetDefaultAttrs (omp::TargetOp targetOp,
3964+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
3965+ bool isTargetDevice) {
3966+ // TODO: Handle constant 'if' clauses.
3967+ Operation *capturedOp = targetOp.getInnermostCapturedOmpOp ();
3968+
3969+ Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
3970+ if (!isTargetDevice) {
3971+ extractHostEvalClauses (targetOp, numThreads, numTeamsLower, numTeamsUpper,
3972+ threadLimit);
3973+ } else {
3974+ // In the target device, values for these clauses are not passed as
3975+ // host_eval, but instead evaluated prior to entry to the region. This
3976+ // ensures values are mapped and available inside of the target region.
3977+ if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
3978+ numTeamsLower = teamsOp.getNumTeamsLower ();
3979+ numTeamsUpper = teamsOp.getNumTeamsUpper ();
3980+ threadLimit = teamsOp.getThreadLimit ();
3981+ }
3982+
3983+ if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
3984+ numThreads = parallelOp.getNumThreads ();
3985+ }
3986+
3987+ auto extractConstInteger = [](Value value) -> std::optional<int64_t > {
3988+ if (auto constOp =
3989+ dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp ()))
3990+ if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue ()))
3991+ return constAttr.getInt ();
3992+
3993+ return std::nullopt ;
3994+ };
3995+
3996+ // Handle clauses impacting the number of teams.
3997+
3998+ int32_t minTeamsVal = 1 , maxTeamsVal = -1 ;
3999+ if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
4000+ // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now, match
4001+ // clang and set min and max to the same value.
4002+ if (numTeamsUpper) {
4003+ if (auto val = extractConstInteger (numTeamsUpper))
4004+ minTeamsVal = maxTeamsVal = *val;
4005+ } else {
4006+ minTeamsVal = maxTeamsVal = 0 ;
4007+ }
4008+ } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
4009+ /* immediateParent=*/ true ) ||
4010+ castOrGetParentOfType<omp::SimdOp>(capturedOp,
4011+ /* immediateParent=*/ true )) {
4012+ minTeamsVal = maxTeamsVal = 1 ;
4013+ } else {
4014+ minTeamsVal = maxTeamsVal = -1 ;
4015+ }
4016+
4017+ // Handle clauses impacting the number of threads.
4018+
4019+ auto setMaxValueFromClause = [&extractConstInteger](Value clauseValue,
4020+ int32_t &result) {
4021+ if (!clauseValue)
4022+ return ;
4023+
4024+ if (auto val = extractConstInteger (clauseValue))
4025+ result = *val;
4026+
4027+ // Found an applicable clause, so it's not undefined. Mark as unknown
4028+ // because it's not constant.
4029+ if (result < 0 )
4030+ result = 0 ;
4031+ };
4032+
4033+ // Extract 'thread_limit' clause from 'target' and 'teams' directives.
4034+ int32_t targetThreadLimitVal = -1 , teamsThreadLimitVal = -1 ;
4035+ setMaxValueFromClause (targetOp.getThreadLimit (), targetThreadLimitVal);
4036+ setMaxValueFromClause (threadLimit, teamsThreadLimitVal);
4037+
4038+ // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
4039+ int32_t maxThreadsVal = -1 ;
4040+ if (castOrGetParentOfType<omp::ParallelOp>(capturedOp))
4041+ setMaxValueFromClause (numThreads, maxThreadsVal);
4042+ else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
4043+ /* immediateParent=*/ true ))
4044+ maxThreadsVal = 1 ;
4045+
4046+ // For max values, < 0 means unset, == 0 means set but unknown. Select the
4047+ // minimum value between 'max_threads' and 'thread_limit' clauses that were
4048+ // set.
4049+ int32_t combinedMaxThreadsVal = targetThreadLimitVal;
4050+ if (combinedMaxThreadsVal < 0 ||
4051+ (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
4052+ combinedMaxThreadsVal = teamsThreadLimitVal;
4053+
4054+ if (combinedMaxThreadsVal < 0 ||
4055+ (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
4056+ combinedMaxThreadsVal = maxThreadsVal;
4057+
4058+ // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
4059+ attrs.MinTeams = minTeamsVal;
4060+ attrs.MaxTeams .front () = maxTeamsVal;
4061+ attrs.MinThreads = 1 ;
4062+ attrs.MaxThreads .front () = combinedMaxThreadsVal;
4063+ }
4064+
4065+ // / Gather LLVM runtime values for all clauses evaluated in the host that are
4066+ // / passed to the kernel invocation.
4067+ // /
4068+ // / This function must be called only when compiling for the host. Also, it will
4069+ // / only provide correct results if it's called after the body of \c targetOp
4070+ // / has been fully generated.
4071+ static void
4072+ initTargetRuntimeAttrs (llvm::IRBuilderBase &builder,
4073+ LLVM::ModuleTranslation &moduleTranslation,
4074+ omp::TargetOp targetOp,
4075+ llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
4076+ Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
4077+ extractHostEvalClauses (targetOp, numThreads, numTeamsLower, numTeamsUpper,
4078+ teamsThreadLimit);
4079+
4080+ // TODO: Handle constant 'if' clauses.
4081+ if (Value targetThreadLimit = targetOp.getThreadLimit ())
4082+ attrs.TargetThreadLimit .front () =
4083+ moduleTranslation.lookupValue (targetThreadLimit);
4084+
4085+ if (numTeamsLower)
4086+ attrs.MinTeams = moduleTranslation.lookupValue (numTeamsLower);
4087+
4088+ if (numTeamsUpper)
4089+ attrs.MaxTeams .front () = moduleTranslation.lookupValue (numTeamsUpper);
4090+
4091+ if (teamsThreadLimit)
4092+ attrs.TeamsThreadLimit .front () =
4093+ moduleTranslation.lookupValue (teamsThreadLimit);
4094+
4095+ if (numThreads)
4096+ attrs.MaxThreads = moduleTranslation.lookupValue (numThreads);
4097+
4098+ // TODO: Populate attrs.LoopTripCount if it is target SPMD.
4099+ }
4100+
38924101static LogicalResult
38934102convertOmpTarget (Operation &opInst, llvm::IRBuilderBase &builder,
38944103 LLVM::ModuleTranslation &moduleTranslation) {
@@ -3898,7 +4107,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
38984107
38994108 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
39004109 bool isTargetDevice = ompBuilder->Config .isTargetDevice ();
4110+
39014111 auto parentFn = opInst.getParentOfType <LLVM::LLVMFuncOp>();
4112+ auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
39024113 auto &targetRegion = targetOp.getRegion ();
39034114 // Holds the private vars that have been mapped along with the block argument
39044115 // that corresponds to the MapInfoOp corresponding to the private var in
@@ -3913,8 +4124,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
39134124 llvm::DenseMap<Value, Value> mappedPrivateVars;
39144125 DataLayout dl = DataLayout (opInst.getParentOfType <ModuleOp>());
39154126 SmallVector<Value> mapVars = targetOp.getMapVars ();
3916- ArrayRef<BlockArgument> mapBlockArgs =
3917- cast<omp::BlockArgOpenMPOpInterface>(opInst).getMapBlockArgs ();
4127+ ArrayRef<BlockArgument> mapBlockArgs = argIface.getMapBlockArgs ();
39184128 llvm::Function *llvmOutlinedFn = nullptr ;
39194129
39204130 // TODO: It can also be false if a compile-time constant `false` IF clause is
@@ -3928,7 +4138,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
39284138 // to quickly look up the corresponding map variable, if any for each
39294139 // private variable.
39304140 if (!targetOp.getPrivateVars ().empty () && !targetOp.getMapVars ().empty ()) {
3931- auto argIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
39324141 OperandRange privateVars = targetOp.getPrivateVars ();
39334142 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms ();
39344143 std::optional<DenseI64ArrayAttr> privateMapIndices =
@@ -4002,7 +4211,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
40024211 // Do privatization after moduleTranslation has already recorded
40034212 // mapped values.
40044213 MutableArrayRef<BlockArgument> privateBlockArgs =
4005- cast<omp::BlockArgOpenMPOpInterface>(opInst) .getPrivateBlockArgs ();
4214+ argIface .getPrivateBlockArgs ();
40064215 SmallVector<mlir::Value> mlirPrivateVars;
40074216 SmallVector<llvm::Value *> llvmPrivateVars;
40084217 SmallVector<omp::PrivateClauseOp> privateDecls;
@@ -4085,14 +4294,29 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
40854294 allocaIP, codeGenIP);
40864295 };
40874296
4088- // TODO: Populate default and runtime attributes based on the construct and
4089- // clauses.
40904297 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
4091- llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = {
4092- /* ExecFlags=*/ llvm::omp::OMP_TGT_EXEC_MODE_GENERIC, /* MaxTeams=*/ {-1 },
4093- /* MinTeams=*/ 0 , /* MaxThreads=*/ {0 }, /* MinThreads=*/ 0 };
4298+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
4299+ initTargetDefaultAttrs (targetOp, defaultAttrs, isTargetDevice);
40944300
4301+ // Collect host-evaluated values needed to properly launch the kernel from the
4302+ // host.
4303+ if (!isTargetDevice)
4304+ initTargetRuntimeAttrs (builder, moduleTranslation, targetOp, runtimeAttrs);
4305+
4306+ // Pass host-evaluated values as parameters to the kernel / host fallback,
4307+ // except if they are constants. In any case, map the MLIR block argument to
4308+ // the corresponding LLVM values.
40954309 llvm::SmallVector<llvm::Value *, 4 > kernelInput;
4310+ SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars ();
4311+ ArrayRef<BlockArgument> hostEvalBlockArgs = argIface.getHostEvalBlockArgs ();
4312+ for (auto [arg, var] : llvm::zip_equal (hostEvalBlockArgs, hostEvalVars)) {
4313+ llvm::Value *value = moduleTranslation.lookupValue (var);
4314+ moduleTranslation.mapValue (arg, value);
4315+
4316+ if (!llvm::isa<llvm::Constant>(value))
4317+ kernelInput.push_back (value);
4318+ }
4319+
40964320 for (size_t i = 0 ; i < mapVars.size (); ++i) {
40974321 // declare target arguments are not passed to kernels as arguments
40984322 // TODO: We currently do not handle cases where a member is explicitly
0 commit comments