1414from numba .core .types import void
1515
1616from numba_dpex import NdRange , Range , config
17- from numba_dpex .core .caching import LRUCache , NullCache , build_key
17+ from numba_dpex .core .caching import LRUCache , NullCache
1818from numba_dpex .core .descriptor import dpex_kernel_target
1919from numba_dpex .core .exceptions import (
2020 ComputeFollowsDataInferenceError ,
3434from numba_dpex .core .kernel_interface .arg_pack_unpacker import Packer
3535from numba_dpex .core .kernel_interface .spirv_kernel import SpirvKernel
3636from numba_dpex .core .types import USMNdArray
37+ from numba_dpex .core .utils import (
38+ build_key ,
39+ create_func_hash ,
40+ strip_usm_metadata ,
41+ )
3742
3843
3944def get_ordered_arg_access_types (pyfunc , access_types ):
@@ -85,6 +90,8 @@ def __init__(
8590 self ._global_range = None
8691 self ._local_range = None
8792
93+ self ._func_hash = create_func_hash (pyfunc )
94+
8895 # caching related attributes
8996 if not config .ENABLE_CACHE :
9097 self ._cache = NullCache ()
@@ -151,7 +158,7 @@ def cache(self):
151158 def cache_hits (self ):
152159 return self ._cache_hits
153160
154- def _compile_and_cache (self , argtypes , cache ):
161+ def _compile_and_cache (self , argtypes , cache , key = None ):
155162 """Helper function to compile the Python function or Numba FunctionIR
156163 object passed to a JitKernel and store it in an internal cache.
157164 """
@@ -171,11 +178,13 @@ def _compile_and_cache(self, argtypes, cache):
171178 device_driver_ir_module = kernel .device_driver_ir_module
172179 kernel_module_name = kernel .module_name
173180
174- key = build_key (
175- tuple (argtypes ),
176- self .pyfunc ,
177- kernel .target_context .codegen (),
178- )
181+ if not key :
182+ stripped_argtypes = strip_usm_metadata (argtypes )
183+ codegen_magic_tuple = kernel .target_context .codegen ().magic_tuple ()
184+ key = build_key (
185+ stripped_argtypes , codegen_magic_tuple , self ._func_hash
186+ )
187+
179188 cache .put (key , (device_driver_ir_module , kernel_module_name ))
180189
181190 return device_driver_ir_module , kernel_module_name
@@ -604,12 +613,12 @@ def __call__(self, *args):
604613 self .kernel_name , backend , JitKernel ._supported_backends
605614 )
606615
607- # load the kernel from cache
608- key = build_key (
609- tuple (argtypes ),
610- self .pyfunc ,
611- dpex_kernel_target .target_context .codegen (),
616+ # Generate key used for cache lookup
617+ stripped_argtypes = strip_usm_metadata (argtypes )
618+ codegen_magic_tuple = (
619+ dpex_kernel_target .target_context .codegen ().magic_tuple ()
612620 )
621+ key = build_key (stripped_argtypes , codegen_magic_tuple , self ._func_hash )
613622
614623 # If the JitKernel was specialized then raise exception if argtypes
615624 # do not match one of the specialized versions.
@@ -630,15 +639,11 @@ def __call__(self, *args):
630639 device_driver_ir_module ,
631640 kernel_module_name ,
632641 ) = self ._compile_and_cache (
633- argtypes = argtypes ,
634- cache = self ._cache ,
642+ argtypes = argtypes , cache = self ._cache , key = key
635643 )
636644
637645 kernel_bundle_key = build_key (
638- tuple (argtypes ),
639- self .pyfunc ,
640- dpex_kernel_target .target_context .codegen (),
641- exec_queue = exec_queue ,
646+ stripped_argtypes , codegen_magic_tuple , exec_queue , self ._func_hash
642647 )
643648
644649 artifact = self ._kernel_bundle_cache .get (kernel_bundle_key )
0 commit comments