Skip to content

Commit ad9955c

Browse files
author
Diptorup Deb
committed
Fix KernelHasReturnValueError inside KernelDispatcher.
1 parent 68b1f39 commit ad9955c

File tree

2 files changed

+51
-7
lines changed

2 files changed

+51
-7
lines changed

numba_dpex/kernel_api_impl/spirv/dispatcher.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -415,13 +415,6 @@ def cb_llvm(dur):
415415
except ExecutionQueueInferenceError as eqie:
416416
raise eqie
417417

418-
# A function being compiled in the KERNEL compilation mode
419-
# cannot have a non-void return value
420-
if return_type and return_type != void:
421-
raise KernelHasReturnValueError(
422-
kernel_name=None, return_type=return_type, sig=sig
423-
)
424-
425418
# Don't recompile if signature already exists
426419
existing = self.overloads.get(tuple(args))
427420
if existing is not None:
@@ -444,6 +437,16 @@ def cb_llvm(dur):
444437
kcres: _SPIRVKernelCompileResult = compiler.compile(
445438
args, return_type
446439
)
440+
if (
441+
self.targetoptions["_compilation_mode"]
442+
== CompilationMode.KERNEL
443+
and kcres.signature.return_type is not None
444+
and kcres.signature.return_type != types.void
445+
):
446+
raise KernelHasReturnValueError(
447+
kernel_name=self.py_func.__name__,
448+
return_type=kcres.signature.return_type,
449+
)
447450
except errors.ForceLiteralArg as err:
448451

449452
def folded(args, kws):
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# SPDX-FileCopyrightText: 2020 - 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import dpnp
6+
import pytest
7+
from numba.core.errors import TypingError
8+
9+
import numba_dpex.experimental as dpex
10+
from numba_dpex import int32, usm_ndarray
11+
from numba_dpex.core.exceptions import KernelHasReturnValueError
12+
from numba_dpex.core.types.kernel_api.index_space_ids import ItemType
13+
14+
i32arrty = usm_ndarray(ndim=1, dtype=int32, layout="C")
15+
item_ty = ItemType(ndim=1)
16+
17+
18+
def f(item, a):
19+
return a
20+
21+
22+
list_of_sig = [
23+
None,
24+
(i32arrty(item_ty, i32arrty)),
25+
]
26+
27+
28+
@pytest.fixture(params=list_of_sig)
29+
def sig(request):
30+
return request.param
31+
32+
33+
def test_return(sig):
34+
a = dpnp.arange(1024, dtype=dpnp.int32)
35+
36+
with pytest.raises((TypingError, KernelHasReturnValueError)) as excinfo:
37+
kernel_fn = dpex.kernel(sig)(f)
38+
dpex.call_kernel(kernel_fn, dpex.Range(a.size), a)
39+
40+
if isinstance(excinfo.type, TypingError):
41+
assert "KernelHasReturnValueError" in excinfo.value.args[0]

0 commit comments

Comments
 (0)