Skip to content

Commit 7160c2a

Browse files
Add support of fp32 for sycl implementations
1 parent 94c2d62 commit 7160c2a

File tree

8 files changed

+86
-76
lines changed

8 files changed

+86
-76
lines changed

dpbench/benchmarks/black_scholes/black_scholes_sycl_native_ext/black_scholes_sycl/_black_scholes_kernel.hpp

Lines changed: 16 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,11 @@
66
#include <stdlib.h>
77
#include <type_traits>
88

9-
#ifdef __DO_FLOAT__
10-
#define EXP(x) expf(x)
11-
#define LOG(x) logf(x)
12-
#define SQRT(x) sqrtf(x)
13-
#define ERF(x) erff(x)
14-
#define INVSQRT(x) 1.0f / sqrtf(x)
15-
16-
#define QUARTER 0.25f
17-
#define HALF 0.5f
18-
#define TWO 2.0f
19-
#else
20-
#define EXP(x) sycl::exp(x)
21-
#define LOG(x) sycl::log(x)
22-
#define SQRT(x) sycl::sqrt(x)
23-
#define ERF(x) sycl::erf(x)
24-
#define INVSQRT(x) 1.0 / sycl::sqrt(x)
25-
26-
#define QUARTER 0.25
27-
#define HALF 0.5
28-
#define TWO 2.0
29-
#endif
30-
319
using namespace sycl;
3210

11+
template <typename FpTy>
12+
class BlackScholesKernel;
13+
3314
template <typename FpTy>
3415
void black_scholes_impl(queue Queue,
3516
size_t nopt,
@@ -41,27 +22,30 @@ void black_scholes_impl(queue Queue,
4122
FpTy *call,
4223
FpTy *put)
4324
{
25+
constexpr FpTy _0_25 = 0.25;
26+
constexpr FpTy _0_5 = 0.5;
27+
4428
auto e = Queue.submit([&](handler &h) {
45-
h.parallel_for<class BlackScholesKernel>(
29+
h.parallel_for<BlackScholesKernel<FpTy>>(
4630
range<1>{nopt}, [=](id<1> myID) {
4731
FpTy mr = -rate;
48-
FpTy sig_sig_two = volatility * volatility * TWO;
32+
FpTy sig_sig_two = volatility * volatility * 2;
4933
int i = myID[0];
5034
FpTy a, b, c, y, z, e;
5135
FpTy d1, d2, w1, w2;
5236

53-
a = LOG(price[i] / strike[i]);
37+
a = sycl::log(price[i] / strike[i]);
5438
b = t[i] * mr;
5539
z = t[i] * sig_sig_two;
56-
c = QUARTER * z;
57-
y = INVSQRT(z);
40+
c = _0_25 * z;
41+
y = sycl::rsqrt(z);
5842
w1 = (a - b + c) * y;
5943
w2 = (a - b - c) * y;
60-
d1 = ERF(w1);
61-
d2 = ERF(w2);
62-
d1 = HALF + HALF * d1;
63-
d2 = HALF + HALF * d2;
64-
e = EXP(b);
44+
d1 = sycl::erf(w1);
45+
d2 = sycl::erf(w2);
46+
d1 = _0_5 + _0_5 * d1;
47+
d2 = _0_5 + _0_5 * d2;
48+
e = sycl::exp(b);
6549
call[i] = price[i] * d1 - strike[i] * e * d2;
6650
put[i] = call[i] - price[i] + strike[i] * e;
6751
});

dpbench/benchmarks/black_scholes/black_scholes_sycl_native_ext/black_scholes_sycl/_black_scholes_sycl.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,19 @@ void black_scholes_sync(size_t /**/,
6464
if (!ensure_compatibility(price, strike, t, call, put))
6565
throw std::runtime_error("Input arrays are not acceptable.");
6666

67-
if (typenum != UAR_DOUBLE) {
68-
throw std::runtime_error("Expected a double precision FP array.");
67+
if (typenum == UAR_FLOAT) {
68+
black_scholes_impl<float>(Queue, nopt, price.get_data<float>(),
69+
strike.get_data<float>(), t.get_data<float>(), rate,
70+
volatility, call.get_data<float>(),
71+
put.get_data<float>());
72+
} else if (typenum == UAR_DOUBLE) {
73+
black_scholes_impl<double>(Queue, nopt, price.get_data<double>(),
74+
strike.get_data<double>(), t.get_data<double>(), rate,
75+
volatility, call.get_data<double>(),
76+
put.get_data<double>());
77+
} else {
78+
throw std::runtime_error("Expected a double or single precision FP array.");
6979
}
70-
71-
black_scholes_impl(Queue, nopt, price.get_data<double>(),
72-
strike.get_data<double>(), t.get_data<double>(), rate,
73-
volatility, call.get_data<double>(),
74-
put.get_data<double>());
7580
}
7681

7782
PYBIND11_MODULE(_black_scholes_sycl, m)

dpbench/benchmarks/dbscan/dbscan_sycl_native_ext/dbscan_sycl/_dbscan_kernel.hpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ void getNeighborhood(size_t n,
106106
}
107107
}
108108

109+
template<typename FpTy>
110+
class DBScanKernel;
111+
109112
template <typename FpTy>
110113
size_t dbscan_impl(queue q,
111114
size_t n_samples,
@@ -126,14 +129,14 @@ size_t dbscan_impl(queue q,
126129
q.wait();
127130

128131
auto e = q.submit([&](handler &h) {
129-
h.parallel_for<class DBScanKernel>(
132+
h.parallel_for<DBScanKernel<FpTy>>(
130133
range<1>{n_samples}, [=](id<1> myID) {
131134
size_t i1 = myID[0];
132135
size_t i2 = (i1 + 1 == n_samples ? n_samples : i1 + 1);
133-
getNeighborhood<double>(n_samples, n_features, data, i2 - i1,
134-
data + i1 * n_features, eps,
135-
d_indices + i1 * n_samples,
136-
d_sizes + i1);
136+
getNeighborhood<FpTy>(n_samples, n_features, data, i2 - i1,
137+
data + i1 * n_features, eps,
138+
d_indices + i1 * n_samples,
139+
d_sizes + i1);
137140
});
138141
});
139142

dpbench/benchmarks/dbscan/dbscan_sycl_native_ext/dbscan_sycl/_dbscan_sycl.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,20 @@ size_t dbscan_sync(size_t n_samples,
3838
size_t min_pts)
3939
{
4040
auto queue = data.get_queue();
41+
auto typenum = data.get_typenum();
4142

4243
if (!ensure_compatibility(data))
4344
throw std::runtime_error("Input arrays are not acceptable.");
4445

45-
if (data.get_typenum() != UAR_DOUBLE) {
46-
throw std::runtime_error("Expected a double precision FP array.");
46+
if (typenum == UAR_FLOAT) {
47+
return dbscan_impl<float>(queue, n_samples, n_features,
48+
data.get_data<float>(), eps, min_pts);
49+
} else if (typenum == UAR_DOUBLE) {
50+
return dbscan_impl<double>(queue, n_samples, n_features,
51+
data.get_data<double>(), eps, min_pts);
4752
}
4853

49-
return dbscan_impl<double>(queue, n_samples, n_features,
50-
data.get_data<double>(), eps, min_pts);
54+
throw std::runtime_error("Expected a double or single precision FP array.");
5155
}
5256

5357
PYBIND11_MODULE(_dbscan_sycl, m)

dpbench/benchmarks/l2_norm/l2_norm_sycl_native_ext/l2_norm_sycl/_l2_norm_kernel.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99

1010
using namespace sycl;
1111

12+
template <typename FpTy>
13+
class theKernel;
14+
1215
template <typename FpTy>
1316
void l2_norm_impl(queue Queue,
1417
size_t npoints,
@@ -18,7 +21,7 @@ void l2_norm_impl(queue Queue,
1821
{
1922
Queue
2023
.submit([&](handler &h) {
21-
h.parallel_for<class theKernel>(range<1>{npoints}, [=](id<1> myID) {
24+
h.parallel_for<theKernel<FpTy>>(range<1>{npoints}, [=](id<1> myID) {
2225
size_t i = myID[0];
2326
for (size_t k = 0; k < dims; k++) {
2427
d[i] += a[i * dims + k] * a[i * dims + k];

dpbench/benchmarks/l2_norm/l2_norm_sycl_native_ext/l2_norm_sycl/_l2_norm_sycl.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,17 @@ void l2_norm_sync(dpctl::tensor::usm_ndarray a, dpctl::tensor::usm_ndarray d)
1919

2020
auto dims = 3;
2121
auto npoints = a.get_size() / dims;
22-
23-
if (a.get_typenum() != UAR_DOUBLE) {
24-
throw std::runtime_error("Expected a double precision FP array.");
22+
auto typenum = a.get_typenum();
23+
24+
if (typenum == UAR_FLOAT) {
25+
l2_norm_impl(Queue, npoints, dims, a.get_data<float>(),
26+
d.get_data<float>());
27+
} else if (typenum == UAR_DOUBLE) {
28+
l2_norm_impl(Queue, npoints, dims, a.get_data<double>(),
29+
d.get_data<double>());
30+
} else {
31+
throw std::runtime_error("Expected a double or single precision FP array.");
2532
}
26-
27-
l2_norm_impl(Queue, npoints, dims, a.get_data<double>(),
28-
d.get_data<double>());
2933
}
3034

3135
PYBIND11_MODULE(_l2_norm_sycl, m)

dpbench/benchmarks/rambo/rambo_sycl_native_ext/rambo_sycl/_rambo_kernel.hpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,11 @@
1010
#include <stdlib.h>
1111
#include <type_traits>
1212

13-
#define SIN(x) sycl::sin(x)
14-
#define COS(x) sycl::cos(x)
15-
#define SQRT(x) sycl::sqrt(x)
16-
#define LOG(x) sycl::log(x)
17-
1813
using namespace sycl;
1914

15+
template <typename FpTy>
16+
class RamboKernel;
17+
2018
template <typename FpTy>
2119
event rambo_impl(queue Queue,
2220
size_t nevts,
@@ -26,20 +24,21 @@ event rambo_impl(queue Queue,
2624
const FpTy *usmQ1,
2725
FpTy *usmOutput)
2826
{
27+
constexpr FpTy pi_v = M_PI;
2928
return Queue.submit([&](handler &h) {
30-
h.parallel_for<class RamboKernel>(range<1>{nevts}, [=](id<1> myID) {
29+
h.parallel_for<RamboKernel<FpTy>>(range<1>{nevts}, [=](id<1> myID) {
3130
for (size_t j = 0; j < nout; j++) {
3231
int i = myID[0];
3332
size_t idx = i * nout + j;
3433

35-
FpTy C = 2.0 * usmC1[idx] - 1.0;
36-
FpTy S = SQRT(1 - C * C);
37-
FpTy F = 2.0 * M_PI * usmF1[idx];
38-
FpTy Q = -LOG(usmQ1[idx]);
34+
FpTy C = 2 * usmC1[idx] - 1;
35+
FpTy S = sycl::sqrt(1 - C * C);
36+
FpTy F = 2 * pi_v * usmF1[idx];
37+
FpTy Q = -sycl::log(usmQ1[idx]);
3938

4039
usmOutput[idx * 4] = Q;
41-
usmOutput[idx * 4 + 1] = Q * S * SIN(F);
42-
usmOutput[idx * 4 + 2] = Q * S * COS(F);
40+
usmOutput[idx * 4 + 1] = Q * S * sycl::sin(F);
41+
usmOutput[idx * 4 + 2] = Q * S * sycl::cos(F);
4342
usmOutput[idx * 4 + 3] = Q * C;
4443
}
4544
});

dpbench/benchmarks/rambo/rambo_sycl_native_ext/rambo_sycl/_rambo_sycl.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,24 @@ void rambo_sync(size_t nevts,
5555
if (!ensure_compatibility(C1, F1, Q1))
5656
throw std::runtime_error("Input arrays are not acceptable.");
5757

58-
if (C1.get_typenum() != UAR_DOUBLE || F1.get_typenum() != UAR_DOUBLE ||
59-
Q1.get_typenum() != UAR_DOUBLE || output.get_typenum() != UAR_DOUBLE)
60-
{
61-
throw std::runtime_error("Expected a double precision FP array.");
62-
}
58+
if (output.get_typenum() != C1.get_typenum())
59+
throw std::runtime_error("Input arrays are not acceptable.");
6360

64-
auto e = rambo_impl(Queue, nevts, nout, C1.get_data<double>(),
65-
F1.get_data<double>(), Q1.get_data<double>(),
66-
output.get_data<double>());
67-
e.wait();
61+
auto typenum = C1.get_typenum();
62+
63+
if (typenum == UAR_FLOAT) {
64+
auto e = rambo_impl(Queue, nevts, nout, C1.get_data<float>(),
65+
F1.get_data<float>(), Q1.get_data<float>(),
66+
output.get_data<float>());
67+
e.wait();
68+
} else if (typenum == UAR_DOUBLE) {
69+
auto e = rambo_impl(Queue, nevts, nout, C1.get_data<double>(),
70+
F1.get_data<double>(), Q1.get_data<double>(),
71+
output.get_data<double>());
72+
e.wait();
73+
} else {
74+
throw std::runtime_error("Expected a double or single precision FP array.");
75+
}
6876
}
6977

7078
PYBIND11_MODULE(_rambo_sycl, m)

0 commit comments

Comments
 (0)