Skip to content

Commit 0d76526

Browse files
committed
Use axpy for scaled_add
1 parent 5e6a6b4 commit 0d76526

File tree

5 files changed

+124
-2
lines changed

5 files changed

+124
-2
lines changed

benches/bench1.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#![feature(test)]
22
#![allow(unused_imports)]
33

4+
extern crate blas_sys;
45
extern crate test;
56
#[macro_use(s)]
67
extern crate ndarray;
@@ -468,6 +469,18 @@ fn scaled_add_2d_f32_regular(bench: &mut test::Bencher)
468469
});
469470
}
470471

472+
/// Benchmark `scaled_add_axpy` on two 64×64 zero-filled `f32` arrays.
///
/// Gated on the `blas` feature because `scaled_add_axpy` is only compiled
/// when that feature is enabled; without the gate this file would not build
/// when the feature is off.
#[cfg(feature = "blas")]
#[bench]
fn scaled_add_2d_f32_axpy(bench: &mut test::Bencher)
{
    let mut av = Array::<f32, _>::zeros((64, 64));
    let bv = Array::<f32, _>::zeros((64, 64));
    let scalar = 3.1415926535;
    bench.iter(|| {
        av.scaled_add_axpy(scalar, &bv);
    });
}
482+
483+
471484
#[bench]
472485
fn assign_scalar_2d_corder(bench: &mut test::Bencher)
473486
{

src/dimension/dimension_trait.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ pub unsafe trait Dimension : Clone + Eq + Debug + Send + Sync + Default +
280280
///
281281
/// Returns `Some(n)` if the strides in all dimensions are equispaced. Returns `None` if not.
282282
#[doc(hidden)]
283-
fn equispaced_stride(dim: &Self, strides: &Self) -> Option<usize> {
283+
fn equispaced_stride(dim: &Self, strides: &Self) -> Option<isize> {
284284
let order = strides._fastest_varying_stride_order();
285285
let base_stride = strides[order[0]];
286286

@@ -295,7 +295,10 @@ pub unsafe trait Dimension : Clone + Eq + Debug + Send + Sync + Default +
295295
}
296296
next_stride *= dim_slice[i];
297297
}
298-
Some(base_stride)
298+
299+
unsafe {
300+
Some(::std::ptr::read(&base_stride as *const _ as *const isize))
301+
}
299302
}
300303

301304
/// Return the axis ordering corresponding to the fastest variation

src/impl_methods.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
7373
self.dim.clone()
7474
}
7575

76+
/// Return the strides of the array as they are stored in the array.
77+
pub fn raw_strides(&self) -> D {
78+
self.strides.clone()
79+
}
80+
7681
/// Return the shape of the array as a slice.
7782
pub fn shape(&self) -> &[Ix] {
7883
self.dim.slice()

src/linalg/impl_linalg.rs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,62 @@ impl<A, S, D> ArrayBase<S, D>
295295
{
296296
self.zip_mut_with(rhs, move |y, &x| *y = *y + (alpha * x));
297297
}
298+
299+
/// Perform the operation `self += alpha * rhs` efficiently, where
/// `alpha` is a scalar and `rhs` is another array. This operation is
/// also known as `axpy` in BLAS.
///
/// The BLAS routine (`cblas_saxpy` for `f32`, `cblas_daxpy` for `f64`)
/// is used only when both arrays have the same shape and the same
/// strides and both pass `blas_compat`. Every other case is forwarded
/// to `scaled_add`.
///
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
///
/// **Panics** if broadcasting isn’t possible.
#[cfg(feature="blas")]
pub fn scaled_add_axpy<S2, E>(&mut self, alpha: A, rhs: &ArrayBase<S2, E>)
    where S: DataMut,
          S2: Data<Elem=A>,
          A: LinalgScalar + ::std::fmt::Debug,
          E: Dimension,
{
    // axpy pairs the elements of its two operands purely by memory order,
    // so the fast path is only valid when `self` and `rhs` lay out the same
    // logical elements identically. Equal lengths alone are not enough
    // (e.g. a 2×3 and a 3×2 array, or a row-major and a column-major array
    // of the same shape).
    let same_layout = self.shape() == rhs.shape() && self.strides() == rhs.strides();

    if same_layout {
        macro_rules! axpy {
            ($ty:ty, $func:ident) => {{
                if blas_compat::<$ty, _, _>(self) && blas_compat::<$ty, _, _>(rhs) {
                    // Element step of `self`, taken along the first axis of
                    // its fastest-varying stride order.
                    let strides = self.raw_strides();
                    let order = Dimension::_fastest_varying_stride_order(&strides);
                    let incx = self.strides()[order[0]];

                    // Same for `rhs`, read from `rhs`' own strides.
                    let strides = rhs.raw_strides();
                    let order = Dimension::_fastest_varying_stride_order(&strides);
                    let incy = rhs.strides()[order[0]];

                    // SAFETY (review): `blas_compat` has confirmed that the
                    // element type is `$ty` and that length and stride fit
                    // in `blas_index`; this presumes `blas_1d_params`
                    // yields a pointer/length/stride triple that stays
                    // inside each array's allocation — confirm against its
                    // definition.
                    unsafe {
                        let (lhs_ptr, n, incx) = blas_1d_params(self.ptr,
                                                                self.len(),
                                                                incx);
                        let (rhs_ptr, _, incy) = blas_1d_params(rhs.ptr,
                                                                rhs.len(),
                                                                incy);
                        // BLAS computes y += alpha * x: `rhs` is x and
                        // `self` is y.
                        blas_sys::c::$func(
                            n,
                            cast_as(&alpha),
                            rhs_ptr as *const $ty,
                            incy,
                            lhs_ptr as *mut $ty,
                            incx);
                        return;
                    }
                }
            }}
        }

        axpy!{f32, cblas_saxpy};
        axpy!{f64, cblas_daxpy};
    }

    // Generic path: other element types, differing layouts, or operands
    // that need broadcasting.
    self.scaled_add(alpha, rhs);
}
298354
}
299355

300356
// mat_mul_impl uses ArrayView arguments to send all array kinds into
@@ -531,6 +587,35 @@ fn blas_compat_1d<A, S>(a: &ArrayBase<S, Ix1>) -> bool
531587
true
532588
}
533589

590+
/// Report whether the array `a` can be handed to a BLAS level-1 routine
/// operating on elements of type `A`.
///
/// Returns `true` only when all of the following hold:
///
/// - the array's element type is exactly `A`,
/// - the number of elements does not exceed `blas_index::max_value()`,
/// - `Dimension::equispaced_stride` yields a stride for the array, and
///   that stride lies within the range of `blas_index`.
#[cfg(feature="blas")]
fn blas_compat<A, S, D>(a: &ArrayBase<S, D>) -> bool
    where S: Data,
          A: 'static,
          S::Elem: 'static,
          D: Dimension,
{
    // Cheap rejections first: wrong element type, or too many elements.
    if !same_type::<A, S::Elem>() || a.len() > blas_index::max_value() as usize {
        return false;
    }

    let lowest = blas_index::min_value() as isize;
    let highest = blas_index::max_value() as isize;

    // The single stride must exist and be representable as a `blas_index`.
    match D::equispaced_stride(&a.raw_dim(), &a.raw_strides()) {
        Some(step) => lowest <= step && step <= highest,
        None => false,
    }
}
618+
534619
#[cfg(feature="blas")]
535620
fn blas_row_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
536621
where S: Data,

tests/oper.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,22 @@ fn scaled_add() {
459459

460460
}
461461

462+
// Check that the BLAS-backed `scaled_add_axpy` produces the same result as
// the generic `scaled_add` on identical inputs.
//
// The gate must be `feature = "blas"`: a bare `cfg(blas)` is not set by
// Cargo features, so the test would never be compiled or run.
#[cfg(feature = "blas")]
#[test]
fn scaled_add_axpy() {
    let a = range_mat(16, 15);
    let mut b = range_mat(16, 15);
    b.mapv_inplace(f32::exp);

    let alpha = 0.2_f32;

    // Reference result from the generic implementation.
    let mut c = a.clone();
    c.scaled_add(alpha, &b);

    // Result from the BLAS path.
    let mut d = a.clone();
    d.scaled_add_axpy(alpha, &b);
    assert_eq!(c, d);
}
477+
462478
#[test]
463479
fn gen_mat_mul() {
464480
let alpha = -2.3;

0 commit comments

Comments
 (0)