Skip to content

Commit d202256

Browse files
committed
New fns to create arrays using Dim4 and DType params
* create_constant * create_constant_i64 * create_constant_u64 * create_range * create_range * create_iota
1 parent 63293af commit d202256

File tree

3 files changed

+167
-15
lines changed

3 files changed

+167
-15
lines changed

src/arith/mod.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ extern crate num;
33

44
use dim4::Dim4;
55
use array::Array;
6-
use defines::AfError;
6+
use defines::{AfError, DType};
77
use error::HANDLE_ERROR;
88
use self::libc::{c_int};
9-
use data::{constant, constant_like, tile};
9+
use data::{constant, create_constant, create_constant_i64, create_constant_u64, tile};
1010
use self::num::Complex;
1111

1212
use std::ops::Neg;
@@ -493,6 +493,10 @@ impl Neg for Array {
493493
type Output = Array;
494494

495495
fn neg(self) -> Self::Output {
496-
constant_like(0.0, &self) - self
496+
match self.get_type() {
497+
DType::S64 => (create_constant_i64(0, self.dims()) - self),
498+
DType::U64 => (create_constant_u64(0, self.dims()) - self),
499+
_ => (create_constant(0.0, 0.0, self.dims(), self.get_type()) - self)
500+
}
497501
}
498502
}

src/data/mod.rs

Lines changed: 156 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ extern crate num;
33

44
use array::Array;
55
use dim4::Dim4;
6-
use defines::AfError;
6+
use defines::{AfError, DType};
77
use error::HANDLE_ERROR;
88
use self::libc::{uint8_t, c_int, c_uint, c_double};
99
use self::num::Complex;
@@ -623,24 +623,170 @@ pub fn replace_scalar(a: &mut Array, cond: &Array, b: f64) {
623623
}
624624
}
625625

626-
/// Create an array filled with given constant retaining type/shape of another Array.
626+
/// Create an Array of particular shape filled with given constant.
627+
///
628+
/// You can use this function to create arrays of the following types only.
629+
/// - DType::F32
630+
/// - DType::C32
631+
/// - DType::F64
632+
/// - DType::C64
633+
/// - DType::B8
634+
/// - DType::S32
635+
/// - DType::U32
636+
/// - DType::U8
637+
/// - DType::S16
638+
/// - DType::U16
639+
///
640+
/// # Parameters
641+
///
642+
/// - `real` is the constant with which output Array is to be filled
643+
/// - `imag` is the constant with which imaginary part of a complex Array will be filed.
644+
/// - `dims` is the size of Array
645+
/// - `dtype` indicates whats the type of the Array to be created
646+
///
647+
/// # Return Values
648+
///
649+
/// Array of `dims` shape and filed with given constant `value`.
650+
pub fn create_constant(real: f64, imag: f64, dims: Dim4, dtype: DType) -> Array {
651+
unsafe {
652+
let mut temp: i64 = 0;
653+
let err_val = match dtype {
654+
DType::S64 => { AfError::ERR_TYPE as i32 },
655+
DType::U64 => { AfError::ERR_TYPE as i32 },
656+
DType::C32 => {
657+
af_constant_complex(&mut temp as MutAfArray,
658+
real as c_double, imag as c_double,
659+
dims.ndims() as c_uint,
660+
dims.get().as_ptr() as *const DimT, 1)
661+
},
662+
DType::C64 => {
663+
af_constant_complex(&mut temp as MutAfArray,
664+
real as c_double, imag as c_double,
665+
dims.ndims() as c_uint,
666+
dims.get().as_ptr() as *const DimT, 3)
667+
},
668+
_ => {
669+
af_constant(&mut temp as MutAfArray, real as c_double,
670+
dims.ndims() as c_uint,
671+
dims.get().as_ptr() as *const DimT,
672+
dtype as c_int)
673+
}
674+
};
675+
HANDLE_ERROR(AfError::from(err_val));
676+
Array::from(temp)
677+
}
678+
}
679+
680+
/// Create an array filled with i64 constant
681+
///
682+
/// # Parameters
683+
///
684+
/// - `value` is the constant with which output Array is to be filled
685+
/// - `dims` is the size of Array
686+
///
687+
/// # Return Values
688+
///
689+
/// Array filled with `value` and `dims` shape.
690+
pub fn create_constant_i64(value: i64, dims: Dim4) -> Array {
691+
unsafe {
692+
let mut temp: i64 = 0;
693+
let err_val = af_constant_long(&mut temp as MutAfArray, value as Intl,
694+
dims.ndims() as c_uint,
695+
dims.get().as_ptr() as *const DimT);
696+
HANDLE_ERROR(AfError::from(err_val));
697+
Array::from(temp)
698+
}
699+
}
700+
701+
/// Create an array filled with u64 constant
627702
///
628703
/// # Parameters
629704
///
630705
/// - `value` is the constant with which output Array is to be filled
631-
/// - `input` is the Array whose shape the output Array has to maintain
706+
/// - `dims` is the size of Array
707+
///
708+
/// # Return Values
709+
///
710+
/// Array filled with `value` and `dims` shape.
711+
pub fn create_constant_u64(value: u64, dims: Dim4) -> Array {
712+
unsafe {
713+
let mut temp: i64 = 0;
714+
let err_val = af_constant_ulong(&mut temp as MutAfArray, value as Uintl,
715+
dims.ndims() as c_uint,
716+
dims.get().as_ptr() as *const DimT);
717+
HANDLE_ERROR(AfError::from(err_val));
718+
Array::from(temp)
719+
}
720+
}
721+
722+
/// Create a Range of values
723+
///
724+
/// Creates an array with [0, n] values along the `seq_dim` which is tiled across other dimensions.
725+
///
726+
/// # Parameters
727+
///
728+
/// - `dims` is the size of Array
729+
/// - `seq_dim` is the dimension along which range values are populated, all values along other
730+
/// dimensions are just repeated
731+
/// - `dtype` indicates whats the type of the Array to be created
732+
///
733+
/// # Return Values
734+
/// Array
735+
#[allow(unused_mut)]
736+
pub fn create_range(dims: Dim4, seq_dim: i32, dtype: DType) -> Array {
737+
unsafe {
738+
let mut temp: i64 = 0;
739+
let err_val = af_range(&mut temp as MutAfArray,
740+
dims.ndims() as c_uint, dims.get().as_ptr() as *const DimT,
741+
seq_dim as c_int, dtype as uint8_t);
742+
HANDLE_ERROR(AfError::from(err_val));
743+
Array::from(temp)
744+
}
745+
}
746+
747+
/// Create a range of values
748+
///
749+
/// Create an sequence [0, dims.elements() - 1] and modify to specified dimensions dims and then tile it according to tile_dims.
750+
///
751+
/// # Parameters
752+
///
753+
/// - `dims` is the dimensions of the sequence to be generated
754+
/// - `tdims` is the number of repitions of the unit dimensions
755+
/// - `dtype` indicates whats the type of the Array to be created
632756
///
633757
/// # Return Values
634758
///
635-
/// Array with given constant value and input Array's shape and similar internal data type.
636-
pub fn constant_like(value: f64, input: &Array) -> Array {
637-
let dims = input.dims();
759+
/// Array
760+
#[allow(unused_mut)]
761+
pub fn create_iota(dims: Dim4, tdims: Dim4, dtype: DType) -> Array {
762+
unsafe {
763+
let mut temp: i64 = 0;
764+
let err_val =af_iota(&mut temp as MutAfArray,
765+
dims.ndims() as c_uint, dims.get().as_ptr() as *const DimT,
766+
tdims.ndims() as c_uint, tdims.get().as_ptr() as *const DimT,
767+
dtype as uint8_t);
768+
HANDLE_ERROR(AfError::from(err_val));
769+
Array::from(temp)
770+
}
771+
}
772+
773+
/// Create an identity array with 1's in diagonal
774+
///
775+
/// # Parameters
776+
///
777+
/// - `dims` is the output Array dimensions
778+
/// - `dtype` indicates whats the type of the Array to be created
779+
///
780+
/// # Return Values
781+
///
782+
/// Identity matrix
783+
#[allow(unused_mut)]
784+
pub fn create_identity(dims: Dim4, dtype: DType) -> Array {
638785
unsafe {
639786
let mut temp: i64 = 0;
640-
let err_val = af_constant(&mut temp as MutAfArray, value as c_double,
641-
dims.ndims() as c_uint,
642-
dims.get().as_ptr() as *const DimT,
643-
input.get_type() as c_int);
787+
let err_val = af_identity(&mut temp as MutAfArray,
788+
dims.ndims() as c_uint, dims.get().as_ptr() as *const DimT,
789+
dtype as uint8_t);
644790
HANDLE_ERROR(AfError::from(err_val));
645791
Array::from(temp)
646792
}

src/lib.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,13 @@ mod backend;
3232
pub use blas::{matmul, dot, transpose, transpose_inplace};
3333
mod blas;
3434

35-
pub use data::{constant, constant_like, range, iota};
36-
pub use data::{identity, diag_create, diag_extract, lower, upper};
35+
pub use data::{constant, range, iota, identity};
36+
pub use data::{diag_create, diag_extract, lower, upper};
3737
pub use data::{join, join_many, tile};
3838
pub use data::{reorder, shift, moddims, flat, flip};
3939
pub use data::{select, selectl, selectr, replace, replace_scalar};
40+
pub use data::{create_constant, create_range, create_iota, create_identity};
41+
pub use data::{create_constant_i64, create_constant_u64};
4042
mod data;
4143

4244
pub use device::{get_version, info, init, device_count, is_double_available, set_device, get_device};

0 commit comments

Comments
 (0)