Skip to content

Commit 0c6f5c6

Browse files
author
Michael Abbott
committed
use _batched_gemm and storage_type from FluxML/NNlib.jl#191
1 parent 8bc792b commit 0c6f5c6

File tree

2 files changed

+14
-114
lines changed

2 files changed

+14
-114
lines changed

src/nnlib.jl

Lines changed: 4 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -32,109 +32,12 @@ end
3232

3333

3434
# Batched matrix multiplication
35+
# Using storage_type from https://github.com/FluxML/NNlib.jl/pull/191
3536

36-
# This method has a slightly tighter signature than the one in NNlib, all same eltype.
37-
function NNlib.batched_mul!(C::AbstractArray{T,3}, A::AbstractArray{T,3}, B::AbstractArray{T,3}) where {T<:CUBLAS.CublasFloat}
38-
if is_strided_cu(A) && is_strided_cu(B) && is_strided_cu(C)
39-
# Data is on GPU, and it's safe to call strides(A). gemm_strided_batched may be legal.
40-
batched_try_gemm!(C, A, B)
41-
42-
elseif is_strided_cu(A) || is_strided_cu(B) || is_strided_cu(C)
43-
# This is hopeless, but best option is the fallback
44-
@debug "weird mix of CPU + GPU?"
45-
NNlib.batched_mul_generic!(C, A, B)
46-
47-
else
48-
# All cases for CPU gemm! will come through here, is_strided_cu(A) compiles away:
49-
NNlib.batched_mul_cpu!(C, A, B)
50-
end
51-
end
52-
53-
const batched_gemm_args = [
54-
(:(AbstractArray{T, 3}), 'N', :identity),
55-
(:(NNlib.BatchedTranspose{T}), 'T', :batched_transpose),
56-
(:(NNlib.BatchedAdjoint{T}), 'C', :batched_adjoint)
57-
]
58-
59-
using NNlib: batched_mul!, BatchedTranspose, BatchedAdjoint, batched_transpose, batched_adjoint
60-
using NNlib: _unbatch, _perm12
61-
62-
for (TA, transA, fA) in batched_gemm_args, (TB, transB, fB) in batched_gemm_args
63-
@eval function batched_try_gemm!(C::AbstractArray{T, 3}, A::$TA, B::$TB) where {T<:CUBLAS.CublasFloat}
64-
65-
Abase, Bbase = _unbatch(A), _unbatch(B)
66-
67-
# Best case, we can call batched_gemm! immediately:
68-
if Base.stride(Abase,1) == Base.stride(Bbase,1) == Base.stride(C,1) == 1
69-
CuArrays.CUBLAS.gemm_strided_batched!($transA, $transB, one(T), Abase, Bbase, zero(T), C)
70-
71-
# Second-best, can we fix it by Perm.ing the base, and adjusing 'T' label?
72-
# But only if we won't produce BatchedTranspose(BatchedAdjoint(complex array)).
73-
elseif Base.stride(Abase,2) == 1 && !(T<:Complex && $TA<:BatchedAdjoint)
74-
newAbase = batched_transpose(_perm12(Abase))
75-
return batched_try_gemm!(C, $fA(newAbase), B)
76-
77-
elseif Base.stride(Bbase,2) == 1 && !(T<:Complex && $TB<:BatchedAdjoint)
78-
newBbase = batched_transpose(_perm12(Bbase))
79-
return batched_try_gemm!(C, A, $fB(newBbase))
80-
81-
# Fallback, e.g when Base.stride(A,3)==1
82-
else
83-
@debug "couldn't re-arrange strides for CUBLAS.gemm_strided_batched!" strides(A) strides(B) strides(C)
84-
NNlib.batched_mul_generic!(C, A, B)
85-
end
86-
C
87-
end
88-
end
89-
90-
91-
# This is obviously the wrong place for this! Not sure where it should go.
92-
# Recursive version, will handle e.g. NamedDimsArray
93-
function Base.unsafe_convert(::Type{CUDAdrv.CuPtr{T}}, A::AbstractArray) where {T}
94-
if A === parent(A)
95-
throw(MethodError(Base.unsafe_convert, Tuple{CUDAdrv.CuPtr{T}, typeof(A)}))
96-
else
97-
return Base.unsafe_convert(CUDAdrv.CuPtr{T}, parent(A))
98-
end
99-
end
100-
37+
NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) =
38+
CuArrays.CUBLAS.gemm_strided_batched!(transA, transB, α, A, B, β, C)
10139

10240
# This is https://github.com/JuliaLang/julia/pull/35304, here just for testing now:
10341
Base.similar(A::PermutedDimsArray, T::Type, dims::Base.Dims) = similar(parent(A), T, dims)
42+
# @which Base.similar(PermutedDimsArray(rand(2,2), (2,1)), Int, Base.Dims{2}((3,3)))
10443

105-
106-
# Also the wong place for this, surely.
107-
"""
108-
is_strided_cu(A)
109-
110-
This should return `true` for `A::CuArray`, and also for:
111-
* Any `view(::CuArray)` or `reshape(::CuArray)` etc. which remains a `StridedArray`
112-
* Any other wrapper for which `is_strided_cu(parent(A))`
113-
* Except that `Adjoint(A)` is only unwrapped for real numbers.
114-
115-
Such wrappers include `PermutedDimsArray(::CuArray, ...)`,
116-
but also those defined elsewhere (such as `NamedDimsArray`s)
117-
which are assumed not to break strided-ness.
118-
119-
`Transpose` and `Adjoint` don't currently define `strides`, so for now they return `false`.
120-
"""
121-
is_strided_cu(A::CuArray) = true
122-
is_strided_cu(A) = false
123-
function is_strided_cu(A::AbstractArray)
124-
M = parentmodule(typeof(A))
125-
if parent(A) === A # Array, SparseMatrix, StaticArray
126-
false
127-
elseif M === Base || M === Core || M ===LinearAlgebra
128-
A isa StridedArray && is_strided_cu(parent(A))
129-
else
130-
is_strided_cu(parent(A)) # PermutedDimsArray, NamedDimsArray
131-
end
132-
end
133-
134-
if hasmethod(Base.strides, Tuple{LinearAlgebra.Transpose})
135-
is_strided_cu(A::LinearAlgebra.Transpose) = is_strided(parent(A))
136-
is_strided_cu(A::LinearAlgebra.Adjoint) = eltype(A) <: Real && is_strided(parent(A))
137-
else
138-
is_strided_cu(A::LinearAlgebra.Transpose) = false
139-
is_strided_cu(A::LinearAlgebra.Adjoint) = false
140-
end

test/nnlib.jl

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,22 @@
1616
@test cu(Ca) batched_mul(cu(A), batched_adjoint(cu(B)))
1717
end
1818

19-
using CuArrays: is_strided_cu
19+
using NNlib: is_strided, are_strided, storage_type
2020
using LinearAlgebra
21-
@testset "is_strided_cu" begin
21+
@testset "NNlib storage_type etc." begin
2222

2323
M = cu(ones(10,10))
2424

25-
@test is_strided_cu(M)
26-
@test is_strided_cu(view(M, 1:2:5,:))
27-
@test is_strided_cu(PermutedDimsArray(M, (2,1)))
25+
@test is_strided(M)
26+
@test is_strided(view(M, 1:2:5,:))
27+
@test is_strided(PermutedDimsArray(M, (2,1)))
2828

29-
@test !is_strided_cu(reshape(view(M, 1:2:10,:), 10,:))
30-
@test !is_strided_cu((M.+im)')
31-
@test !is_strided_cu(ones(10,10))
32-
@test !is_strided_cu(Diagonal(ones(3)))
29+
@test !is_strided(reshape(view(M, 1:2:10,:), 10,:))
30+
@test !is_strided((M.+im)')
31+
@test !is_strided(Diagonal(cu(ones(3))))
3332

34-
#=
35-
using NamedDims
36-
@test is_strided(NamedDimsArray(M,(:a, :b))) # and 0.029 ns, 0 allocations
37-
=#
33+
@test storage_type(M) == CuArray{Float32,2,Nothing}
34+
@test storage_type(reshape(view(M, 1:2:10,:), 10,:)) == CuArray{Float32,2,Nothing}
3835

3936
end
4037

0 commit comments

Comments
 (0)