22#
33# SPDX-License-Identifier: Apache-2.0
44
5- import math
6-
75import numba_dpex as dpex
8- import numpy as np
96
107# This implementation is numba dpex kernel version with atomics.
118
@@ -27,6 +24,7 @@ def count_weighted_pairs_3d_intel_no_slm_ker(
2724 rbins_squared ,
2825 result ,
2926):
27+ dtype = x0 .dtype
3028 lid0 = dpex .get_local_id (0 )
3129 gr0 = dpex .get_group_id (0 )
3230
@@ -38,9 +36,9 @@ def count_weighted_pairs_3d_intel_no_slm_ker(
3836
3937 n_wi = 20
4038
41- dsq_mat = dpex .private .array (shape = (20 * 20 ), dtype = np . float32 )
42- w0_vec = dpex .private .array (shape = (20 ), dtype = np . float32 )
43- w1_vec = dpex .private .array (shape = (20 ), dtype = np . float32 )
39+ dsq_mat = dpex .private .array (shape = (20 * 20 ), dtype = dtype )
40+ w0_vec = dpex .private .array (shape = (20 ), dtype = dtype )
41+ w1_vec = dpex .private .array (shape = (20 ), dtype = dtype )
4442
4543 offset0 = gr0 * n_wi * lws0 + lid0
4644 offset1 = gr1 * n_wi * lws1 + lid1
@@ -80,7 +78,7 @@ def count_weighted_pairs_3d_intel_no_slm_ker(
8078
8179 # update slm_hist. Use work-item private buffer of 16 tfloat elements
8280 for k in range (0 , slm_hist_size , private_hist_size ):
83- private_hist = dpex .private .array (shape = (16 ), dtype = np . float32 )
81+ private_hist = dpex .private .array (shape = (16 ), dtype = dtype )
8482 for p in range (private_hist_size ):
8583 private_hist [p ] = 0.0
8684
@@ -95,7 +93,9 @@ def count_weighted_pairs_3d_intel_no_slm_ker(
9593 pk = k
9694 for p in range (private_hist_size ):
9795 private_hist [p ] += (
98- pw if (pk < nbins and dsq <= rbins_squared [pk ]) else 0.0
96+ pw
97+ if (pk < nbins and dsq <= rbins_squared [pk ])
98+ else dtype .type (0.0 )
9999 )
100100 pk += 1
101101
0 commit comments