@@ -158,6 +158,12 @@ static LogicalResult checkImplementationStatus(Operation &op) {
158
158
if (op.getBare())
159
159
result = todo("ompx_bare");
160
160
};
161
+ auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
162
+ omp::ClauseCancellationConstructType cancelledDirective =
163
+ op.getCancelDirective();
164
+ if (cancelledDirective != omp::ClauseCancellationConstructType::Parallel)
165
+ result = todo("cancel directive construct type not yet supported");
166
+ };
161
167
auto checkDepend = [&todo](auto op, LogicalResult &result) {
162
168
if (!op.getDependVars().empty() || op.getDependKinds())
163
169
result = todo("depend");
@@ -248,6 +254,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
248
254
249
255
LogicalResult result = success();
250
256
llvm::TypeSwitch<Operation &>(op)
257
+ .Case([&](omp::CancelOp op) { checkCancelDirective(op, result); })
251
258
.Case([&](omp::DistributeOp op) {
252
259
checkAllocate(op, result);
253
260
checkDistSchedule(op, result);
@@ -1580,6 +1587,19 @@ cleanupPrivateVars(llvm::IRBuilderBase &builder,
1580
1587
return success();
1581
1588
}
1582
1589
1590
+ /// Returns true if the construct contains omp.cancel or omp.cancellation_point
1591
+ static bool constructIsCancellable(Operation *op) {
1592
+ // omp.cancel must be "closely nested" so it will be visible and not inside of
1593
+ // funcion calls. This is enforced by the verifier.
1594
+ return op
1595
+ ->walk([](Operation *child) {
1596
+ if (mlir::isa<omp::CancelOp>(child))
1597
+ return WalkResult::interrupt();
1598
+ return WalkResult::advance();
1599
+ })
1600
+ .wasInterrupted();
1601
+ }
1602
+
1583
1603
static LogicalResult
1584
1604
convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
1585
1605
LLVM::ModuleTranslation &moduleTranslation) {
@@ -2524,8 +2544,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
2524
2544
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
2525
2545
if (auto bind = opInst.getProcBindKind())
2526
2546
pbKind = getProcBindKind(*bind);
2527
- // TODO: Is the Parallel construct cancellable?
2528
- bool isCancellable = false;
2547
+ bool isCancellable = constructIsCancellable(opInst);
2529
2548
2530
2549
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2531
2550
findAllocaInsertPoint(builder, moduleTranslation);
@@ -2991,6 +3010,47 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
2991
3010
return success();
2992
3011
}
2993
3012
3013
+ static llvm::omp::Directive convertCancellationConstructType(
3014
+ omp::ClauseCancellationConstructType directive) {
3015
+ switch (directive) {
3016
+ case omp::ClauseCancellationConstructType::Loop:
3017
+ return llvm::omp::Directive::OMPD_for;
3018
+ case omp::ClauseCancellationConstructType::Parallel:
3019
+ return llvm::omp::Directive::OMPD_parallel;
3020
+ case omp::ClauseCancellationConstructType::Sections:
3021
+ return llvm::omp::Directive::OMPD_sections;
3022
+ case omp::ClauseCancellationConstructType::Taskgroup:
3023
+ return llvm::omp::Directive::OMPD_taskgroup;
3024
+ }
3025
+ }
3026
+
3027
+ static LogicalResult
3028
+ convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
3029
+ LLVM::ModuleTranslation &moduleTranslation) {
3030
+ llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3031
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3032
+
3033
+ if (failed(checkImplementationStatus(*op.getOperation())))
3034
+ return failure();
3035
+
3036
+ llvm::Value *ifCond = nullptr;
3037
+ if (Value ifVar = op.getIfExpr())
3038
+ ifCond = moduleTranslation.lookupValue(ifVar);
3039
+
3040
+ llvm::omp::Directive cancelledDirective =
3041
+ convertCancellationConstructType(op.getCancelDirective());
3042
+
3043
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3044
+ ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
3045
+
3046
+ if (failed(handleError(afterIP, *op.getOperation())))
3047
+ return failure();
3048
+
3049
+ builder.restoreIP(afterIP.get());
3050
+
3051
+ return success();
3052
+ }
3053
+
2994
3054
/// Converts an OpenMP Threadprivate operation into LLVM IR using
2995
3055
/// OpenMPIRBuilder.
2996
3056
static LogicalResult
@@ -5421,6 +5481,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
5421
5481
.Case([&](omp::AtomicCaptureOp op) {
5422
5482
return convertOmpAtomicCapture(op, builder, moduleTranslation);
5423
5483
})
5484
+ .Case([&](omp::CancelOp op) {
5485
+ return convertOmpCancel(op, builder, moduleTranslation);
5486
+ })
5424
5487
.Case([&](omp::SectionsOp) {
5425
5488
return convertOmpSections(*op, builder, moduleTranslation);
5426
5489
})
0 commit comments