@@ -295,6 +295,62 @@ impl<A, S, D> ArrayBase<S, D>
295
295
{
296
296
self . zip_mut_with ( rhs, move |y, & x| * y = * y + ( alpha * x) ) ;
297
297
}
298
+
299
+ /// Perform the operation `self += alpha * rhs` efficiently, where
300
+ /// `alpha` is a scalar and `rhs` is another array. This operation is
301
+ /// also known as `axpy` in BLAS.
302
+ ///
303
+ /// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
304
+ ///
305
+ /// **Panics** if broadcasting isn’t possible.
306
+ #[ cfg( feature="blas" ) ]
307
+ pub fn scaled_add_axpy < S2 , E > ( & mut self , alpha : A , rhs : & ArrayBase < S2 , E > )
308
+ where S : DataMut ,
309
+ S2 : Data < Elem =A > ,
310
+ A : LinalgScalar + :: std:: fmt:: Debug ,
311
+ E : Dimension ,
312
+ {
313
+ debug_assert_eq ! ( self . len( ) , rhs. len( ) ) ;
314
+ assert ! ( self . len( ) == rhs. len( ) ) ;
315
+
316
+ {
317
+ macro_rules! axpy {
318
+ ( $ty: ty, $func: ident) => { {
319
+ if blas_compat:: <$ty, _, _>( self ) && blas_compat:: <$ty, _, _>( rhs) {
320
+ let strides = self . raw_strides( ) ;
321
+ let order = Dimension :: _fastest_varying_stride_order( & strides) ;
322
+ let incx = self . strides( ) [ order[ 0 ] ] ;
323
+
324
+ let strides = rhs. raw_strides( ) ;
325
+ let order = Dimension :: _fastest_varying_stride_order( & strides) ;
326
+ let incy = self . strides( ) [ order[ 0 ] ] ;
327
+
328
+ unsafe {
329
+ let ( lhs_ptr, n, incx) = blas_1d_params( self . ptr,
330
+ self . len( ) ,
331
+ incx) ;
332
+ let ( rhs_ptr, _, incy) = blas_1d_params( rhs. ptr,
333
+ rhs. len( ) ,
334
+ incy) ;
335
+ blas_sys:: c:: $func(
336
+ n,
337
+ cast_as( & alpha) ,
338
+ rhs_ptr as * const $ty,
339
+ incy,
340
+ lhs_ptr as * mut $ty,
341
+ incx) ;
342
+ return ;
343
+ }
344
+ }
345
+ } }
346
+ }
347
+
348
+ axpy ! { f32 , cblas_saxpy} ;
349
+ axpy ! { f64 , cblas_daxpy} ;
350
+ }
351
+
352
+ self . scaled_add ( alpha, rhs) ;
353
+ }
298
354
}
299
355
300
356
// mat_mul_impl uses ArrayView arguments to send all array kinds into
@@ -531,6 +587,35 @@ fn blas_compat_1d<A, S>(a: &ArrayBase<S, Ix1>) -> bool
531
587
true
532
588
}
533
589
590
+ #[ cfg( feature="blas" ) ]
591
+ fn blas_compat < A , S , D > ( a : & ArrayBase < S , D > ) -> bool
592
+ where S : Data ,
593
+ A : ' static ,
594
+ S :: Elem : ' static ,
595
+ D : Dimension ,
596
+ {
597
+ if !same_type :: < A , S :: Elem > ( ) {
598
+ return false ;
599
+ }
600
+
601
+ if a. len ( ) > blas_index:: max_value ( ) as usize {
602
+ return false ;
603
+ }
604
+
605
+ match D :: equispaced_stride ( & a. raw_dim ( ) , & a. raw_strides ( ) ) {
606
+ Some ( stride) => {
607
+ if stride > blas_index:: max_value ( ) as isize ||
608
+ stride < blas_index:: min_value ( ) as isize {
609
+ return false ;
610
+ }
611
+ } ,
612
+ None => {
613
+ return false ;
614
+ }
615
+ }
616
+ true
617
+ }
618
+
534
619
#[ cfg( feature="blas" ) ]
535
620
fn blas_row_major_2d < A , S > ( a : & ArrayBase < S , Ix2 > ) -> bool
536
621
where S : Data ,
0 commit comments