Skip to content

Fix output array for logical ops AND, OR and NEQ #292

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 55 additions & 33 deletions src/algorithm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -527,11 +527,12 @@ where
macro_rules! all_reduce_func_def {
($doc_str: expr, $fn_name: ident, $ffi_name: ident, $assoc_type:ident) => {
#[doc=$doc_str]
pub fn $fn_name<T>(input: &Array<T>)
-> (
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType,
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType
)
pub fn $fn_name<T>(
input: &Array<T>,
) -> (
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType,
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType,
)
where
T: HasAfEnum,
<T as HasAfEnum>::$assoc_type: HasAfEnum,
Expand All @@ -541,7 +542,9 @@ macro_rules! all_reduce_func_def {
let mut imag: f64 = 0.0;
unsafe {
let err_val = $ffi_name(
&mut real as *mut c_double, &mut imag as *mut c_double, input.get(),
&mut real as *mut c_double,
&mut imag as *mut c_double,
input.get(),
);
HANDLE_ERROR(AfError::from(err_val));
}
Expand Down Expand Up @@ -676,13 +679,15 @@ macro_rules! all_reduce_func_def2 {
pub fn $fn_name<T>(input: &Array<T>) -> ($out_type, $out_type)
where
T: HasAfEnum,
$out_type: HasAfEnum + Fromf64
$out_type: HasAfEnum + Fromf64,
{
let mut real: f64 = 0.0;
let mut imag: f64 = 0.0;
unsafe {
let err_val = $ffi_name(
&mut real as *mut c_double, &mut imag as *mut c_double, input.get(),
&mut real as *mut c_double,
&mut imag as *mut c_double,
input.get(),
);
HANDLE_ERROR(AfError::from(err_val));
}
Expand Down Expand Up @@ -869,13 +874,16 @@ macro_rules! dim_ireduce_func_def {
T::$out_type: HasAfEnum,
{
unsafe {
let mut temp: af_array = std::ptr::null_mut();
let mut idx: af_array = std::ptr::null_mut();
let mut temp: af_array = std::ptr::null_mut();
let mut idx: af_array = std::ptr::null_mut();
let err_val = $ffi_name(
&mut temp as *mut af_array, &mut idx as *mut af_array, input.get(), dim,
&mut temp as *mut af_array,
&mut idx as *mut af_array,
input.get(),
dim,
);
HANDLE_ERROR(AfError::from(err_val));
(temp.into(), idx.into())
(temp.into(), idx.into())
}
}
};
Expand Down Expand Up @@ -910,12 +918,13 @@ dim_ireduce_func_def!("
macro_rules! all_ireduce_func_def {
($doc_str: expr, $fn_name: ident, $ffi_name: ident, $assoc_type:ident) => {
#[doc=$doc_str]
pub fn $fn_name<T>(input: &Array<T>)
-> (
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType,
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType,
u32
)
pub fn $fn_name<T>(
input: &Array<T>,
) -> (
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType,
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType,
u32,
)
where
T: HasAfEnum,
<T as HasAfEnum>::$assoc_type: HasAfEnum,
Expand All @@ -926,8 +935,10 @@ macro_rules! all_ireduce_func_def {
let mut temp: u32 = 0;
unsafe {
let err_val = $ffi_name(
&mut real as *mut c_double, &mut imag as *mut c_double,
&mut temp as *mut c_uint, input.get(),
&mut real as *mut c_double,
&mut imag as *mut c_double,
&mut temp as *mut c_uint,
input.get(),
);
HANDLE_ERROR(AfError::from(err_val));
}
Expand Down Expand Up @@ -1277,23 +1288,28 @@ macro_rules! dim_reduce_by_key_func_def {
/// Tuple of Arrays, with output keys and values after reduction
///
#[doc=$ex_str]
pub fn $fn_name<KeyType, ValueType>(keys: &Array<KeyType>, vals: &Array<ValueType>,
dim: i32
pub fn $fn_name<KeyType, ValueType>(
keys: &Array<KeyType>,
vals: &Array<ValueType>,
dim: i32,
) -> (Array<KeyType>, Array<$out_type>)
where
KeyType: ReduceByKeyInput,
ValueType: HasAfEnum,
$out_type: HasAfEnum,
{
unsafe {
let mut out_keys: af_array = std::ptr::null_mut();
let mut out_vals: af_array = std::ptr::null_mut();
let mut out_keys: af_array = std::ptr::null_mut();
let mut out_vals: af_array = std::ptr::null_mut();
let err_val = $ffi_name(
&mut out_keys as *mut af_array, &mut out_vals as *mut af_array,
keys.get(), vals.get(), dim,
&mut out_keys as *mut af_array,
&mut out_vals as *mut af_array,
keys.get(),
vals.get(),
dim,
);
HANDLE_ERROR(AfError::from(err_val));
(out_keys.into(), out_vals.into())
(out_keys.into(), out_vals.into())
}
}
};
Expand Down Expand Up @@ -1408,24 +1424,30 @@ macro_rules! dim_reduce_by_key_nan_func_def {
/// Tuple of Arrays, with output keys and values after reduction
///
#[doc=$ex_str]
pub fn $fn_name<KeyType, ValueType>(keys: &Array<KeyType>, vals: &Array<ValueType>,
dim: i32, replace_value: f64
pub fn $fn_name<KeyType, ValueType>(
keys: &Array<KeyType>,
vals: &Array<ValueType>,
dim: i32,
replace_value: f64,
) -> (Array<KeyType>, Array<$out_type>)
where
KeyType: ReduceByKeyInput,
ValueType: HasAfEnum,
$out_type: HasAfEnum,
{
unsafe {
let mut out_keys: af_array = std::ptr::null_mut();
let mut out_vals: af_array = std::ptr::null_mut();
let mut out_keys: af_array = std::ptr::null_mut();
let mut out_vals: af_array = std::ptr::null_mut();
let err_val = $ffi_name(
&mut out_keys as *mut af_array,
&mut out_vals as *mut af_array,
keys.get(), vals.get(), dim, replace_value,
keys.get(),
vals.get(),
dim,
replace_value,
);
HANDLE_ERROR(AfError::from(err_val));
(out_keys.into(), out_vals.into())
(out_keys.into(), out_vals.into())
}
}
};
Expand Down
59 changes: 28 additions & 31 deletions src/core/arith.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,7 @@ macro_rules! binary_func {
{
unsafe {
let mut temp: af_array = std::ptr::null_mut();
let err_val = $ffi_fn(
&mut temp as *mut af_array, lhs.get(), rhs.get(), batch,
);
let err_val = $ffi_fn(&mut temp as *mut af_array, lhs.get(), rhs.get(), batch);
HANDLE_ERROR(AfError::from(err_val));
Into::<Array<A::Output>>::into(temp)
}
Expand All @@ -318,17 +316,6 @@ binary_func!(
bitxor,
af_bitxor
);
binary_func!(
"Elementwise not equals comparison of two Arrays",
neq,
af_neq
);
binary_func!(
"Elementwise logical and operation of two Arrays",
and,
af_and
);
binary_func!("Elementwise logical or operation of two Arrays", or, af_or);
binary_func!(
"Elementwise minimum operation of two Arrays",
minof,
Expand Down Expand Up @@ -404,9 +391,7 @@ macro_rules! overloaded_binary_func {
{
unsafe {
let mut temp: af_array = std::ptr::null_mut();
let err_val = $ffi_name(
&mut temp as *mut af_array, lhs.get(), rhs.get(), batch,
);
let err_val = $ffi_name(&mut temp as *mut af_array, lhs.get(), rhs.get(), batch);
HANDLE_ERROR(AfError::from(err_val));
temp.into()
}
Expand Down Expand Up @@ -499,7 +484,7 @@ overloaded_binary_func!(
overloaded_binary_func!("Compute root", root, root_helper, af_root);
overloaded_binary_func!("Computer power", pow, pow_helper, af_pow);

macro_rules! overloaded_compare_func {
macro_rules! overloaded_logic_func {
($doc_str: expr, $fn_name: ident, $help_name: ident, $ffi_name: ident) => {
fn $help_name<A, B>(lhs: &Array<A>, rhs: &Array<B>, batch: bool) -> Array<bool>
where
Expand All @@ -508,9 +493,7 @@ macro_rules! overloaded_compare_func {
{
unsafe {
let mut temp: af_array = std::ptr::null_mut();
let err_val = $ffi_name(
&mut temp as *mut af_array, lhs.get(), rhs.get(), batch,
);
let err_val = $ffi_name(&mut temp as *mut af_array, lhs.get(), rhs.get(), batch);
HANDLE_ERROR(AfError::from(err_val));
temp.into()
}
Expand Down Expand Up @@ -545,11 +528,7 @@ macro_rules! overloaded_compare_func {
/// - Only one element in `arg1` or `arg2` along a given dimension/axis
///
/// - The trait `Convertable` essentially translates to a scalar native type on rust or Array.
pub fn $fn_name<T, U>(
arg1: &T,
arg2: &U,
batch: bool,
) -> Array<bool>
pub fn $fn_name<T, U>(arg1: &T, arg2: &U, batch: bool) -> Array<bool>
where
T: Convertable,
U: Convertable,
Expand All @@ -573,36 +552,54 @@ macro_rules! overloaded_compare_func {
};
}

overloaded_compare_func!(
overloaded_logic_func!(
"Perform `less than` comparison operation",
lt,
lt_helper,
af_lt
);
overloaded_compare_func!(
overloaded_logic_func!(
"Perform `greater than` comparison operation",
gt,
gt_helper,
af_gt
);
overloaded_compare_func!(
overloaded_logic_func!(
"Perform `less than equals` comparison operation",
le,
le_helper,
af_le
);
overloaded_compare_func!(
overloaded_logic_func!(
"Perform `greater than equals` comparison operation",
ge,
ge_helper,
af_ge
);
overloaded_compare_func!(
overloaded_logic_func!(
"Perform `equals` comparison operation",
eq,
eq_helper,
af_eq
);
overloaded_logic_func!(
"Elementwise `not equals` comparison of two Arrays",
neq,
neq_helper,
af_neq
);
overloaded_logic_func!(
"Elementwise logical AND operation of two Arrays",
and,
and_helper,
af_and
);
overloaded_logic_func!(
"Elementwise logical OR operation of two Arrays",
or,
or_helper,
af_or
);

fn clamp_helper<X, Y>(
inp: &Array<X>,
Expand Down
4 changes: 2 additions & 2 deletions src/core/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ unsafe impl<T: HasAfEnum> Send for Array<T> {}
unsafe impl<T: HasAfEnum> Sync for Array<T> {}

macro_rules! is_func {
($doc_str: expr, $fn_name: ident, $ffi_fn: ident) => (
($doc_str: expr, $fn_name: ident, $ffi_fn: ident) => {
#[doc=$doc_str]
pub fn $fn_name(&self) -> bool {
unsafe {
Expand All @@ -195,7 +195,7 @@ macro_rules! is_func {
ret_val
}
}
)
};
}

impl<T> Array<T>
Expand Down
6 changes: 3 additions & 3 deletions src/image/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -963,7 +963,7 @@ macro_rules! filt_func_def {
T: HasAfEnum + ImageFilterType,
{
unsafe {
let mut temp: af_array = std::ptr::null_mut();
let mut temp: af_array = std::ptr::null_mut();
let err_val = $ffi_name(
&mut temp as *mut af_array,
input.get(),
Expand Down Expand Up @@ -1181,7 +1181,7 @@ macro_rules! grayrgb_func_def {
T: HasAfEnum + GrayRGBConvertible,
{
unsafe {
let mut temp: af_array = std::ptr::null_mut();
let mut temp: af_array = std::ptr::null_mut();
let err_val = $ffi_name(&mut temp as *mut af_array, input.get(), r, g, b);
HANDLE_ERROR(AfError::from(err_val));
temp.into()
Expand All @@ -1201,7 +1201,7 @@ macro_rules! hsvrgb_func_def {
T: HasAfEnum + RealFloating,
{
unsafe {
let mut temp: af_array = std::ptr::null_mut();
let mut temp: af_array = std::ptr::null_mut();
let err_val = $ffi_name(&mut temp as *mut af_array, input.get());
HANDLE_ERROR(AfError::from(err_val));
temp.into()
Expand Down
10 changes: 7 additions & 3 deletions src/signal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,7 @@ macro_rules! conv_func_def {
F: HasAfEnum,
{
unsafe {
let mut temp: af_array = std::ptr::null_mut();
let mut temp: af_array = std::ptr::null_mut();
let err_val = $ffi_name(
&mut temp as *mut af_array,
signal.get(),
Expand Down Expand Up @@ -796,9 +796,13 @@ macro_rules! fft_conv_func_def {
F: HasAfEnum,
{
unsafe {
let mut temp: af_array = std::ptr::null_mut();
let mut temp: af_array = std::ptr::null_mut();
let err_val = $ffi_name(
&mut temp as *mut af_array, signal.get(), filter.get(), mode as c_uint);
&mut temp as *mut af_array,
signal.get(),
filter.get(),
mode as c_uint,
);
HANDLE_ERROR(AfError::from(err_val));
temp.into()
}
Expand Down
2 changes: 1 addition & 1 deletion src/statistics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ macro_rules! stat_wtd_func_def {
{
unsafe {
let mut temp: af_array = std::ptr::null_mut();
let err_val = $ffi_fn(&mut temp as *mut af_array,input.get(), weights.get(), dim);
let err_val = $ffi_fn(&mut temp as *mut af_array, input.get(), weights.get(), dim);
HANDLE_ERROR(AfError::from(err_val));
temp.into()
}
Expand Down
Loading