Skip to content

Commit af72372

Browse files
chudur-budurDiptorup Deb
authored andcommitted
Init clean up and monkeypatch fix
1 parent a308784 commit af72372

File tree

10 files changed

+455
-311
lines changed

10 files changed

+455
-311
lines changed

numba_dpex/__init__.py

Lines changed: 111 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,33 +5,130 @@
55
"""
66
The numba-dpex extension module adds data-parallel offload support to Numba.
77
"""
8+
import glob
9+
import logging
10+
import os
11+
import platform as plt
812

9-
import numba_dpex.core.dpjit_dispatcher
10-
import numba_dpex.core.offload_dispatcher
13+
import dpctl
14+
import llvmlite.binding as ll
15+
import numba
16+
from numba.core import ir_utils
17+
from numba.np import arrayobj
18+
from numba.np.ufunc import array_exprs
19+
from numba.np.ufunc.decorators import Vectorize
20+
21+
from numba_dpex._patches import _empty_nd_impl, _is_ufunc, _mk_alloc
22+
from numba_dpex.vectorizers import Vectorize as DpexVectorize
23+
24+
# Monkey patches
25+
array_exprs._is_ufunc = _is_ufunc
26+
ir_utils.mk_alloc = _mk_alloc
27+
arrayobj._empty_nd_impl = _empty_nd_impl
28+
29+
30+
def load_dpctl_sycl_interface():
31+
"""Permanently loads the ``DPCTLSyclInterface`` library provided by dpctl.
32+
The ``DPCTLSyclInterface`` library provides C wrappers over SYCL functions
33+
that are directly invoked from the LLVM modules generated by numba_dpex.
34+
We load the library once at the time of initialization using llvmlite's
35+
load_library_permanently function.
36+
Raises:
37+
ImportError: If the ``DPCTLSyclInterface`` library could not be loaded.
38+
"""
39+
40+
platform = plt.system()
41+
if platform == "Windows":
42+
paths = glob.glob(
43+
os.path.join(
44+
os.path.dirname(dpctl.__file__), "*DPCTLSyclInterface.dll"
45+
)
46+
)
47+
else:
48+
paths = glob.glob(
49+
os.path.join(
50+
os.path.dirname(dpctl.__file__), "*DPCTLSyclInterface.so.0"
51+
)
52+
)
53+
54+
if len(paths) == 1:
55+
ll.load_library_permanently(paths[0])
56+
else:
57+
raise ImportError
58+
59+
Vectorize.target_registry.ondemand["dpex"] = lambda: DpexVectorize
60+
61+
62+
numba_version = tuple(map(int, numba.__version__.split(".")[:3]))
63+
if numba_version < (0, 56, 4):
64+
logging.warning(
65+
"numba_dpex needs numba 0.56.4, using "
66+
f"numba={numba_version} may cause unexpected behavior"
67+
)
68+
69+
70+
dpctl_version = tuple(map(int, dpctl.__version__.split(".")[:2]))
71+
if dpctl_version < (0, 14):
72+
logging.warning(
73+
"numba_dpex needs dpctl 0.14 or greater, using "
74+
f"dpctl={dpctl_version} may cause unexpected behavior"
75+
)
76+
77+
78+
import numba_dpex.core.dpjit_dispatcher # noqa E402
79+
import numba_dpex.core.offload_dispatcher # noqa E402
1180

1281
# Initialize the _dpexrt_python extension
13-
import numba_dpex.core.runtime
14-
import numba_dpex.core.targets.dpjit_target
82+
import numba_dpex.core.runtime # noqa E402
83+
import numba_dpex.core.targets.dpjit_target # noqa E402
1584

1685
# Re-export types itself
17-
import numba_dpex.core.types as types
18-
from numba_dpex.core.kernel_interface.indexers import NdRange, Range
86+
import numba_dpex.core.types as types # noqa E402
87+
from numba_dpex import config # noqa E402
88+
from numba_dpex.core.kernel_interface.indexers import ( # noqa E402
89+
NdRange,
90+
Range,
91+
)
1992

2093
# Re-export all type names
21-
from numba_dpex.core.types import *
22-
from numba_dpex.retarget import offload_to_sycl_device
23-
24-
from . import config
94+
from numba_dpex.core.types import * # noqa E402
95+
from numba_dpex.retarget import offload_to_sycl_device # noqa E402
2596

2697
if config.HAS_NON_HOST_DEVICE:
27-
from .device_init import *
98+
# Re export
99+
from .core.targets import dpjit_target, kernel_target
100+
from .decorators import dpjit, func, kernel
101+
102+
# We are importing dpnp stub module to make Numba recognize the
103+
# module when we rename Numpy functions.
104+
from .dpnp_iface.stubs import dpnp
105+
from .ocl.stubs import (
106+
GLOBAL_MEM_FENCE,
107+
LOCAL_MEM_FENCE,
108+
atomic,
109+
barrier,
110+
get_global_id,
111+
get_global_size,
112+
get_group_id,
113+
get_local_id,
114+
get_local_size,
115+
get_num_groups,
116+
get_work_dim,
117+
local,
118+
mem_fence,
119+
private,
120+
sub_group_barrier,
121+
)
122+
123+
DEFAULT_LOCAL_SIZE = []
124+
load_dpctl_sycl_interface()
125+
del load_dpctl_sycl_interface
28126
else:
29127
raise ImportError("No non-host SYCL device found to execute kernels.")
30128

31-
32-
from ._version import get_versions
129+
from numba_dpex._version import get_versions # noqa E402
33130

34131
__version__ = get_versions()["version"]
35132
del get_versions
36133

37-
__all__ = ["offload_to_sycl_device"] + types.__all__ + ["Range", "NdRange"]
134+
__all__ = types.__all__ + ["offload_to_sycl_device"] + ["Range", "NdRange"]

0 commit comments

Comments
 (0)