Skip to content

implement jacvec products #30

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
name = "SparseDiffTools"
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
authors = ["Pankaj Mishra <[email protected]>","Chris Rackauckas <[email protected]>"]
authors = ["Pankaj Mishra <[email protected]>", "Chris Rackauckas <[email protected]>"]
version = "0.1.0"

[deps]
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
DiffEqDiffTools = "01453d9d-ee7c-5054-8395-0335cb756afa"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"

[compat]
julia = "1"

[extras]
DiffEqDiffTools = "01453d9d-ee7c-5054-8395-0335cb756afa"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "DiffEqDiffTools"]
test = ["Test", "DiffEqDiffTools", "IterativeSolvers"]
14 changes: 12 additions & 2 deletions src/SparseDiffTools.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module SparseDiffTools

using SparseArrays, LinearAlgebra, BandedMatrices, BlockBandedMatrices, LightGraphs, VertexSafeGraphs
using SparseArrays, LinearAlgebra, BandedMatrices, BlockBandedMatrices,
LightGraphs, VertexSafeGraphs, DiffEqDiffTools, ForwardDiff
using BlockBandedMatrices:blocksize,nblocks
using ForwardDiff: Dual, jacobian, partials, DEFAULT_CHUNK_THRESHOLD

Expand All @@ -9,12 +10,21 @@ export contract_color,
matrix2graph,
matrix_colors,
forwarddiff_color_jacobian!,
ForwardColorJacCache
ForwardColorJacCache,
auto_jacvec,auto_jacvec!,
num_jacvec,num_jacvec!,
num_hesvec,num_hesvec!,
numauto_hesvec,numauto_hesvec!,
autonum_hesvec,autonum_hesvec!,
num_hesvecgrad,num_hesvecgrad!,
auto_hesvecgrad,auto_hesvecgrad!,
JacVec,HesVec,HesVecGrad

include("coloring/high_level.jl")
include("coloring/contraction_coloring.jl")
include("coloring/greedy_d1_coloring.jl")
include("coloring/matrix2graph.jl")
include("differentiation/compute_jacobian_ad.jl")
include("differentiation/jaches_products.jl")

end # module
243 changes: 243 additions & 0 deletions src/differentiation/jaches_products.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
struct DeivVecTag end

# J(f(x))*v
function auto_jacvec!(du, f, x, v,
cache1 = ForwardDiff.Dual{DeivVecTag}.(x, v),
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v))
cache1 .= Dual{DeivVecTag}.(x, v)
f(cache2,cache1)
du .= partials.(cache2, 1)
end
function auto_jacvec(f, x, v)
partials.(f(Dual{DeivVecTag}.(x, v)), 1)
end

function num_jacvec!(du,f,x,v,cache1 = similar(v),
cache2 = similar(v);
compute_f0 = true)
compute_f0 && (f(cache1,x))
T = eltype(x)
# Should it be min? max? mean?
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
@. x += ϵ*v
f(cache2,x)
@. x -= ϵ*v
@. du = (cache2 - cache1)/ϵ
end

function num_jacvec(f,x,v,f0=nothing)
f0 === nothing ? _f0 = f(x) : _f0 = f0
T = eltype(x)
# Should it be min? max? mean?
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(minimum(x)))
(f(x.+ϵ.*v) .- f(x))./ϵ
end

function num_hesvec!(du,f,x,v,
cache1 = similar(v),
cache2 = similar(v),
cache3 = similar(v))
cache = DiffEqDiffTools.GradientCache(v[1],cache1,Val{:central})
g = let f=f,cache=cache
(dx,x) -> DiffEqDiffTools.finite_difference_gradient!(dx,f,x,cache)
end
T = eltype(x)
# Should it be min? max? mean?
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
@. x += ϵ*v
g(cache2,x)
@. x -= 2ϵ*v
g(cache3,x)
@. du = (cache2 - cache3)/(2ϵ)
end

function num_hesvec(f,x,v)
g = (x) -> DiffEqDiffTools.finite_difference_gradient(f,x)
T = eltype(x)
# Should it be min? max? mean?
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
x += ϵ*v
gxp = g(x)
x -= 2ϵ*v
gxm = g(x)
(gxp - gxm)/(2ϵ)
end

function numauto_hesvec!(du,f,x,v,
cache = ForwardDiff.GradientConfig(f,v),
cache1 = similar(v),
cache2 = similar(v))
g = let f=f,x=x,cache=cache
g = (dx,x) -> ForwardDiff.gradient!(dx,f,x,cache)
end
T = eltype(x)
# Should it be min? max? mean?
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
@. x += ϵ*v
g(cache1,x)
@. x -= 2ϵ*v
g(cache2,x)
@. du = (cache1 - cache2)/(2ϵ)
end

function numauto_hesvec(f,x,v)
g = (x) -> ForwardDiff.gradient(f,x)
T = eltype(x)
# Should it be min? max? mean?
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
x += ϵ*v
gxp = g(x)
x -= 2ϵ*v
gxm = g(x)
(gxp - gxm)/(2ϵ)
end

function autonum_hesvec!(du,f,x,v,
cache1 = similar(v),
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v),
cache3 = ForwardDiff.Dual{DeivVecTag}.(x, v))
cache = DiffEqDiffTools.GradientCache(v[1],cache1,Val{:central})
g = (dx,x) -> DiffEqDiffTools.finite_difference_gradient!(dx,f,x,cache)
cache2 .= Dual{DeivVecTag}.(x, v)
g(cache3,cache2)
du .= partials.(cache3, 1)
end

function autonum_hesvec(f,x,v)
g = (x) -> DiffEqDiffTools.finite_difference_gradient(f,x)
partials.(g(Dual{DeivVecTag}.(x, v)), 1)
end

function num_hesvecgrad!(du,g,x,v,
cache2 = similar(v),
cache3 = similar(v))
T = eltype(x)
# Should it be min? max? mean?
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
@. x += ϵ*v
g(cache2,x)
@. x -= 2ϵ*v
g(cache3,x)
@. du = (cache2 - cache3)/(2ϵ)
end

function num_hesvecgrad(g,x,v)
T = eltype(x)
# Should it be min? max? mean?
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
x += ϵ*v
gxp = g(x)
x -= 2ϵ*v
gxm = g(x)
(gxp - gxm)/(2ϵ)
end

function auto_hesvecgrad!(du,g,x,v,
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v),
cache3 = ForwardDiff.Dual{DeivVecTag}.(x, v))
cache2 .= Dual{DeivVecTag}.(x, v)
g(cache3,cache2)
du .= partials.(cache3, 1)
end

function auto_hesvecgrad(g,x,v)
partials.(g(Dual{DeivVecTag}.(x, v)), 1)
end

### Operator Forms

mutable struct JacVec{F,T1,T2,uType}
f::F
cache1::T1
cache2::T2
u::uType
autodiff::Bool
end

function JacVec(f,u::AbstractArray;autodiff=true)
if autodiff
cache1 = ForwardDiff.Dual{DeivVecTag}.(u, u)
cache2 = ForwardDiff.Dual{DeivVecTag}.(u, u)
else
cache1 = similar(u)
cache2 = similar(u)
end
JacVec(f,cache1,cache2,u,autodiff)
end

Base.size(L::JacVec) = (length(L.cache1),length(L.cache1))
Base.size(L::JacVec,i::Int) = length(L.cache1)
Base.:*(L::JacVec,x::AbstractVector) = L.autodiff ? auto_jacvec(_u->L.f(_u),L.u,x) : num_jacvec(_u->L.f(_u),L.u,x)

function LinearAlgebra.mul!(du::AbstractVector,L::JacVec,v::AbstractVector)
if L.autodiff
auto_jacvec!(du,(_du,_u)->L.f(_du,_u),L.u,v,L.cache1,L.cache2)
else
num_jacvec!(du,(_du,_u)->L.f(_du,_u),L.u,v,L.cache1,L.cache2)
end
end

mutable struct HesVec{F,T1,T2,uType}
f::F
cache1::T1
cache2::T2
cache3::T2
u::uType
autodiff::Bool
end

function HesVec(f,u::AbstractArray;autodiff=true)
if autodiff
cache1 = ForwardDiff.GradientConfig(f,u)
cache2 = similar(u)
cache3 = similar(u)
else
cache1 = similar(u)
cache2 = similar(u)
cache3 = similar(u)
end
HesVec(f,cache1,cache2,cache3,u,autodiff)
end

Base.size(L::HesVec) = (length(L.cache2),length(L.cache2))
Base.size(L::HesVec,i::Int) = length(L.cache2)
Base.:*(L::HesVec,x::AbstractVector) = L.autodiff ? numauto_hesvec(L.f,L.u,x) : num_hesvec(L.f,L.u,x)

function LinearAlgebra.mul!(du::AbstractVector,L::HesVec,v::AbstractVector)
if L.autodiff
numauto_hesvec!(du,L.f,L.u,v,L.cache1,L.cache2,L.cache3)
else
num_hesvec!(du,L.f,L.u,v,L.cache1,L.cache2,L.cache3)
end
end

mutable struct HesVecGrad{G,T1,T2,uType}
g::G
cache1::T1
cache2::T2
u::uType
autodiff::Bool
end

function HesVecGrad(g,u::AbstractArray;autodiff=true)
if autodiff
cache1 = ForwardDiff.Dual{DeivVecTag}.(u, u)
cache2 = ForwardDiff.Dual{DeivVecTag}.(u, u)
else
cache1 = similar(u)
cache2 = similar(u)
end
HesVecGrad(g,cache1,cache2,u,autodiff)
end

Base.size(L::HesVecGrad) = (length(L.cache2),length(L.cache2))
Base.size(L::HesVecGrad,i::Int) = length(L.cache2)
Base.:*(L::HesVecGrad,x::AbstractVector) = L.autodiff ? auto_hesvecgrad(L.g,L.u,x) : num_hesvecgrad(L.g,L.u,x)

function LinearAlgebra.mul!(du::AbstractVector,L::HesVecGrad,v::AbstractVector)
if L.autodiff
auto_hesvecgrad!(du,L.g,L.u,v,L.cache1,L.cache2)
else
num_hesvecgrad!(du,L.g,L.u,v,L.cache1,L.cache2)
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ using Test
@testset "AD using color vector" begin include("test_ad.jl") end
@testset "Integration test" begin include("test_integration.jl") end
@testset "Special matrices" begin include("test_specialmatrices.jl") end
@testset "Jac Vecs and Hes Vecs" begin include("test_jaches_products.jl") end
Loading