diff --git a/src/core/index.rs b/src/core/index.rs index 4055152c..17ec7b4a 100644 --- a/src/core/index.rs +++ b/src/core/index.rs @@ -293,13 +293,11 @@ pub fn row(input: &Array, row_num: i64) -> Array where T: HasAfEnum, { - index( - input, - &[ - Seq::new(row_num as f64, row_num as f64, 1.0), - Seq::default(), - ], - ) + let mut seqs = vec![Seq::new(row_num as f64, row_num as f64, 1.0)]; + for _d in 1..input.dims().ndims() { + seqs.push(Seq::default()); + } + index(input, &seqs) } /// Set `row_num`^th row in `inout` Array to a new Array `new_row` @@ -308,7 +306,7 @@ where T: HasAfEnum, { let mut seqs = vec![Seq::new(row_num as f64, row_num as f64, 1.0)]; - if inout.dims().ndims() > 1 { + for _d in 1..inout.dims().ndims() { seqs.push(Seq::default()); } assign_seq(inout, &seqs, new_row) @@ -320,10 +318,11 @@ where T: HasAfEnum, { let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 }; - index( - input, - &[Seq::new(first as f64, last as f64, step), Seq::default()], - ) + let mut seqs = vec![Seq::new(first as f64, last as f64, step)]; + for _d in 1..input.dims().ndims() { + seqs.push(Seq::default()); + } + index(input, &seqs) } /// Set rows from `first` to `last` in `inout` Array with rows from Array `new_rows` @@ -332,7 +331,10 @@ where T: HasAfEnum, { let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 }; - let seqs = [Seq::new(first as f64, last as f64, step), Seq::default()]; + let mut seqs = vec![Seq::new(first as f64, last as f64, step)]; + for _d in 1..inout.dims().ndims() { + seqs.push(Seq::default()); + } assign_seq(inout, &seqs, new_rows) } @@ -352,13 +354,14 @@ pub fn col(input: &Array, col_num: i64) -> Array where T: HasAfEnum, { - index( - input, - &[ - Seq::default(), - Seq::new(col_num as f64, col_num as f64, 1.0), - ], - ) + let mut seqs = vec![ + Seq::default(), + Seq::new(col_num as f64, col_num as f64, 1.0), + ]; + for _d in 2..input.dims().ndims() { + seqs.push(Seq::default()); + } + index(input, &seqs) } /// Set `col_num`^th col in `inout` Array to a new Array `new_col` @@ -366,10 +369,13 @@ pub fn set_col(inout: &mut Array, new_col: &Array, col_num: i64) where T: HasAfEnum, { - let seqs = [ + let mut seqs = vec![ Seq::default(), Seq::new(col_num as f64, col_num as f64, 1.0), ]; + for _d in 2..inout.dims().ndims() { + seqs.push(Seq::default()); + } assign_seq(inout, &seqs, new_col) } @@ -379,10 +385,11 @@ where T: HasAfEnum, { let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 }; - index( - input, - &[Seq::default(), Seq::new(first as f64, last as f64, step)], - ) + let mut seqs = vec![Seq::default(), Seq::new(first as f64, last as f64, step)]; + for _d in 2..input.dims().ndims() { + seqs.push(Seq::default()); + } + index(input, &seqs) } /// Set cols from `first` to `last` in `inout` Array with cols from Array `new_cols` @@ -391,7 +398,10 @@ where T: HasAfEnum, { let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 }; - let seqs = [Seq::default(), Seq::new(first as f64, last as f64, step)]; + let mut seqs = vec![Seq::default(), Seq::new(first as f64, last as f64, step)]; + for _d in 2..inout.dims().ndims() { + seqs.push(Seq::default()); + } assign_seq(inout, &seqs, new_cols) } @@ -402,11 +412,14 @@ pub fn slice(input: &Array, slice_num: i64) -> Array where T: HasAfEnum, { - let seqs = [ + let mut seqs = vec![ Seq::default(), Seq::default(), Seq::new(slice_num as f64, slice_num as f64, 1.0), ]; + for _d in 3..input.dims().ndims() { + seqs.push(Seq::default()); + } index(input, &seqs) } @@ -417,11 +430,14 @@ pub fn set_slice(inout: &mut Array, new_slice: &Array, slice_num: i64) where T: HasAfEnum, { - let seqs = [ + let mut seqs = vec![ Seq::default(), Seq::default(), Seq::new(slice_num as f64, slice_num as f64, 1.0), ]; + for _d in 3..inout.dims().ndims() { + seqs.push(Seq::default()); + } assign_seq(inout, &seqs, new_slice) } @@ -433,11 +449,14 @@ where T: HasAfEnum, { let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 }; - let seqs = [ + let mut seqs = vec![ Seq::default(), Seq::default(), Seq::new(first as f64, last as f64, step), ]; + for _d in 3..input.dims().ndims() { + seqs.push(Seq::default()); + } index(input, &seqs) } @@ -449,11 +468,14 @@ where T: HasAfEnum, { let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 }; - let seqs = [ + let mut seqs = vec![ Seq::default(), Seq::default(), Seq::new(first as f64, last as f64, step), ]; + for _d in 3..inout.dims().ndims() { + seqs.push(Seq::default()); + } assign_seq(inout, &seqs, new_slices) } @@ -655,7 +677,7 @@ mod tests { use super::super::device::set_device; use super::super::dim4::Dim4; use super::super::index::{assign_gen, assign_seq, col, index, index_gen, row, Indexer}; - use super::super::index::{cols, rows}; + use super::super::index::{cols, rows, set_row, set_rows}; use super::super::random::randu; use super::super::seq::Seq; @@ -868,4 +890,44 @@ mod tests { // 0.9675 0.3712 0.7896 // ANCHOR_END: get_rows } + + #[test] + fn change_row() { + set_device(0); + + let v0: Vec = vec![true, true, true, true, true, true]; + let mut a0 = Array::new(&v0, dim4!(v0.len() as u64)); + + let v1: Vec = vec![false]; + let a1 = Array::new(&v1, dim4!(v1.len() as u64)); + + set_row(&mut a0, &a1, 2); + + let mut res = vec![true; a0.elements()]; + a0.host(&mut res); + + let gold = vec![true, true, false, true, true, true]; + + assert_eq!(gold, res); + } + + #[test] + fn change_rows() { + set_device(0); + + let v0: Vec = vec![true, true, true, true, true, true]; + let mut a0 = Array::new(&v0, dim4!(v0.len() as u64)); + + let v1: Vec = vec![false, false]; + let a1 = Array::new(&v1, dim4!(v1.len() as u64)); + + set_rows(&mut a0, &a1, 2, 3); + + let mut res = vec![true; a0.elements()]; + a0.host(&mut res); + + let gold = vec![true, true, false, false, true, true]; + + assert_eq!(gold, res); + } } diff --git a/src/core/macros.rs b/src/core/macros.rs index d6b0b159..4abba9ee 100644 --- a/src/core/macros.rs +++ b/src/core/macros.rs @@ -190,6 +190,9 @@ macro_rules! view { $( seq_vec.push($crate::seq!($start:$end:$step)); )* + for _d in seq_vec.len()..$array_ident.dims().ndims() { + seq_vec.push($crate::seq!()); + } $crate::index(&$array_ident, &seq_vec) } }; @@ -354,7 +357,7 @@ mod tests { use super::super::array::Array; use super::super::data::constant; use super::super::device::set_device; - use super::super::index::index; + use super::super::index::{index, rows, set_rows}; use super::super::random::randu; #[test] @@ -505,4 +508,55 @@ mod tests { let _ruu32_5x5 = randu!(u32; 5, 5); let _ruu8_5x5 = randu!(u8; 5, 5); } + + #[test] + fn match_eval_macro_with_set_rows() { + set_device(0); + + let inpt = vec![true, true, true, true, true, true, true, true, true, true]; + let gold = vec![ + true, true, false, false, true, true, true, false, false, true, + ]; + + let mut orig_arr = Array::new(&inpt, dim4!(5, 2)); + let mut orig_cln = orig_arr.clone(); + + let new_vals = vec![false, false, false, false]; + let new_arr = Array::new(&new_vals, dim4!(2, 2)); + + eval!( orig_arr[2:3:1,1:1:0] = new_arr ); + let mut res1 = vec![true; orig_arr.elements()]; + orig_arr.host(&mut res1); + + set_rows(&mut orig_cln, &new_arr, 2, 3); + let mut res2 = vec![true; orig_cln.elements()]; + orig_cln.host(&mut res2); + + assert_eq!(gold, res1); + assert_eq!(res1, res2); + } + + #[test] + fn match_view_macro_with_get_rows() { + set_device(0); + + let inpt: Vec = (0..10).collect(); + let gold: Vec = vec![2, 3, 7, 8]; + + println!("input {:?}", inpt); + println!("gold {:?}", gold); + + let orig_arr = Array::new(&inpt, dim4!(5, 2)); + + let view_out = view!( orig_arr[2:3:1] ); + let mut res1 = vec![0i32; view_out.elements()]; + view_out.host(&mut res1); + + let rows_out = rows(&orig_arr, 2, 3); + let mut res2 = vec![0i32; rows_out.elements()]; + rows_out.host(&mut res2); + + assert_eq!(gold, res1); + assert_eq!(res1, res2); + } }