Skip to content

Commit 009989b

Browse files
Implement pairwise distance on 2d grid
1 parent 94c2d62 commit 009989b

File tree

4 files changed

+30
-31
lines changed

4 files changed

+30
-31
lines changed

dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_dpex_k.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
@dpex.kernel
1010
def _pairwise_distance_kernel(X1, X2, D):
11-
i = dpex.get_global_id(0)
12-
j = dpex.get_global_id(1)
11+
i = dpex.get_global_id(1)
12+
j = dpex.get_global_id(0)
1313

1414
X1_cols = X1.shape[1]
1515

@@ -21,4 +21,4 @@ def _pairwise_distance_kernel(X1, X2, D):
2121

2222

2323
def pairwise_distance(X1, X2, D):
24-
_pairwise_distance_kernel[dpex.Range(X1.shape[0], X2.shape[0])](X1, X2, D)
24+
_pairwise_distance_kernel[dpex.Range(X2.shape[0], X1.shape[0])](X1, X2, D)

dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_mlir_k.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
@nb.kernel(gpu_fp64_truncate="auto")
1010
def _pairwise_distance_kernel(X1, X2, D):
11-
i = nb.get_global_id(0)
12-
j = nb.get_global_id(1)
11+
i = nb.get_global_id(1)
12+
j = nb.get_global_id(0)
1313

1414
X1_cols = X1.shape[1]
1515

@@ -22,5 +22,5 @@ def _pairwise_distance_kernel(X1, X2, D):
2222

2323
def pairwise_distance(X1, X2, D):
2424
_pairwise_distance_kernel[
25-
(X1.shape[0], X2.shape[0]), nb.DEFAULT_LOCAL_SIZE
25+
(X2.shape[0], X1.shape[0]), nb.DEFAULT_LOCAL_SIZE
2626
](X1, X2, D)

dpbench/benchmarks/pairwise_distance/pairwise_distance_sycl_native_ext/pairwise_distance_sycl/_pairwise_distance_kernel.hpp

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,34 +6,29 @@
66

77
using namespace sycl;
88

9-
#ifdef __DO_FLOAT__
10-
#define SQRT(x) sqrtf(x)
11-
#else
12-
#define SQRT(x) sqrt(x)
13-
#endif
9+
template<typename T>
10+
class PairwiseDistanceKernel;
1411

1512
template <typename FpTy>
1613
void pairwise_distance_impl(queue Queue,
17-
size_t npoints,
14+
size_t x1_npoints,
15+
size_t x2_npoints,
1816
size_t ndims,
1917
const FpTy *p1,
2018
const FpTy *p2,
2119
FpTy *distance_op)
2220
{
2321
Queue.submit([&](handler &h) {
24-
h.parallel_for<class PairwiseDistanceKernel>(
25-
range<1>{npoints}, [=](id<1> myID) {
26-
size_t i = myID[0];
27-
for (size_t j = 0; j < npoints; j++) {
28-
FpTy d = 0.;
29-
for (size_t k = 0; k < ndims; k++) {
30-
auto tmp = p1[i * ndims + k] - p2[j * ndims + k];
31-
d += tmp * tmp;
32-
}
33-
if (d != 0.0) {
34-
distance_op[i * npoints + j] = sqrt(d);
35-
}
22+
h.parallel_for<PairwiseDistanceKernel<FpTy>>(
23+
range<2>{x1_npoints, x2_npoints}, [=](id<2> myID) {
24+
auto i = myID[0];
25+
auto j = myID[1];
26+
FpTy d = 0.;
27+
for (size_t k = 0; k < ndims; k++) {
28+
auto tmp = p1[i * ndims + k] - p2[j * ndims + k];
29+
d += tmp * tmp;
3630
}
31+
distance_op[i * x2_npoints + j] = sycl::sqrt(d);
3732
});
3833
});
3934

dpbench/benchmarks/pairwise_distance/pairwise_distance_sycl_native_ext/pairwise_distance_sycl/_pairwise_distance_sycl.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,22 @@ void pairwise_distance_sync(dpctl::tensor::usm_ndarray X1,
4949
{
5050
sycl::event res_ev;
5151
auto Queue = X1.get_queue();
52-
auto ndims = 3;
53-
auto npoints = X1.get_size() / ndims;
52+
auto ndims = X1.get_shape(1);
53+
auto x1_npoints = X1.get_shape(0);
54+
auto x2_npoints = X2.get_shape(0);
5455

5556
if (!ensure_compatibility(X1, X2, D))
5657
throw std::runtime_error("Input arrays are not acceptable.");
5758

58-
if (X1.get_typenum() != UAR_DOUBLE || X2.get_typenum() != UAR_DOUBLE) {
59-
throw std::runtime_error("Expected a double precision FP array.");
59+
if (X1.get_typenum() == UAR_FLOAT) {
60+
pairwise_distance_impl(Queue, x1_npoints, x2_npoints, ndims, X1.get_data<float>(),
61+
X2.get_data<float>(), D.get_data<float>());
62+
} else if (X1.get_typenum() == UAR_DOUBLE) {
63+
pairwise_distance_impl(Queue, x1_npoints, x2_npoints, ndims, X1.get_data<double>(),
64+
X2.get_data<double>(), D.get_data<double>());
65+
} else {
66+
throw std::runtime_error("Expected a double or single precision FP array.");
6067
}
61-
62-
pairwise_distance_impl(Queue, npoints, ndims, X1.get_data<double>(),
63-
X2.get_data<double>(), D.get_data<double>());
6468
}
6569

6670
PYBIND11_MODULE(_pairwise_distance_sycl, m)

0 commit comments

Comments
 (0)