From 0307411f625e35f7c6a3cf98e7f29f2ac9826e06 Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Wed, 19 Mar 2025 06:30:43 +0900 Subject: [PATCH 01/12] implement cumprod, add tests --- src/numeric/impl_numeric.rs | 68 +++++++++++++++++++++++++++ tests/numeric.rs | 92 ++++++++++++++++++++++++++++++++++++- 2 files changed, 159 insertions(+), 1 deletion(-) diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index a8a008395..559c57e5c 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -14,6 +14,7 @@ use std::ops::{Add, Div, Mul, Sub}; use crate::imp_prelude::*; use crate::numeric_util; +use crate::ScalarOperand; use crate::Slice; /// # Numerical Methods for Arrays @@ -99,6 +100,73 @@ where sum } + /// Return the cumulative product of elements along a given axis. + /// + /// If `axis` is None, the array is flattened before taking the cumulative product. + /// + /// ``` + /// use ndarray::{arr2, Axis}; + /// + /// let a = arr2(&[[1., 2., 3.], + /// [4., 5., 6.]]); + /// + /// // Cumulative product along rows (axis 0) + /// assert_eq!( + /// a.cumprod(Some(Axis(0))), + /// arr2(&[[1., 2., 3.], + /// [4., 10., 18.]]) + /// ); + /// + /// // Cumulative product along columns (axis 1) + /// assert_eq!( + /// a.cumprod(Some(Axis(1))), + /// arr2(&[[1., 2., 6.], + /// [4., 20., 120.]]) + /// ); + /// ``` + /// + /// **Panics** if `axis` is out of bounds. + #[track_caller] + pub fn cumprod(&self, axis: Option) -> Array + where + A: Clone + One + Mul + ScalarOperand, + D: Dimension + RemoveAxis, + { + // First check dimensionality + if self.ndim() > 1 && axis.is_none() { + panic!("axis parameter is required for arrays with more than one dimension"); + } + + match axis { + None => { + // This case now only happens for 1D arrays + let mut res = Array::ones(self.raw_dim()); + let mut acc = A::one(); + + for (r, x) in res.iter_mut().zip(self.iter()) { + acc = acc * x.clone(); + *r = acc.clone(); + } + + res + } + Some(axis) => { + let mut res: Array = Array::ones(self.raw_dim()); + + // Process each lane independently + for (mut out_lane, in_lane) in res.lanes_mut(axis).into_iter().zip(self.lanes(axis)) { + let mut acc = A::one(); + for (r, x) in out_lane.iter_mut().zip(in_lane.iter()) { + acc = acc * x.clone(); + *r = acc.clone(); + } + } + + res + } + } + } + /// Return variance of elements in the array. /// /// The variance is computed using the [Welford one-pass diff --git a/tests/numeric.rs b/tests/numeric.rs index 839aba58e..40f6967b6 100644 --- a/tests/numeric.rs +++ b/tests/numeric.rs @@ -4,7 +4,7 @@ )] use approx::assert_abs_diff_eq; -use ndarray::{arr0, arr1, arr2, array, aview1, Array, Array1, Array2, Array3, Axis}; +use ndarray::{arr0, arr1, arr2, arr3, array, aview1, Array, Array1, Array2, Array3, Axis}; use std::f64; #[test] @@ -75,6 +75,96 @@ fn sum_mean_prod_empty() assert_eq!(a, None); } +#[test] +fn test_cumprod_1d() +{ + let a = array![1, 2, 3, 4]; + // For 1D arrays, both None and Some(Axis(0)) should work + let result_none = a.cumprod(None); + let result_axis = a.cumprod(Some(Axis(0))); + assert_eq!(result_none, array![1, 2, 6, 24]); + assert_eq!(result_axis, array![1, 2, 6, 24]); +} + +#[test] +fn test_cumprod_2d() +{ + let a = array![[1, 2], [3, 4]]; + + // For 2D arrays, we must specify an axis + let result_axis0 = a.cumprod(Some(Axis(0))); + assert_eq!(result_axis0, array![[1, 2], [3, 8]]); + + let result_axis1 = a.cumprod(Some(Axis(1))); + assert_eq!(result_axis1, array![[1, 2], [3, 12]]); +} + +#[test] +fn test_cumprod_3d() +{ + let a = array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]]; + + // For 3D arrays, we must specify an axis + let result_axis0 = a.cumprod(Some(Axis(0))); + assert_eq!(result_axis0, array![[[1, 2], [3, 4]], [[5, 12], [21, 32]]]); + + let result_axis1 = a.cumprod(Some(Axis(1))); + assert_eq!(result_axis1, array![[[1, 2], [3, 8]], [[5, 6], [35, 48]]]); + + let result_axis2 = a.cumprod(Some(Axis(2))); + assert_eq!(result_axis2, array![[[1, 2], [3, 12]], [[5, 30], [7, 56]]]); +} + +#[test] +fn test_cumprod_empty() +{ + // For 1D empty array + let a: Array1 = array![]; + let result = a.cumprod(None); + assert_eq!(result, array![]); + + // For 2D empty array, must specify axis + let b: Array2 = Array2::zeros((0, 0)); + let result_axis0 = b.cumprod(Some(Axis(0))); + assert_eq!(result_axis0, Array2::zeros((0, 0))); + let result_axis1 = b.cumprod(Some(Axis(1))); + assert_eq!(result_axis1, Array2::zeros((0, 0))); +} + +#[test] +fn test_cumprod_1_element() +{ + // For 1D array with one element + let a = array![5]; + let result_none = a.cumprod(None); + let result_axis = a.cumprod(Some(Axis(0))); + assert_eq!(result_none, array![5]); + assert_eq!(result_axis, array![5]); + + // For 2D array with one element, must specify axis + let b = array![[5]]; + let result_axis0 = b.cumprod(Some(Axis(0))); + let result_axis1 = b.cumprod(Some(Axis(1))); + assert_eq!(result_axis0, array![[5]]); + assert_eq!(result_axis1, array![[5]]); +} + +#[test] +#[should_panic(expected = "axis parameter is required for arrays with more than one dimension")] +fn test_cumprod_nd_none_axis() +{ + let a = array![[1, 2], [3, 4]]; + let _result = a.cumprod(None); +} + +#[test] +#[should_panic(expected = "index out of bounds")] +fn test_cumprod_axis_out_of_bounds() +{ + let a = array![[1, 2], [3, 4]]; + let _result = a.cumprod(Some(Axis(2))); +} + #[test] #[cfg(feature = "std")] fn var() From eebd1cc743708820d201490cbcd821341c6fb2dc Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Wed, 19 Mar 2025 06:46:05 +0900 Subject: [PATCH 02/12] remove unused import --- tests/numeric.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/numeric.rs b/tests/numeric.rs index 40f6967b6..81b852ebb 100644 --- a/tests/numeric.rs +++ b/tests/numeric.rs @@ -4,7 +4,7 @@ )] use approx::assert_abs_diff_eq; -use ndarray::{arr0, arr1, arr2, arr3, array, aview1, Array, Array1, Array2, Array3, Axis}; +use ndarray::{arr0, arr1, arr2, array, aview1, Array, Array1, Array2, Array3, Axis}; use std::f64; #[test] From 0181719dee7f356790982f2acdc9feeeac4a6e04 Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Wed, 19 Mar 2025 20:10:19 +0900 Subject: [PATCH 03/12] fold_axis for a bit more vectorized ops --- src/numeric/impl_numeric.rs | 34 ++++++++++++++-------------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 559c57e5c..8b1347293 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -16,6 +16,7 @@ use crate::imp_prelude::*; use crate::numeric_util; use crate::ScalarOperand; use crate::Slice; +use crate::Zip; /// # Numerical Methods for Arrays impl ArrayBase @@ -132,36 +133,29 @@ where A: Clone + One + Mul + ScalarOperand, D: Dimension + RemoveAxis, { - // First check dimensionality - if self.ndim() > 1 && axis.is_none() { - panic!("axis parameter is required for arrays with more than one dimension"); - } + let mut res = Array::ones(self.raw_dim()); match axis { None => { - // This case now only happens for 1D arrays - let mut res = Array::ones(self.raw_dim()); + // For 1D arrays, use simple iteration let mut acc = A::one(); - - for (r, x) in res.iter_mut().zip(self.iter()) { - acc = acc * x.clone(); + Zip::from(&mut res).and(self).for_each(|r, x| { + acc = acc.clone() * x.clone(); *r = acc.clone(); - } - + }); res } Some(axis) => { - let mut res: Array = Array::ones(self.raw_dim()); + // For nD arrays, use fold_axis approach + // Create accumulator array with one less dimension + let mut acc = Array::ones(self.raw_dim().remove_axis(axis)); - // Process each lane independently - for (mut out_lane, in_lane) in res.lanes_mut(axis).into_iter().zip(self.lanes(axis)) { - let mut acc = A::one(); - for (r, x) in out_lane.iter_mut().zip(in_lane.iter()) { - acc = acc * x.clone(); - *r = acc.clone(); - } + for i in 0..self.len_of(axis) { + // Get view of current slice along axis, and update accumulator element-wise multiplication + let view = self.index_axis(axis, i); + acc = acc * &view; + res.index_axis_mut(axis, i).assign(&acc); } - res } } From 9dfa7ed3e0cb76bda01b698771703408a98e83b3 Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Wed, 19 Mar 2025 20:18:43 +0900 Subject: [PATCH 04/12] restore missing dim check logic(removed accidently) --- src/numeric/impl_numeric.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 8b1347293..4a8adb655 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -133,6 +133,11 @@ where A: Clone + One + Mul + ScalarOperand, D: Dimension + RemoveAxis, { + // First check dimensionality + if self.ndim() > 1 && axis.is_none() { + panic!("axis parameter is required for arrays with more than one dimension"); + } + let mut res = Array::ones(self.raw_dim()); match axis { From 72449807cc284fbb417947bfe6aefd676daf42c4 Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Wed, 19 Mar 2025 20:26:40 +0900 Subject: [PATCH 05/12] corrected panic msg --- tests/numeric.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/numeric.rs b/tests/numeric.rs index 81b852ebb..9955728ff 100644 --- a/tests/numeric.rs +++ b/tests/numeric.rs @@ -158,7 +158,7 @@ fn test_cumprod_nd_none_axis() } #[test] -#[should_panic(expected = "index out of bounds")] +#[should_panic(expected = "assertion failed: axis < self.ndim()")] fn test_cumprod_axis_out_of_bounds() { let a = array![[1, 2], [3, 4]]; From a60bec20720eb8e01f147dfd4b983b10bdc6dcb3 Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Wed, 19 Mar 2025 20:47:11 +0900 Subject: [PATCH 06/12] sperated check for out of bounds case --- src/numeric/impl_numeric.rs | 5 +++++ tests/numeric.rs | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 4a8adb655..3ecbd02ea 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -151,6 +151,11 @@ where res } Some(axis) => { + // Check if axis is valid before any array operations + if axis.0 >= self.ndim() { + panic!("axis is out of bounds for array of dimension"); + } + // For nD arrays, use fold_axis approach // Create accumulator array with one less dimension let mut acc = Array::ones(self.raw_dim().remove_axis(axis)); diff --git a/tests/numeric.rs b/tests/numeric.rs index 9955728ff..309467467 100644 --- a/tests/numeric.rs +++ b/tests/numeric.rs @@ -158,7 +158,7 @@ fn test_cumprod_nd_none_axis() } #[test] -#[should_panic(expected = "assertion failed: axis < self.ndim()")] +#[should_panic(expected = "axis is out of bounds for array of dimension")] fn test_cumprod_axis_out_of_bounds() { let a = array![[1, 2], [3, 4]]; From 0a3b95fa7b7c7cbe0386f53483829c2f530371fd Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Tue, 25 Mar 2025 08:47:25 +0900 Subject: [PATCH 07/12] Remove Option handling since axis is always required --- src/numeric/impl_numeric.rs | 51 +++++++++++------------------------- tests/numeric.rs | 52 ++++++++++++------------------------- 2 files changed, 31 insertions(+), 72 deletions(-) diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 3ecbd02ea..7555c4524 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -103,8 +103,6 @@ where /// Return the cumulative product of elements along a given axis. /// - /// If `axis` is None, the array is flattened before taking the cumulative product. - /// /// ``` /// use ndarray::{arr2, Axis}; /// @@ -113,14 +111,14 @@ where /// /// // Cumulative product along rows (axis 0) /// assert_eq!( - /// a.cumprod(Some(Axis(0))), + /// a.cumprod(Axis(0)), /// arr2(&[[1., 2., 3.], /// [4., 10., 18.]]) /// ); /// /// // Cumulative product along columns (axis 1) /// assert_eq!( - /// a.cumprod(Some(Axis(1))), + /// a.cumprod(Axis(1)), /// arr2(&[[1., 2., 6.], /// [4., 20., 120.]]) /// ); @@ -128,47 +126,28 @@ where /// /// **Panics** if `axis` is out of bounds. #[track_caller] - pub fn cumprod(&self, axis: Option) -> Array + pub fn cumprod(&self, axis: Axis) -> Array where A: Clone + One + Mul + ScalarOperand, D: Dimension + RemoveAxis, { - // First check dimensionality - if self.ndim() > 1 && axis.is_none() { - panic!("axis parameter is required for arrays with more than one dimension"); + // Check if axis is valid before any array operations + if axis.0 >= self.ndim() { + panic!("axis is out of bounds for array of dimension"); } let mut res = Array::ones(self.raw_dim()); + let mut acc = Array::ones(self.raw_dim().remove_axis(axis)); - match axis { - None => { - // For 1D arrays, use simple iteration - let mut acc = A::one(); - Zip::from(&mut res).and(self).for_each(|r, x| { - acc = acc.clone() * x.clone(); - *r = acc.clone(); - }); - res - } - Some(axis) => { - // Check if axis is valid before any array operations - if axis.0 >= self.ndim() { - panic!("axis is out of bounds for array of dimension"); - } - - // For nD arrays, use fold_axis approach - // Create accumulator array with one less dimension - let mut acc = Array::ones(self.raw_dim().remove_axis(axis)); - - for i in 0..self.len_of(axis) { - // Get view of current slice along axis, and update accumulator element-wise multiplication - let view = self.index_axis(axis, i); - acc = acc * &view; - res.index_axis_mut(axis, i).assign(&acc); - } - res - } + // Use fold_axis approach + for i in 0..self.len_of(axis) { + // Get view of current slice along axis, and update accumulator element-wise multiplication + let view = self.index_axis(axis, i); + acc = acc * &view; + res.index_axis_mut(axis, i).assign(&acc); } + + res } /// Return variance of elements in the array. diff --git a/tests/numeric.rs b/tests/numeric.rs index 309467467..7e6964812 100644 --- a/tests/numeric.rs +++ b/tests/numeric.rs @@ -79,11 +79,8 @@ fn sum_mean_prod_empty() fn test_cumprod_1d() { let a = array![1, 2, 3, 4]; - // For 1D arrays, both None and Some(Axis(0)) should work - let result_none = a.cumprod(None); - let result_axis = a.cumprod(Some(Axis(0))); - assert_eq!(result_none, array![1, 2, 6, 24]); - assert_eq!(result_axis, array![1, 2, 6, 24]); + let result = a.cumprod(Axis(0)); + assert_eq!(result, array![1, 2, 6, 24]); } #[test] @@ -91,11 +88,10 @@ fn test_cumprod_2d() { let a = array![[1, 2], [3, 4]]; - // For 2D arrays, we must specify an axis - let result_axis0 = a.cumprod(Some(Axis(0))); + let result_axis0 = a.cumprod(Axis(0)); assert_eq!(result_axis0, array![[1, 2], [3, 8]]); - let result_axis1 = a.cumprod(Some(Axis(1))); + let result_axis1 = a.cumprod(Axis(1)); assert_eq!(result_axis1, array![[1, 2], [3, 12]]); } @@ -104,30 +100,24 @@ fn test_cumprod_3d() { let a = array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]]; - // For 3D arrays, we must specify an axis - let result_axis0 = a.cumprod(Some(Axis(0))); + let result_axis0 = a.cumprod(Axis(0)); assert_eq!(result_axis0, array![[[1, 2], [3, 4]], [[5, 12], [21, 32]]]); - let result_axis1 = a.cumprod(Some(Axis(1))); + let result_axis1 = a.cumprod(Axis(1)); assert_eq!(result_axis1, array![[[1, 2], [3, 8]], [[5, 6], [35, 48]]]); - let result_axis2 = a.cumprod(Some(Axis(2))); + let result_axis2 = a.cumprod(Axis(2)); assert_eq!(result_axis2, array![[[1, 2], [3, 12]], [[5, 30], [7, 56]]]); } #[test] fn test_cumprod_empty() { - // For 1D empty array - let a: Array1 = array![]; - let result = a.cumprod(None); - assert_eq!(result, array![]); - - // For 2D empty array, must specify axis + // For 2D empty array let b: Array2 = Array2::zeros((0, 0)); - let result_axis0 = b.cumprod(Some(Axis(0))); + let result_axis0 = b.cumprod(Axis(0)); assert_eq!(result_axis0, Array2::zeros((0, 0))); - let result_axis1 = b.cumprod(Some(Axis(1))); + let result_axis1 = b.cumprod(Axis(1)); assert_eq!(result_axis1, Array2::zeros((0, 0))); } @@ -136,33 +126,23 @@ fn test_cumprod_1_element() { // For 1D array with one element let a = array![5]; - let result_none = a.cumprod(None); - let result_axis = a.cumprod(Some(Axis(0))); - assert_eq!(result_none, array![5]); - assert_eq!(result_axis, array![5]); + let result = a.cumprod(Axis(0)); + assert_eq!(result, array![5]); - // For 2D array with one element, must specify axis + // For 2D array with one element let b = array![[5]]; - let result_axis0 = b.cumprod(Some(Axis(0))); - let result_axis1 = b.cumprod(Some(Axis(1))); + let result_axis0 = b.cumprod(Axis(0)); + let result_axis1 = b.cumprod(Axis(1)); assert_eq!(result_axis0, array![[5]]); assert_eq!(result_axis1, array![[5]]); } -#[test] -#[should_panic(expected = "axis parameter is required for arrays with more than one dimension")] -fn test_cumprod_nd_none_axis() -{ - let a = array![[1, 2], [3, 4]]; - let _result = a.cumprod(None); -} - #[test] #[should_panic(expected = "axis is out of bounds for array of dimension")] fn test_cumprod_axis_out_of_bounds() { let a = array![[1, 2], [3, 4]]; - let _result = a.cumprod(Some(Axis(2))); + let _result = a.cumprod(Axis(2)); } #[test] From 8dde517b43e88996a6090a0664c0086b6404d516 Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Tue, 25 Mar 2025 08:59:14 +0900 Subject: [PATCH 08/12] apply Zip instead of naive for loop --- src/numeric/impl_numeric.rs | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 7555c4524..740d0c736 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -139,13 +139,12 @@ where let mut res = Array::ones(self.raw_dim()); let mut acc = Array::ones(self.raw_dim().remove_axis(axis)); - // Use fold_axis approach - for i in 0..self.len_of(axis) { - // Get view of current slice along axis, and update accumulator element-wise multiplication - let view = self.index_axis(axis, i); - acc = acc * &view; - res.index_axis_mut(axis, i).assign(&acc); - } + Zip::from(self.axis_iter(axis)) + .and(res.axis_iter_mut(axis)) + .for_each(|view, mut res| { + acc = acc.clone() * &view; + res.assign(&acc); + }); res } From 2f0a09d389a1461d24af2f3da2784002ac8764fc Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Tue, 25 Mar 2025 09:29:11 +0900 Subject: [PATCH 09/12] rename var, remove redandunt cloning --- src/numeric/impl_numeric.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 740d0c736..0ab6f8fba 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -137,13 +137,13 @@ where } let mut res = Array::ones(self.raw_dim()); - let mut acc = Array::ones(self.raw_dim().remove_axis(axis)); + let mut running_product = Array::ones(self.raw_dim().remove_axis(axis)); Zip::from(self.axis_iter(axis)) .and(res.axis_iter_mut(axis)) .for_each(|view, mut res| { - acc = acc.clone() * &view; - res.assign(&acc); + running_product = &running_product * &view; + res.assign(&running_product); }); res From 56774bf5d6ed0cc49a9a029652b3226298422733 Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Tue, 25 Mar 2025 10:46:16 +0900 Subject: [PATCH 10/12] closure to remove clones --- src/numeric/impl_numeric.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 0ab6f8fba..57eceac27 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -137,13 +137,14 @@ where } let mut res = Array::ones(self.raw_dim()); - let mut running_product = Array::ones(self.raw_dim().remove_axis(axis)); + let running_product = Array::ones(self.raw_dim().remove_axis(axis)); Zip::from(self.axis_iter(axis)) .and(res.axis_iter_mut(axis)) - .for_each(|view, mut res| { - running_product = &running_product * &view; + .fold(running_product, |mut running_product, view, mut res| { + running_product = running_product * &view; res.assign(&running_product); + running_product }); res From 58d7e8a9ef03c3a168d87114b24d4a62c51d957f Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Wed, 26 Mar 2025 09:42:38 +0900 Subject: [PATCH 11/12] using accumulate_axis_inplace --- src/numeric/impl_numeric.rs | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 302e607b8..aca7aeecf 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -14,9 +14,7 @@ use std::ops::{Add, Div, Mul, Sub}; use crate::imp_prelude::*; use crate::numeric_util; -use crate::ScalarOperand; use crate::Slice; -use crate::Zip; /// # Numerical Methods for Arrays impl ArrayRef @@ -126,26 +124,18 @@ where D: Dimension #[track_caller] pub fn cumprod(&self, axis: Axis) -> Array where - A: Clone + One + Mul + ScalarOperand, + A: Copy + Clone + Mul, D: Dimension + RemoveAxis, { - // Check if axis is valid before any array operations if axis.0 >= self.ndim() { panic!("axis is out of bounds for array of dimension"); } - let mut res = Array::ones(self.raw_dim()); - let running_product = Array::ones(self.raw_dim().remove_axis(axis)); - - Zip::from(self.axis_iter(axis)) - .and(res.axis_iter_mut(axis)) - .fold(running_product, |mut running_product, view, mut res| { - running_product = running_product * &view; - res.assign(&running_product); - running_product - }); - - res + let mut result = self.to_owned(); + result.accumulate_axis_inplace(axis, |&prev, curr| { + *curr = *curr * prev; + }); + result } /// Return variance of elements in the array. From f3173b83e2a3617a1f951d4c0830e5deaf2809fa Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Thu, 27 Mar 2025 07:50:32 +0900 Subject: [PATCH 12/12] mulassign and clone once --- src/numeric/impl_numeric.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index aca7aeecf..ae82a482a 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -10,7 +10,7 @@ use num_traits::Float; use num_traits::One; use num_traits::{FromPrimitive, Zero}; -use std::ops::{Add, Div, Mul, Sub}; +use std::ops::{Add, Div, Mul, MulAssign, Sub}; use crate::imp_prelude::*; use crate::numeric_util; @@ -124,7 +124,7 @@ where D: Dimension #[track_caller] pub fn cumprod(&self, axis: Axis) -> Array where - A: Copy + Clone + Mul, + A: Clone + Mul + MulAssign, D: Dimension + RemoveAxis, { if axis.0 >= self.ndim() { @@ -132,9 +132,7 @@ where D: Dimension } let mut result = self.to_owned(); - result.accumulate_axis_inplace(axis, |&prev, curr| { - *curr = *curr * prev; - }); + result.accumulate_axis_inplace(axis, |prev, curr| *curr *= prev.clone()); result }