@@ -95,6 +95,11 @@ def __init__(
9595 capacity = config .CACHE_SIZE ,
9696 pyfunc = self .pyfunc ,
9797 )
98+ self ._kernel_bundle_cache = LRUCache (
99+ name = "KernelBundleCache" ,
100+ capacity = config .CACHE_SIZE ,
101+ pyfunc = self .pyfunc ,
102+ )
98103 else :
99104 self ._cache = NullCache ()
100105 self ._cache_hits = 0
@@ -587,6 +592,7 @@ def __call__(self, *args):
587592 # redundant. We should avoid these checks for the specialized case.
588593 exec_queue = self ._determine_kernel_launch_queue (args , argtypes )
589594 backend = exec_queue .backend
595+ device = exec_queue .sycl_device
590596
591597 if exec_queue .backend not in [
592598 dpctl .backend_type .opencl ,
@@ -626,12 +632,25 @@ def __call__(self, *args):
626632 cache = self ._cache ,
627633 )
628634
629- # create a sycl::KernelBundle
630- kernel_bundle = dpctl_prog .create_program_from_spirv (
631- exec_queue ,
632- device_driver_ir_module ,
633- " " .join (self ._create_sycl_kernel_bundle_flags ),
635+ kernel_bundle_key = build_key (
636+ tuple (argtypes ),
637+ self .pyfunc ,
638+ dpex_kernel_target .target_context .codegen (),
639+ backend = backend ,
640+ device_type = device .device_type ,
634641 )
642+
643+ kernel_bundle = self ._kernel_bundle_cache .get (kernel_bundle_key )
644+
645+ if kernel_bundle is None :
646+ # create a sycl::KernelBundle
647+ kernel_bundle = dpctl_prog .create_program_from_spirv (
648+ exec_queue ,
649+ device_driver_ir_module ,
650+ " " .join (self ._create_sycl_kernel_bundle_flags ),
651+ )
652+ self ._kernel_bundle_cache .put (kernel_bundle_key , kernel_bundle )
653+
635654 # get the sycl::kernel
636655 sycl_kernel = kernel_bundle .get_sycl_kernel (kernel_module_name )
637656
0 commit comments