@@ -35,6 +35,19 @@ def set_last_one_item(item: Item, a):
3535 a [i ] = 1
3636
3737
38+ @dpex_exp .kernel
39+ def set_last_one_linear_item (item : Item , a ):
40+ i = item .get_linear_range () - 1
41+ a [i ] = 1
42+
43+
44+ @dpex_exp .kernel
45+ def set_last_one_linear_nd_item (nd_item : NdItem , a ):
46+ i = nd_item .get_global_linear_range () - 1
47+ a [0 ] = i
48+ a [i ] = 1
49+
50+
3851@dpex_exp .kernel
3952def set_last_one_nd_item (item : NdItem , a ):
4053 if item .get_global_id (0 ) == 0 :
@@ -43,6 +56,20 @@ def set_last_one_nd_item(item: NdItem, a):
4356 a [i ] = 1
4457
4558
59+ @dpex_exp .kernel
60+ def set_last_group_one_linear_nd_item (nd_item : NdItem , a ):
61+ i = nd_item .get_local_linear_range () - 1
62+ a [0 ] = i
63+ a [i ] = 1
64+
65+
66+ @dpex_exp .kernel
67+ def set_last_group_one_group_linear_nd_item (nd_item : NdItem , a ):
68+ i = nd_item .get_group ().get_local_linear_range () - 1
69+ a [0 ] = i
70+ a [i ] = 1
71+
72+
4673@dpex_exp .kernel
4774def set_last_group_one_nd_item (item : NdItem , a ):
4875 if item .get_global_id (0 ) == 0 :
@@ -99,6 +126,12 @@ def _get_group_range_driver(nditem: NdItem, a):
99126 a [i ] = g .get_group_range (0 )
100127
101128
129+ def _get_group_linear_range_driver (nditem : NdItem , a ):
130+ i = nditem .get_global_linear_id ()
131+ g = nditem .get_group ()
132+ a [i ] = g .get_group_linear_range ()
133+
134+
102135def _get_group_local_range_driver (nditem : NdItem , a ):
103136 i = nditem .get_global_id (0 )
104137 g = nditem .get_group ()
@@ -122,11 +155,34 @@ def test_item_get_range():
122155 assert np .array_equal (a .asnumpy (), want )
123156
124157
125- def test_nd_item_get_global_range ():
158+ @pytest .mark .parametrize (
159+ "rng" ,
160+ [dpex .Range (_SIZE ), dpex .Range (1 , _GROUP_SIZE , int (_SIZE / _GROUP_SIZE ))],
161+ )
162+ def test_item_get_linear_range (rng ):
126163 a = dpnp .zeros (_SIZE , dtype = dpnp .float32 )
127- dpex_exp .call_kernel (
128- set_last_one_nd_item , dpex .NdRange ((a .size ,), (_GROUP_SIZE ,)), a
129- )
164+ dpex_exp .call_kernel (set_last_one_linear_item , rng , a )
165+
166+ want = np .zeros (a .size , dtype = np .float32 )
167+ want [- 1 ] = 1
168+
169+ assert np .array_equal (a .asnumpy (), want )
170+
171+
172+ @pytest .mark .parametrize (
173+ "kernel,rng" ,
174+ [
175+ (set_last_one_nd_item , dpex .NdRange ((_SIZE ,), (_GROUP_SIZE ,))),
176+ (set_last_one_linear_nd_item , dpex .NdRange ((_SIZE ,), (_GROUP_SIZE ,))),
177+ (
178+ set_last_one_linear_nd_item ,
179+ dpex .NdRange ((1 , 1 , _SIZE ), (1 , 1 , _GROUP_SIZE )),
180+ ),
181+ ],
182+ )
183+ def test_nd_item_get_global_range (kernel , rng ):
184+ a = dpnp .zeros (_SIZE , dtype = dpnp .float32 )
185+ dpex_exp .call_kernel (kernel , rng , a )
130186
131187 want = np .zeros (a .size , dtype = np .float32 )
132188 want [- 1 ] = 1
@@ -135,11 +191,31 @@ def test_nd_item_get_global_range():
135191 assert np .array_equal (a .asnumpy (), want )
136192
137193
138- def test_nd_item_get_local_range ():
194+ @pytest .mark .parametrize (
195+ "kernel,rng" ,
196+ [
197+ (set_last_group_one_nd_item , dpex .NdRange ((_SIZE ,), (_GROUP_SIZE ,))),
198+ (
199+ set_last_group_one_linear_nd_item ,
200+ dpex .NdRange ((_SIZE ,), (_GROUP_SIZE ,)),
201+ ),
202+ (
203+ set_last_group_one_linear_nd_item ,
204+ dpex .NdRange ((1 , 1 , _SIZE ), (1 , 1 , _GROUP_SIZE )),
205+ ),
206+ (
207+ set_last_group_one_group_linear_nd_item ,
208+ dpex .NdRange ((_SIZE ,), (_GROUP_SIZE ,)),
209+ ),
210+ (
211+ set_last_group_one_group_linear_nd_item ,
212+ dpex .NdRange ((1 , 1 , _SIZE ), (1 , 1 , _GROUP_SIZE )),
213+ ),
214+ ],
215+ )
216+ def test_nd_item_get_local_range (kernel , rng ):
139217 a = dpnp .zeros (_SIZE , dtype = dpnp .float32 )
140- dpex_exp .call_kernel (
141- set_last_group_one_nd_item , dpex .NdRange ((a .size ,), (_GROUP_SIZE ,)), a
142- )
218+ dpex_exp .call_kernel (kernel , rng , a )
143219
144220 want = np .zeros (a .size , dtype = np .float32 )
145221 want [_GROUP_SIZE - 1 ] = 1
@@ -240,21 +316,32 @@ def test_get_group_id(driver, rng):
240316 assert np .array_equal (ka .asnumpy (), expected )
241317
242318
243- def test_get_group_range ():
244- global_size = 100
245- group_size = 20
246- num_groups = global_size // group_size
319+ @pytest .mark .parametrize (
320+ "driver,rng" ,
321+ [
322+ (_get_group_range_driver , dpex .NdRange ((_SIZE ,), (_GROUP_SIZE ,))),
323+ (
324+ _get_group_linear_range_driver ,
325+ dpex .NdRange ((_SIZE ,), (_GROUP_SIZE ,)),
326+ ),
327+ (
328+ _get_group_linear_range_driver ,
329+ dpex .NdRange ((1 , 1 , _SIZE ), (1 , 1 , _GROUP_SIZE )),
330+ ),
331+ ],
332+ )
333+ def test_get_group_range (driver , rng ):
334+ num_groups = _SIZE // _GROUP_SIZE
247335
248- a = dpnp .empty (global_size , dtype = dpnp .int32 )
249- ka = dpnp .empty (global_size , dtype = dpnp .int32 )
250- expected = np .empty (global_size , dtype = np .int32 )
251- ndrange = NdRange ((global_size ,), (group_size ,))
252- dpex_exp .call_kernel (dpex_exp .kernel (_get_group_range_driver ), ndrange , a )
253- kapi_call_kernel (_get_group_range_driver , ndrange , ka )
336+ a = dpnp .empty (_SIZE , dtype = dpnp .int32 )
337+ ka = dpnp .empty (_SIZE , dtype = dpnp .int32 )
338+ expected = np .empty (_SIZE , dtype = np .int32 )
339+ dpex_exp .call_kernel (dpex_exp .kernel (driver ), rng , a )
340+ kapi_call_kernel (driver , rng , ka )
254341
255342 for gid in range (num_groups ):
256- for lid in range (group_size ):
257- expected [gid * group_size + lid ] = num_groups
343+ for lid in range (_GROUP_SIZE ):
344+ expected [gid * _GROUP_SIZE + lid ] = num_groups
258345
259346 assert np .array_equal (a .asnumpy (), expected )
260347 assert np .array_equal (ka .asnumpy (), expected )
0 commit comments