@@ -91,6 +91,64 @@ def test_cumulative_sum(x, data):
9191 idx = out_idx .raw , out = out_val ,
9292 expected = expected )
9393
94+
95+
96+ @pytest .mark .min_version ("2024.12" )
97+ @pytest .mark .unvectorized
98+ @given (
99+ x = hh .arrays (
100+ dtype = hh .numeric_dtypes ,
101+ shape = hh .shapes (min_dims = 1 )),
102+ data = st .data (),
103+ )
104+ def test_cumulative_prod (x , data ):
105+ axes = st .integers (- x .ndim , x .ndim - 1 )
106+ if x .ndim == 1 :
107+ axes = axes | st .none ()
108+ axis = data .draw (axes , label = 'axis' )
109+ _axis , = sh .normalize_axis (axis , x .ndim )
110+ dtype = data .draw (kwarg_dtypes (x .dtype ))
111+ include_initial = data .draw (st .booleans (), label = "include_initial" )
112+
113+ kw = data .draw (
114+ hh .specified_kwargs (
115+ ("axis" , axis , None ),
116+ ("dtype" , dtype , None ),
117+ ("include_initial" , include_initial , False ),
118+ ),
119+ label = "kw" ,
120+ )
121+
122+ out = xp .cumulative_prod (x , ** kw )
123+
124+ expected_shape = list (x .shape )
125+ if include_initial :
126+ expected_shape [_axis ] += 1
127+ expected_shape = tuple (expected_shape )
128+ ph .assert_shape ("cumulative_prod" , out_shape = out .shape , expected = expected_shape )
129+
130+ expected_dtype = dh .accumulation_result_dtype (x .dtype , dtype )
131+ if expected_dtype is None :
132+ # If a default uint cannot exist (i.e. in PyTorch which doesn't support
133+ # uint32 or uint64), we skip testing the output dtype.
134+ # See https://github.com/data-apis/array-api-tests/issues/106
135+ if x .dtype in dh .uint_dtypes :
136+ assert dh .is_int_dtype (out .dtype ) # sanity check
137+ else :
138+ ph .assert_dtype ("cumulative_prod" , in_dtype = x .dtype , out_dtype = out .dtype , expected = expected_dtype )
139+
140+ scalar_type = dh .get_scalar_type (out .dtype )
141+
142+ for x_idx , out_idx , in iter_indices (x .shape , expected_shape , skip_axes = _axis ):
143+ #x_arr = x[x_idx.raw]
144+ out_arr = out [out_idx .raw ]
145+
146+ if include_initial :
147+ ph .assert_scalar_equals ("cumulative_prod" , type_ = scalar_type , idx = out_idx .raw , out = out_arr [0 ], expected = 0 )
148+
149+ #TODO: add value testing of cumulative_prod
150+
151+
94152def kwarg_dtypes (dtype : DataType ) -> st .SearchStrategy [Optional [DataType ]]:
95153 dtypes = [d2 for d1 , d2 in dh .promotion_table if d1 == dtype ]
96154 dtypes = [d for d in dtypes if not isinstance (d , _UndefinedStub )]
0 commit comments