From 1442f0f207a3f7b02a2abd3b86da09d93c892c07 Mon Sep 17 00:00:00 2001 From: pradeep Date: Thu, 10 Dec 2020 11:20:23 +0530 Subject: [PATCH] Call set_device in tests for valid cuda context in threaded runs --- src/algorithm/mod.rs | 3 +++ src/core/data.rs | 3 +++ src/core/index.rs | 18 ++++++++++++++---- src/core/macros.rs | 12 +++++++++++- 4 files changed, 31 insertions(+), 5 deletions(-) diff --git a/src/algorithm/mod.rs b/src/algorithm/mod.rs index 3d23530c4..9e4d400d0 100644 --- a/src/algorithm/mod.rs +++ b/src/algorithm/mod.rs @@ -1444,10 +1444,12 @@ dim_reduce_by_key_nan_func_def!( mod tests { use super::super::core::c32; use super::{imax_all, imin_all, product_nan_all, sum_all, sum_nan_all}; + use crate::core::set_device; use crate::randu; #[test] fn all_reduce_api() { + set_device(0); let a = randu!(c32; 10, 10); println!("Reduction of complex f32 matrix: {:?}", sum_all(&a)); @@ -1469,6 +1471,7 @@ mod tests { #[test] fn all_ireduce_api() { + set_device(0); let a = randu!(c32; 10); println!("Reduction of complex f32 matrix: {:?}", imin_all(&a)); diff --git a/src/core/data.rs b/src/core/data.rs index 0fa0f0b94..26de918e1 100644 --- a/src/core/data.rs +++ b/src/core/data.rs @@ -962,6 +962,7 @@ mod tests { use super::reorder_v2; use super::super::defines::BorderType; + use super::super::device::set_device; use super::super::random::randu; use super::pad; @@ -969,6 +970,7 @@ mod tests { #[test] fn check_reorder_api() { + set_device(0); let a = randu::(dim4!(4, 5, 2, 3)); let _transposed = reorder_v2(&a, 1, 0, None); @@ -979,6 +981,7 @@ mod tests { #[test] fn check_pad_api() { + set_device(0); let a = randu::(dim4![3, 3]); let begin_dims = dim4!(0, 0, 0, 0); let end_dims = dim4!(2, 2, 0, 0); diff --git a/src/core/index.rs b/src/core/index.rs index 641e9735b..157598971 100644 --- a/src/core/index.rs +++ b/src/core/index.rs @@ -655,6 +655,7 @@ impl SeqInternal { mod tests { use super::super::array::Array; use super::super::data::constant; + use super::super::device::set_device; use super::super::dim4::Dim4; use super::super::index::{assign_gen, assign_seq, col, index, index_gen, row, Indexer}; use super::super::index::{cols, rows}; @@ -665,6 +666,7 @@ mod tests { #[test] fn non_macro_seq_index() { + set_device(0); // ANCHOR: non_macro_seq_index let dims = Dim4::new(&[5, 5, 1, 1]); let a = randu::(dims); @@ -690,6 +692,7 @@ mod tests { #[test] fn seq_index() { + set_device(0); // ANCHOR: seq_index let dims = dim4!(5, 5, 1, 1); let a = randu::(dims); @@ -701,8 +704,9 @@ mod tests { #[test] fn non_macro_seq_assign() { + set_device(0); // ANCHOR: non_macro_seq_assign - let mut a = constant(2.0 as f32, Dim4::new(&[5, 3, 1, 1])); + let mut a = constant(2.0 as f32, dim4!(5, 3)); //print(&a); // 2.0 2.0 2.0 // 2.0 2.0 2.0 @@ -710,9 +714,9 @@ mod tests { // 2.0 2.0 2.0 // 2.0 2.0 2.0 - let b = constant(1.0 as f32, Dim4::new(&[3, 3, 1, 1])); - let seqs = &[Seq::new(1.0, 3.0, 1.0), Seq::default()]; - assign_seq(&mut a, seqs, &b); + let b = constant(1.0 as f32, dim4!(3, 3)); + let seqs = [seq!(1:3:1), seq!()]; + assign_seq(&mut a, &seqs, &b); //print(&a); // 2.0 2.0 2.0 // 1.0 1.0 1.0 @@ -724,6 +728,7 @@ mod tests { #[test] fn non_macro_seq_array_index() { + set_device(0); // ANCHOR: non_macro_seq_array_index let values: [f32; 3] = [1.0, 2.0, 3.0]; let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1])); @@ -751,6 +756,7 @@ mod tests { #[test] fn seq_array_index() { + set_device(0); // ANCHOR: seq_array_index let values: [f32; 3] = [1.0, 2.0, 3.0]; let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1])); @@ -762,6 +768,7 @@ mod tests { #[test] fn non_macro_seq_array_assign() { + set_device(0); // ANCHOR: non_macro_seq_array_assign let values: [f32; 3] = [1.0, 2.0, 3.0]; let indices = Array::new(&values, dim4!(3, 1, 1, 1)); @@ -793,6 +800,7 @@ mod tests { #[test] fn setrow() { + set_device(0); // ANCHOR: setrow let a = randu::(dim4!(5, 5, 1, 1)); //print(&a); @@ -817,6 +825,7 @@ mod tests { #[test] fn get_row() { + set_device(0); // ANCHOR: get_row let a = randu::(dim4!(5, 5)); // [5 5 1 1] @@ -840,6 +849,7 @@ mod tests { #[test] fn get_rows() { + set_device(0); // ANCHOR: get_rows let a = randu::(dim4!(5, 5)); // [5 5 1 1] diff --git a/src/core/macros.rs b/src/core/macros.rs index c06018a1f..d6b0b1596 100644 --- a/src/core/macros.rs +++ b/src/core/macros.rs @@ -353,6 +353,7 @@ macro_rules! randn { mod tests { use super::super::array::Array; use super::super::data::constant; + use super::super::device::set_device; use super::super::index::index; use super::super::random::randu; @@ -377,6 +378,7 @@ mod tests { #[test] fn seq_view() { + set_device(0); let mut dim4d = dim4!(5, 3, 2, 1); dim4d[2] = 1; @@ -387,14 +389,17 @@ mod tests { #[test] fn seq_view2() { + set_device(0); // ANCHOR: seq_view2 let a = randu::(dim4!(5, 5)); let _sub = view!(a[1:3:1, 1:1:0]); // 1:1:0 means all elements along axis - // ANCHOR_END: seq_view2 + + // ANCHOR_END: seq_view2 } #[test] fn view_macro() { + set_device(0); let dims = dim4!(5, 5, 2, 1); let a = randu::(dims); let b = a.clone(); @@ -421,6 +426,7 @@ mod tests { #[test] fn eval_assign_seq_indexed_array() { + set_device(0); let dims = dim4!(5, 5); let mut a = randu::(dims); //print(&a); @@ -456,6 +462,7 @@ mod tests { #[test] fn eval_assign_array_to_seqd_array() { + set_device(0); // ANCHOR: macro_seq_assign let mut a = randu::(dim4!(5, 5)); let b = randu::(dim4!(2, 2)); @@ -465,6 +472,7 @@ mod tests { #[test] fn macro_seq_array_assign() { + set_device(0); // ANCHOR: macro_seq_array_assign let values: [f32; 3] = [1.0, 2.0, 3.0]; let indices = Array::new(&values, dim4!(3)); @@ -479,6 +487,7 @@ mod tests { #[test] fn constant_macro() { + set_device(0); let _zeros_1d = constant!(0.0f32; 10); let _zeros_2d = constant!(0.0f64; 5, 5); let _ones_3d = constant!(1u32; 3, 3, 3); @@ -490,6 +499,7 @@ mod tests { #[test] fn rand_macro() { + set_device(0); let _ru5x5 = randu!(5, 5); let _rn5x5 = randn!(5, 5); let _ruu32_5x5 = randu!(u32; 5, 5);