diff --git a/src/algorithm/mod.rs b/src/algorithm/mod.rs index 1adc0153d..e7fe2a783 100644 --- a/src/algorithm/mod.rs +++ b/src/algorithm/mod.rs @@ -527,11 +527,12 @@ where macro_rules! all_reduce_func_def { ($doc_str: expr, $fn_name: ident, $ffi_name: ident, $assoc_type:ident) => { #[doc=$doc_str] - pub fn $fn_name(input: &Array) - -> ( - <::$assoc_type as HasAfEnum>::BaseType, - <::$assoc_type as HasAfEnum>::BaseType - ) + pub fn $fn_name( + input: &Array, + ) -> ( + <::$assoc_type as HasAfEnum>::BaseType, + <::$assoc_type as HasAfEnum>::BaseType, + ) where T: HasAfEnum, ::$assoc_type: HasAfEnum, @@ -541,7 +542,9 @@ macro_rules! all_reduce_func_def { let mut imag: f64 = 0.0; unsafe { let err_val = $ffi_name( - &mut real as *mut c_double, &mut imag as *mut c_double, input.get(), + &mut real as *mut c_double, + &mut imag as *mut c_double, + input.get(), ); HANDLE_ERROR(AfError::from(err_val)); } @@ -676,13 +679,15 @@ macro_rules! all_reduce_func_def2 { pub fn $fn_name(input: &Array) -> ($out_type, $out_type) where T: HasAfEnum, - $out_type: HasAfEnum + Fromf64 + $out_type: HasAfEnum + Fromf64, { let mut real: f64 = 0.0; let mut imag: f64 = 0.0; unsafe { let err_val = $ffi_name( - &mut real as *mut c_double, &mut imag as *mut c_double, input.get(), + &mut real as *mut c_double, + &mut imag as *mut c_double, + input.get(), ); HANDLE_ERROR(AfError::from(err_val)); } @@ -869,13 +874,16 @@ macro_rules! dim_ireduce_func_def { T::$out_type: HasAfEnum, { unsafe { - let mut temp: af_array = std::ptr::null_mut(); - let mut idx: af_array = std::ptr::null_mut(); + let mut temp: af_array = std::ptr::null_mut(); + let mut idx: af_array = std::ptr::null_mut(); let err_val = $ffi_name( - &mut temp as *mut af_array, &mut idx as *mut af_array, input.get(), dim, + &mut temp as *mut af_array, + &mut idx as *mut af_array, + input.get(), + dim, ); HANDLE_ERROR(AfError::from(err_val)); - (temp.into(), idx.into()) + (temp.into(), idx.into()) } } }; @@ -910,12 +918,13 @@ dim_ireduce_func_def!(" macro_rules! all_ireduce_func_def { ($doc_str: expr, $fn_name: ident, $ffi_name: ident, $assoc_type:ident) => { #[doc=$doc_str] - pub fn $fn_name(input: &Array) - -> ( - <::$assoc_type as HasAfEnum>::BaseType, - <::$assoc_type as HasAfEnum>::BaseType, - u32 - ) + pub fn $fn_name( + input: &Array, + ) -> ( + <::$assoc_type as HasAfEnum>::BaseType, + <::$assoc_type as HasAfEnum>::BaseType, + u32, + ) where T: HasAfEnum, ::$assoc_type: HasAfEnum, @@ -926,8 +935,10 @@ macro_rules! all_ireduce_func_def { let mut temp: u32 = 0; unsafe { let err_val = $ffi_name( - &mut real as *mut c_double, &mut imag as *mut c_double, - &mut temp as *mut c_uint, input.get(), + &mut real as *mut c_double, + &mut imag as *mut c_double, + &mut temp as *mut c_uint, + input.get(), ); HANDLE_ERROR(AfError::from(err_val)); } @@ -1277,8 +1288,10 @@ macro_rules! dim_reduce_by_key_func_def { /// Tuple of Arrays, with output keys and values after reduction /// #[doc=$ex_str] - pub fn $fn_name(keys: &Array, vals: &Array, - dim: i32 + pub fn $fn_name( + keys: &Array, + vals: &Array, + dim: i32, ) -> (Array, Array<$out_type>) where KeyType: ReduceByKeyInput, @@ -1286,14 +1299,17 @@ macro_rules! dim_reduce_by_key_func_def { $out_type: HasAfEnum, { unsafe { - let mut out_keys: af_array = std::ptr::null_mut(); - let mut out_vals: af_array = std::ptr::null_mut(); + let mut out_keys: af_array = std::ptr::null_mut(); + let mut out_vals: af_array = std::ptr::null_mut(); let err_val = $ffi_name( - &mut out_keys as *mut af_array, &mut out_vals as *mut af_array, - keys.get(), vals.get(), dim, + &mut out_keys as *mut af_array, + &mut out_vals as *mut af_array, + keys.get(), + vals.get(), + dim, ); HANDLE_ERROR(AfError::from(err_val)); - (out_keys.into(), out_vals.into()) + (out_keys.into(), out_vals.into()) } } }; @@ -1408,8 +1424,11 @@ macro_rules! dim_reduce_by_key_nan_func_def { /// Tuple of Arrays, with output keys and values after reduction /// #[doc=$ex_str] - pub fn $fn_name(keys: &Array, vals: &Array, - dim: i32, replace_value: f64 + pub fn $fn_name( + keys: &Array, + vals: &Array, + dim: i32, + replace_value: f64, ) -> (Array, Array<$out_type>) where KeyType: ReduceByKeyInput, @@ -1417,15 +1436,18 @@ macro_rules! dim_reduce_by_key_nan_func_def { $out_type: HasAfEnum, { unsafe { - let mut out_keys: af_array = std::ptr::null_mut(); - let mut out_vals: af_array = std::ptr::null_mut(); + let mut out_keys: af_array = std::ptr::null_mut(); + let mut out_vals: af_array = std::ptr::null_mut(); let err_val = $ffi_name( &mut out_keys as *mut af_array, &mut out_vals as *mut af_array, - keys.get(), vals.get(), dim, replace_value, + keys.get(), + vals.get(), + dim, + replace_value, ); HANDLE_ERROR(AfError::from(err_val)); - (out_keys.into(), out_vals.into()) + (out_keys.into(), out_vals.into()) } } }; diff --git a/src/core/arith.rs b/src/core/arith.rs index 3c978eef8..a0e7e7322 100644 --- a/src/core/arith.rs +++ b/src/core/arith.rs @@ -293,9 +293,7 @@ macro_rules! binary_func { { unsafe { let mut temp: af_array = std::ptr::null_mut(); - let err_val = $ffi_fn( - &mut temp as *mut af_array, lhs.get(), rhs.get(), batch, - ); + let err_val = $ffi_fn(&mut temp as *mut af_array, lhs.get(), rhs.get(), batch); HANDLE_ERROR(AfError::from(err_val)); Into::>::into(temp) } @@ -318,17 +316,6 @@ binary_func!( bitxor, af_bitxor ); -binary_func!( - "Elementwise not equals comparison of two Arrays", - neq, - af_neq -); -binary_func!( - "Elementwise logical and operation of two Arrays", - and, - af_and -); -binary_func!("Elementwise logical or operation of two Arrays", or, af_or); binary_func!( "Elementwise minimum operation of two Arrays", minof, @@ -404,9 +391,7 @@ macro_rules! overloaded_binary_func { { unsafe { let mut temp: af_array = std::ptr::null_mut(); - let err_val = $ffi_name( - &mut temp as *mut af_array, lhs.get(), rhs.get(), batch, - ); + let err_val = $ffi_name(&mut temp as *mut af_array, lhs.get(), rhs.get(), batch); HANDLE_ERROR(AfError::from(err_val)); temp.into() } @@ -499,7 +484,7 @@ overloaded_binary_func!( overloaded_binary_func!("Compute root", root, root_helper, af_root); overloaded_binary_func!("Computer power", pow, pow_helper, af_pow); -macro_rules! overloaded_compare_func { +macro_rules! overloaded_logic_func { ($doc_str: expr, $fn_name: ident, $help_name: ident, $ffi_name: ident) => { fn $help_name(lhs: &Array, rhs: &Array, batch: bool) -> Array where @@ -508,9 +493,7 @@ macro_rules! overloaded_compare_func { { unsafe { let mut temp: af_array = std::ptr::null_mut(); - let err_val = $ffi_name( - &mut temp as *mut af_array, lhs.get(), rhs.get(), batch, - ); + let err_val = $ffi_name(&mut temp as *mut af_array, lhs.get(), rhs.get(), batch); HANDLE_ERROR(AfError::from(err_val)); temp.into() } @@ -545,11 +528,7 @@ macro_rules! overloaded_compare_func { /// - Only one element in `arg1` or `arg2` along a given dimension/axis /// /// - The trait `Convertable` essentially translates to a scalar native type on rust or Array. - pub fn $fn_name( - arg1: &T, - arg2: &U, - batch: bool, - ) -> Array + pub fn $fn_name(arg1: &T, arg2: &U, batch: bool) -> Array where T: Convertable, U: Convertable, @@ -573,36 +552,54 @@ macro_rules! overloaded_compare_func { }; } -overloaded_compare_func!( +overloaded_logic_func!( "Perform `less than` comparison operation", lt, lt_helper, af_lt ); -overloaded_compare_func!( +overloaded_logic_func!( "Perform `greater than` comparison operation", gt, gt_helper, af_gt ); -overloaded_compare_func!( +overloaded_logic_func!( "Perform `less than equals` comparison operation", le, le_helper, af_le ); -overloaded_compare_func!( +overloaded_logic_func!( "Perform `greater than equals` comparison operation", ge, ge_helper, af_ge ); -overloaded_compare_func!( +overloaded_logic_func!( "Perform `equals` comparison operation", eq, eq_helper, af_eq ); +overloaded_logic_func!( + "Elementwise `not equals` comparison of two Arrays", + neq, + neq_helper, + af_neq +); +overloaded_logic_func!( + "Elementwise logical AND operation of two Arrays", + and, + and_helper, + af_and +); +overloaded_logic_func!( + "Elementwise logical OR operation of two Arrays", + or, + or_helper, + af_or +); fn clamp_helper( inp: &Array, diff --git a/src/core/array.rs b/src/core/array.rs index 083ed4798..0e741a30a 100644 --- a/src/core/array.rs +++ b/src/core/array.rs @@ -185,7 +185,7 @@ unsafe impl Send for Array {} unsafe impl Sync for Array {} macro_rules! is_func { - ($doc_str: expr, $fn_name: ident, $ffi_fn: ident) => ( + ($doc_str: expr, $fn_name: ident, $ffi_fn: ident) => { #[doc=$doc_str] pub fn $fn_name(&self) -> bool { unsafe { @@ -195,7 +195,7 @@ macro_rules! is_func { ret_val } } - ) + }; } impl Array diff --git a/src/image/mod.rs b/src/image/mod.rs index 68aee1992..b1a694708 100644 --- a/src/image/mod.rs +++ b/src/image/mod.rs @@ -963,7 +963,7 @@ macro_rules! filt_func_def { T: HasAfEnum + ImageFilterType, { unsafe { - let mut temp: af_array = std::ptr::null_mut(); + let mut temp: af_array = std::ptr::null_mut(); let err_val = $ffi_name( &mut temp as *mut af_array, input.get(), @@ -1181,7 +1181,7 @@ macro_rules! grayrgb_func_def { T: HasAfEnum + GrayRGBConvertible, { unsafe { - let mut temp: af_array = std::ptr::null_mut(); + let mut temp: af_array = std::ptr::null_mut(); let err_val = $ffi_name(&mut temp as *mut af_array, input.get(), r, g, b); HANDLE_ERROR(AfError::from(err_val)); temp.into() @@ -1201,7 +1201,7 @@ macro_rules! hsvrgb_func_def { T: HasAfEnum + RealFloating, { unsafe { - let mut temp: af_array = std::ptr::null_mut(); + let mut temp: af_array = std::ptr::null_mut(); let err_val = $ffi_name(&mut temp as *mut af_array, input.get()); HANDLE_ERROR(AfError::from(err_val)); temp.into() diff --git a/src/signal/mod.rs b/src/signal/mod.rs index e9eb27feb..6933c92c3 100644 --- a/src/signal/mod.rs +++ b/src/signal/mod.rs @@ -721,7 +721,7 @@ macro_rules! conv_func_def { F: HasAfEnum, { unsafe { - let mut temp: af_array = std::ptr::null_mut(); + let mut temp: af_array = std::ptr::null_mut(); let err_val = $ffi_name( &mut temp as *mut af_array, signal.get(), @@ -796,9 +796,13 @@ macro_rules! fft_conv_func_def { F: HasAfEnum, { unsafe { - let mut temp: af_array = std::ptr::null_mut(); + let mut temp: af_array = std::ptr::null_mut(); let err_val = $ffi_name( - &mut temp as *mut af_array, signal.get(), filter.get(), mode as c_uint); + &mut temp as *mut af_array, + signal.get(), + filter.get(), + mode as c_uint, + ); HANDLE_ERROR(AfError::from(err_val)); temp.into() } diff --git a/src/statistics/mod.rs b/src/statistics/mod.rs index 550add48d..d79d280d2 100644 --- a/src/statistics/mod.rs +++ b/src/statistics/mod.rs @@ -142,7 +142,7 @@ macro_rules! stat_wtd_func_def { { unsafe { let mut temp: af_array = std::ptr::null_mut(); - let err_val = $ffi_fn(&mut temp as *mut af_array,input.get(), weights.get(), dim); + let err_val = $ffi_fn(&mut temp as *mut af_array, input.get(), weights.get(), dim); HANDLE_ERROR(AfError::from(err_val)); temp.into() } diff --git a/src/vision/mod.rs b/src/vision/mod.rs index 7422386c4..6c6840a0e 100644 --- a/src/vision/mod.rs +++ b/src/vision/mod.rs @@ -131,7 +131,7 @@ unsafe impl Send for Features {} unsafe impl Sync for Features {} macro_rules! feat_func_def { - ($doc_str: expr, $fn_name: ident, $ffi_name: ident) => ( + ($doc_str: expr, $fn_name: ident, $ffi_name: ident) => { #[doc=$doc_str] pub fn $fn_name(&self) -> Array { unsafe { @@ -146,7 +146,7 @@ macro_rules! feat_func_def { retained } } - ) + }; } impl Features {