Skip to content

Commit 7b07840

Browse files
committed
[MLIR][NVVM] Fix the lowering of mbarrier.test.wait
PR llvm#165993 broke the lowering of the `mbarrier.test.wait` Op: its `llvmBuilder` called `MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs` instead of `MBarrierTestWaitOp::getIntrinsicIDAndArgs`. This patch fixes that call and adds tests that verify the lowering to intrinsics for all mbarrier Ops, so that similar regressions are caught in the future. Additionally, the `cp-async-mbarrier` test is moved into the `mbarriers.mlir` test file to keep all related tests together.
Signed-off-by: Durgadoss R <[email protected]>
1 parent c1dc064 commit 7b07840

File tree

3 files changed

+117
-14
lines changed

3 files changed

+117
-14
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -949,7 +949,7 @@ def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,
949949
}];
950950

951951
string llvmBuilder = [{
952-
auto [id, args] = NVVM::MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs(
952+
auto [id, args] = NVVM::MBarrierTestWaitOp::getIntrinsicIDAndArgs(
953953
*op, moduleTranslation, builder);
954954
$res = createIntrinsicCall(builder, id, args);
955955
}];
mbarriers.mlir (new test file; its directory path was not captured on this page)

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.ptr) { // Exercises all four cp.async.mbarrier.arrive forms: generic and shared pointer, each with and without noinc.
4+
// CHECK-LABEL: define void @cp_async_mbarrier_arrive(ptr addrspace(3) %0, ptr %1) {
5+
// CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive(ptr %1)
6+
// CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc(ptr %1)
7+
// CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive.shared(ptr addrspace(3) %0)
8+
// CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc.shared(ptr addrspace(3) %0)
9+
// CHECK-NEXT: ret void
10+
// CHECK-NEXT: }
11+
nvvm.cp.async.mbarrier.arrive %bar_gen : !llvm.ptr
12+
nvvm.cp.async.mbarrier.arrive %bar_gen {noinc = true} : !llvm.ptr
13+
nvvm.cp.async.mbarrier.arrive %bar_shared : !llvm.ptr<3>
14+
nvvm.cp.async.mbarrier.arrive %bar_shared {noinc = true} : !llvm.ptr<3>
15+
llvm.return
16+
}
17+
18+
llvm.func @mbarrier_init_generic(%barrier: !llvm.ptr) { // Generic-pointer mbarrier.init: expects @llvm.nvvm.mbarrier.init with the ntid.x value as the count.
19+
// CHECK-LABEL: define void @mbarrier_init_generic(ptr %0) {
20+
// CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
21+
// CHECK-NEXT: call void @llvm.nvvm.mbarrier.init(ptr %0, i32 %2)
22+
// CHECK-NEXT: ret void
23+
// CHECK-NEXT: }
24+
%count = nvvm.read.ptx.sreg.ntid.x : i32
25+
nvvm.mbarrier.init %barrier, %count : !llvm.ptr, i32
26+
llvm.return
27+
}
28+
29+
llvm.func @mbarrier_init_shared(%barrier: !llvm.ptr<3>) { // Address-space-3 mbarrier.init: expects the .shared intrinsic with the ntid.x value as the count.
30+
// CHECK-LABEL: define void @mbarrier_init_shared(ptr addrspace(3) %0) {
31+
// CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
32+
// CHECK-NEXT: call void @llvm.nvvm.mbarrier.init.shared(ptr addrspace(3) %0, i32 %2)
33+
// CHECK-NEXT: ret void
34+
// CHECK-NEXT: }
35+
%count = nvvm.read.ptx.sreg.ntid.x : i32
36+
nvvm.mbarrier.init %barrier, %count : !llvm.ptr<3>, i32
37+
llvm.return
38+
}
39+
40+
llvm.func @mbarrier_inval_generic(%barrier: !llvm.ptr) { // Generic-pointer mbarrier.inval: expects @llvm.nvvm.mbarrier.inval.
41+
// CHECK-LABEL: define void @mbarrier_inval_generic(ptr %0) {
42+
// CHECK-NEXT: call void @llvm.nvvm.mbarrier.inval(ptr %0)
43+
// CHECK-NEXT: ret void
44+
// CHECK-NEXT: }
45+
nvvm.mbarrier.inval %barrier : !llvm.ptr
46+
llvm.return
47+
}
48+
49+
llvm.func @mbarrier_inval_shared(%barrier: !llvm.ptr<3>) { // Address-space-3 mbarrier.inval: expects @llvm.nvvm.mbarrier.inval.shared.
50+
// CHECK-LABEL: define void @mbarrier_inval_shared(ptr addrspace(3) %0) {
51+
// CHECK-NEXT: call void @llvm.nvvm.mbarrier.inval.shared(ptr addrspace(3) %0)
52+
// CHECK-NEXT: ret void
53+
// CHECK-NEXT: }
54+
nvvm.mbarrier.inval %barrier : !llvm.ptr<3>
55+
llvm.return
56+
}
57+
58+
llvm.func @mbarrier_arrive(%barrier: !llvm.ptr) { // Generic-pointer mbarrier.arrive: expects @llvm.nvvm.mbarrier.arrive producing an i64 (result unused here).
59+
// CHECK-LABEL: define void @mbarrier_arrive(ptr %0) {
60+
// CHECK-NEXT: %2 = call i64 @llvm.nvvm.mbarrier.arrive(ptr %0)
61+
// CHECK-NEXT: ret void
62+
// CHECK-NEXT: }
63+
%0 = nvvm.mbarrier.arrive %barrier : !llvm.ptr -> i64
64+
llvm.return
65+
}
66+
67+
llvm.func @mbarrier_arrive_shared(%barrier: !llvm.ptr<3>) { // Address-space-3 mbarrier.arrive: expects @llvm.nvvm.mbarrier.arrive.shared producing an i64 (result unused here).
68+
// CHECK-LABEL: define void @mbarrier_arrive_shared(ptr addrspace(3) %0) {
69+
// CHECK-NEXT: %2 = call i64 @llvm.nvvm.mbarrier.arrive.shared(ptr addrspace(3) %0)
70+
// CHECK-NEXT: ret void
71+
// CHECK-NEXT: }
72+
%0 = nvvm.mbarrier.arrive %barrier : !llvm.ptr<3> -> i64
73+
llvm.return
74+
}
75+
76+
llvm.func @mbarrier_arrive_nocomplete(%barrier: !llvm.ptr) { // Generic-pointer arrive.nocomplete: expects @llvm.nvvm.mbarrier.arrive.noComplete taking the pointer and the i32 count.
77+
// CHECK-LABEL: define void @mbarrier_arrive_nocomplete(ptr %0) {
78+
// CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
79+
// CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.noComplete(ptr %0, i32 %2)
80+
// CHECK-NEXT: ret void
81+
// CHECK-NEXT: }
82+
%count = nvvm.read.ptx.sreg.ntid.x : i32
83+
%0 = nvvm.mbarrier.arrive.nocomplete %barrier, %count : !llvm.ptr, i32 -> i64
84+
llvm.return
85+
}
86+
87+
llvm.func @mbarrier_arrive_nocomplete_shared(%barrier: !llvm.ptr<3>) { // Address-space-3 arrive.nocomplete: expects @llvm.nvvm.mbarrier.arrive.noComplete.shared taking the pointer and the i32 count.
88+
// CHECK-LABEL: define void @mbarrier_arrive_nocomplete_shared(ptr addrspace(3) %0) {
89+
// CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
90+
// CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.noComplete.shared(ptr addrspace(3) %0, i32 %2)
91+
// CHECK-NEXT: ret void
92+
// CHECK-NEXT: }
93+
%count = nvvm.read.ptx.sreg.ntid.x : i32
94+
%0 = nvvm.mbarrier.arrive.nocomplete %barrier, %count : !llvm.ptr<3>, i32 -> i64
95+
llvm.return
96+
}
97+
98+
llvm.func @mbarrier_test_wait(%barrier: !llvm.ptr, %token : i64) -> i1 { // Generic-pointer test.wait: expects @llvm.nvvm.mbarrier.test.wait(ptr, i64) and returns its i1 result.
99+
// CHECK-LABEL: define i1 @mbarrier_test_wait(ptr %0, i64 %1) {
100+
// CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait(ptr %0, i64 %1)
101+
// CHECK-NEXT: ret i1 %3
102+
// CHECK-NEXT: }
103+
%isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr, i64 -> i1
104+
llvm.return %isComplete : i1
105+
}
106+
107+
llvm.func @mbarrier_test_wait_shared(%barrier: !llvm.ptr<3>, %token : i64) -> i1 {
108+
// Address-space-3 test.wait: expects the .shared intrinsic taking (ptr addrspace(3), i64).
109+
// The i1 result is returned so this test has the same shape as the generic-pointer variant.
110+
// CHECK-LABEL: define i1 @mbarrier_test_wait_shared(ptr addrspace(3) %0, i64 %1) {
111+
// CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait.shared(ptr addrspace(3) %0, i64 %1)
112+
// CHECK-NEXT: ret i1 %3
113+
// CHECK-NEXT: }
114+
%isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr<3>, i64 -> i1
115+
llvm.return %isComplete : i1
116+
}

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -531,19 +531,6 @@ llvm.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32)
531531
llvm.return
532532
}
533533

534-
// CHECK-LABEL: @cp_async_mbarrier_arrive
535-
llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.ptr) {
536-
// CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive(ptr %{{.*}})
537-
nvvm.cp.async.mbarrier.arrive %bar_gen : !llvm.ptr
538-
// CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc(ptr %{{.*}})
539-
nvvm.cp.async.mbarrier.arrive %bar_gen {noinc = true} : !llvm.ptr
540-
// CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive.shared(ptr addrspace(3) %{{.*}})
541-
nvvm.cp.async.mbarrier.arrive %bar_shared : !llvm.ptr<3>
542-
// CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc.shared(ptr addrspace(3) %{{.*}})
543-
nvvm.cp.async.mbarrier.arrive %bar_shared {noinc = true} : !llvm.ptr<3>
544-
llvm.return
545-
}
546-
547534
// CHECK-LABEL: @llvm_nvvm_setmaxregister
548535
llvm.func @llvm_nvvm_setmaxregister() {
549536
// CHECK: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)

0 commit comments

Comments
 (0)