Skip to content

Commit ae994cd

Browse files
author
Diptorup Deb
authored
Merge pull request #912 from adarshyoga/kmeans_perf_improvements
Dispatcher/caching rewrite to address performance regression
2 parents a3f5604 + 1056a58 commit ae994cd

File tree

5 files changed

+117
-91
lines changed

5 files changed

+117
-91
lines changed

numba_dpex/core/caching.py

Lines changed: 7 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -2,77 +2,12 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
import hashlib
65
import sys
76
from abc import ABCMeta, abstractmethod
87

98
from numba.core.caching import CacheImpl, IndexDataCacheFile
10-
from numba.core.serialize import dumps
119

1210
from numba_dpex import config
13-
from numba_dpex.core.types import USMNdArray
14-
15-
16-
def build_key(
17-
argtypes, pyfunc, codegen, backend=None, device_type=None, exec_queue=None
18-
):
19-
"""Constructs a key from python function, context, backend, the device
20-
type and execution queue.
21-
22-
Compute index key for the given argument types and codegen. It includes a
23-
description of the OS, target architecture and hashes of the bytecode for
24-
the function and, if the function has a __closure__, a hash of the
25-
cell_contents.type
26-
27-
Args:
28-
argtypes : A tuple of numba types corresponding to the arguments to the
29-
compiled function.
30-
pyfunc : The Python function that is to be compiled and cached.
31-
codegen (numba.core.codegen.Codegen):
32-
The codegen object found from the target context.
33-
backend (enum, optional): A 'backend_type' enum.
34-
Defaults to None.
35-
device_type (enum, optional): A 'device_type' enum.
36-
Defaults to None.
37-
exec_queue (dpctl._sycl_queue.SyclQueue', optional): A SYCL queue object.
38-
39-
Returns:
40-
tuple: A tuple of return type, argtpes, magic_tuple of codegen
41-
and another tuple of hashcodes from bytecode and cell_contents.
42-
"""
43-
44-
codebytes = pyfunc.__code__.co_code
45-
if pyfunc.__closure__ is not None:
46-
try:
47-
cvars = tuple([x.cell_contents for x in pyfunc.__closure__])
48-
# Note: cloudpickle serializes a function differently depending
49-
# on how the process is launched; e.g. multiprocessing.Process
50-
cvarbytes = dumps(cvars)
51-
except:
52-
cvarbytes = b"" # a temporary solution for function template
53-
else:
54-
cvarbytes = b""
55-
56-
argtylist = list(argtypes)
57-
for i, argty in enumerate(argtylist):
58-
if isinstance(argty, USMNdArray):
59-
# Convert the USMNdArray to an abridged type that disregards the
60-
# usm_type, device, queue, address space attributes.
61-
argtylist[i] = (argty.ndim, argty.dtype, argty.layout)
62-
63-
argtypes = tuple(argtylist)
64-
65-
return (
66-
argtypes,
67-
codegen.magic_tuple(),
68-
backend,
69-
device_type,
70-
exec_queue,
71-
(
72-
hashlib.sha256(codebytes).hexdigest(),
73-
hashlib.sha256(cvarbytes).hexdigest(),
74-
),
75-
)
7611

7712

7813
class _CacheImpl(CacheImpl):
@@ -475,8 +410,13 @@ def put(self, key, value):
475410
self._name, len(self._lookup), str(key)
476411
)
477412
)
478-
self._lookup[key].value = value
479-
self.get(key)
413+
node = self._lookup[key]
414+
node.value = value
415+
416+
if node is not self._tail:
417+
self._unlink_node(node)
418+
self._append_tail(node)
419+
480420
return
481421

482422
if key in self._evicted:

numba_dpex/core/kernel_interface/dispatcher.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from numba.core.types import void
1515

1616
from numba_dpex import NdRange, Range, config
17-
from numba_dpex.core.caching import LRUCache, NullCache, build_key
17+
from numba_dpex.core.caching import LRUCache, NullCache
1818
from numba_dpex.core.descriptor import dpex_kernel_target
1919
from numba_dpex.core.exceptions import (
2020
ComputeFollowsDataInferenceError,
@@ -34,6 +34,11 @@
3434
from numba_dpex.core.kernel_interface.arg_pack_unpacker import Packer
3535
from numba_dpex.core.kernel_interface.spirv_kernel import SpirvKernel
3636
from numba_dpex.core.types import USMNdArray
37+
from numba_dpex.core.utils import (
38+
build_key,
39+
create_func_hash,
40+
strip_usm_metadata,
41+
)
3742

3843

3944
def get_ordered_arg_access_types(pyfunc, access_types):
@@ -85,6 +90,8 @@ def __init__(
8590
self._global_range = None
8691
self._local_range = None
8792

93+
self._func_hash = create_func_hash(pyfunc)
94+
8895
# caching related attributes
8996
if not config.ENABLE_CACHE:
9097
self._cache = NullCache()
@@ -151,7 +158,7 @@ def cache(self):
151158
def cache_hits(self):
152159
return self._cache_hits
153160

154-
def _compile_and_cache(self, argtypes, cache):
161+
def _compile_and_cache(self, argtypes, cache, key=None):
155162
"""Helper function to compile the Python function or Numba FunctionIR
156163
object passed to a JitKernel and store it in an internal cache.
157164
"""
@@ -171,11 +178,13 @@ def _compile_and_cache(self, argtypes, cache):
171178
device_driver_ir_module = kernel.device_driver_ir_module
172179
kernel_module_name = kernel.module_name
173180

174-
key = build_key(
175-
tuple(argtypes),
176-
self.pyfunc,
177-
kernel.target_context.codegen(),
178-
)
181+
if not key:
182+
stripped_argtypes = strip_usm_metadata(argtypes)
183+
codegen_magic_tuple = kernel.target_context.codegen().magic_tuple()
184+
key = build_key(
185+
stripped_argtypes, codegen_magic_tuple, self._func_hash
186+
)
187+
179188
cache.put(key, (device_driver_ir_module, kernel_module_name))
180189

181190
return device_driver_ir_module, kernel_module_name
@@ -604,12 +613,12 @@ def __call__(self, *args):
604613
self.kernel_name, backend, JitKernel._supported_backends
605614
)
606615

607-
# load the kernel from cache
608-
key = build_key(
609-
tuple(argtypes),
610-
self.pyfunc,
611-
dpex_kernel_target.target_context.codegen(),
616+
# Generate key used for cache lookup
617+
stripped_argtypes = strip_usm_metadata(argtypes)
618+
codegen_magic_tuple = (
619+
dpex_kernel_target.target_context.codegen().magic_tuple()
612620
)
621+
key = build_key(stripped_argtypes, codegen_magic_tuple, self._func_hash)
613622

614623
# If the JitKernel was specialized then raise exception if argtypes
615624
# do not match one of the specialized versions.
@@ -630,15 +639,11 @@ def __call__(self, *args):
630639
device_driver_ir_module,
631640
kernel_module_name,
632641
) = self._compile_and_cache(
633-
argtypes=argtypes,
634-
cache=self._cache,
642+
argtypes=argtypes, cache=self._cache, key=key
635643
)
636644

637645
kernel_bundle_key = build_key(
638-
tuple(argtypes),
639-
self.pyfunc,
640-
dpex_kernel_target.target_context.codegen(),
641-
exec_queue=exec_queue,
646+
stripped_argtypes, codegen_magic_tuple, exec_queue, self._func_hash
642647
)
643648

644649
artifact = self._kernel_bundle_cache.get(kernel_bundle_key)

numba_dpex/core/kernel_interface/func.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,18 @@
55
"""_summary_
66
"""
77

8-
98
from numba.core import sigutils, types
109
from numba.core.typing.templates import AbstractTemplate, ConcreteTemplate
1110

1211
from numba_dpex import config
13-
from numba_dpex.core.caching import LRUCache, NullCache, build_key
12+
from numba_dpex.core.caching import LRUCache, NullCache
1413
from numba_dpex.core.compiler import compile_with_dpex
1514
from numba_dpex.core.descriptor import dpex_kernel_target
15+
from numba_dpex.core.utils import (
16+
build_key,
17+
create_func_hash,
18+
strip_usm_metadata,
19+
)
1620
from numba_dpex.utils import npytypes_array_to_dpex_array
1721

1822

@@ -91,6 +95,8 @@ def __init__(self, pyfunc, debug=False, enable_cache=True):
9195
self._debug = debug
9296
self._enable_cache = enable_cache
9397

98+
self._func_hash = create_func_hash(pyfunc)
99+
94100
if not config.ENABLE_CACHE:
95101
self._cache = NullCache()
96102
elif self._enable_cache:
@@ -132,11 +138,14 @@ def compile(self, args):
132138
dpex_kernel_target.typing_context.resolve_argument_type(arg)
133139
for arg in args
134140
]
135-
key = build_key(
136-
tuple(argtypes),
137-
self._pyfunc,
138-
dpex_kernel_target.target_context.codegen(),
141+
142+
# Generate key used for cache lookup
143+
stripped_argtypes = strip_usm_metadata(argtypes)
144+
codegen_magic_tuple = (
145+
dpex_kernel_target.target_context.codegen().magic_tuple()
139146
)
147+
key = build_key(stripped_argtypes, codegen_magic_tuple, self._func_hash)
148+
140149
cres = self._cache.get(key)
141150
if cres is None:
142151
self._cache_hits += 1

numba_dpex/core/utils/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from .caching_utils import build_key, create_func_hash, strip_usm_metadata
56
from .suai_helper import SyclUSMArrayInterface, get_info_from_suai
67

78
__all__ = [
89
"get_info_from_suai",
910
"SyclUSMArrayInterface",
11+
"create_func_hash",
12+
"strip_usm_metadata",
13+
"build_key",
1014
]
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# SPDX-FileCopyrightText: 2020 - 2022 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import hashlib
6+
7+
from numba.core.serialize import dumps
8+
9+
from numba_dpex.core.types import USMNdArray
10+
11+
12+
def build_key(*args):
13+
"""Constructs key from variable list of args
14+
15+
Args:
16+
*args: List of components to construct key
17+
Return:
18+
Tuple of args
19+
"""
20+
return tuple(args)
21+
22+
23+
def create_func_hash(pyfunc):
24+
"""Creates a tuple of sha256 hashes out of code and
25+
variable bytes extracted from the compiled funtion.
26+
27+
Args:
28+
pyfunc: Python function object
29+
Return:
30+
Tuple of hashes of code and variable bytes
31+
"""
32+
codebytes = pyfunc.__code__.co_code
33+
if pyfunc.__closure__ is not None:
34+
try:
35+
cvars = tuple([x.cell_contents for x in pyfunc.__closure__])
36+
# Note: cloudpickle serializes a function differently depending
37+
# on how the process is launched; e.g. multiprocessing.Process
38+
cvarbytes = dumps(cvars)
39+
except:
40+
cvarbytes = b"" # a temporary solution for function template
41+
else:
42+
cvarbytes = b""
43+
44+
return (
45+
hashlib.sha256(codebytes).hexdigest(),
46+
hashlib.sha256(cvarbytes).hexdigest(),
47+
)
48+
49+
50+
def strip_usm_metadata(argtypes):
51+
"""Convert the USMNdArray to an abridged type that disregards the
52+
usm_type, device, queue, address space attributes.
53+
54+
Args:
55+
argtypes: List of types
56+
57+
Return:
58+
Tuple of types after removing USM metadata from USMNdArray type
59+
"""
60+
61+
stripped_argtypes = []
62+
for argty in argtypes:
63+
if isinstance(argty, USMNdArray):
64+
stripped_argtypes.append((argty.ndim, argty.dtype, argty.layout))
65+
else:
66+
stripped_argtypes.append(argty)
67+
68+
return tuple(stripped_argtypes)

0 commit comments

Comments
 (0)