Skip to content

Commit b56669c

Browse files
committed
[MLIR][NVVM] Fix the lowering of mbarrier.test.wait
PR #165993 broke the lowering of the `test.wait` Op. This patch fixes the issue and adds tests to verify the lowering to intrinsics for all mbarrier Ops, ensuring similar regressions are caught in the future. Additionally, the `cp-async-mbarrier` test is moved to the `mbarriers.mlir` test file to keep all related tests together. Signed-off-by: Durgadoss R <[email protected]>
1 parent c1dc064 commit b56669c

File tree

3 files changed

+118
-14
lines changed

3 files changed

+118
-14
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -949,7 +949,7 @@ def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,
949949
}];
950950

951951
string llvmBuilder = [{
952-
auto [id, args] = NVVM::MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs(
952+
auto [id, args] = NVVM::MBarrierTestWaitOp::getIntrinsicIDAndArgs(
953953
*op, moduleTranslation, builder);
954954
$res = createIntrinsicCall(builder, id, args);
955955
}];
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.ptr) {
4+
// CHECK-LABEL: define void @cp_async_mbarrier_arrive(ptr addrspace(3) %0, ptr %1) {
5+
// CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive(ptr %1)
6+
// CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc(ptr %1)
7+
// CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive.shared(ptr addrspace(3) %0)
8+
// CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc.shared(ptr addrspace(3) %0)
9+
// CHECK-NEXT: ret void
10+
// CHECK-NEXT: }
11+
nvvm.cp.async.mbarrier.arrive %bar_gen : !llvm.ptr
12+
nvvm.cp.async.mbarrier.arrive %bar_gen {noinc = true} : !llvm.ptr
13+
nvvm.cp.async.mbarrier.arrive %bar_shared : !llvm.ptr<3>
14+
nvvm.cp.async.mbarrier.arrive %bar_shared {noinc = true} : !llvm.ptr<3>
15+
llvm.return
16+
}
17+
18+
llvm.func @mbarrier_init_generic(%barrier: !llvm.ptr) {
19+
// CHECK-LABEL: define void @mbarrier_init_generic(ptr %0) {
20+
// CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
21+
// CHECK-NEXT: call void @llvm.nvvm.mbarrier.init(ptr %0, i32 %2)
22+
// CHECK-NEXT: ret void
23+
// CHECK-NEXT: }
24+
%count = nvvm.read.ptx.sreg.ntid.x : i32
25+
nvvm.mbarrier.init %barrier, %count : !llvm.ptr, i32
26+
llvm.return
27+
}
28+
29+
30+
llvm.func @mbarrier_init_shared(%barrier: !llvm.ptr<3>) {
31+
// CHECK-LABEL: define void @mbarrier_init_shared(ptr addrspace(3) %0) {
32+
// CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
33+
// CHECK-NEXT: call void @llvm.nvvm.mbarrier.init.shared(ptr addrspace(3) %0, i32 %2)
34+
// CHECK-NEXT: ret void
35+
// CHECK-NEXT: }
36+
%count = nvvm.read.ptx.sreg.ntid.x : i32
37+
nvvm.mbarrier.init %barrier, %count : !llvm.ptr<3>, i32
38+
llvm.return
39+
}
40+
41+
llvm.func @mbarrier_inval_generic(%barrier: !llvm.ptr) {
42+
// CHECK-LABEL: define void @mbarrier_inval_generic(ptr %0) {
43+
// CHECK-NEXT: call void @llvm.nvvm.mbarrier.inval(ptr %0)
44+
// CHECK-NEXT: ret void
45+
// CHECK-NEXT: }
46+
nvvm.mbarrier.inval %barrier : !llvm.ptr
47+
llvm.return
48+
}
49+
50+
llvm.func @mbarrier_inval_shared(%barrier: !llvm.ptr<3>) {
51+
// CHECK-LABEL: define void @mbarrier_inval_shared(ptr addrspace(3) %0) {
52+
// CHECK-NEXT: call void @llvm.nvvm.mbarrier.inval.shared(ptr addrspace(3) %0)
53+
// CHECK-NEXT: ret void
54+
// CHECK-NEXT: }
55+
nvvm.mbarrier.inval %barrier : !llvm.ptr<3>
56+
llvm.return
57+
}
58+
59+
llvm.func @mbarrier_arrive(%barrier: !llvm.ptr) {
60+
// CHECK-LABEL: define void @mbarrier_arrive(ptr %0) {
61+
// CHECK-NEXT: %2 = call i64 @llvm.nvvm.mbarrier.arrive(ptr %0)
62+
// CHECK-NEXT: ret void
63+
// CHECK-NEXT: }
64+
%0 = nvvm.mbarrier.arrive %barrier : !llvm.ptr -> i64
65+
llvm.return
66+
}
67+
68+
llvm.func @mbarrier_arrive_shared(%barrier: !llvm.ptr<3>) {
69+
// CHECK-LABEL: define void @mbarrier_arrive_shared(ptr addrspace(3) %0) {
70+
// CHECK-NEXT: %2 = call i64 @llvm.nvvm.mbarrier.arrive.shared(ptr addrspace(3) %0)
71+
// CHECK-NEXT: ret void
72+
// CHECK-NEXT: }
73+
%0 = nvvm.mbarrier.arrive %barrier : !llvm.ptr<3> -> i64
74+
llvm.return
75+
}
76+
77+
llvm.func @mbarrier_arrive_nocomplete(%barrier: !llvm.ptr) {
78+
// CHECK-LABEL: define void @mbarrier_arrive_nocomplete(ptr %0) {
79+
// CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
80+
// CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.noComplete(ptr %0, i32 %2)
81+
// CHECK-NEXT: ret void
82+
// CHECK-NEXT: }
83+
%count = nvvm.read.ptx.sreg.ntid.x : i32
84+
%0 = nvvm.mbarrier.arrive.nocomplete %barrier, %count : !llvm.ptr, i32 -> i64
85+
llvm.return
86+
}
87+
88+
llvm.func @mbarrier_arrive_nocomplete_shared(%barrier: !llvm.ptr<3>) {
89+
// CHECK-LABEL: define void @mbarrier_arrive_nocomplete_shared(ptr addrspace(3) %0) {
90+
// CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
91+
// CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.noComplete.shared(ptr addrspace(3) %0, i32 %2)
92+
// CHECK-NEXT: ret void
93+
// CHECK-NEXT: }
94+
%count = nvvm.read.ptx.sreg.ntid.x : i32
95+
%0 = nvvm.mbarrier.arrive.nocomplete %barrier, %count : !llvm.ptr<3>, i32 -> i64
96+
llvm.return
97+
}
98+
99+
llvm.func @mbarrier_test_wait(%barrier: !llvm.ptr, %token : i64) -> i1 {
100+
// CHECK-LABEL: define i1 @mbarrier_test_wait(ptr %0, i64 %1) {
101+
// CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait(ptr %0, i64 %1)
102+
// CHECK-NEXT: ret i1 %3
103+
// CHECK-NEXT: }
104+
%isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr, i64 -> i1
105+
llvm.return %isComplete : i1
106+
}
107+
108+
llvm.func @mbarrier_test_wait_shared(%barrier: !llvm.ptr<3>, %token : i64) {
109+
// CHECK-LABEL: define void @mbarrier_test_wait_shared(ptr addrspace(3) %0, i64 %1) {
110+
// CHECK-NEXT: %3 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
111+
// CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.shared(ptr addrspace(3) %0, i64 %1)
112+
// CHECK-NEXT: ret void
113+
// CHECK-NEXT: }
114+
%count = nvvm.read.ptx.sreg.ntid.x : i32
115+
%isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr<3>, i64 -> i1
116+
llvm.return
117+
}

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -531,19 +531,6 @@ llvm.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32)
531531
llvm.return
532532
}
533533

534-
// CHECK-LABEL: @cp_async_mbarrier_arrive
535-
llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.ptr) {
536-
// CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive(ptr %{{.*}})
537-
nvvm.cp.async.mbarrier.arrive %bar_gen : !llvm.ptr
538-
// CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc(ptr %{{.*}})
539-
nvvm.cp.async.mbarrier.arrive %bar_gen {noinc = true} : !llvm.ptr
540-
// CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive.shared(ptr addrspace(3) %{{.*}})
541-
nvvm.cp.async.mbarrier.arrive %bar_shared : !llvm.ptr<3>
542-
// CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc.shared(ptr addrspace(3) %{{.*}})
543-
nvvm.cp.async.mbarrier.arrive %bar_shared {noinc = true} : !llvm.ptr<3>
544-
llvm.return
545-
}
546-
547534
// CHECK-LABEL: @llvm_nvvm_setmaxregister
548535
llvm.func @llvm_nvvm_setmaxregister() {
549536
// CHECK: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)

0 commit comments

Comments
 (0)