Skip to content

make AD operators scimloperators #212

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 32 commits into from
Feb 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
5f0714e
ignore vim files
vpuri3 Feb 4, 2023
3fefab1
add scimlops
vpuri3 Feb 4, 2023
1691faa
JVP WIP
vpuri3 Feb 4, 2023
c5ef8ae
VJP stuff copy pasted from https://github.com/JuliaDiff/SparseDiffToo…
vpuri3 Feb 4, 2023
5d41abf
comments
vpuri3 Feb 4, 2023
0de693c
scimloperator compat: wait for https://github.com/SciML/SciMLOperator…
vpuri3 Feb 4, 2023
e1e7c95
make all these structs FunctionOperators
vpuri3 Feb 4, 2023
1559a9b
WIP
vpuri3 Feb 5, 2023
35ab990
comments.
vpuri3 Feb 6, 2023
e0729f8
comments
vpuri3 Feb 6, 2023
807d294
compiles
vpuri3 Feb 6, 2023
ea57d14
typos
vpuri3 Feb 6, 2023
1fad88b
tests
vpuri3 Feb 6, 2023
2e90ffc
typo
vpuri3 Feb 6, 2023
78eb0bf
cleanup
vpuri3 Feb 6, 2023
fbda265
vecjacprod tests
vpuri3 Feb 7, 2023
5aa1f0d
5 arg mul<bang> tests
vpuri3 Feb 7, 2023
4b0fcce
comments, 5-arg mul<bang>
vpuri3 Feb 7, 2023
1796880
uncomment tests
vpuri3 Feb 8, 2023
e1d8be1
v2
vpuri3 Feb 8, 2023
84542bb
redefine JacVec, HesVec, HesVecGrad, VecJac
vpuri3 Feb 8, 2023
9375983
tests
vpuri3 Feb 8, 2023
75974ed
no inplace tests for rev mode
vpuri3 Feb 8, 2023
efb05a5
fix typo
vpuri3 Feb 9, 2023
edac7af
dont need multiple update_coeff definitinos
vpuri3 Feb 9, 2023
866daa1
rm ODE JacVec tests
vpuri3 Feb 10, 2023
ec9e0f4
tricks compat
vpuri3 Feb 10, 2023
a8f374b
iip,oop
vpuri3 Feb 18, 2023
ce0febe
zygote
vpuri3 Feb 18, 2023
df9b45a
remove other arrayinterfaces in tests
ChrisRackauckas Feb 19, 2023
4c760e8
make structs immutable
vpuri3 Feb 19, 2023
f322ec7
Zygote
vpuri3 Feb 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
.DS_Store
.*.swp
.*.swo
Manifest.toml
/dev/

docs/build/
docs/site/
docs/site/
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseDiffTools"
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
authors = ["Pankaj Mishra <[email protected]>", "Chris Rackauckas <[email protected]>"]
version = "1.31.0"
version = "2.00.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -13,8 +13,10 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"
VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f"

[compat]
Expand All @@ -25,8 +27,10 @@ DataStructures = "0.18"
FiniteDiff = "2.8.1"
ForwardDiff = "0.10"
Graphs = "1"
Requires = "1.0"
Requires = "1"
SciMLOperators = "0.1.19"
StaticArrays = "1"
Tricks = "0.1.6"
VertexSafeGraphs = "0.2"
julia = "1.6"

Expand Down
15 changes: 12 additions & 3 deletions src/SparseDiffTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ using DataStructures: DisjointSets, find_root!, union!

using ArrayInterface: matrix_colors

using SciMLOperators
import SciMLOperators: update_coefficients, update_coefficients!
using Tricks: static_hasmethod

abstract type AbstractAutoDiffVecProd end

export contract_color,
greedy_d1,
greedy_star1_coloring,
Expand All @@ -42,7 +48,8 @@ export contract_color,
autonum_hesvec, autonum_hesvec!,
num_hesvecgrad, num_hesvecgrad!,
auto_hesvecgrad, auto_hesvecgrad!,
JacVec, HesVec, HesVecGrad,
JacVec, HesVec, HesVecGrad, VecJac,
update_coefficients, update_coefficients!,
value!

include("coloring/high_level.jl")
Expand All @@ -64,8 +71,10 @@ parameterless_type(x::Type) = __parameterless_type(x)

function __init__()
@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin
export numback_hesvec, numback_hesvec!, autoback_hesvec, autoback_hesvec!,
auto_vecjac, auto_vecjac!
export numback_hesvec, numback_hesvec!,
autoback_hesvec, autoback_hesvec!,
auto_vecjac, auto_vecjac!,
ZygoteVecJac, ZygoteHesVec

include("differentiation/vecjac_products_zygote.jl")
include("differentiation/jaches_products_zygote.jl")
Expand Down
174 changes: 94 additions & 80 deletions src/differentiation/jaches_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,112 +198,126 @@ end

### Operator Forms

struct JacVec{F, T1, T2, xType}
struct FwdModeAutoDiffVecProd{F,U,C,V,V!} <: AbstractAutoDiffVecProd
f::F
cache1::T1
cache2::T2
x::xType
autodiff::Bool
u::U
cache::C
vecprod::V
vecprod!::V!
end

function JacVec(f, x::AbstractArray, tag = DeivVecTag(); autodiff = true)
if autodiff
cache1 = Dual{typeof(ForwardDiff.Tag(tag, eltype(x))), eltype(x), 1
}.(x, ForwardDiff.Partials.(tuple.(x)))
cache2 = Dual{typeof(ForwardDiff.Tag(tag, eltype(x))), eltype(x), 1
}.(x, ForwardDiff.Partials.(tuple.(x)))
else
cache1 = similar(x)
cache2 = similar(x)
end
JacVec(f, cache1, cache2, x, autodiff)
function update_coefficients(L::FwdModeAutoDiffVecProd, u, p, t)
FwdModeAutoDiffVecProd(L.f, u, L.vecprod, L.vecprod!, L.cache)
end

Base.eltype(L::JacVec) = eltype(L.x)
Base.size(L::JacVec) = (length(L.cache1), length(L.cache1))
Base.size(L::JacVec, i::Int) = length(L.cache1)
function Base.:*(L::JacVec, v::AbstractVector)
L.autodiff ? auto_jacvec(_x -> L.f(_x), L.x, v) :
num_jacvec(_x -> L.f(_x), L.x, v)
function update_coefficients!(L::FwdModeAutoDiffVecProd, u, p, t)
copy!(L.u, u)
L
end

function LinearAlgebra.mul!(dy::AbstractVector, L::JacVec, v::AbstractVector)
if L.autodiff
auto_jacvec!(dy, (_y, _x) -> L.f(_y, _x), L.x, v, L.cache1, L.cache2)
else
num_jacvec!(dy, (_y, _x) -> L.f(_y, _x), L.x, v, L.cache1, L.cache2)
end
function (L::FwdModeAutoDiffVecProd)(v, p, t)
L.vecprod(L.f, L.u, v)
end

struct HesVec{F, T1, T2, xType}
f::F
cache1::T1
cache2::T2
cache3::T2
x::xType
autodiff::Bool
function (L::FwdModeAutoDiffVecProd)(dv, v, p, t)
L.vecprod!(dv, L.f, L.u, v, L.cache...)
end

function HesVec(f, x::AbstractArray; autodiff = true)
function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)

if autodiff
cache1 = ForwardDiff.GradientConfig(f, x)
cache2 = similar(x)
cache3 = similar(x)
cache1 = Dual{
typeof(ForwardDiff.Tag(DeivVecTag(),eltype(u))), eltype(u), 1
}.(u, ForwardDiff.Partials.(tuple.(u)))

cache2 = copy(cache1)
else
cache1 = similar(x)
cache2 = similar(x)
cache3 = similar(x)
cache1 = similar(u)
cache2 = similar(u)
end
HesVec(f, cache1, cache2, cache3, x, autodiff)
end

Base.size(L::HesVec) = (length(L.cache2), length(L.cache2))
Base.size(L::HesVec, i::Int) = length(L.cache2)
function Base.:*(L::HesVec, v::AbstractVector)
L.autodiff ? numauto_hesvec(L.f, L.x, v) : num_hesvec(L.f, L.x, v)
end
cache = (cache1, cache2,)

function LinearAlgebra.mul!(dy::AbstractVector, L::HesVec, v::AbstractVector)
if L.autodiff
numauto_hesvec!(dy, L.f, L.x, v, L.cache1, L.cache2, L.cache3)
else
num_hesvec!(dy, L.f, L.x, v, L.cache1, L.cache2, L.cache3)
vecprod = autodiff ? auto_jacvec : num_jacvec
vecprod! = autodiff ? auto_jacvec! : num_jacvec!

outofplace = static_hasmethod(f, typeof((u,)))
isinplace = static_hasmethod(f, typeof((u, u,)))

if !(isinplace) & !(outofplace)
error("$f must have signature f(u), or f(du, u).")
end
end

struct HesVecGrad{G, T1, T2, uType}
g::G
cache1::T1
cache2::T2
x::uType
autodiff::Bool
L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!)

FunctionOperator(L, u, u;
isinplace = isinplace, outofplace = outofplace,
p = p, t = t, islinear = true,
)
end

function HesVecGrad(g, x::AbstractArray, tag = DeivVecTag(); autodiff = false)
function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)

if autodiff
cache1 = Dual{typeof(ForwardDiff.Tag(tag, eltype(x))), eltype(x), 1
}.(x, ForwardDiff.Partials.(tuple.(x)))
cache2 = Dual{typeof(ForwardDiff.Tag(tag, eltype(x))), eltype(x), 1
}.(x, ForwardDiff.Partials.(tuple.(x)))
cache1 = ForwardDiff.GradientConfig(f, u)
cache2 = similar(u)
cache3 = similar(u)
else
cache1 = similar(x)
cache2 = similar(x)
cache1 = similar(u)
cache2 = similar(u)
cache3 = similar(u)
end
HesVecGrad(g, cache1, cache2, x, autodiff)
end

Base.size(L::HesVecGrad) = (length(L.cache2), length(L.cache2))
Base.size(L::HesVecGrad, i::Int) = length(L.cache2)
function Base.:*(L::HesVecGrad, v::AbstractVector)
L.autodiff ? auto_hesvecgrad(L.g, L.x, v) : num_hesvecgrad(L.g, L.x, v)
cache = (cache1, cache2, cache3,)

vecprod = autodiff ? numauto_hesvec : num_hesvec
vecprod! = autodiff ? numauto_hesvec! : num_hesvec!

outofplace = static_hasmethod(f, typeof((u,)))
isinplace = static_hasmethod(f, typeof((u,)))

if !(isinplace) & !(outofplace)
error("$f must have signature f(u).")
end

L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!)

FunctionOperator(L, u, u;
isinplace = isinplace, outofplace = outofplace,
p = p, t = t, islinear = true,
)
end

function LinearAlgebra.mul!(dy::AbstractVector,
L::HesVecGrad,
v::AbstractVector)
if L.autodiff
auto_hesvecgrad!(dy, L.g, L.x, v, L.cache1, L.cache2)
function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)

if autodiff
cache1 = Dual{
typeof(ForwardDiff.Tag(DeivVecTag(), eltype(u))), eltype(u), 1
}.(u, ForwardDiff.Partials.(tuple.(u)))

cache2 = copy(cache1)
else
num_hesvecgrad!(dy, L.g, L.x, v, L.cache1, L.cache2)
cache1 = similar(u)
cache2 = similar(u)
end

cache = (cache1, cache2,)

vecprod = autodiff ? auto_hesvecgrad : num_hesvecgrad
vecprod! = autodiff ? auto_hesvecgrad! : num_hesvecgrad!

outofplace = static_hasmethod(f, typeof((u,)))
isinplace = static_hasmethod(f, typeof((u, u,)))

if !(isinplace) & !(outofplace)
error("$f must have signature f(u), or f(du, u).")
end

L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!)

FunctionOperator(L, u, u;
isinplace = isinplace, outofplace = outofplace,
p = p, t = t, islinear = true,
)
end
#
45 changes: 40 additions & 5 deletions src/differentiation/jaches_products_zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,21 @@ function numback_hesvec(f, x, v)
end

function autoback_hesvec!(dy, f, x, v,
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))),
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))),
eltype(x), 1
}.(x,
ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))),
cache3 = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))),
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))),
eltype(x), 1
}.(x,
ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))))
g = let f = f
(dx, x) -> dx .= first(Zygote.gradient(f, x))
end
cache2 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), eltype(x), 1
cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), eltype(x), 1
}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
g(cache3, cache2)
dy .= partials.(cache3, 1)
g(cache2, cache1)
dy .= partials.(cache2, 1)
end

function autoback_hesvec(f, x, v)
Expand All @@ -48,3 +48,38 @@ function autoback_hesvec(f, x, v)
}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
ForwardDiff.partials.(g(y), 1)
end

### Operator Forms

function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)

if autodiff
cache1 = Dual{
typeof(ForwardDiff.Tag(DeivVecTag(),eltype(u))), eltype(u), 1
}.(u, ForwardDiff.Partials.(tuple.(u)))
cache2 = copy(u)
else
cache1 = similar(u)
cache2 = similar(u)
end

cache = (cache1, cache2,)

vecprod = autodiff ? autoback_hesvec : numback_hesvec
vecprod! = autodiff ? autoback_hesvec! : numback_hesvec!

outofplace = static_hasmethod(f, typeof((u,)))
isinplace = static_hasmethod(f, typeof((u,)))

if !(isinplace) & !(outofplace)
error("$f must have signature f(u).")
end

L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!)

FunctionOperator(L, u, u;
isinplace = isinplace, outofplace = outofplace,
p = p, t = t, islinear = true,
)
end
#
Loading