Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions benches/bench1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,19 @@ fn add_2d_strided(bench: &mut test::Bencher)
});
}

#[bench]
fn add_2d_strided_dyn(bench: &mut test::Bencher)
{
let mut a = Array::<i32, _>::zeros(&[64, 64 * 2][..]);
let mut a = a.slice_mut(s![.., ..;2]);
let b = Array::<i32, _>::zeros(&[64, 64][..]);
let bv = b.view();
bench.iter(|| {
a += &bv;
});
}


#[bench]
fn add_2d_zip_strided(bench: &mut test::Bencher)
{
Expand Down
44 changes: 44 additions & 0 deletions src/dimension/dimension_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use itertools::{enumerate, zip};

use {Ix, Ixs, Ix0, Ix1, Ix2, Ix3, IxDyn, Dim, Si};
use IntoDimension;
use RemoveAxis;
use {ArrayView1, ArrayViewMut1};
use {zipsl, zipsl_mut, ZipExt};
use Axis;
Expand Down Expand Up @@ -64,6 +65,9 @@ pub unsafe trait Dimension : Clone + Eq + Debug + Send + Sync + Default +
/// - and so on..
/// - For `Vec<Ix>`: `Vec<usize>`,
type Pattern: IntoDimension<Dim=Self>;
// Next smaller dimension (if it exists)
#[doc(hidden)]
type TrySmaller: Dimension;
#[doc(hidden)]
fn ndim(&self) -> usize;

Expand Down Expand Up @@ -341,6 +345,9 @@ pub unsafe trait Dimension : Clone + Eq + Debug + Send + Sync + Default +
.max_by_key(|ax| ax.stride().abs())
.map_or(Axis(0), |ax| ax.axis())
}

#[doc(hidden)]
fn try_remove_axis(&self, axis: Axis) -> Self::TrySmaller;
}

// utility functions
Expand All @@ -361,6 +368,7 @@ fn abs_index(len: Ixs, index: Ixs) -> Ix {
unsafe impl Dimension for Dim<[Ix; 0]> {
type SliceArg = [Si; 0];
type Pattern = ();
type TrySmaller = Self;
// empty product is 1 -> size is 1
#[inline]
fn ndim(&self) -> usize { 0 }
Expand All @@ -376,12 +384,17 @@ unsafe impl Dimension for Dim<[Ix; 0]> {
fn next_for(&self, _index: Self) -> Option<Self> {
None
}
#[inline]
fn try_remove_axis(&self, _ignore: Axis) -> Self::TrySmaller {
*self
}
}


unsafe impl Dimension for Dim<[Ix; 1]> {
type SliceArg = [Si; 1];
type Pattern = Ix;
type TrySmaller = <Self as RemoveAxis>::Smaller;
#[inline]
fn ndim(&self) -> usize { 1 }
#[inline]
Expand Down Expand Up @@ -456,11 +469,16 @@ unsafe impl Dimension for Dim<[Ix; 1]> {
None
}
}
#[inline]
fn try_remove_axis(&self, axis: Axis) -> Self::TrySmaller {
self.remove_axis(axis)
}
}

unsafe impl Dimension for Dim<[Ix; 2]> {
type SliceArg = [Si; 2];
type Pattern = (Ix, Ix);
type TrySmaller = <Self as RemoveAxis>::Smaller;
#[inline]
fn ndim(&self) -> usize { 2 }
#[inline]
Expand Down Expand Up @@ -601,11 +619,16 @@ unsafe impl Dimension for Dim<[Ix; 2]> {
None
}
}
#[inline]
fn try_remove_axis(&self, axis: Axis) -> Self::TrySmaller {
self.remove_axis(axis)
}
}

unsafe impl Dimension for Dim<[Ix; 3]> {
type SliceArg = [Si; 3];
type Pattern = (Ix, Ix, Ix);
type TrySmaller = <Self as RemoveAxis>::Smaller;
#[inline]
fn ndim(&self) -> usize { 3 }
#[inline]
Expand Down Expand Up @@ -681,13 +704,18 @@ unsafe impl Dimension for Dim<[Ix; 3]> {
}
order
}
#[inline]
fn try_remove_axis(&self, axis: Axis) -> Self::TrySmaller {
self.remove_axis(axis)
}
}

macro_rules! large_dim {
($n:expr, $name:ident, $($ix:ident),+) => (
unsafe impl Dimension for Dim<[Ix; $n]> {
type SliceArg = [Si; $n];
type Pattern = ($($ix,)*);
type TrySmaller = <Self as RemoveAxis>::Smaller;
#[inline]
fn ndim(&self) -> usize { $n }
#[inline]
Expand All @@ -698,6 +726,10 @@ macro_rules! large_dim {
fn slice(&self) -> &[Ix] { self.ix() }
#[inline]
fn slice_mut(&mut self) -> &mut [Ix] { self.ixm() }
#[inline]
fn try_remove_axis(&self, axis: Axis) -> Self::TrySmaller {
self.remove_axis(axis)
}
}
)
}
Expand All @@ -712,13 +744,25 @@ unsafe impl Dimension for IxDyn
{
type SliceArg = [Si];
type Pattern = Self;
type TrySmaller = <Self as RemoveAxis>::Smaller;
#[inline]
fn ndim(&self) -> usize { self.ix().len() }
#[inline]
fn slice(&self) -> &[Ix] { self.ix() }
#[inline]
fn slice_mut(&mut self) -> &mut [Ix] { self.ixm() }
#[inline]
fn into_pattern(self) -> Self::Pattern {
self
}
#[inline]
fn try_remove_axis(&self, axis: Axis) -> Self::TrySmaller {
if self.ndim() > 0 {
self.remove_axis(axis)
} else {
self.clone()
}
}
}

impl<J> Index<J> for Dim<Vec<usize>>
Expand Down
7 changes: 6 additions & 1 deletion src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ use super::ZipExt;
use dimension::IntoDimension;
use dimension::{axes_of, Axes, merge_axes, stride_offset};
use iterators::whole_chunks_of;
use iterators::{
new_inner_iter_smaller,
new_inner_iter_smaller_mut,
};

use {
NdIndex,
Expand Down Expand Up @@ -1138,7 +1142,8 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
}
// otherwise, break the arrays up into their inner rows
let mut try_slices = true;
let rows = self.inner_iter_mut().zip(rhs.inner_iter());
let rows = new_inner_iter_smaller_mut(self.view_mut()).zip(
new_inner_iter_smaller(rhs.view()));
for (mut s_row, r_row) in rows {
if try_slices {
if let Some(self_s) = s_row.as_slice_mut() {
Expand Down
49 changes: 49 additions & 0 deletions src/iterators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,55 @@ impl<'a, A, D> ExactSizeIterator for InnerIterMut<'a, A, D>
}
}

/// Create an InnerIter one dimension smaller than D (if possible)
pub fn new_inner_iter_smaller<A, D>(v: ArrayView<A, D>)
-> InnerIter<A, D::TrySmaller>
where D: Dimension
{
let ndim = v.ndim();
let len;
let stride;
let iter_v;
if ndim == 0 {
len = 1;
stride = 0;
iter_v = v.try_remove_axis(Axis(0))
} else {
len = v.dim.last_elem();
stride = v.strides.last_elem() as isize;
iter_v = v.try_remove_axis(Axis(ndim - 1))
}
InnerIter {
inner_len: len,
inner_stride: stride,
iter: iter_v.into_base_iter(),
}
}

pub fn new_inner_iter_smaller_mut<A, D>(v: ArrayViewMut<A, D>)
-> InnerIterMut<A, D::TrySmaller>
where D: Dimension,
{
let ndim = v.ndim();
let len;
let stride;
let iter_v;
if ndim == 0 {
len = 1;
stride = 0;
iter_v = v.try_remove_axis(Axis(0))
} else {
len = v.dim.last_elem();
stride = v.strides.last_elem() as isize;
iter_v = v.try_remove_axis(Axis(ndim - 1))
}
InnerIterMut {
inner_len: len,
inner_stride: stride,
iter: iter_v.into_base_iter(),
}
}

#[derive(Debug)]
pub struct OuterIterCore<A, D> {
index: Ix,
Expand Down
14 changes: 14 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,20 @@ impl<A, S, D> ArrayBase<S, D>
row.into_iter_().fold((), |(), elt| f(elt));
}
}

/// Remove array axis `axis` and return the result.
fn try_remove_axis(self, axis: Axis) -> ArrayBase<S, D::TrySmaller>
{
let d = self.dim.try_remove_axis(axis);
let s = self.strides.try_remove_axis(axis);
ArrayBase {
ptr: self.ptr,
data: self.data,
dim: d,
strides: s,
}
}

}


Expand Down
52 changes: 52 additions & 0 deletions tests/ix0.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@

extern crate ndarray;

use ndarray::Array;
use ndarray::Ix0;
use ndarray::ShapeBuilder;

#[test]
fn test_ix0() {
let mut a = Array::zeros(Ix0());
assert_eq!(a[()], 0.);
a[()] = 1.;
assert_eq!(a[()], 1.);
assert_eq!(a.len(), 1);
assert_eq!(a.as_slice().unwrap(), &[1.]);

let mut a = Array::zeros(Ix0().f());
assert_eq!(a[()], 0.);
a[()] = 1.;
assert_eq!(a[()], 1.);
assert_eq!(a.len(), 1);
assert_eq!(a.as_slice().unwrap(), &[1.]);
}

#[test]
fn test_ix0_add() {
let mut a = Array::zeros(Ix0());
a += 1.;
assert_eq!(a[()], 1.);
a += 2.;
assert_eq!(a[()], 3.);
}

#[test]
fn test_ix0_add_add() {
let mut a = Array::zeros(Ix0());
a += 1.;
let mut b = Array::zeros(Ix0());
b += 1.;
a += &b;
assert_eq!(a[()], 2.);
}

#[test]
fn test_ix0_add_broad() {
let mut b = Array::from_vec(vec![5., 6.]);
let mut a = Array::zeros(Ix0());
a += 1.;
b += &a;
assert_eq!(b[0], 6.);
assert_eq!(b[1], 7.);
}