99
1010import llvmlite .ir as llvmir
1111from llvmlite .ir .builder import IRBuilder
12+ from numba .core import cgutils , types
1213from numba .core .typing .npydecl import parse_dtype as _ty_parse_dtype
1314from numba .core .typing .npydecl import parse_shape as _ty_parse_shape
1415from numba .core .typing .templates import Signature
15- from numba .extending import intrinsic , overload
16+ from numba .extending import type_callable
1617
1718from numba_dpex .core .types import USMNdArray
1819from numba_dpex .experimental .target import DpexExpKernelTypingContext
2324)
2425from numba_dpex .utils import address_space as AddressSpace
2526
26- from .. target import DPEX_KERNEL_EXP_TARGET_NAME
27+ from ._registry import lower
2728
2829
29- @intrinsic (target = DPEX_KERNEL_EXP_TARGET_NAME )
30- def _intrinsic_private_array_ctor (
31- ty_context , ty_shape , ty_dtype # pylint: disable=unused-argument
32- ):
33- require_literal (ty_shape )
34-
35- ty_array = USMNdArray (
36- dtype = _ty_parse_dtype (ty_dtype ),
37- ndim = _ty_parse_shape (ty_shape ),
38- layout = "C" ,
39- addrspace = AddressSpace .PRIVATE ,
40- )
41-
42- sig = ty_array (ty_shape , ty_dtype )
43-
44- def codegen (
45- context : DpexExpKernelTypingContext ,
46- builder : IRBuilder ,
47- sig : Signature ,
48- args : list [llvmir .Value ],
49- ):
50- shape = args [0 ]
51- ty_shape = sig .args [0 ]
52- ty_array = sig .return_type
53-
54- ary = make_spirv_generic_array_on_stack (
55- context , builder , ty_array , ty_shape , shape
56- )
57- return ary ._getvalue () # pylint: disable=protected-access
58-
59- return (
60- sig ,
61- codegen ,
62- )
63-
64-
65- @overload (
66- PrivateArray ,
67- prefer_literal = True ,
68- target = DPEX_KERNEL_EXP_TARGET_NAME ,
69- )
70- def ol_private_array_ctor (
71- shape ,
72- dtype ,
73- ):
74- """Overload of the constructor for the class
30+ @type_callable (PrivateArray )
31+ def type_interval (context ): # pylint: disable=unused-argument
32+ """Sets type of the constructor for the class
7533 class:`numba_dpex.kernel_api.PrivateArray`.
7634
7735 Raises:
@@ -81,11 +39,48 @@ def ol_private_array_ctor(
8139 type.
8240 """
8341
84- def ol_private_array_ctor_impl (
85- shape ,
86- dtype ,
87- ):
88- # pylint: disable=no-value-for-parameter
89- return _intrinsic_private_array_ctor (shape , dtype )
42+ def typer (shape , dtype , fill_zeros = types .BooleanLiteral (False )):
43+ require_literal (shape )
44+ require_literal (fill_zeros )
45+
46+ return USMNdArray (
47+ dtype = _ty_parse_dtype (dtype ),
48+ ndim = _ty_parse_shape (shape ),
49+ layout = "C" ,
50+ addrspace = AddressSpace .PRIVATE ,
51+ )
52+
53+ return typer
54+
55+
56+ @lower (PrivateArray , types .IntegerLiteral , types .Any , types .BooleanLiteral )
57+ @lower (PrivateArray , types .Tuple , types .Any , types .BooleanLiteral )
58+ @lower (PrivateArray , types .UniTuple , types .Any , types .BooleanLiteral )
59+ @lower (PrivateArray , types .IntegerLiteral , types .Any )
60+ @lower (PrivateArray , types .Tuple , types .Any )
61+ @lower (PrivateArray , types .UniTuple , types .Any )
62+ def dpex_private_array_lower (
63+ context : DpexExpKernelTypingContext ,
64+ builder : IRBuilder ,
65+ sig : Signature ,
66+ args : list [llvmir .Value ],
67+ ):
68+ """Implements lower for the class:`numba_dpex.kernel_api.PrivateArray`"""
69+ shape = args [0 ]
70+ ty_shape = sig .args [0 ]
71+ if len (sig .args ) == 3 :
72+ fill_zeros = sig .args [- 1 ].literal_value
73+ else :
74+ fill_zeros = False
75+ ty_array = sig .return_type
76+
77+ ary = make_spirv_generic_array_on_stack (
78+ context , builder , ty_array , ty_shape , shape
79+ )
80+
81+ if fill_zeros :
82+ cgutils .memset (
83+ builder , ary .data , builder .mul (ary .itemsize , ary .nitems ), 0
84+ )
9085
91- return ol_private_array_ctor_impl
86+ return ary . _getvalue () # pylint: disable=protected-access
0 commit comments