From 92d3d7897c9b810f62c75fb41af85f2cdeb594df Mon Sep 17 00:00:00 2001 From: pradeep Date: Sat, 10 Apr 2021 18:17:30 +0530 Subject: [PATCH 1/2] Formating fixes after latest rust update --- src/algorithm/mod.rs | 88 +++++++++++++++++++++++++++---------------- src/core/arith.rs | 18 ++------- src/core/array.rs | 4 +- src/image/mod.rs | 6 +-- src/signal/mod.rs | 10 +++-- src/statistics/mod.rs | 2 +- src/vision/mod.rs | 4 +- 7 files changed, 74 insertions(+), 58 deletions(-) diff --git a/src/algorithm/mod.rs b/src/algorithm/mod.rs index 1adc0153d..e7fe2a783 100644 --- a/src/algorithm/mod.rs +++ b/src/algorithm/mod.rs @@ -527,11 +527,12 @@ where macro_rules! all_reduce_func_def { ($doc_str: expr, $fn_name: ident, $ffi_name: ident, $assoc_type:ident) => { #[doc=$doc_str] - pub fn $fn_name(input: &Array) - -> ( - <::$assoc_type as HasAfEnum>::BaseType, - <::$assoc_type as HasAfEnum>::BaseType - ) + pub fn $fn_name( + input: &Array, + ) -> ( + <::$assoc_type as HasAfEnum>::BaseType, + <::$assoc_type as HasAfEnum>::BaseType, + ) where T: HasAfEnum, ::$assoc_type: HasAfEnum, @@ -541,7 +542,9 @@ macro_rules! all_reduce_func_def { let mut imag: f64 = 0.0; unsafe { let err_val = $ffi_name( - &mut real as *mut c_double, &mut imag as *mut c_double, input.get(), + &mut real as *mut c_double, + &mut imag as *mut c_double, + input.get(), ); HANDLE_ERROR(AfError::from(err_val)); } @@ -676,13 +679,15 @@ macro_rules! all_reduce_func_def2 { pub fn $fn_name(input: &Array) -> ($out_type, $out_type) where T: HasAfEnum, - $out_type: HasAfEnum + Fromf64 + $out_type: HasAfEnum + Fromf64, { let mut real: f64 = 0.0; let mut imag: f64 = 0.0; unsafe { let err_val = $ffi_name( - &mut real as *mut c_double, &mut imag as *mut c_double, input.get(), + &mut real as *mut c_double, + &mut imag as *mut c_double, + input.get(), ); HANDLE_ERROR(AfError::from(err_val)); } @@ -869,13 +874,16 @@ macro_rules! dim_ireduce_func_def { T::$out_type: HasAfEnum, { unsafe { - let mut temp: af_array = std::ptr::null_mut(); - let mut idx: af_array = std::ptr::null_mut(); + let mut temp: af_array = std::ptr::null_mut(); + let mut idx: af_array = std::ptr::null_mut(); let err_val = $ffi_name( - &mut temp as *mut af_array, &mut idx as *mut af_array, input.get(), dim, + &mut temp as *mut af_array, + &mut idx as *mut af_array, + input.get(), + dim, ); HANDLE_ERROR(AfError::from(err_val)); - (temp.into(), idx.into()) + (temp.into(), idx.into()) } } }; @@ -910,12 +918,13 @@ dim_ireduce_func_def!(" macro_rules! all_ireduce_func_def { ($doc_str: expr, $fn_name: ident, $ffi_name: ident, $assoc_type:ident) => { #[doc=$doc_str] - pub fn $fn_name(input: &Array) - -> ( - <::$assoc_type as HasAfEnum>::BaseType, - <::$assoc_type as HasAfEnum>::BaseType, - u32 - ) + pub fn $fn_name( + input: &Array, + ) -> ( + <::$assoc_type as HasAfEnum>::BaseType, + <::$assoc_type as HasAfEnum>::BaseType, + u32, + ) where T: HasAfEnum, ::$assoc_type: HasAfEnum, @@ -926,8 +935,10 @@ macro_rules! all_ireduce_func_def { let mut temp: u32 = 0; unsafe { let err_val = $ffi_name( - &mut real as *mut c_double, &mut imag as *mut c_double, - &mut temp as *mut c_uint, input.get(), + &mut real as *mut c_double, + &mut imag as *mut c_double, + &mut temp as *mut c_uint, + input.get(), ); HANDLE_ERROR(AfError::from(err_val)); } @@ -1277,8 +1288,10 @@ macro_rules! dim_reduce_by_key_func_def { /// Tuple of Arrays, with output keys and values after reduction /// #[doc=$ex_str] - pub fn $fn_name(keys: &Array, vals: &Array, - dim: i32 + pub fn $fn_name( + keys: &Array, + vals: &Array, + dim: i32, ) -> (Array, Array<$out_type>) where KeyType: ReduceByKeyInput, @@ -1286,14 +1299,17 @@ macro_rules! dim_reduce_by_key_func_def { $out_type: HasAfEnum, { unsafe { - let mut out_keys: af_array = std::ptr::null_mut(); - let mut out_vals: af_array = std::ptr::null_mut(); + let mut out_keys: af_array = std::ptr::null_mut(); + let mut out_vals: af_array = std::ptr::null_mut(); let err_val = $ffi_name( - &mut out_keys as *mut af_array, &mut out_vals as *mut af_array, - keys.get(), vals.get(), dim, + &mut out_keys as *mut af_array, + &mut out_vals as *mut af_array, + keys.get(), + vals.get(), + dim, ); HANDLE_ERROR(AfError::from(err_val)); - (out_keys.into(), out_vals.into()) + (out_keys.into(), out_vals.into()) } } }; @@ -1408,8 +1424,11 @@ macro_rules! dim_reduce_by_key_nan_func_def { /// Tuple of Arrays, with output keys and values after reduction /// #[doc=$ex_str] - pub fn $fn_name(keys: &Array, vals: &Array, - dim: i32, replace_value: f64 + pub fn $fn_name( + keys: &Array, + vals: &Array, + dim: i32, + replace_value: f64, ) -> (Array, Array<$out_type>) where KeyType: ReduceByKeyInput, @@ -1417,15 +1436,18 @@ macro_rules! dim_reduce_by_key_nan_func_def { $out_type: HasAfEnum, { unsafe { - let mut out_keys: af_array = std::ptr::null_mut(); - let mut out_vals: af_array = std::ptr::null_mut(); + let mut out_keys: af_array = std::ptr::null_mut(); + let mut out_vals: af_array = std::ptr::null_mut(); let err_val = $ffi_name( &mut out_keys as *mut af_array, &mut out_vals as *mut af_array, - keys.get(), vals.get(), dim, replace_value, + keys.get(), + vals.get(), + dim, + replace_value, ); HANDLE_ERROR(AfError::from(err_val)); - (out_keys.into(), out_vals.into()) + (out_keys.into(), out_vals.into()) } } }; diff --git a/src/core/arith.rs b/src/core/arith.rs index 3c978eef8..b0e26ee3c 100644 --- a/src/core/arith.rs +++ b/src/core/arith.rs @@ -293,9 +293,7 @@ macro_rules! binary_func { { unsafe { let mut temp: af_array = std::ptr::null_mut(); - let err_val = $ffi_fn( - &mut temp as *mut af_array, lhs.get(), rhs.get(), batch, - ); + let err_val = $ffi_fn(&mut temp as *mut af_array, lhs.get(), rhs.get(), batch); HANDLE_ERROR(AfError::from(err_val)); Into::>::into(temp) } @@ -404,9 +402,7 @@ macro_rules! overloaded_binary_func { { unsafe { let mut temp: af_array = std::ptr::null_mut(); - let err_val = $ffi_name( - &mut temp as *mut af_array, lhs.get(), rhs.get(), batch, - ); + let err_val = $ffi_name(&mut temp as *mut af_array, lhs.get(), rhs.get(), batch); HANDLE_ERROR(AfError::from(err_val)); temp.into() } @@ -508,9 +504,7 @@ macro_rules! overloaded_compare_func { { unsafe { let mut temp: af_array = std::ptr::null_mut(); - let err_val = $ffi_name( - &mut temp as *mut af_array, lhs.get(), rhs.get(), batch, - ); + let err_val = $ffi_name(&mut temp as *mut af_array, lhs.get(), rhs.get(), batch); HANDLE_ERROR(AfError::from(err_val)); temp.into() } @@ -545,11 +539,7 @@ macro_rules! overloaded_compare_func { /// - Only one element in `arg1` or `arg2` along a given dimension/axis /// /// - The trait `Convertable` essentially translates to a scalar native type on rust or Array. - pub fn $fn_name( - arg1: &T, - arg2: &U, - batch: bool, - ) -> Array + pub fn $fn_name(arg1: &T, arg2: &U, batch: bool) -> Array where T: Convertable, U: Convertable, diff --git a/src/core/array.rs b/src/core/array.rs index 083ed4798..0e741a30a 100644 --- a/src/core/array.rs +++ b/src/core/array.rs @@ -185,7 +185,7 @@ unsafe impl Send for Array {} unsafe impl Sync for Array {} macro_rules! is_func { - ($doc_str: expr, $fn_name: ident, $ffi_fn: ident) => ( + ($doc_str: expr, $fn_name: ident, $ffi_fn: ident) => { #[doc=$doc_str] pub fn $fn_name(&self) -> bool { unsafe { @@ -195,7 +195,7 @@ macro_rules! is_func { ret_val } } - ) + }; } impl Array diff --git a/src/image/mod.rs b/src/image/mod.rs index 68aee1992..b1a694708 100644 --- a/src/image/mod.rs +++ b/src/image/mod.rs @@ -963,7 +963,7 @@ macro_rules! filt_func_def { T: HasAfEnum + ImageFilterType, { unsafe { - let mut temp: af_array = std::ptr::null_mut(); + let mut temp: af_array = std::ptr::null_mut(); let err_val = $ffi_name( &mut temp as *mut af_array, input.get(), @@ -1181,7 +1181,7 @@ macro_rules! grayrgb_func_def { T: HasAfEnum + GrayRGBConvertible, { unsafe { - let mut temp: af_array = std::ptr::null_mut(); + let mut temp: af_array = std::ptr::null_mut(); let err_val = $ffi_name(&mut temp as *mut af_array, input.get(), r, g, b); HANDLE_ERROR(AfError::from(err_val)); temp.into() @@ -1201,7 +1201,7 @@ macro_rules! hsvrgb_func_def { T: HasAfEnum + RealFloating, { unsafe { - let mut temp: af_array = std::ptr::null_mut(); + let mut temp: af_array = std::ptr::null_mut(); let err_val = $ffi_name(&mut temp as *mut af_array, input.get()); HANDLE_ERROR(AfError::from(err_val)); temp.into() diff --git a/src/signal/mod.rs b/src/signal/mod.rs index e9eb27feb..6933c92c3 100644 --- a/src/signal/mod.rs +++ b/src/signal/mod.rs @@ -721,7 +721,7 @@ macro_rules! conv_func_def { F: HasAfEnum, { unsafe { - let mut temp: af_array = std::ptr::null_mut(); + let mut temp: af_array = std::ptr::null_mut(); let err_val = $ffi_name( &mut temp as *mut af_array, signal.get(), @@ -796,9 +796,13 @@ macro_rules! fft_conv_func_def { F: HasAfEnum, { unsafe { - let mut temp: af_array = std::ptr::null_mut(); + let mut temp: af_array = std::ptr::null_mut(); let err_val = $ffi_name( - &mut temp as *mut af_array, signal.get(), filter.get(), mode as c_uint); + &mut temp as *mut af_array, + signal.get(), + filter.get(), + mode as c_uint, + ); HANDLE_ERROR(AfError::from(err_val)); temp.into() } diff --git a/src/statistics/mod.rs b/src/statistics/mod.rs index 550add48d..d79d280d2 100644 --- a/src/statistics/mod.rs +++ b/src/statistics/mod.rs @@ -142,7 +142,7 @@ macro_rules! stat_wtd_func_def { { unsafe { let mut temp: af_array = std::ptr::null_mut(); - let err_val = $ffi_fn(&mut temp as *mut af_array,input.get(), weights.get(), dim); + let err_val = $ffi_fn(&mut temp as *mut af_array, input.get(), weights.get(), dim); HANDLE_ERROR(AfError::from(err_val)); temp.into() } diff --git a/src/vision/mod.rs b/src/vision/mod.rs index 7422386c4..6c6840a0e 100644 --- a/src/vision/mod.rs +++ b/src/vision/mod.rs @@ -131,7 +131,7 @@ unsafe impl Send for Features {} unsafe impl Sync for Features {} macro_rules! feat_func_def { - ($doc_str: expr, $fn_name: ident, $ffi_name: ident) => ( + ($doc_str: expr, $fn_name: ident, $ffi_name: ident) => { #[doc=$doc_str] pub fn $fn_name(&self) -> Array { unsafe { @@ -146,7 +146,7 @@ macro_rules! feat_func_def { retained } } - ) + }; } impl Features { From c2754e0104869230e652b8081966a8aeac273619 Mon Sep 17 00:00:00 2001 From: pradeep Date: Mon, 12 Apr 2021 10:14:27 +0530 Subject: [PATCH 2/2] Fix output array for logical ops AND, OR and NEQ Earlier, these functions were using binary_func macro which picks output based on the input arrays types. However, as these are logical operations which always result in boolean output, the correct macro to use for this is overloaded_logic_func. With this fix, ::and, ::or and ::neq functions have acquired an additional feature i.e. the user can also use scalars as one of the inputs of these functions. --- src/core/arith.rs | 41 ++++++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/src/core/arith.rs b/src/core/arith.rs index b0e26ee3c..a0e7e7322 100644 --- a/src/core/arith.rs +++ b/src/core/arith.rs @@ -316,17 +316,6 @@ binary_func!( bitxor, af_bitxor ); -binary_func!( - "Elementwise not equals comparison of two Arrays", - neq, - af_neq -); -binary_func!( - "Elementwise logical and operation of two Arrays", - and, - af_and -); -binary_func!("Elementwise logical or operation of two Arrays", or, af_or); binary_func!( "Elementwise minimum operation of two Arrays", minof, @@ -495,7 +484,7 @@ overloaded_binary_func!( overloaded_binary_func!("Compute root", root, root_helper, af_root); overloaded_binary_func!("Computer power", pow, pow_helper, af_pow); -macro_rules! overloaded_compare_func { +macro_rules! overloaded_logic_func { ($doc_str: expr, $fn_name: ident, $help_name: ident, $ffi_name: ident) => { fn $help_name(lhs: &Array, rhs: &Array, batch: bool) -> Array where @@ -563,36 +552,54 @@ macro_rules! overloaded_compare_func { }; } -overloaded_compare_func!( +overloaded_logic_func!( "Perform `less than` comparison operation", lt, lt_helper, af_lt ); -overloaded_compare_func!( +overloaded_logic_func!( "Perform `greater than` comparison operation", gt, gt_helper, af_gt ); -overloaded_compare_func!( +overloaded_logic_func!( "Perform `less than equals` comparison operation", le, le_helper, af_le ); -overloaded_compare_func!( +overloaded_logic_func!( "Perform `greater than equals` comparison operation", ge, ge_helper, af_ge ); -overloaded_compare_func!( +overloaded_logic_func!( "Perform `equals` comparison operation", eq, eq_helper, af_eq ); +overloaded_logic_func!( + "Elementwise `not equals` comparison of two Arrays", + neq, + neq_helper, + af_neq +); +overloaded_logic_func!( + "Elementwise logical AND operation of two Arrays", + and, + and_helper, + af_and +); +overloaded_logic_func!( + "Elementwise logical OR operation of two Arrays", + or, + or_helper, + af_or +); fn clamp_helper( inp: &Array,