Skip to content

Commit 6566185

Browse files
committed
Use axpy for scaled_add
1 parent 5e6a6b4 commit 6566185

File tree

4 files changed

+117
-2
lines changed

4 files changed

+117
-2
lines changed

benches/bench1.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#![feature(test)]
22
#![allow(unused_imports)]
33

4+
extern crate blas_sys;
45
extern crate test;
56
#[macro_use(s)]
67
extern crate ndarray;
@@ -468,6 +469,18 @@ fn scaled_add_2d_f32_regular(bench: &mut test::Bencher)
468469
});
469470
}
470471

472+
// `scaled_add_axpy` is only defined when the `blas` feature is enabled, so
// this benchmark is gated on the same feature to keep the bench target
// building without it.
/// Benchmark `scaled_add_axpy` on two 64x64 `f32` arrays filled with zeros.
#[cfg(feature = "blas")]
#[bench]
fn scaled_add_2d_f32_axpy(bench: &mut test::Bencher)
{
    let mut av = Array::<f32, _>::zeros((64, 64));
    let bv = Array::<f32, _>::zeros((64, 64));
    let scalar = 3.1415926535;
    bench.iter(|| {
        av.scaled_add_axpy(scalar, &bv);
    });
}
482+
483+
471484
#[bench]
472485
fn assign_scalar_2d_corder(bench: &mut test::Bencher)
473486
{

src/dimension/dimension_trait.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ pub unsafe trait Dimension : Clone + Eq + Debug + Send + Sync + Default +
280280
///
281281
/// Returns `Some(n)` if the strides in all dimensions are equispaced. Returns `None` if not.
282282
#[doc(hidden)]
283-
fn equispaced_stride(dim: &Self, strides: &Self) -> Option<usize> {
283+
fn equispaced_stride(dim: &Self, strides: &Self) -> Option<isize> {
284284
let order = strides._fastest_varying_stride_order();
285285
let base_stride = strides[order[0]];
286286

@@ -295,7 +295,10 @@ pub unsafe trait Dimension : Clone + Eq + Debug + Send + Sync + Default +
295295
}
296296
next_stride *= dim_slice[i];
297297
}
298-
Some(base_stride)
298+
299+
unsafe {
300+
Some(::std::ptr::read(&base_stride as *const _ as *const isize))
301+
}
299302
}
300303

301304
/// Return the axis ordering corresponding to the fastest variation

src/linalg/impl_linalg.rs

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,60 @@ impl<A, S, D> ArrayBase<S, D>
295295
{
296296
self.zip_mut_with(rhs, move |y, &x| *y = *y + (alpha * x));
297297
}
298+
299+
/// Perform the operation `self += alpha * rhs`, where `alpha` is a scalar
/// and `rhs` is another array. This operation is also known as `axpy` in
/// BLAS.
///
/// When the elements are `f32` or `f64`, both arrays pass the `blas_compat`
/// check, and both arrays have identical shape and strides, the update is
/// delegated to `cblas_saxpy` / `cblas_daxpy`. In every other case the work
/// is handed to `scaled_add`.
///
/// **Panics** if `self` and `rhs` do not contain the same number of
/// elements. The `scaled_add` fallback may impose its own shape
/// requirements on top of that.
#[cfg(feature="blas")]
pub fn scaled_add_axpy<S2, E>(&mut self, alpha: A, rhs: &ArrayBase<S2, E>)
    where S: DataMut,
          S2: Data<Elem=A>,
          A: LinalgScalar + ::std::fmt::Debug,
          E: Dimension,
{
    assert!(self.len() == rhs.len());

    // BLAS walks both operands linearly through memory. That only pairs up
    // the same logical elements when both arrays are laid out identically,
    // so anything else (equal length but different shape or axis order)
    // must take the generic path below.
    let same_layout = self.shape() == rhs.shape() && self.strides() == rhs.strides();

    if same_layout {
        macro_rules! axpy {
            ($ty:ty, $func:ident) => {{
                if blas_compat::<$ty, _, _>(self) && blas_compat::<$ty, _, _>(rhs) {
                    // Memory increment of the destination (`self`).
                    let lhs_order = Dimension::_fastest_varying_stride_order(&self.strides);
                    let lhs_inc = self.strides()[lhs_order[0]];

                    // Memory increment of the source: it must be read from
                    // `rhs`' own strides, not from `self`.
                    let rhs_order = Dimension::_fastest_varying_stride_order(&rhs.strides);
                    let rhs_inc = rhs.strides()[rhs_order[0]];

                    // SAFETY: `blas_compat` verified that the element type
                    // is `$ty`, that the length fits in `blas_index`, and
                    // that each array is equispaced with a stride that fits
                    // in `blas_index`; both lengths are equal (asserted
                    // above).
                    unsafe {
                        let (lhs_ptr, n, lhs_inc) = blas_1d_params(self.ptr,
                                                                   self.len(),
                                                                   lhs_inc);
                        let (rhs_ptr, _, rhs_inc) = blas_1d_params(rhs.ptr,
                                                                   rhs.len(),
                                                                   rhs_inc);
                        // axpy argument order: (n, alpha, x, incx, y, incy)
                        // with x = rhs (read) and y = self (updated).
                        blas_sys::c::$func(
                            n,
                            cast_as(&alpha),
                            rhs_ptr as *const $ty,
                            rhs_inc,
                            lhs_ptr as *mut $ty,
                            lhs_inc);
                    }
                    return;
                }
            }}
        }

        axpy!{f32, cblas_saxpy};
        axpy!{f64, cblas_daxpy};
    }

    // Generic fallback for non-BLAS element types or incompatible layouts.
    self.scaled_add(alpha, rhs);
}
298352
}
299353

300354
// mat_mul_impl uses ArrayView arguments to send all array kinds into
@@ -531,6 +585,35 @@ fn blas_compat_1d<A, S>(a: &ArrayBase<S, Ix1>) -> bool
531585
true
532586
}
533587

588+
/// Decide whether array `a` can be handed to a BLAS level-1 routine that
/// operates on elements of type `A`.
///
/// Returns `true` only when all of the following hold:
/// - the element type of `a` is exactly `A`,
/// - the number of elements fits in `blas_index`,
/// - `Dimension::equispaced_stride` reports a single stride for the whole
///   array, and that stride fits in `blas_index`.
#[cfg(feature="blas")]
fn blas_compat<A, S, D>(a: &ArrayBase<S, D>) -> bool
    where S: Data,
          A: 'static,
          S::Elem: 'static,
          D: Dimension,
{
    // The checks short-circuit in the order listed in the doc comment.
    same_type::<A, S::Elem>()
        && a.len() <= blas_index::max_value() as usize
        && match D::equispaced_stride(&a.raw_dim(), &a.strides) {
            Some(step) => {
                let lower = blas_index::min_value() as isize;
                let upper = blas_index::max_value() as isize;
                lower <= step && step <= upper
            }
            None => false,
        }
}
616+
534617
#[cfg(feature="blas")]
535618
fn blas_row_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
536619
where S: Data,

tests/oper.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,22 @@ fn scaled_add() {
459459

460460
}
461461

462+
// Gate on the `blas` Cargo feature. A bare `cfg(blas)` flag is never set by
// Cargo, which meant this test was silently excluded from every build.
/// Check that `scaled_add_axpy` produces the same result as `scaled_add`
/// for the same inputs.
#[cfg(feature = "blas")]
#[test]
fn scaled_add_axpy() {
    let a = range_mat(16, 15);
    let mut b = range_mat(16, 15);
    b.mapv_inplace(f32::exp);

    let alpha = 0.2_f32;

    // Reference result from `scaled_add`.
    let mut c = a.clone();
    c.scaled_add(alpha, &b);

    // Result from the method under test.
    let mut d = a.clone();
    d.scaled_add_axpy(alpha, &b);

    // NOTE(review): exact equality assumes the BLAS backend rounds the same
    // way as the scalar path; confirm, or switch to an approximate check.
    assert_eq!(c, d);
}
477+
462478
#[test]
463479
fn gen_mat_mul() {
464480
let alpha = -2.3;

0 commit comments

Comments
 (0)