@@ -23,11 +23,16 @@ def debug_option(request):
2323 return request .param
2424
2525
26- def get_kernel_ir (sycl_queue , fn , sig , debug = None ):
27- kernel = compiler . compile_kernel (
28- sycl_queue , fn .py_func , sig , None , debug = debug
26+ def get_kernel_ir (fn , sig , debug = None ):
27+ kernel = dpex . core . kernel_interface . spirv_kernel . SpirvKernel (
28+ fn , fn .__name__
2929 )
30- return kernel .assembly
30+ kernel .compile (
31+ arg_types = sig ,
32+ debug = debug ,
33+ extra_compile_flags = None ,
34+ )
35+ return kernel .llvm_module
3136
3237
3338def make_check (ir , val_to_search ):
@@ -45,15 +50,11 @@ def test_debug_flag_generates_ir_with_debuginfo(debug_option):
4550 Check debug info is emitting to IR if debug parameter is set to True
4651 """
4752
48- @dpex .kernel
4953 def foo (x ):
5054 x = 1 # noqa
5155
52- sycl_queue = dpctl .get_current_queue ()
5356 sig = (types .int32 ,)
54-
55- kernel_ir = get_kernel_ir (sycl_queue , foo , sig , debug = debug_option )
56-
57+ kernel_ir = get_kernel_ir (foo , sig , debug = debug_option )
5758 tag = "!dbg"
5859
5960 if debug_option :
@@ -68,7 +69,6 @@ def test_debug_info_locals_vars_on_no_opt():
6869 if debug parameter is set to True and optimization is O0
6970 """
7071
71- @dpex .kernel
7272 def foo (var_a , var_b , var_c ):
7373 i = dpex .get_global_id (0 )
7474 var_c [i ] = var_a [i ] + var_b [i ]
@@ -79,16 +79,14 @@ def foo(var_a, var_b, var_c):
7979 '!DILocalVariable(name: "var_c"' ,
8080 '!DILocalVariable(name: "i"' ,
8181 ]
82-
83- sycl_queue = dpctl .get_current_queue ()
8482 sig = (
8583 npytypes_array_to_dpex_array (types .float32 [:]),
8684 npytypes_array_to_dpex_array (types .float32 [:]),
8785 npytypes_array_to_dpex_array (types .float32 [:]),
8886 )
8987
9088 with override_config ("OPT" , 0 ):
91- kernel_ir = get_kernel_ir (sycl_queue , foo , sig , debug = True )
89+ kernel_ir = get_kernel_ir (foo , sig , debug = True )
9290
9391 for tag in ir_tags :
9492 assert tag in kernel_ir
@@ -100,7 +98,6 @@ def test_debug_kernel_local_vars_in_ir():
10098 created in kernel
10199 """
102100
103- @dpex .kernel
104101 def foo (arr ):
105102 index = dpex .get_global_id (0 )
106103 local_d = 9 * 99 + 5
@@ -110,11 +107,8 @@ def foo(arr):
110107 '!DILocalVariable(name: "index"' ,
111108 '!DILocalVariable(name: "local_d"' ,
112109 ]
113-
114- sycl_queue = dpctl .get_current_queue ()
115110 sig = (npytypes_array_to_dpex_array (types .float32 [:]),)
116-
117- kernel_ir = get_kernel_ir (sycl_queue , foo , sig , debug = True )
111+ kernel_ir = get_kernel_ir (foo , sig , debug = True )
118112
119113 for tag in ir_tags :
120114 assert tag in kernel_ir
@@ -140,16 +134,13 @@ def data_parallel_sum(a, b, c):
140134 r'\!DISubprogram\(name: ".*data_parallel_sum"' ,
141135 ]
142136
143- sycl_queue = dpctl .get_current_queue ()
144137 sig = (
145138 npytypes_array_to_dpex_array (types .float32 [:]),
146139 npytypes_array_to_dpex_array (types .float32 [:]),
147140 npytypes_array_to_dpex_array (types .float32 [:]),
148141 )
149142
150- kernel_ir = get_kernel_ir (
151- sycl_queue , data_parallel_sum , sig , debug = debug_option
152- )
143+ kernel_ir = get_kernel_ir (data_parallel_sum , sig , debug = debug_option )
153144
154145 for tag in ir_tags :
155146 assert debug_option == make_check (kernel_ir , tag )
@@ -175,22 +166,20 @@ def data_parallel_sum(a, b, c):
175166 r'\!DISubprogram\(name: ".*data_parallel_sum"' ,
176167 ]
177168
178- sycl_queue = dpctl .get_current_queue ()
179169 sig = (
180170 npytypes_array_to_dpex_array (types .float32 [:]),
181171 npytypes_array_to_dpex_array (types .float32 [:]),
182172 npytypes_array_to_dpex_array (types .float32 [:]),
183173 )
184174
185175 with override_config ("DEBUGINFO_DEFAULT" , int (debug_option )):
186- kernel_ir = get_kernel_ir (sycl_queue , data_parallel_sum , sig )
176+ kernel_ir = get_kernel_ir (data_parallel_sum , sig )
187177
188178 for tag in ir_tags :
189179 assert debug_option == make_check (kernel_ir , tag )
190180
191181
192182def test_debuginfo_DISubprogram_linkageName ():
193- @dpex .kernel
194183 def func (a , b ):
195184 i = dpex .get_global_id (0 )
196185 b [i ] = a [i ]
@@ -199,20 +188,18 @@ def func(a, b):
199188 r'\!DISubprogram\(.*linkageName: ".*e4func.*"' ,
200189 ]
201190
202- sycl_queue = dpctl .get_current_queue ()
203191 sig = (
204192 npytypes_array_to_dpex_array (types .float32 [:]),
205193 npytypes_array_to_dpex_array (types .float32 [:]),
206194 )
207195
208- kernel_ir = get_kernel_ir (sycl_queue , func , sig , debug = True )
196+ kernel_ir = get_kernel_ir (func , sig , debug = True )
209197
210198 for tag in ir_tags :
211199 assert make_check (kernel_ir , tag )
212200
213201
214202def test_debuginfo_DICompileUnit_language_and_producer ():
215- @dpex .kernel
216203 def func (a , b ):
217204 i = dpex .get_global_id (0 )
218205 b [i ] = a [i ]
@@ -222,13 +209,12 @@ def func(a, b):
222209 r'\!DICompileUnit\(.*producer: "numba-dpex"' ,
223210 ]
224211
225- sycl_queue = dpctl .get_current_queue ()
226212 sig = (
227213 npytypes_array_to_dpex_array (types .float32 [:]),
228214 npytypes_array_to_dpex_array (types .float32 [:]),
229215 )
230216
231- kernel_ir = get_kernel_ir (sycl_queue , func , sig , debug = True )
217+ kernel_ir = get_kernel_ir (func , sig , debug = True )
232218
233219 for tag in ir_tags :
234220 assert make_check (kernel_ir , tag )
0 commit comments