Skip to content

Commit 65c7554

Browse files
Update to fix build and tests after rebasing on main.
Also add the if clause to the newly generated omp.task op that encloses the omp.target op.
1 parent 0f71c22 commit 65c7554

File tree

3 files changed

+27
-19
lines changed

3 files changed

+27
-19
lines changed

mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIROpenMPTransforms
99

1010
LINK_LIBS PUBLIC
1111
MLIROpenMPDialect
12+
MLIRArithDialect
1213
MLIRFuncDialect
1314
MLIRIR
1415
MLIRPass

mlir/lib/Dialect/OpenMP/Transforms/OpenMPTaskBasedTarget.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "mlir/Dialect/OpenMP/Passes.h"
3434

3535
#include "mlir/Dialect/Func/IR/FuncOps.h"
36+
#include "mlir/Dialect/Arith/IR/Arith.h"
3637
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
3738
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
3839
#include "llvm/Support/Debug.h"
@@ -68,8 +69,14 @@ class OmpTaskBasedTargetRewritePattern : public OpRewritePattern<OpTy> {
6869

6970
// Step 1: Create a new task op and tack on the dependency from the 'depend'
7071
// clause on it.
72+
Type i1Ty = rewriter.getI1Type();
73+
// mlir::BoolAttr T = rewriter.getBoolAttr(true);
74+
// mlir::BoolAttr F = rewriter.getBoolAttr(false);
7175
omp::TaskOp taskOp = rewriter.create<omp::TaskOp>(
72-
op.getLoc(), /*if_expr*/ Value(),
76+
op.getLoc(),
77+
/*if_expr*/ op.getNowait() ?
78+
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1))
79+
: rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0)),
7380
/*final_expr*/ Value(),
7481
/*untied*/ UnitAttr(),
7582
/*mergeable*/ UnitAttr(),
@@ -100,9 +107,9 @@ class OmpTaskBasedTargetRewritePattern : public OpRewritePattern<OpTy> {
100107
static void
101108
populateOmpTaskBasedTargetRewritePatterns(RewritePatternSet &patterns) {
102109
patterns.add<OmpTaskBasedTargetRewritePattern<omp::TargetOp>,
103-
OmpTaskBasedTargetRewritePattern<omp::EnterDataOp>,
104-
OmpTaskBasedTargetRewritePattern<omp::UpdateDataOp>,
105-
OmpTaskBasedTargetRewritePattern<omp::ExitDataOp>>(
110+
OmpTaskBasedTargetRewritePattern<omp::TargetEnterDataOp>,
111+
OmpTaskBasedTargetRewritePattern<omp::TargetUpdateOp>,
112+
OmpTaskBasedTargetRewritePattern<omp::TargetExitDataOp>>(
106113
patterns.getContext());
107114
}
108115

mlir/test/Dialect/OpenMP/task-based-target.mlir

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
// CHECK-LABEL: @omp_target_depend
44
// CHECK-SAME: (%arg0: memref<i32>, %arg1: memref<i32>) {
55
func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
6-
// CHECK: omp.task depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
6+
// CHECK: omp.task if(%false) depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
77
// CHECK: omp.target {
88
omp.target depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
99
// CHECK: omp.terminator
@@ -14,12 +14,12 @@ func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
1414
// CHECK-LABEL: func @omp_target_enter_update_exit_data_depend
1515
// CHECK-SAME:([[ARG0:%.*]]: memref<?xi32>, [[ARG1:%.*]]: memref<?xi32>, [[ARG2:%.*]]: memref<?xi32>) {
1616
func.func @omp_target_enter_update_exit_data_depend(%a: memref<?xi32>, %b: memref<?xi32>, %c: memref<?xi32>) {
17-
// CHECK-NEXT: [[MAP0:%.*]] = omp.map_info
18-
// CHECK-NEXT: [[MAP1:%.*]] = omp.map_info
19-
// CHECK-NEXT: [[MAP2:%.*]] = omp.map_info
20-
%map_a = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
21-
%map_b = omp.map_info var_ptr(%b: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
22-
%map_c = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32>
17+
// CHECK: [[MAP0:%.*]] = omp.map.info
18+
// CHECK-NEXT: [[MAP1:%.*]] = omp.map.info
19+
// CHECK-NEXT: [[MAP2:%.*]] = omp.map.info
20+
%map_a = omp.map.info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
21+
%map_b = omp.map.info var_ptr(%b: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
22+
%map_c = omp.map.info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32>
2323

2424
// Do some work on the host that writes to 'a'
2525
omp.task depend(taskdependout -> %a : memref<?xi32>) {
@@ -28,7 +28,7 @@ func.func @omp_target_enter_update_exit_data_depend(%a: memref<?xi32>, %b: memre
2828
}
2929

3030
// Then map that over to the target
31-
// CHECK: omp.task depend(taskdependin -> [[ARG0]] : memref<?xi32>)
31+
// CHECK: omp.task if(%true) depend(taskdependin -> [[ARG0]] : memref<?xi32>)
3232
// CHECK: omp.target_enter_data nowait map_entries([[MAP0]], [[MAP2]] : memref<?xi32>, memref<?xi32>)
3333
omp.target_enter_data nowait map_entries(%map_a, %map_c: memref<?xi32>, memref<?xi32>) depend(taskdependin -> %a: memref<?xi32>)
3434

@@ -46,21 +46,21 @@ func.func @omp_target_enter_update_exit_data_depend(%a: memref<?xi32>, %b: memre
4646
}
4747

4848
// Copy the updated 'a' onto the target
49-
// CHECK: omp.task depend(taskdependin -> [[ARG0]] : memref<?xi32>)
50-
// CHECK: omp.target_update_data nowait motion_entries([[MAP0]] : memref<?xi32>)
51-
omp.target_update_data motion_entries(%map_a : memref<?xi32>) depend(taskdependin -> %a : memref<?xi32>) nowait
49+
// CHECK: omp.task if(%true) depend(taskdependin -> [[ARG0]] : memref<?xi32>)
50+
// CHECK: omp.target_update nowait motion_entries([[MAP0]] : memref<?xi32>)
51+
omp.target_update motion_entries(%map_a : memref<?xi32>) depend(taskdependin -> %a : memref<?xi32>) nowait
5252

5353
// Compute 'c' on the target and copy it back
54-
// CHECK:[[MAP3:%.*]] = omp.map_info var_ptr([[ARG2]] : memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
55-
%map_c_from = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
56-
// CHECK: omp.task depend(taskdependout -> [[ARG2]] : memref<?xi32>)
54+
// CHECK:[[MAP3:%.*]] = omp.map.info var_ptr([[ARG2]] : memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
55+
%map_c_from = omp.map.info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
56+
// CHECK: omp.task if(%false) depend(taskdependout -> [[ARG2]] : memref<?xi32>)
5757
// CHECK: omp.target map_entries([[MAP0]] -> {{%.*}}, [[MAP3]] -> {{%.*}} : memref<?xi32>, memref<?xi32>) {
5858
omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) depend(taskdependout -> %c : memref<?xi32>) {
5959
^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
6060
"test.foobar"() : ()->()
6161
omp.terminator
6262
}
63-
// CHECK: omp.task depend(taskdependin -> [[ARG2]] : memref<?xi32>) {
63+
// CHECK: omp.task if(%false) depend(taskdependin -> [[ARG2]] : memref<?xi32>) {
6464
// CHECK: omp.target_exit_data map_entries([[MAP2]] : memref<?xi32>)
6565
omp.target_exit_data map_entries(%map_c : memref<?xi32>) depend(taskdependin -> %c : memref<?xi32>)
6666
return

0 commit comments

Comments
 (0)