File tree Expand file tree Collapse file tree 2 files changed +51
-7
lines changed Expand file tree Collapse file tree 2 files changed +51
-7
lines changed Original file line number Diff line number Diff line change @@ -415,13 +415,6 @@ def cb_llvm(dur):
415415 except ExecutionQueueInferenceError as eqie :
416416 raise eqie
417417
418- # A function being compiled in the KERNEL compilation mode
419- # cannot have a non-void return value
420- if return_type and return_type != void :
421- raise KernelHasReturnValueError (
422- kernel_name = None , return_type = return_type , sig = sig
423- )
424-
425418 # Don't recompile if signature already exists
426419 existing = self .overloads .get (tuple (args ))
427420 if existing is not None :
@@ -444,6 +437,16 @@ def cb_llvm(dur):
444437 kcres : _SPIRVKernelCompileResult = compiler .compile (
445438 args , return_type
446439 )
440+ if (
441+ self .targetoptions ["_compilation_mode" ]
442+ == CompilationMode .KERNEL
443+ and kcres .signature .return_type is not None
444+ and kcres .signature .return_type != types .void
445+ ):
446+ raise KernelHasReturnValueError (
447+ kernel_name = self .py_func .__name__ ,
448+ return_type = kcres .signature .return_type ,
449+ )
447450 except errors .ForceLiteralArg as err :
448451
449452 def folded (args , kws ):
Original file line number Diff line number Diff line change 1+ # SPDX-FileCopyrightText: 2020 - 2024 Intel Corporation
2+ #
3+ # SPDX-License-Identifier: Apache-2.0
4+
5+ import dpnp
6+ import pytest
7+ from numba .core .errors import TypingError
8+
9+ import numba_dpex .experimental as dpex
10+ from numba_dpex import int32 , usm_ndarray
11+ from numba_dpex .core .exceptions import KernelHasReturnValueError
12+ from numba_dpex .core .types .kernel_api .index_space_ids import ItemType
13+
14+ i32arrty = usm_ndarray (ndim = 1 , dtype = int32 , layout = "C" )
15+ item_ty = ItemType (ndim = 1 )
16+
17+
18+ def f (item , a ):
19+ return a
20+
21+
22+ list_of_sig = [
23+ None ,
24+ (i32arrty (item_ty , i32arrty )),
25+ ]
26+
27+
28+ @pytest .fixture (params = list_of_sig )
29+ def sig (request ):
30+ return request .param
31+
32+
33+ def test_return (sig ):
34+ a = dpnp .arange (1024 , dtype = dpnp .int32 )
35+
36+ with pytest .raises ((TypingError , KernelHasReturnValueError )) as excinfo :
37+ kernel_fn = dpex .kernel (sig )(f )
38+ dpex .call_kernel (kernel_fn , dpex .Range (a .size ), a )
39+
40+ if isinstance (excinfo .type , TypingError ):
41+ assert "KernelHasReturnValueError" in excinfo .value .args [0 ]
You can’t perform that action at this time.
0 commit comments