From 1e83dd4b228a83c32d6034a78bafc9bb1bb6f24a Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Fri, 30 Nov 2018 22:56:54 -0800 Subject: [PATCH] Merging code handling Adjoint and Transpose --- stdlib/LinearAlgebra/src/adjtrans.jl | 31 +++++++ stdlib/LinearAlgebra/src/diagonal.jl | 41 +++------ stdlib/LinearAlgebra/src/triangular.jl | 119 +++++-------------------- stdlib/SparseArrays/src/linalg.jl | 37 ++------ 4 files changed, 72 insertions(+), 156 deletions(-) diff --git a/stdlib/LinearAlgebra/src/adjtrans.jl b/stdlib/LinearAlgebra/src/adjtrans.jl index c980e500adde0..2549351aaead9 100644 --- a/stdlib/LinearAlgebra/src/adjtrans.jl +++ b/stdlib/LinearAlgebra/src/adjtrans.jl @@ -71,6 +71,37 @@ struct Transpose{T,S} <: AbstractMatrix{T} end end +""" + functor(::AbstractArray) -> adjoint|transpose|identity + functor(::Type{<:AbstractArray}) -> adjoint|transpose|identity + +Return [`adjoint`](@ref) from an `Adjoint` type or object and +[`transpose`](@ref) from an `Transpose` type or object. Otherwise, +return [`identity`](@ref). Note that `Adjoint` and `Transpose` have +to be the outer-most wrapper object for non-`identity` function to be +returned. +""" +functor(::T) where {T <: AbstractArray} = functor(T) +functor(::Type{<:AbstractArray}) = identity +functor(::Type{<:Adjoint}) = adjoint +functor(::Type{<:Transpose}) = transpose + +""" + inplace(f) -> f! + +Return an in-place variant of function `f`. + +# Examples +```jldoctest +julia> using LinearAlgebra: inplace + +julia> inplace(adjoint) === adjoint! +true +``` +""" +inplace(::typeof(adjoint)) = adjoint! +inplace(::typeof(transpose)) = transpose! + function checkeltype_adjoint(::Type{ResultEltype}, ::Type{ParentEltype}) where {ResultEltype,ParentEltype} Expected = Base.promote_op(adjoint, ParentEltype) ResultEltype === Expected || error(string( diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index cd17474407d6f..31114638f1e8e 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -213,43 +213,30 @@ function lmul!(D::Diagonal, B::UnitUpperTriangular) UpperTriangular(B.data) end -*(D::Adjoint{<:Any,<:Diagonal}, B::Diagonal) = Diagonal(adjoint.(D.parent.diag) .* B.diag) -*(A::Adjoint{<:Any,<:AbstractTriangular}, D::Diagonal) = rmul!(copy(A), D) -function *(adjA::Adjoint{<:Any,<:AbstractMatrix}, D::Diagonal) +*(D::AdjOrTrans{<:Any,<:Diagonal}, B::Diagonal) = + Diagonal(functor(D).(D.parent.diag) .* B.diag) +*(A::AdjOrTrans{<:Any,<:AbstractTriangular}, D::Diagonal) = + rmul!(copy(A), D) +function *(adjA::AdjOrTrans{<:Any,<:AbstractMatrix}, D::Diagonal) A = adjA.parent Ac = similar(A, promote_op(*, eltype(A), eltype(D.diag)), (size(A, 2), size(A, 1))) - adjoint!(Ac, A) + inplace(functor(adjA))(Ac, A) rmul!(Ac, D) end -*(D::Transpose{<:Any,<:Diagonal}, B::Diagonal) = Diagonal(transpose.(D.parent.diag) .* B.diag) -*(A::Transpose{<:Any,<:AbstractTriangular}, D::Diagonal) = rmul!(copy(A), D) -function *(transA::Transpose{<:Any,<:AbstractMatrix}, D::Diagonal) - A = transA.parent - At = similar(A, promote_op(*, eltype(A), eltype(D.diag)), (size(A, 2), size(A, 1))) - transpose!(At, A) - rmul!(At, D) -end - -*(D::Diagonal, B::Adjoint{<:Any,<:Diagonal}) = Diagonal(D.diag .* adjoint.(B.parent.diag)) -*(D::Diagonal, B::Adjoint{<:Any,<:AbstractTriangular}) = lmul!(D, collect(B)) -*(D::Diagonal, adjQ::Adjoint{<:Any,<:Union{QRCompactWYQ,QRPackedQ}}) = (Q = adjQ.parent; rmul!(Array(D), adjoint(Q))) -function *(D::Diagonal, adjA::Adjoint{<:Any,<:AbstractMatrix}) +*(D::Diagonal, B::AdjOrTrans{<:Any,<:Diagonal}) = + Diagonal(D.diag .* functor(B).(B.parent.diag)) +*(D::Diagonal, B::AdjOrTrans{<:Any,<:AbstractTriangular}) = + lmul!(D, copy(B)) +*(D::Diagonal, adjQ::AdjOrTrans{<:Any,<:Union{QRCompactWYQ,QRPackedQ}}) = + (Q = adjQ.parent; rmul!(Array(D), functor(adjQ)(Q))) +function *(D::Diagonal, adjA::AdjOrTrans{<:Any,<:AbstractMatrix}) A = adjA.parent Ac = similar(A, promote_op(*, eltype(A), eltype(D.diag)), (size(A, 2), size(A, 1))) - adjoint!(Ac, A) + inplace(functor(adjA))(Ac, A) lmul!(D, Ac) end -*(D::Diagonal, B::Transpose{<:Any,<:Diagonal}) = Diagonal(D.diag .* transpose.(B.parent.diag)) -*(D::Diagonal, B::Transpose{<:Any,<:AbstractTriangular}) = lmul!(D, copy(B)) -function *(D::Diagonal, transA::Transpose{<:Any,<:AbstractMatrix}) - A = transA.parent - At = similar(A, promote_op(*, eltype(A), eltype(D.diag)), (size(A, 2), size(A, 1))) - transpose!(At, A) - lmul!(D, At) -end - *(D::Adjoint{<:Any,<:Diagonal}, B::Adjoint{<:Any,<:Diagonal}) = Diagonal(adjoint.(D.parent.diag) .* adjoint.(B.parent.diag)) *(D::Transpose{<:Any,<:Diagonal}, B::Transpose{<:Any,<:Diagonal}) = diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index 903e8954319ee..5b7dc21ef05ee 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -852,17 +852,18 @@ function lmul!(A::UnitLowerTriangular, B::StridedVecOrMat) B end -function lmul!(adjA::Adjoint{<:Any,<:UpperTriangular}, B::StridedVecOrMat) +function lmul!(adjA::AdjOrTrans{<:Any,<:UpperTriangular}, B::StridedVecOrMat) A = adjA.parent + f = functor(adjA) m, n = size(B, 1), size(B, 2) if m != size(A, 1) throw(DimensionMismatch("right hand side B needs first dimension of size $(size(A,1)), has size $m")) end for j = 1:n for i = m:-1:1 - Bij = A.data[i,i]'B[i,j] + Bij = f(A.data[i,i]) * B[i,j] for k = 1:i - 1 - Bij += A.data[k,i]'B[k,j] + Bij += f(A.data[k,i]) * B[k,j] end B[i,j] = Bij end @@ -870,8 +871,9 @@ function lmul!(adjA::Adjoint{<:Any,<:UpperTriangular}, B::StridedVecOrMat) B end -function lmul!(adjA::Adjoint{<:Any,<:UnitUpperTriangular}, B::StridedVecOrMat) +function lmul!(adjA::AdjOrTrans{<:Any,<:UnitUpperTriangular}, B::StridedVecOrMat) A = adjA.parent + f = functor(adjA) m, n = size(B, 1), size(B, 2) if m != size(A, 1) throw(DimensionMismatch("right hand side B needs first dimension of size $(size(A,1)), has size $m")) @@ -880,7 +882,7 @@ function lmul!(adjA::Adjoint{<:Any,<:UnitUpperTriangular}, B::StridedVecOrMat) for i = m:-1:1 Bij = B[i,j] for k = 1:i - 1 - Bij += A.data[k,i]'B[k,j] + Bij += f(A.data[k,i]) * B[k,j] end B[i,j] = Bij end @@ -888,25 +890,27 @@ function lmul!(adjA::Adjoint{<:Any,<:UnitUpperTriangular}, B::StridedVecOrMat) B end -function lmul!(adjA::Adjoint{<:Any,<:LowerTriangular}, B::StridedVecOrMat) +function lmul!(adjA::AdjOrTrans{<:Any,<:LowerTriangular}, B::StridedVecOrMat) A = adjA.parent + f = functor(adjA) m, n = size(B, 1), size(B, 2) if m != size(A, 1) throw(DimensionMismatch("right hand side B needs first dimension of size $(size(A,1)), has size $m")) end for j = 1:n for i = 1:m - Bij = A.data[i,i]'B[i,j] + Bij = f(A.data[i,i]) * B[i,j] for k = i + 1:m - Bij += A.data[k,i]'B[k,j] + Bij += f(A.data[k,i]) * B[k,j] end B[i,j] = Bij end end B end -function lmul!(adjA::Adjoint{<:Any,<:UnitLowerTriangular}, B::StridedVecOrMat) +function lmul!(adjA::AdjOrTrans{<:Any,<:UnitLowerTriangular}, B::StridedVecOrMat) A = adjA.parent + f = functor(adjA) m, n = size(B, 1), size(B, 2) if m != size(A, 1) throw(DimensionMismatch("right hand side B needs first dimension of size $(size(A,1)), has size $m")) @@ -915,7 +919,7 @@ function lmul!(adjA::Adjoint{<:Any,<:UnitLowerTriangular}, B::StridedVecOrMat) for i = 1:m Bij = B[i,j] for k = i + 1:m - Bij += A.data[k,i]'B[k,j] + Bij += f(A.data[k,i]) * B[k,j] end B[i,j] = Bij end @@ -923,75 +927,6 @@ function lmul!(adjA::Adjoint{<:Any,<:UnitLowerTriangular}, B::StridedVecOrMat) B end -function lmul!(transA::Transpose{<:Any,<:UpperTriangular}, B::StridedVecOrMat) - A = transA.parent - m, n = size(B, 1), size(B, 2) - if m != size(A, 1) - throw(DimensionMismatch("right hand side B needs first dimension of size $(size(A,1)), has size $m")) - end - for j = 1:n - for i = m:-1:1 - Bij = transpose(A.data[i,i]) * B[i,j] - for k = 1:i - 1 - Bij += transpose(A.data[k,i]) * B[k,j] - end - B[i,j] = Bij - end - end - B -end -function lmul!(transA::Transpose{<:Any,<:UnitUpperTriangular}, B::StridedVecOrMat) - A = transA.parent - m, n = size(B, 1), size(B, 2) - if m != size(A, 1) - throw(DimensionMismatch("right hand side B needs first dimension of size $(size(A,1)), has size $m")) - end - for j = 1:n - for i = m:-1:1 - Bij = B[i,j] - for k = 1:i - 1 - Bij += transpose(A.data[k,i]) * B[k,j] - end - B[i,j] = Bij - end - end - B -end - -function lmul!(transA::Transpose{<:Any,<:LowerTriangular}, B::StridedVecOrMat) - A = transA.parent - m, n = size(B, 1), size(B, 2) - if m != size(A, 1) - throw(DimensionMismatch("right hand side B needs first dimension of size $(size(A,1)), has size $m")) - end - for j = 1:n - for i = 1:m - Bij = transpose(A.data[i,i]) * B[i,j] - for k = i + 1:m - Bij += transpose(A.data[k,i]) * B[k,j] - end - B[i,j] = Bij - end - end - B -end -function lmul!(transA::Transpose{<:Any,<:UnitLowerTriangular}, B::StridedVecOrMat) - A = transA.parent - m, n = size(B, 1), size(B, 2) - if m != size(A, 1) - throw(DimensionMismatch("right hand side B needs first dimension of size $(size(A,1)), has size $m")) - end - for j = 1:n - for i = 1:m - Bij = B[i,j] - for k = i + 1:m - Bij += transpose(A.data[k,i]) * B[k,j] - end - B[i,j] = Bij - end - end - B -end function rmul!(A::StridedMatrix, B::UpperTriangular) m, n = size(A) @@ -1873,21 +1808,14 @@ for mat in (:AbstractVector, :AbstractMatrix) copyto!(BB, B) lmul!(convert(AbstractArray{TAB}, A), BB) end - function *(adjA::Adjoint{<:Any,<:AbstractTriangular}, B::$mat) + function *(adjA::AdjOrTrans{<:Any,<:AbstractTriangular}, B::$mat) require_one_based_indexing(B) A = adjA.parent + f = functor(adjA) TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))) BB = similar(B, TAB, size(B)) copyto!(BB, B) - lmul!(adjoint(convert(AbstractArray{TAB}, A)), BB) - end - function *(transA::Transpose{<:Any,<:AbstractTriangular}, B::$mat) - require_one_based_indexing(B) - A = transA.parent - TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))) - BB = similar(B, TAB, size(B)) - copyto!(BB, B) - lmul!(transpose(convert(AbstractArray{TAB}, A)), BB) + lmul!(f(convert(AbstractArray{TAB}, A)), BB) end end ### Left division with triangle to the left hence rhs cannot be transposed. No quotients. @@ -2004,21 +1932,14 @@ function *(A::AbstractMatrix, B::AbstractTriangular) copyto!(AA, A) rmul!(AA, convert(AbstractArray{TAB}, B)) end -function *(A::AbstractMatrix, adjB::Adjoint{<:Any,<:AbstractTriangular}) +function *(A::AbstractMatrix, adjB::AdjOrTrans{<:Any,<:AbstractTriangular}) require_one_based_indexing(A) B = adjB.parent + f = functor(adjB) TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))) AA = similar(A, TAB, size(A)) copyto!(AA, A) - rmul!(AA, adjoint(convert(AbstractArray{TAB}, B))) -end -function *(A::AbstractMatrix, transB::Transpose{<:Any,<:AbstractTriangular}) - require_one_based_indexing(A) - B = transB.parent - TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))) - AA = similar(A, TAB, size(A)) - copyto!(AA, A) - rmul!(AA, transpose(convert(AbstractArray{TAB}, B))) + rmul!(AA, f(convert(AbstractArray{TAB}, B))) end # ambiguity resolution with definitions in linalg/rowvector.jl *(v::AdjointAbsVec, A::AbstractTriangular) = adjoint(adjoint(A) * v.parent) diff --git a/stdlib/SparseArrays/src/linalg.jl b/stdlib/SparseArrays/src/linalg.jl index 437c5a8683eea..8060b52adbcfd 100644 --- a/stdlib/SparseArrays/src/linalg.jl +++ b/stdlib/SparseArrays/src/linalg.jl @@ -1,6 +1,7 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license import LinearAlgebra: checksquare +using LinearAlgebra: functor, AdjOrTrans using Random: rand! ## sparse matrix multiplication @@ -55,8 +56,10 @@ end *(A::AbstractSparseMatrixCSC{TA,S}, B::StridedMatrix{Tx}) where {TA,S,Tx} = (T = promote_op(matprod, TA, Tx); mul!(similar(B, T, (size(A, 1), size(B, 2))), A, B, one(T), zero(T))) -function mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedMatrix}, α::Number, β::Number) +function mul!(C::StridedVecOrMat, adjA::AdjOrTrans{<:Any,<:SparseMatrixCSC}, + B::Union{StridedVector,AdjOrTransStridedMatrix}, α::Number, β::Number) A = adjA.parent + f = functor(adjA) size(A, 2) == size(C, 1) || throw(DimensionMismatch()) size(A, 1) == size(B, 1) || throw(DimensionMismatch()) size(B, 2) == size(C, 2) || throw(DimensionMismatch()) @@ -69,44 +72,18 @@ function mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC} @inbounds for col = 1:size(A, 2) tmp = zero(eltype(C)) for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1) - tmp += adjoint(nzv[j])*B[rv[j],k] + tmp += f(nzv[j])*B[rv[j],k] end C[col,k] += tmp * α end end C end -*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC{TA,S}}, x::StridedVector{Tx}) where {TA,S,Tx} = +*(adjA::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC{TA,S}}, x::StridedVector{Tx}) where {TA,S,Tx} = (T = promote_op(matprod, TA, Tx); mul!(similar(x, T, size(adjA, 1)), adjA, x, one(T), zero(T))) -*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC{TA,S}}, B::AdjOrTransStridedMatrix{Tx}) where {TA,S,Tx} = +*(adjA::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC{TA,S}}, B::AdjOrTransStridedMatrix{Tx}) where {TA,S,Tx} = (T = promote_op(matprod, TA, Tx); mul!(similar(B, T, (size(adjA, 1), size(B, 2))), adjA, B, one(T), zero(T))) -function mul!(C::StridedVecOrMat, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedMatrix}, α::Number, β::Number) - A = transA.parent - size(A, 2) == size(C, 1) || throw(DimensionMismatch()) - size(A, 1) == size(B, 1) || throw(DimensionMismatch()) - size(B, 2) == size(C, 2) || throw(DimensionMismatch()) - nzv = nonzeros(A) - rv = rowvals(A) - if β != 1 - β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C))) - end - for k = 1:size(C, 2) - @inbounds for col = 1:size(A, 2) - tmp = zero(eltype(C)) - for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1) - tmp += transpose(nzv[j])*B[rv[j],k] - end - C[col,k] += tmp * α - end - end - C -end -*(transA::Transpose{<:Any,<:AbstractSparseMatrixCSC{TA,S}}, x::StridedVector{Tx}) where {TA,S,Tx} = - (T = promote_op(matprod, TA, Tx); mul!(similar(x, T, size(transA, 1)), transA, x, one(T), zero(T))) -*(transA::Transpose{<:Any,<:AbstractSparseMatrixCSC{TA,S}}, B::AdjOrTransStridedMatrix{Tx}) where {TA,S,Tx} = - (T = promote_op(matprod, TA, Tx); mul!(similar(B, T, (size(transA, 1), size(B, 2))), transA, B, one(T), zero(T))) - # For compatibility with dense multiplication API. Should be deleted when dense multiplication # API is updated to follow BLAS API. mul!(C::StridedVecOrMat, A::AbstractSparseMatrixCSC, B::Union{StridedVector,AdjOrTransStridedMatrix}) =