@@ -287,6 +287,45 @@ def test_permute_dims(x, axes):
287287 out_indices = permuted_indices )
288288
289289
290+ @pytest .mark .min_version ("2023.12" )
291+ @given (
292+ x = hh .arrays (dtype = hh .all_dtypes , shape = shared_shapes (min_dims = 1 )),
293+ kw = hh .kwargs (
294+ axis = st .none () | shared_shapes (min_dims = 1 ).flatmap (
295+ lambda s : st .integers (- len (s ), len (s ) - 1 )
296+ )
297+ ),
298+ data = st .data (),
299+ )
300+ def test_repeat (x , kw , data ):
301+ shape = x .shape
302+ axis = kw .get ("axis" , None )
303+ dim = math .prod (shape ) if axis is None else shape [axis ]
304+ repeat_strat = st .integers (1 , 4 )
305+ repeats = data .draw (repeat_strat
306+ | hh .arrays (dtype = hh .int_dtypes , elements = repeat_strat ,
307+ shape = st .sampled_from ([(1 ,), (dim ,)])),
308+ label = "repeats" )
309+ if isinstance (repeats , int ):
310+ n_repitions = dim * repeats
311+ else :
312+ if repeats .shape == (1 ,):
313+ n_repitions = dim * repeats [0 ]
314+ else :
315+ n_repitions = int (xp .sum (repeats ))
316+
317+ out = xp .repeat (x , repeats , ** kw )
318+ ph .assert_dtype ("repeat" , in_dtype = x .dtype , out_dtype = out .dtype )
319+ if axis is None :
320+ expected_shape = (n_repitions ,)
321+ else :
322+ expected_shape = list (shape )
323+ expected_shape [axis ] = n_repitions
324+ expected_shape = tuple (expected_shape )
325+ ph .assert_shape ("repeat" , out_shape = out .shape , expected = expected_shape )
326+ # TODO: values testing
327+
328+
290329@st .composite
291330def reshape_shapes (draw , shape ):
292331 size = 1 if len (shape ) == 0 else math .prod (shape )
@@ -298,20 +337,6 @@ def reshape_shapes(draw, shape):
298337 return tuple (rshape )
299338
300339
301- @pytest .mark .min_version ("2023.12" )
302- @given (
303- x = hh .arrays (dtype = hh .all_dtypes , shape = hh .shapes (min_dims = 1 )),
304- repeats = st .integers (1 , 4 ),
305- )
306- def test_repeat (x , repeats ):
307- # TODO: test array repeats and non-None axis, adjust shape and value testing accordingly
308- out = xp .repeat (x , repeats )
309- ph .assert_dtype ("repeat" , in_dtype = x .dtype , out_dtype = out .dtype )
310- expected_shape = (math .prod (x .shape ) * repeats ,)
311- ph .assert_shape ("repeat" , out_shape = out .shape , expected = expected_shape )
312- # TODO: values testing
313-
314-
315340@pytest .mark .unvectorized
316341@pytest .mark .skip ("flaky" ) # TODO: fix!
317342@given (
0 commit comments