Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions numba_dpex/core/passes/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
new_error_context,
)
from numba.core.ir_utils import remove_dels
from numba.core.typed_passes import NativeLowering
from numba.parfors.parfor import Parfor
from numba.parfors.parfor import ParforPass as _parfor_ParforPass
from numba.parfors.parfor import PreParforPass as _parfor_PreParforPass
Expand Down Expand Up @@ -387,3 +388,30 @@ def run_pass(self, state):
else:
raise RuntimeError("Diagnostics failed.")
return True


@register_pass(mutates_CFG=False, analysis_only=True)
class QualNameDisambiguationLowering(NativeLowering):
"""Qualified name disambiguation lowering pass

If there are multiple @func decorated functions exist inside
another @func decorated block, the numba compiler machinery
creates same qualified names for different compiled function.
Therefore, we utilize `unique_name` to resolve the ambiguity.

Args:
NativeLowering (CompilerPass): Superclass from which this
class has been inherited.

Returns:
bool: True if `run_pass()` of the superclass is successful.
"""

_name = "qual-name-disambiguation-lowering"

def run_pass(self, state):
qual_name = state.func_id.func_qualname
state.func_id.func_qualname = state.func_id.unique_name
ret = NativeLowering.run_pass(self, state)
state.func_id.func_qualname = qual_name
return ret
10 changes: 8 additions & 2 deletions numba_dpex/core/pipelines/kernel_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from numba.core.typed_passes import (
AnnotateTypes,
IRLegalization,
NativeLowering,
NopythonRewrites,
NoPythonSupportedFeatureValidation,
NopythonTypeInference,
Expand All @@ -34,6 +33,7 @@
from numba_dpex.core.passes.passes import (
ConstantSizeStaticLocalMemoryPass,
NoPythonBackend,
QualNameDisambiguationLowering,
)


Expand Down Expand Up @@ -139,7 +139,13 @@ def define_nopython_lowering_pipeline(state, name="dpex_kernel_lowering"):
pm.add_pass(IRLegalization, "ensure IR is legal prior to lowering")

# lower
pm.add_pass(NativeLowering, "native lowering")
# NativeLowering has some issue with freevar ambiguity,
# therefore, we are using QualNameDisambiguationLowering instead
# numba-dpex github issue: https://github.com/IntelPython/numba-dpex/issues/898
pm.add_pass(
QualNameDisambiguationLowering,
"numba_dpex qualified name disambiguation",
)
pm.add_pass(NoPythonBackend, "nopython mode backend")

pm.finalize()
Expand Down
119 changes: 119 additions & 0 deletions numba_dpex/tests/kernel_tests/test_func_qualname_disambiguation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import dpctl
import dpctl.tensor as dpt
import numpy as np
import pytest

import numba_dpex as ndpx
from numba_dpex.tests._helper import filter_strings


def make_write_values_kernel(n_rows):
"""Uppermost kernel to set 1s in a certain way.
The uppermost kernel function invokes two levels
of inner functions to set 1s in an empty matrix
in a certain way.

Args:
n_rows (int): Number of rows to iterate.

Returns:
numba_dpex.core.kernel_interface.dispatcher.JitKernel:
A JitKernel object that encapsulates a @kernel
decorated numba_dpex compiled kernel object.
"""
write_values = make_write_values_kernel_func()

@ndpx.kernel
def write_values_kernel(array_in):
for row_idx in range(n_rows):
is_even = (row_idx % 2) == 0
write_values(array_in, row_idx, is_even)

return write_values_kernel[ndpx.NdRange(ndpx.Range(1), ndpx.Range(1))]


def make_write_values_kernel_func():
"""An upper function to set 1 or 3 ones. A function to set
one or three 1s. If the row index is even it will set three 1s,
otherwise one 1. It uses the inner function to do this.

Returns:
numba_dpex.core.kernel_interface.func.DpexFunctionTemplate:
A DpexFunctionTemplate that encapsulates a @func decorated
numba_dpex compiled function object.
"""
write_when_odd = make_write_values_kernel_func_inner(1)
write_when_even = make_write_values_kernel_func_inner(3)

@ndpx.func
def write_values(array_in, row_idx, is_even):
if is_even:
write_when_even(array_in, row_idx)
else:
write_when_odd(array_in, row_idx)

return write_values


def make_write_values_kernel_func_inner(n_cols):
"""Inner function to set 1s. An inner function to set 1s in
n_cols number of columns.

Args:
n_cols (int): Number of columns to be set to 1.

Returns:
numba_dpex.core.kernel_interface.func.DpexFunctionTemplate:
A DpexFunctionTemplate that encapsulates a @func decorated
numba_dpex compiled function object.
"""

@ndpx.func
def write_values_inner(array_in, row_idx):
for idx in range(n_cols):
array_in[row_idx, idx] = 1

return write_values_inner


@pytest.mark.parametrize("offload_device", filter_strings)
def test_qualname_basic(offload_device):
"""A basic test function to test
qualified name disambiguation.
"""
ans = np.zeros((10, 10), dtype=np.int64)
for i in range(ans.shape[0]):
if i % 2 == 0:
ans[i, 0:3] = 1
else:
ans[i, 0] = 1

a = np.zeros((10, 10), dtype=dpt.int64)

device = dpctl.SyclDevice(offload_device)
queue = dpctl.SyclQueue(device)

da = dpt.usm_ndarray(
a.shape,
dtype=a.dtype,
buffer="device",
buffer_ctor_kwargs={"queue": queue},
)
da.usm_data.copy_from_host(a.reshape((-1)).view("|u1"))

kernel = make_write_values_kernel(10)
kernel(da)

result = np.zeros_like(a)
da.usm_data.copy_to_host(result.reshape((-1)).view("|u1"))

print(ans)
print(result)

assert np.array_equal(result, ans)


if __name__ == "__main__":
test_qualname_basic("level_zero:gpu:0")
test_qualname_basic("opencl:gpu:0")
test_qualname_basic("opencl:cpu:0")
8 changes: 4 additions & 4 deletions numba_dpex/tests/test_debuginfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ def data_parallel_sum(a, b, c):
c[i] = func_sum(a[i], b[i])

ir_tags = [
r'\!DISubprogram\(name: ".*func_sum"',
r'\!DISubprogram\(name: ".*data_parallel_sum"',
r'\!DISubprogram\(name: ".*func_sum\$?\d*"',
r'\!DISubprogram\(name: ".*data_parallel_sum\$?\d*"',
]

sig = (f32arrty, f32arrty, f32arrty)
Expand All @@ -154,8 +154,8 @@ def data_parallel_sum(a, b, c):
c[i] = func_sum(a[i], b[i])

ir_tags = [
r'\!DISubprogram\(name: ".*func_sum"',
r'\!DISubprogram\(name: ".*data_parallel_sum"',
r'\!DISubprogram\(name: ".*func_sum\$?\d*"',
r'\!DISubprogram\(name: ".*data_parallel_sum\$\d*"',
]

sig = (f32arrty, f32arrty, f32arrty)
Expand Down