Skip to content

Commit 1923258

Browse files
authored
Merge pull request #290 from bluss/try-smaller
Code size improvement in .zip_mut_with()
2 parents 2ba45b7 + da78e30 commit 1923258

File tree

6 files changed

+178
-1
lines changed

6 files changed

+178
-1
lines changed

benches/bench1.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,19 @@ fn add_2d_strided(bench: &mut test::Bencher)
438438
});
439439
}
440440

441+
#[bench]
442+
fn add_2d_strided_dyn(bench: &mut test::Bencher)
443+
{
444+
let mut a = Array::<i32, _>::zeros(&[64, 64 * 2][..]);
445+
let mut a = a.slice_mut(s![.., ..;2]);
446+
let b = Array::<i32, _>::zeros(&[64, 64][..]);
447+
let bv = b.view();
448+
bench.iter(|| {
449+
a += &bv;
450+
});
451+
}
452+
453+
441454
#[bench]
442455
fn add_2d_zip_strided(bench: &mut test::Bencher)
443456
{

src/dimension/dimension_trait.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use itertools::{enumerate, zip};
1515

1616
use {Ix, Ixs, Ix0, Ix1, Ix2, Ix3, IxDyn, Dim, Si};
1717
use IntoDimension;
18+
use RemoveAxis;
1819
use {ArrayView1, ArrayViewMut1};
1920
use {zipsl, zipsl_mut, ZipExt};
2021
use Axis;
@@ -64,6 +65,9 @@ pub unsafe trait Dimension : Clone + Eq + Debug + Send + Sync + Default +
6465
/// - and so on..
6566
/// - For `Vec<Ix>`: `Vec<usize>`,
6667
type Pattern: IntoDimension<Dim=Self>;
68+
// Next smaller dimension (if it exists)
69+
#[doc(hidden)]
70+
type TrySmaller: Dimension;
6771
#[doc(hidden)]
6872
fn ndim(&self) -> usize;
6973

@@ -341,6 +345,9 @@ pub unsafe trait Dimension : Clone + Eq + Debug + Send + Sync + Default +
341345
.max_by_key(|ax| ax.stride().abs())
342346
.map_or(Axis(0), |ax| ax.axis())
343347
}
348+
349+
#[doc(hidden)]
350+
fn try_remove_axis(&self, axis: Axis) -> Self::TrySmaller;
344351
}
345352

346353
// utility functions
@@ -361,6 +368,7 @@ fn abs_index(len: Ixs, index: Ixs) -> Ix {
361368
unsafe impl Dimension for Dim<[Ix; 0]> {
362369
type SliceArg = [Si; 0];
363370
type Pattern = ();
371+
type TrySmaller = Self;
364372
// empty product is 1 -> size is 1
365373
#[inline]
366374
fn ndim(&self) -> usize { 0 }
@@ -376,12 +384,17 @@ unsafe impl Dimension for Dim<[Ix; 0]> {
376384
fn next_for(&self, _index: Self) -> Option<Self> {
377385
None
378386
}
387+
#[inline]
388+
fn try_remove_axis(&self, _ignore: Axis) -> Self::TrySmaller {
389+
*self
390+
}
379391
}
380392

381393

382394
unsafe impl Dimension for Dim<[Ix; 1]> {
383395
type SliceArg = [Si; 1];
384396
type Pattern = Ix;
397+
type TrySmaller = <Self as RemoveAxis>::Smaller;
385398
#[inline]
386399
fn ndim(&self) -> usize { 1 }
387400
#[inline]
@@ -456,11 +469,16 @@ unsafe impl Dimension for Dim<[Ix; 1]> {
456469
None
457470
}
458471
}
472+
#[inline]
473+
fn try_remove_axis(&self, axis: Axis) -> Self::TrySmaller {
474+
self.remove_axis(axis)
475+
}
459476
}
460477

461478
unsafe impl Dimension for Dim<[Ix; 2]> {
462479
type SliceArg = [Si; 2];
463480
type Pattern = (Ix, Ix);
481+
type TrySmaller = <Self as RemoveAxis>::Smaller;
464482
#[inline]
465483
fn ndim(&self) -> usize { 2 }
466484
#[inline]
@@ -601,11 +619,16 @@ unsafe impl Dimension for Dim<[Ix; 2]> {
601619
None
602620
}
603621
}
622+
#[inline]
623+
fn try_remove_axis(&self, axis: Axis) -> Self::TrySmaller {
624+
self.remove_axis(axis)
625+
}
604626
}
605627

606628
unsafe impl Dimension for Dim<[Ix; 3]> {
607629
type SliceArg = [Si; 3];
608630
type Pattern = (Ix, Ix, Ix);
631+
type TrySmaller = <Self as RemoveAxis>::Smaller;
609632
#[inline]
610633
fn ndim(&self) -> usize { 3 }
611634
#[inline]
@@ -681,13 +704,18 @@ unsafe impl Dimension for Dim<[Ix; 3]> {
681704
}
682705
order
683706
}
707+
#[inline]
708+
fn try_remove_axis(&self, axis: Axis) -> Self::TrySmaller {
709+
self.remove_axis(axis)
710+
}
684711
}
685712

686713
macro_rules! large_dim {
687714
($n:expr, $name:ident, $($ix:ident),+) => (
688715
unsafe impl Dimension for Dim<[Ix; $n]> {
689716
type SliceArg = [Si; $n];
690717
type Pattern = ($($ix,)*);
718+
type TrySmaller = <Self as RemoveAxis>::Smaller;
691719
#[inline]
692720
fn ndim(&self) -> usize { $n }
693721
#[inline]
@@ -698,6 +726,10 @@ macro_rules! large_dim {
698726
fn slice(&self) -> &[Ix] { self.ix() }
699727
#[inline]
700728
fn slice_mut(&mut self) -> &mut [Ix] { self.ixm() }
729+
#[inline]
730+
fn try_remove_axis(&self, axis: Axis) -> Self::TrySmaller {
731+
self.remove_axis(axis)
732+
}
701733
}
702734
)
703735
}
@@ -712,13 +744,25 @@ unsafe impl Dimension for IxDyn
712744
{
713745
type SliceArg = [Si];
714746
type Pattern = Self;
747+
type TrySmaller = <Self as RemoveAxis>::Smaller;
748+
#[inline]
715749
fn ndim(&self) -> usize { self.ix().len() }
750+
#[inline]
716751
fn slice(&self) -> &[Ix] { self.ix() }
752+
#[inline]
717753
fn slice_mut(&mut self) -> &mut [Ix] { self.ixm() }
718754
#[inline]
719755
fn into_pattern(self) -> Self::Pattern {
720756
self
721757
}
758+
#[inline]
759+
fn try_remove_axis(&self, axis: Axis) -> Self::TrySmaller {
760+
if self.ndim() > 0 {
761+
self.remove_axis(axis)
762+
} else {
763+
self.clone()
764+
}
765+
}
722766
}
723767

724768
impl<J> Index<J> for Dim<Vec<usize>>

src/impl_methods.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ use super::ZipExt;
2323
use dimension::IntoDimension;
2424
use dimension::{axes_of, Axes, merge_axes, stride_offset};
2525
use iterators::whole_chunks_of;
26+
use iterators::{
27+
new_inner_iter_smaller,
28+
new_inner_iter_smaller_mut,
29+
};
2630

2731
use {
2832
NdIndex,
@@ -1138,7 +1142,8 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
11381142
}
11391143
// otherwise, break the arrays up into their inner rows
11401144
let mut try_slices = true;
1141-
let rows = self.inner_iter_mut().zip(rhs.inner_iter());
1145+
let rows = new_inner_iter_smaller_mut(self.view_mut()).zip(
1146+
new_inner_iter_smaller(rhs.view()));
11421147
for (mut s_row, r_row) in rows {
11431148
if try_slices {
11441149
if let Some(self_s) = s_row.as_slice_mut() {

src/iterators/mod.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,55 @@ impl<'a, A, D> ExactSizeIterator for InnerIterMut<'a, A, D>
520520
}
521521
}
522522

523+
/// Create an InnerIter one dimension smaller than D (if possible)
524+
pub fn new_inner_iter_smaller<A, D>(v: ArrayView<A, D>)
525+
-> InnerIter<A, D::TrySmaller>
526+
where D: Dimension
527+
{
528+
let ndim = v.ndim();
529+
let len;
530+
let stride;
531+
let iter_v;
532+
if ndim == 0 {
533+
len = 1;
534+
stride = 0;
535+
iter_v = v.try_remove_axis(Axis(0))
536+
} else {
537+
len = v.dim.last_elem();
538+
stride = v.strides.last_elem() as isize;
539+
iter_v = v.try_remove_axis(Axis(ndim - 1))
540+
}
541+
InnerIter {
542+
inner_len: len,
543+
inner_stride: stride,
544+
iter: iter_v.into_base_iter(),
545+
}
546+
}
547+
548+
pub fn new_inner_iter_smaller_mut<A, D>(v: ArrayViewMut<A, D>)
549+
-> InnerIterMut<A, D::TrySmaller>
550+
where D: Dimension,
551+
{
552+
let ndim = v.ndim();
553+
let len;
554+
let stride;
555+
let iter_v;
556+
if ndim == 0 {
557+
len = 1;
558+
stride = 0;
559+
iter_v = v.try_remove_axis(Axis(0))
560+
} else {
561+
len = v.dim.last_elem();
562+
stride = v.strides.last_elem() as isize;
563+
iter_v = v.try_remove_axis(Axis(ndim - 1))
564+
}
565+
InnerIterMut {
566+
inner_len: len,
567+
inner_stride: stride,
568+
iter: iter_v.into_base_iter(),
569+
}
570+
}
571+
523572
#[derive(Debug)]
524573
pub struct OuterIterCore<A, D> {
525574
index: Ix,

src/lib.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,20 @@ impl<A, S, D> ArrayBase<S, D>
606606
row.into_iter_().fold((), |(), elt| f(elt));
607607
}
608608
}
609+
610+
/// Remove array axis `axis` and return the result.
611+
fn try_remove_axis(self, axis: Axis) -> ArrayBase<S, D::TrySmaller>
612+
{
613+
let d = self.dim.try_remove_axis(axis);
614+
let s = self.strides.try_remove_axis(axis);
615+
ArrayBase {
616+
ptr: self.ptr,
617+
data: self.data,
618+
dim: d,
619+
strides: s,
620+
}
621+
}
622+
609623
}
610624

611625

tests/ix0.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
2+
extern crate ndarray;
3+
4+
use ndarray::Array;
5+
use ndarray::Ix0;
6+
use ndarray::ShapeBuilder;
7+
8+
#[test]
9+
fn test_ix0() {
10+
let mut a = Array::zeros(Ix0());
11+
assert_eq!(a[()], 0.);
12+
a[()] = 1.;
13+
assert_eq!(a[()], 1.);
14+
assert_eq!(a.len(), 1);
15+
assert_eq!(a.as_slice().unwrap(), &[1.]);
16+
17+
let mut a = Array::zeros(Ix0().f());
18+
assert_eq!(a[()], 0.);
19+
a[()] = 1.;
20+
assert_eq!(a[()], 1.);
21+
assert_eq!(a.len(), 1);
22+
assert_eq!(a.as_slice().unwrap(), &[1.]);
23+
}
24+
25+
#[test]
26+
fn test_ix0_add() {
27+
let mut a = Array::zeros(Ix0());
28+
a += 1.;
29+
assert_eq!(a[()], 1.);
30+
a += 2.;
31+
assert_eq!(a[()], 3.);
32+
}
33+
34+
#[test]
35+
fn test_ix0_add_add() {
36+
let mut a = Array::zeros(Ix0());
37+
a += 1.;
38+
let mut b = Array::zeros(Ix0());
39+
b += 1.;
40+
a += &b;
41+
assert_eq!(a[()], 2.);
42+
}
43+
44+
#[test]
45+
fn test_ix0_add_broad() {
46+
let mut b = Array::from_vec(vec![5., 6.]);
47+
let mut a = Array::zeros(Ix0());
48+
a += 1.;
49+
b += &a;
50+
assert_eq!(b[0], 6.);
51+
assert_eq!(b[1], 7.);
52+
}

0 commit comments

Comments
 (0)