diff --git a/base/sparse/umfpack.jl b/base/sparse/umfpack.jl
index aeaa590dc82db..496faf5552d89 100644
--- a/base/sparse/umfpack.jl
+++ b/base/sparse/umfpack.jl
@@ -383,42 +383,74 @@ function nnz(lu::UmfpackLU)
 end
 
 ### Solve with Factorization
-for (f!, umfpack) in ((:A_ldiv_B!, :UMFPACK_A),
-                      (:Ac_ldiv_B!, :UMFPACK_At),
-                      (:At_ldiv_B!, :UMFPACK_Aat))
-    @eval begin
-        function $f!{T<:UMFVTypes}(x::StridedVecOrMat{T}, lu::UmfpackLU{T}, b::StridedVecOrMat{T})
-            n = size(x, 2)
-            if n != size(b, 2)
-                throw(DimensionMismatch("in and output arrays must have the same number of columns"))
-            end
-            for j in 1:n
-                solve!(view(x, :, j), lu, view(b, :, j), $umfpack)
-            end
-            return x
-        end
-        $f!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::StridedVector{T}) = $f!(b, lu, copy(b))
-        $f!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::StridedMatrix{T}) = $f!(b, lu, copy(b))
-
-        function $f!{Tb<:Complex}(x::StridedVector{Tb}, lu::UmfpackLU{Float64}, b::StridedVector{Tb})
-            m, n = size(x, 1), size(x, 2)
-            if n != size(b, 2)
-                throw(DimensionMismatch("in and output arrays must have the same number of columns"))
-            end
-            # TODO: Optionally let user allocate these and pass in somehow
-            r = similar(b, Float64, m)
-            i = similar(b, Float64, m)
-            for j in 1:n
-                solve!(r, lu, convert(Vector{Float64}, real(view(b, :, j))), $umfpack)
-                solve!(i, lu, convert(Vector{Float64}, imag(view(b, :, j))), $umfpack)
-
-                map!((t,s) -> t + im*s, view(x, :, j), r, i)
-            end
-            return x
-        end
-        $f!{Tb<:Complex}(lu::UmfpackLU{Float64}, b::StridedVector{Tb}) = $f!(b, lu, copy(b))
+# Two-argument forms solve in place: `B` is overwritten with the solution and a copy
+# of `B` supplies the right-hand side.
+A_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, B::StridedVecOrMat{T}) = A_ldiv_B!(B, lu, copy(B))
+At_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, B::StridedVecOrMat{T}) = At_ldiv_B!(B, lu, copy(B))
+Ac_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, B::StridedVecOrMat{T}) = Ac_ldiv_B!(B, lu, copy(B))
+A_ldiv_B!{Tb<:Complex}(lu::UmfpackLU{Float64}, B::StridedVecOrMat{Tb}) = A_ldiv_B!(B, lu, copy(B))
+At_ldiv_B!{Tb<:Complex}(lu::UmfpackLU{Float64}, B::StridedVecOrMat{Tb}) = At_ldiv_B!(B, lu, copy(B))
+Ac_ldiv_B!{Tb<:Complex}(lu::UmfpackLU{Float64}, B::StridedVecOrMat{Tb}) = Ac_ldiv_B!(B, lu, copy(B))
+
+# Three-argument forms store the solution in `X`. `UMFPACK_At` selects the
+# conjugate-transposed system, so it belongs to `Ac_ldiv_B!`; `UMFPACK_Aat` selects the
+# plain (non-conjugated) transposed system, so it belongs to `At_ldiv_B!`.
+A_ldiv_B!{T<:UMFVTypes}(X::StridedVecOrMat{T}, lu::UmfpackLU{T}, B::StridedVecOrMat{T}) =
+    _Aq_ldiv_B!(X, lu, B, UMFPACK_A)
+At_ldiv_B!{T<:UMFVTypes}(X::StridedVecOrMat{T}, lu::UmfpackLU{T}, B::StridedVecOrMat{T}) =
+    _Aq_ldiv_B!(X, lu, B, UMFPACK_Aat)
+Ac_ldiv_B!{T<:UMFVTypes}(X::StridedVecOrMat{T}, lu::UmfpackLU{T}, B::StridedVecOrMat{T}) =
+    _Aq_ldiv_B!(X, lu, B, UMFPACK_At)
+A_ldiv_B!{Tb<:Complex}(X::StridedVecOrMat{Tb}, lu::UmfpackLU{Float64}, B::StridedVecOrMat{Tb}) =
+    _Aq_ldiv_B!(X, lu, B, UMFPACK_A)
+At_ldiv_B!{Tb<:Complex}(X::StridedVecOrMat{Tb}, lu::UmfpackLU{Float64}, B::StridedVecOrMat{Tb}) =
+    _Aq_ldiv_B!(X, lu, B, UMFPACK_Aat)
+Ac_ldiv_B!{Tb<:Complex}(X::StridedVecOrMat{Tb}, lu::UmfpackLU{Float64}, B::StridedVecOrMat{Tb}) =
+    _Aq_ldiv_B!(X, lu, B, UMFPACK_At)
+
+# Check that `X` and `B` have the same number of columns, run the kernel method selected
+# by the argument types, and return `X`.
+function _Aq_ldiv_B!(X::StridedVecOrMat, lu::UmfpackLU, B::StridedVecOrMat, transposeoptype)
+    if size(X, 2) != size(B, 2)
+        throw(DimensionMismatch("input and output arrays must have same number of columns"))
+    end
+    _AqldivB_kernel!(X, lu, B, transposeoptype)
+    return X
+end
+# Vector right-hand side with matching element type: a single solve.
+function _AqldivB_kernel!{T<:UMFVTypes}(x::StridedVector{T}, lu::UmfpackLU{T},
+                                        b::StridedVector{T}, transposeoptype)
+    solve!(x, lu, b, transposeoptype)
+end
+# Matrix right-hand side with matching element type: solve column by column.
+function _AqldivB_kernel!{T<:UMFVTypes}(X::StridedMatrix{T}, lu::UmfpackLU{T},
+                                        B::StridedMatrix{T}, transposeoptype)
+    for col in 1:size(X, 2)
+        solve!(view(X, :, col), lu, view(B, :, col), transposeoptype)
     end
 end
+# Real factorization with a complex vector right-hand side: solve for the real and
+# imaginary parts separately and combine the two results into `x`.
+function _AqldivB_kernel!{Tb<:Complex}(x::StridedVector{Tb}, lu::UmfpackLU{Float64},
+                                       b::StridedVector{Tb}, transposeoptype)
+    r, i = similar(b, Float64), similar(b, Float64)
+    solve!(r, lu, Vector{Float64}(real(b)), transposeoptype)
+    solve!(i, lu, Vector{Float64}(imag(b)), transposeoptype)
+    map!(complex, x, r, i)
+end
+# Real factorization with a complex matrix right-hand side: the same real/imaginary
+# split, one column at a time, reusing the scratch vectors `r` and `i`.
+function _AqldivB_kernel!{Tb<:Complex}(X::StridedMatrix{Tb}, lu::UmfpackLU{Float64},
+                                       B::StridedMatrix{Tb}, transposeoptype)
+    r = similar(B, Float64, size(B, 1))
+    i = similar(B, Float64, size(B, 1))
+    for j in 1:size(B, 2)
+        solve!(r, lu, Vector{Float64}(real(view(B, :, j))), transposeoptype)
+        solve!(i, lu, Vector{Float64}(imag(view(B, :, j))), transposeoptype)
+        map!(complex, view(X, :, j), r, i)
+    end
+end
+
 
 function getindex(lu::UmfpackLU, d::Symbol)
     L,U,p,q,Rs = umf_extract(lu)
diff --git a/test/sparse/umfpack.jl b/test/sparse/umfpack.jl
index dbfe2dfae0387..65f63f6ebdb20 100644
--- a/test/sparse/umfpack.jl
+++ b/test/sparse/umfpack.jl
@@ -143,3 +143,27 @@ let
     F = lufact(A)
     @test F[:p] == [3 ; 4 ; 2 ; 1]
 end
+
+# Test that A[c|t]_ldiv_B!{T<:Complex}(X::StridedMatrix{T}, lu::UmfpackLU{Float64},
+# B::StridedMatrix{T}) works as expected.
+let N = 10, p = 0.5
+    A = N*speye(N) + sprand(N, N, p)
+    X = zeros(Complex{Float64}, N, N)
+    B = complex.(rand(N, N), rand(N, N))
+    luA, lufA = lufact(A), lufact(Array(A))
+    @test A_ldiv_B!(copy(X), luA, B) ≈ A_ldiv_B!(copy(X), lufA, B)
+    @test At_ldiv_B!(copy(X), luA, B) ≈ At_ldiv_B!(copy(X), lufA, B)
+    @test Ac_ldiv_B!(copy(X), luA, B) ≈ Ac_ldiv_B!(copy(X), lufA, B)
+end
+
+# Test that At_ldiv_B! (transpose) and Ac_ldiv_B! (conjugate transpose) are not
+# interchanged; this needs a complex A, for which the two operations differ.
+let N = 10, p = 0.5
+    A = N*speye(N) + sprand(N, N, p) + im*sprand(N, N, p)
+    X = zeros(Complex{Float64}, N, N)
+    B = complex.(rand(N, N), rand(N, N))
+    luA, Ad = lufact(A), Array(A)
+    @test A_ldiv_B!(copy(X), luA, B) ≈ Ad \ B
+    @test At_ldiv_B!(copy(X), luA, B) ≈ Ad.' \ B
+    @test Ac_ldiv_B!(copy(X), luA, B) ≈ Ad' \ B
+end