|
5 | 5 | """ |
6 | 6 | The numba-dpex extension module adds data-parallel offload support to Numba. |
7 | 7 | """ |
| 8 | +import glob |
| 9 | +import logging |
| 10 | +import os |
| 11 | +import platform as plt |
8 | 12 |
|
9 | | -import numba_dpex.core.dpjit_dispatcher |
10 | | -import numba_dpex.core.offload_dispatcher |
| 13 | +import dpctl |
| 14 | +import llvmlite.binding as ll |
| 15 | +import numba |
| 16 | +from numba.core import ir_utils |
| 17 | +from numba.np import arrayobj |
| 18 | +from numba.np.ufunc import array_exprs |
| 19 | +from numba.np.ufunc.decorators import Vectorize |
| 20 | + |
| 21 | +from numba_dpex._patches import _empty_nd_impl, _is_ufunc, _mk_alloc |
| 22 | +from numba_dpex.vectorizers import Vectorize as DpexVectorize |
| 23 | + |
| 24 | +# Monkey patches |
| 25 | +array_exprs._is_ufunc = _is_ufunc |
| 26 | +ir_utils.mk_alloc = _mk_alloc |
| 27 | +arrayobj._empty_nd_impl = _empty_nd_impl |
| 28 | + |
| 29 | + |
| 30 | +def load_dpctl_sycl_interface(): |
| 31 | + """Permanently loads the ``DPCTLSyclInterface`` library provided by dpctl. |
| 32 | + The ``DPCTLSyclInterface`` library provides C wrappers over SYCL functions |
| 33 | + that are directly invoked from the LLVM modules generated by numba_dpex. |
| 34 | + We load the library once at the time of initialization using llvmlite's |
| 35 | + load_library_permanently function. |
| 36 | + Raises: |
| 37 | + ImportError: If the ``DPCTLSyclInterface`` library could not be loaded. |
| 38 | + """ |
| 39 | + |
| 40 | + platform = plt.system() |
| 41 | + if platform == "Windows": |
| 42 | + paths = glob.glob( |
| 43 | + os.path.join( |
| 44 | + os.path.dirname(dpctl.__file__), "*DPCTLSyclInterface.dll" |
| 45 | + ) |
| 46 | + ) |
| 47 | + else: |
| 48 | + paths = glob.glob( |
| 49 | + os.path.join( |
| 50 | + os.path.dirname(dpctl.__file__), "*DPCTLSyclInterface.so.0" |
| 51 | + ) |
| 52 | + ) |
| 53 | + |
| 54 | + if len(paths) == 1: |
| 55 | + ll.load_library_permanently(paths[0]) |
| 56 | + else: |
| 57 | + raise ImportError |
| 58 | + |
| 59 | + Vectorize.target_registry.ondemand["dpex"] = lambda: DpexVectorize |
| 60 | + |
| 61 | + |
| 62 | +numba_version = tuple(map(int, numba.__version__.split(".")[:3])) |
| 63 | +if numba_version < (0, 56, 4): |
| 64 | + logging.warning( |
| 65 | + "numba_dpex needs numba 0.56.4, using " |
| 66 | + f"numba={numba_version} may cause unexpected behavior" |
| 67 | + ) |
| 68 | + |
| 69 | + |
| 70 | +dpctl_version = tuple(map(int, dpctl.__version__.split(".")[:2])) |
| 71 | +if dpctl_version < (0, 14): |
| 72 | + logging.warning( |
| 73 | + "numba_dpex needs dpctl 0.14 or greater, using " |
| 74 | + f"dpctl={dpctl_version} may cause unexpected behavior" |
| 75 | + ) |
| 76 | + |
| 77 | + |
| 78 | +import numba_dpex.core.dpjit_dispatcher # noqa E402 |
| 79 | +import numba_dpex.core.offload_dispatcher # noqa E402 |
11 | 80 |
|
12 | 81 | # Initialize the _dpexrt_python extension |
13 | | -import numba_dpex.core.runtime |
14 | | -import numba_dpex.core.targets.dpjit_target |
| 82 | +import numba_dpex.core.runtime # noqa E402 |
| 83 | +import numba_dpex.core.targets.dpjit_target # noqa E402 |
15 | 84 |
|
16 | 85 | # Re-export types itself |
17 | | -import numba_dpex.core.types as types |
18 | | -from numba_dpex.core.kernel_interface.indexers import NdRange, Range |
| 86 | +import numba_dpex.core.types as types # noqa E402 |
| 87 | +from numba_dpex import config # noqa E402 |
| 88 | +from numba_dpex.core.kernel_interface.indexers import ( # noqa E402 |
| 89 | + NdRange, |
| 90 | + Range, |
| 91 | +) |
19 | 92 |
|
20 | 93 | # Re-export all type names |
21 | | -from numba_dpex.core.types import * |
22 | | -from numba_dpex.retarget import offload_to_sycl_device |
23 | | - |
24 | | -from . import config |
| 94 | +from numba_dpex.core.types import * # noqa E402 |
| 95 | +from numba_dpex.retarget import offload_to_sycl_device # noqa E402 |
25 | 96 |
|
26 | 97 | if config.HAS_NON_HOST_DEVICE: |
27 | | - from .device_init import * |
| 98 | + # Re export |
| 99 | + from .core.targets import dpjit_target, kernel_target |
| 100 | + from .decorators import dpjit, func, kernel |
| 101 | + |
| 102 | + # We are importing dpnp stub module to make Numba recognize the |
| 103 | + # module when we rename Numpy functions. |
| 104 | + from .dpnp_iface.stubs import dpnp |
| 105 | + from .ocl.stubs import ( |
| 106 | + GLOBAL_MEM_FENCE, |
| 107 | + LOCAL_MEM_FENCE, |
| 108 | + atomic, |
| 109 | + barrier, |
| 110 | + get_global_id, |
| 111 | + get_global_size, |
| 112 | + get_group_id, |
| 113 | + get_local_id, |
| 114 | + get_local_size, |
| 115 | + get_num_groups, |
| 116 | + get_work_dim, |
| 117 | + local, |
| 118 | + mem_fence, |
| 119 | + private, |
| 120 | + sub_group_barrier, |
| 121 | + ) |
| 122 | + |
| 123 | + DEFAULT_LOCAL_SIZE = [] |
| 124 | + load_dpctl_sycl_interface() |
| 125 | + del load_dpctl_sycl_interface |
28 | 126 | else: |
29 | 127 | raise ImportError("No non-host SYCL device found to execute kernels.") |
30 | 128 |
|
31 | | - |
32 | | -from ._version import get_versions |
| 129 | +from numba_dpex._version import get_versions # noqa E402 |
33 | 130 |
|
34 | 131 | __version__ = get_versions()["version"] |
35 | 132 | del get_versions |
36 | 133 |
|
37 | | -__all__ = ["offload_to_sycl_device"] + types.__all__ + ["Range", "NdRange"] |
| 134 | +__all__ = types.__all__ + ["offload_to_sycl_device"] + ["Range", "NdRange"] |
0 commit comments