Skip to content
6 changes: 3 additions & 3 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,10 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[NNlib]]
deps = ["Libdl", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
git-tree-sha1 = "a8180fd1445e31c0b1add98dae8da694ac2c23fd"
deps = ["Compat", "Libdl", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
git-tree-sha1 = "1ae42464fea5258fd2ff49f1c4a40fc41cba3860"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.7.6"
version = "0.7.7"

[[OrderedCollections]]
git-tree-sha1 = "cf59cfed2e2c12e8a2ff0a4f1e9b2cd8650da6db"
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ GPUArrays = "6.1.0"
GPUCompiler = "0.8.1"
LLVM = "3"
MacroTools = "0.5"
NNlib = "0.6.5, 0.7"
NNlib = "0.7.7"
Reexport = "0.2"
Requires = "0.5, 1.0"
TimerOutputs = "0.5"
Expand Down
25 changes: 13 additions & 12 deletions lib/cublas/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -923,15 +923,16 @@ for (fname, elty) in
function gemm_strided_batched!(transA::Char,
transB::Char,
alpha::Number,
A::DenseCuArray{$elty, 3},
B::DenseCuArray{$elty, 3},
A::AbstractArray{$elty, 3}, # allow PermutedDimsArray
B::AbstractArray{$elty, 3},
beta::Number,
C::DenseCuArray{$elty, 3})
C::AbstractArray{$elty, 3})
m = size(A, transA == 'N' ? 1 : 2)
k = size(A, transA == 'N' ? 2 : 1)
n = size(B, transB == 'N' ? 2 : 1)

@assert size(A, 3) == size(B, 3) == size(C, 3) "Batch size mismatch"
@assert size(A, 3) == size(C, 3) || size(A, 3) == 1 "batch size mismatch: A != C"
@assert size(B, 3) == size(C, 3) || size(B, 3) == 1 "batch size mismatch: B != C"

if m != size(C,1) || n != size(C,2) || k != size(B, transB == 'N' ? 1 : 2)
throw(DimensionMismatch(""))
Expand All @@ -940,26 +941,26 @@ for (fname, elty) in
ldb = max(1,stride(B,2))
ldc = max(1,stride(C,2))

strideA = stride(A, 3)
strideB = stride(B, 3)
strideA = size(A, 3) == 1 ? 0 : stride(A, 3)
strideB = size(B, 3) == 1 ? 0 : stride(B, 3)
strideC = stride(C, 3)
batchCount = size(A, 3)
batchCount = size(C, 3)
$fname(handle(), transA, transB, m, n, k, alpha, A, lda, strideA, B,
ldb, strideB, beta, C, ldc, strideC, batchCount)
C
end
function gemm_strided_batched(transA::Char,
transB::Char,
alpha::Number,
A::DenseCuArray{$elty, 3},
B::DenseCuArray{$elty, 3})
C = similar(B, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1), size(A, 3)))
A::AbstractArray{$elty, 3},
B::AbstractArray{$elty, 3})
C = similar(B, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1), max(size(A, 3), size(B, 3))))
gemm_strided_batched!(transA, transB, alpha, A, B, zero($elty), C )
end
function gemm_strided_batched(transA::Char,
transB::Char,
A::DenseCuArray{$elty, 3},
B::DenseCuArray{$elty, 3})
A::AbstractArray{$elty, 3},
B::AbstractArray{$elty, 3})
gemm_strided_batched(transA, transB, one($elty), A, B)
end
end
Expand Down
6 changes: 6 additions & 0 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,12 @@ function Base.unsafe_convert(::Type{CuPtr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{
end


## PermutedDimsArray

Base.unsafe_convert(::Type{CuPtr{T}}, A::PermutedDimsArray) where {T} =
Base.unsafe_convert(CuPtr{T}, parent(A))


## reshape

# optimize reshape to return a CuArray
Expand Down
17 changes: 5 additions & 12 deletions src/nnlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,9 @@ end


# Batched matrix multiplication
# 1st argument is produced by NNlib.storage_type(A)
NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) =
CUBLAS.gemm_strided_batched!(transA, transB, α, A, B, β, C)

const batched_gemm_args = [
(:(CuArray{T, 3}), 'N'),
(:(NNlib.BatchedTranspose{T, <:CuArray{T, 3}}), 'T'),
(:(NNlib.BatchedAdjoint{T, <:CuArray{T, 3}}), 'C')
]

for (TA, transA) in batched_gemm_args, (TB, transB) in batched_gemm_args
@eval function NNlib.batched_mul!(C::CuArray{T, 3}, A::$TA, B::$TB) where {T<:CUBLAS.CublasFloat}
CUBLAS.gemm_strided_batched!($transA, $transB, one(T), NNlib._unbatch(A), NNlib._unbatch(B), zero(T), C)
C
end
end
Base.unsafe_convert(::Type{CuPtr{T}}, A::NNlib.BatchedAdjOrTrans{T}) where {T} =
Base.unsafe_convert(CuPtr{T}, parent(A))
43 changes: 42 additions & 1 deletion test/nnlib.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using NNlib

@testset "batched_mul" begin
using NNlib: batched_mul, batched_adjoint, batched_transpose
using NNlib: batched_mul, batched_mul!, batched_vec, batched_adjoint, batched_transpose

A = randn(Float32, 3,3,2);
B = randn(Float32, 3,3,2);
Expand All @@ -14,6 +14,47 @@ using NNlib

Ca = batched_mul(A, batched_adjoint(B))
@test CuArray(Ca) ≈ batched_mul(CuArray(A), batched_adjoint(CuArray(B)))

# 5-arg batched_mul!
C .= pi
batched_mul!(C, A, B, 2f0, 3f0)
cuCpi = CuArray(similar(C)) .= pi
@test CuArray(C) ≈ batched_mul!(cuCpi, CuArray(A), CuArray(B), 2f0, 3f0)

# PermutedDimsArray
@test CuArray(Ct) ≈ batched_mul(PermutedDimsArray(CuArray(A), (2,1,3)), CuArray(B))

D = permutedims(B, (1,3,2))
Cp = batched_mul(batched_adjoint(A), B)
@test CuArray(Cp) ≈ batched_mul(batched_adjoint(CuArray(A)), PermutedDimsArray(CuArray(D), (1,3,2)))

# Methods which reshape
M = randn(Float32, 3,3)

Cm = batched_mul(A, M)
@test CuArray(Cm) ≈ batched_mul(CuArray(A), CuArray(M))

Cv = batched_vec(permutedims(A,(3,1,2)), M)
@test CuArray(Cv) ≈ batched_vec(PermutedDimsArray(CuArray(A),(3,1,2)), CuArray(M))
end

@testset "NNlib storage_type etc." begin
using LinearAlgebra
using NNlib: is_strided, are_strided, storage_type

M = cu(ones(10,10))

@test is_strided(M)
@test is_strided(view(M, 1:2:5,:))
@test is_strided(PermutedDimsArray(M, (2,1)))

@test !is_strided(reshape(view(M, 1:2:10,:), 10,:))
@test !is_strided((M .+ im)')
@test !is_strided(Diagonal(cu(ones(3))))

@test storage_type(M) == CuArray{Float32,2}
@test storage_type(reshape(view(M, 1:2:10,:), 10,:)) == CuArray{Float32,2}

end

@testset "Broadcast Fix" begin
Expand Down