diff --git a/Project.toml b/Project.toml index c2fd76fc..c5d2ee78 100644 --- a/Project.toml +++ b/Project.toml @@ -21,7 +21,8 @@ julia = "1" [extras] DiffEqDiffTools = "01453d9d-ee7c-5054-8395-0335cb756afa" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "DiffEqDiffTools", "IterativeSolvers"] +test = ["Test", "DiffEqDiffTools", "IterativeSolvers", "Random"] diff --git a/src/differentiation/compute_jacobian_ad.jl b/src/differentiation/compute_jacobian_ad.jl index 47b0d1b0..bc397aed 100644 --- a/src/differentiation/compute_jacobian_ad.jl +++ b/src/differentiation/compute_jacobian_ad.jl @@ -1,9 +1,10 @@ -struct ForwardColorJacCache{T,T2,T3,T4,T5} +struct ForwardColorJacCache{T,T2,T3,T4,T5,T6} t::T fx::T2 dx::T3 p::T4 color::T5 + sparsity::T6 end function default_chunk_size(maxcolor) @@ -19,7 +20,8 @@ getsize(N::Integer) = N function ForwardColorJacCache(f,x,_chunksize = nothing; dx = nothing, - color=1:length(x)) + color=1:length(x), + sparsity::Union{SparseMatrixCSC,Nothing}=nothing) if _chunksize === nothing chunksize = default_chunk_size(maximum(color)) @@ -38,7 +40,7 @@ function ForwardColorJacCache(f,x,_chunksize = nothing; end p = generate_chunked_partials(x,color,chunksize) - ForwardColorJacCache(t,fx,_dx,p,color) + ForwardColorJacCache(t,fx,_dx,p,color,sparsity) end generate_chunked_partials(x,color,N::Integer) = generate_chunked_partials(x,color,Val(N)) @@ -78,8 +80,9 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number}; dx = nothing, - color = eachindex(x)) - forwarddiff_color_jacobian!(J,f,x,ForwardColorJacCache(f,x,dx=dx,color=color)) + color = eachindex(x), + sparsity = J isa SparseMatrixCSC ? J : nothing) + forwarddiff_color_jacobian!(J,f,x,ForwardColorJacCache(f,x,dx=dx,color=color,sparsity=sparsity)) end function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number}, @@ -92,6 +95,7 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number}, dx = jac_cache.dx p = jac_cache.p color = jac_cache.color + sparsity = jac_cache.sparsity color_i = 1 chunksize = length(first(first(jac_cache.p))) @@ -99,8 +103,8 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number}, partial_i = p[i] t .= Dual{typeof(f)}.(x, partial_i) f(fx,t) - if J isa SparseMatrixCSC - rows_index, cols_index, val = findnz(J) + if sparsity isa SparseMatrixCSC + rows_index, cols_index, val = findnz(sparsity) for j in 1:chunksize dx .= partials.(fx, j) for k in 1:length(cols_index) diff --git a/test/test_ad.jl b/test/test_ad.jl index 1b5a80bc..113eb795 100644 --- a/test/test_ad.jl +++ b/test/test_ad.jl @@ -36,7 +36,22 @@ forwarddiff_color_jacobian!(_J1, f, x, color = repeat(1:3,10)) @test fcalls == 1 fcalls = 0 -jac_cache = ForwardColorJacCache(f,x,color = repeat(1:3,10)) +jac_cache = ForwardColorJacCache(f,x,color = repeat(1:3,10), sparsity = _J1) forwarddiff_color_jacobian!(_J1, f, x, jac_cache) @test _J1 ≈ J @test fcalls == 1 + +fcalls = 0 +_J1 = similar(_J) +_denseJ1 = collect(_J1) +forwarddiff_color_jacobian!(_denseJ1, f, x, color = repeat(1:3,10), sparsity = _J1) +@test _denseJ1 ≈ J +@test fcalls == 1 + +fcalls = 0 +_J1 = similar(_J) +_denseJ1 = collect(_J1) +jac_cache = ForwardColorJacCache(f,x,color = repeat(1:3,10), sparsity = _J1) +forwarddiff_color_jacobian!(_denseJ1, f, x, jac_cache) +@test _denseJ1 ≈ J +@test fcalls == 1 diff --git a/test/test_integration.jl b/test/test_integration.jl index 5dd5dedb..e258ad8b 100644 --- a/test/test_integration.jl +++ b/test/test_integration.jl @@ -37,7 +37,14 @@ J = DiffEqDiffTools.finite_difference_jacobian(f, rand(30)) #Jacobian computed with coloring vectors fcalls = 0 -_J = 200 .* true_jac +_J = similar(true_jac) DiffEqDiffTools.finite_difference_jacobian!(_J, f, rand(30), color = colors) @test fcalls == 4 @test _J ≈ J + +fcalls = 0 +_J = similar(true_jac) +_denseJ = collect(_J) +DiffEqDiffTools.finite_difference_jacobian!(_denseJ, f, rand(30), color = colors, sparsity=_J) +@test fcalls == 4 +@test _denseJ ≈ J