Skip to content

Commit bc11a35

Browse files
committed
Improve pairwise_distance workloads
* Use 2D dispatch in kernel impls instead of a huge sequential inner loop. * Use nested prange in the `numba_mlir_p` impl; `numba` and `numba-dpex` don't support nested pranges, but `numba-mlir` does.
1 parent 432d276 commit bc11a35

File tree

3 files changed

+19
-17
lines changed

3 files changed

+19
-17
lines changed

dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_dpex_k.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,16 @@
99
@dpex.kernel
1010
def _pairwise_distance_kernel(X1, X2, D):
1111
i = dpex.get_global_id(0)
12+
j = dpex.get_global_id(1)
1213

13-
X2_rows = X2.shape[0]
1414
X1_cols = X1.shape[1]
15-
for j in range(X2_rows):
16-
d = X1.dtype.type(0.0)
17-
for k in range(X1_cols):
18-
tmp = X1[i, k] - X2[j, k]
19-
d += tmp * tmp
20-
D[i, j] = np.sqrt(d)
15+
16+
d = X1.dtype.type(0.0)
17+
for k in range(X1_cols):
18+
tmp = X1[i, k] - X2[j, k]
19+
d += tmp * tmp
20+
D[i, j] = np.sqrt(d)
2121

2222

2323
def pairwise_distance(X1, X2, D):
24-
_pairwise_distance_kernel[X1.shape[0],](X1, X2, D)
24+
_pairwise_distance_kernel[(X1.shape[0], X2.shape[0]),](X1, X2, D)

dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_mlir_k.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,18 @@
99
@nb.kernel(gpu_fp64_truncate="auto")
1010
def _pairwise_distance_kernel(X1, X2, D):
1111
i = nb.get_global_id(0)
12+
j = nb.get_global_id(1)
1213

13-
X2_rows = X2.shape[0]
1414
X1_cols = X1.shape[1]
15-
for j in range(X2_rows):
16-
d = 0.0
17-
for k in range(X1_cols):
18-
tmp = X1[i, k] - X2[j, k]
19-
d += tmp * tmp
20-
D[i, j] = np.sqrt(d)
15+
16+
d = 0.0
17+
for k in range(X1_cols):
18+
tmp = X1[i, k] - X2[j, k]
19+
d += tmp * tmp
20+
D[i, j] = np.sqrt(d)
2121

2222

2323
def pairwise_distance(X1, X2, D):
24-
_pairwise_distance_kernel[X1.shape[0], nb.DEFAULT_LOCAL_SIZE](X1, X2, D)
24+
_pairwise_distance_kernel[
25+
(X1.shape[0], X2.shape[0]), nb.DEFAULT_LOCAL_SIZE
26+
](X1, X2, D)

dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_mlir_p.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def _pairwise_distance(X1, X2, D):
2525
# Outermost parallel loop over the matrix X1
2626
for i in numba.prange(X1_rows):
2727
# Loop over the matrix X2
28-
for j in range(X2_rows):
28+
for j in numba.prange(X2_rows):
2929
d = 0.0
3030
# Compute Euclidean distance
3131
for k in range(X1_cols):

0 commit comments

Comments
 (0)