|
32 | 32 |
|
33 | 33 |
|
34 | 34 | # Batched matrix multiplication |
| 35 | +# Using storage_type from https://github.com/FluxML/NNlib.jl/pull/191 |
35 | 36 |
|
36 | | -# This method has a slightly tighter signature than the one in NNlib, all same eltype. |
37 | | -function NNlib.batched_mul!(C::AbstractArray{T,3}, A::AbstractArray{T,3}, B::AbstractArray{T,3}) where {T<:CUBLAS.CublasFloat} |
38 | | - if is_strided_cu(A) && is_strided_cu(B) && is_strided_cu(C) |
39 | | - # Data is on GPU, and it's safe to call strides(A). gemm_strided_batched may be legal. |
40 | | - batched_try_gemm!(C, A, B) |
41 | | - |
42 | | - elseif is_strided_cu(A) || is_strided_cu(B) || is_strided_cu(C) |
43 | | - # This is hopeless, but best option is the fallback |
44 | | - @debug "weird mix of CPU + GPU?" |
45 | | - NNlib.batched_mul_generic!(C, A, B) |
46 | | - |
47 | | - else |
48 | | - # All cases for CPU gemm! will come through here, is_strided_cu(A) compiles away: |
49 | | - NNlib.batched_mul_cpu!(C, A, B) |
50 | | - end |
51 | | -end |
52 | | - |
53 | | -const batched_gemm_args = [ |
54 | | - (:(AbstractArray{T, 3}), 'N', :identity), |
55 | | - (:(NNlib.BatchedTranspose{T}), 'T', :batched_transpose), |
56 | | - (:(NNlib.BatchedAdjoint{T}), 'C', :batched_adjoint) |
57 | | -] |
58 | | - |
59 | | -using NNlib: batched_mul!, BatchedTranspose, BatchedAdjoint, batched_transpose, batched_adjoint |
60 | | -using NNlib: _unbatch, _perm12 |
61 | | - |
62 | | -for (TA, transA, fA) in batched_gemm_args, (TB, transB, fB) in batched_gemm_args |
63 | | - @eval function batched_try_gemm!(C::AbstractArray{T, 3}, A::$TA, B::$TB) where {T<:CUBLAS.CublasFloat} |
64 | | - |
65 | | - Abase, Bbase = _unbatch(A), _unbatch(B) |
66 | | - |
67 | | - # Best case, we can call batched_gemm! immediately: |
68 | | - if Base.stride(Abase,1) == Base.stride(Bbase,1) == Base.stride(C,1) == 1 |
69 | | - CuArrays.CUBLAS.gemm_strided_batched!($transA, $transB, one(T), Abase, Bbase, zero(T), C) |
70 | | - |
71 | | - # Second-best, can we fix it by Perm.ing the base, and adjusing 'T' label? |
72 | | - # But only if we won't produce BatchedTranspose(BatchedAdjoint(complex array)). |
73 | | - elseif Base.stride(Abase,2) == 1 && !(T<:Complex && $TA<:BatchedAdjoint) |
74 | | - newAbase = batched_transpose(_perm12(Abase)) |
75 | | - return batched_try_gemm!(C, $fA(newAbase), B) |
76 | | - |
77 | | - elseif Base.stride(Bbase,2) == 1 && !(T<:Complex && $TB<:BatchedAdjoint) |
78 | | - newBbase = batched_transpose(_perm12(Bbase)) |
79 | | - return batched_try_gemm!(C, A, $fB(newBbase)) |
80 | | - |
81 | | - # Fallback, e.g when Base.stride(A,3)==1 |
82 | | - else |
83 | | - @debug "couldn't re-arrange strides for CUBLAS.gemm_strided_batched!" strides(A) strides(B) strides(C) |
84 | | - NNlib.batched_mul_generic!(C, A, B) |
85 | | - end |
86 | | - C |
87 | | - end |
88 | | -end |
89 | | - |
90 | | - |
91 | | -# This is obviously the wrong place for this! Not sure where it should go. |
92 | | -# Recursive version, will handle e.g. NamedDimsArray |
93 | | -function Base.unsafe_convert(::Type{CUDAdrv.CuPtr{T}}, A::AbstractArray) where {T} |
94 | | - if A === parent(A) |
95 | | - throw(MethodError(Base.unsafe_convert, Tuple{CUDAdrv.CuPtr{T}, typeof(A)})) |
96 | | - else |
97 | | - return Base.unsafe_convert(CUDAdrv.CuPtr{T}, parent(A)) |
98 | | - end |
99 | | -end |
100 | | - |
| 37 | +NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) = |
| 38 | + CuArrays.CUBLAS.gemm_strided_batched!(transA, transB, α, A, B, β, C) |
101 | 39 |
|
102 | 40 | # This is https://github.com/JuliaLang/julia/pull/35304, here just for testing now: |
103 | 41 | Base.similar(A::PermutedDimsArray, T::Type, dims::Base.Dims) = similar(parent(A), T, dims) |
| 42 | +# @which Base.similar(PermutedDimsArray(rand(2,2), (2,1)), Int, Base.Dims{2}((3,3))) |
104 | 43 |
|
105 | | - |
106 | | -# Also the wong place for this, surely. |
107 | | -""" |
108 | | - is_strided_cu(A) |
109 | | -
|
110 | | -This should return `true` for `A::CuArray`, and also for: |
111 | | -* Any `view(::CuArray)` or `reshape(::CuArray)` etc. which remains a `StridedArray` |
112 | | -* Any other wrapper for which `is_strided_cu(parent(A))` |
113 | | -* Except that `Adjoint(A)` is only unwrapped for real numbers. |
114 | | -
|
115 | | -Such wrappers include `PermutedDimsArray(::CuArray, ...)`, |
116 | | -but also those defined elsewhere (such as `NamedDimsArray`s) |
117 | | -which are assumed not to break strided-ness. |
118 | | -
|
119 | | -`Transpose` and `Adjoint` don't currently define `strides`, so for now they return `false`. |
120 | | -""" |
121 | | -is_strided_cu(A::CuArray) = true |
122 | | -is_strided_cu(A) = false |
123 | | -function is_strided_cu(A::AbstractArray) |
124 | | - M = parentmodule(typeof(A)) |
125 | | - if parent(A) === A # Array, SparseMatrix, StaticArray |
126 | | - false |
127 | | - elseif M === Base || M === Core || M ===LinearAlgebra |
128 | | - A isa StridedArray && is_strided_cu(parent(A)) |
129 | | - else |
130 | | - is_strided_cu(parent(A)) # PermutedDimsArray, NamedDimsArray |
131 | | - end |
132 | | -end |
133 | | - |
134 | | -if hasmethod(Base.strides, Tuple{LinearAlgebra.Transpose}) |
135 | | - is_strided_cu(A::LinearAlgebra.Transpose) = is_strided(parent(A)) |
136 | | - is_strided_cu(A::LinearAlgebra.Adjoint) = eltype(A) <: Real && is_strided(parent(A)) |
137 | | -else |
138 | | - is_strided_cu(A::LinearAlgebra.Transpose) = false |
139 | | - is_strided_cu(A::LinearAlgebra.Adjoint) = false |
140 | | -end |
0 commit comments