@@ -14,25 +14,26 @@ def numba_type_to_dpctl_typenum(context, ty):
1414 """
1515
1616 if dpctl_sem_version >= (0 , 17 , 0 ):
17- # FIXME change to imports from a dpctl enum/class rather than
18- # hard coding these numbers.
17+ from dpctl ._sycl_queue import kernel_arg_type as kargty
1918
2019 if ty == types .boolean :
21- return context .get_constant (types .int32 , 1 )
20+ return context .get_constant (types .int32 , kargty . dpctl_uint8 . value )
2221 elif ty == types .int32 or isinstance (ty , types .scalars .IntegerLiteral ):
23- return context .get_constant (types .int32 , 4 )
22+ return context .get_constant (types .int32 , kargty . dpctl_int32 . value )
2423 elif ty == types .uint32 :
25- return context .get_constant (types .int32 , 5 )
24+ return context .get_constant (types .int32 , kargty . dpctl_uint32 . value )
2625 elif ty == types .int64 :
27- return context .get_constant (types .int32 , 6 )
26+ return context .get_constant (types .int32 , kargty . dpctl_int64 . value )
2827 elif ty == types .uint64 :
29- return context .get_constant (types .int32 , 7 )
28+ return context .get_constant (types .int32 , kargty . dpctl_uint64 . value )
3029 elif ty == types .float32 :
31- return context .get_constant (types .int32 , 8 )
30+ return context .get_constant (types .int32 , kargty . dpctl_float32 . value )
3231 elif ty == types .float64 :
33- return context .get_constant (types .int32 , 9 )
32+ return context .get_constant (types .int32 , kargty . dpctl_float64 . value )
3433 elif ty == types .voidptr or isinstance (ty , types .CPointer ):
35- return context .get_constant (types .int32 , 10 )
34+ return context .get_constant (
35+ types .int32 , kargty .dpctl_void_ptr .value
36+ )
3637 else :
3738 raise NotImplementedError
3839 else :
0 commit comments