@@ -923,15 +923,16 @@ for (fname, elty) in
923923        function  gemm_strided_batched! (transA:: Char ,
924924                               transB:: Char ,
925925                               alpha:: Number ,
926-                                A:: DenseCuArray {$elty, 3} ,
927-                                B:: DenseCuArray {$elty, 3} ,
926+                                A:: AbstractArray {$elty, 3} ,  #  allow PermutedDimsArray 
927+                                B:: AbstractArray {$elty, 3} ,
928928                               beta:: Number ,
929-                                C:: DenseCuArray {$elty, 3} )
929+                                C:: AbstractArray {$elty, 3} )
930930           m =  size (A, transA ==  ' N'   ?  1  :  2 )
931931           k =  size (A, transA ==  ' N'   ?  2  :  1 )
932932           n =  size (B, transB ==  ' N'   ?  2  :  1 )
933933
934-            @assert  size (A, 3 ) ==  size (B, 3 ) ==  size (C, 3 ) " Batch size mismatch" 
934+            @assert  size (A, 3 ) ==  size (C, 3 ) ||  size (A, 3 ) ==  1  " batch size mismatch: A != C" 
935+            @assert  size (B, 3 ) ==  size (C, 3 ) ||  size (B, 3 ) ==  1  " batch size mismatch: B != C" 
935936
936937           if  m !=  size (C,1 ) ||  n !=  size (C,2 ) ||  k !=  size (B, transB ==  ' N'   ?  1  :  2 )
937938               throw (DimensionMismatch (" "  ))
@@ -940,26 +941,26 @@ for (fname, elty) in
940941           ldb =  max (1 ,stride (B,2 ))
941942           ldc =  max (1 ,stride (C,2 ))
942943
943-            strideA =  stride (A, 3 )
944-            strideB =  stride (B, 3 )
944+            strideA =  size (A,  3 )  ==   1   ?   0   :   stride (A, 3 )
945+            strideB =  size (B,  3 )  ==   1   ?   0   :   stride (B, 3 )
945946           strideC =  stride (C, 3 )
946-            batchCount =  size (A , 3 )
947+            batchCount =  size (C , 3 )
947948           $ fname (handle (), transA, transB, m, n, k, alpha, A, lda, strideA, B,
948949                  ldb, strideB, beta, C, ldc, strideC, batchCount)
949950           C
950951        end 
951952        function  gemm_strided_batched (transA:: Char ,
952953                      transB:: Char ,
953954                      alpha:: Number ,
954-                       A:: DenseCuArray {$elty, 3} ,
955-                       B:: DenseCuArray {$elty, 3} )
956-             C =  similar (B, (size (A, transA ==  ' N'   ?  1  :  2 ), size (B, transB ==  ' N'   ?  2  :  1 ), size (A, 3 )))
955+                       A:: AbstractArray {$elty, 3} ,
956+                       B:: AbstractArray {$elty, 3} )
957+             C =  similar (B, (size (A, transA ==  ' N'   ?  1  :  2 ), size (B, transB ==  ' N'   ?  2  :  1 ), max ( size (A, 3 ),  size (B,  3 ) )))
957958            gemm_strided_batched! (transA, transB, alpha, A, B, zero ($ elty), C )
958959        end 
959960        function  gemm_strided_batched (transA:: Char ,
960961                      transB:: Char ,
961-                       A:: DenseCuArray {$elty, 3} ,
962-                       B:: DenseCuArray {$elty, 3} )
962+                       A:: AbstractArray {$elty, 3} ,
963+                       B:: AbstractArray {$elty, 3} )
963964            gemm_strided_batched (transA, transB, one ($ elty), A, B)
964965        end 
965966    end 
0 commit comments