Skip to content

Hessian sparsity detection #46

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Jul 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
5517877
nonlinearity tracking for Hessian sparsity
shashi Jul 5, 2019
f5d53d5
turn TermCombination into Sparsity
shashi Jul 6, 2019
19d01d6
optimize squaring
shashi Jul 6, 2019
516f616
minor refactor
shashi Jul 6, 2019
55c74cb
make Hessian sparsity symmetric, change sparsity and hsparsity to ret…
shashi Jul 6, 2019
0c8cd9f
Refactor & Merge remote-tracking branch 'origin/master' into s/hessian
shashi Jul 6, 2019
611e9e0
throw when a function of unknown linearity is called as a leaf-call
shashi Jul 6, 2019
ae9e4c6
optimize squaring with an equivalent expression
shashi Jul 6, 2019
ffa488e
overdub unsafe_copyto!
shashi Jul 6, 2019
7b1f9ca
add SpecialFunctions dependency
shashi Jul 7, 2019
e9d8788
take all branches -- treat comparison as linear
shashi Jul 7, 2019
f369d50
make max and min linear
shashi Jul 7, 2019
135dc9d
more linearity info, getting broadcast to work
shashi Jul 7, 2019
e3cf262
handle getindex when index is tagged -- fixes iterate with state
shashi Jul 7, 2019
ad1463d
overdub copy on input
shashi Jul 7, 2019
a11d95f
the function itself maybe tagged, peel the tag off before checking li…
shashi Jul 7, 2019
78c19e4
revert some frantic linearity mess
shashi Jul 7, 2019
536d10a
hessian test helpers
shashi Jul 7, 2019
6c42074
reroute blas calls to generic methods
shashi Jul 7, 2019
e36801f
use Set of Dicts for terms
shashi Jul 8, 2019
1fb80d6
fix a wrecked branch
shashi Jul 8, 2019
dab5ba6
add hessian sparsity detection tests
shashi Jul 8, 2019
58c2402
fix ambiguity for x[] where x is tagged
shashi Jul 8, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand Down
6 changes: 5 additions & 1 deletion src/SparseDiffTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ export contract_color,
numback_hesvec,numback_hesvec!,
autoback_hesvec,autoback_hesvec!,
JacVec,HesVec,HesVecGrad,
Sparsity, sparsity!
Sparsity, sparsity!, hsparsity


include("coloring/high_level.jl")
Expand All @@ -40,5 +40,9 @@ include("program_sparsity/program_sparsity.jl")
include("program_sparsity/sparsity_tracker.jl")
include("program_sparsity/path.jl")
include("program_sparsity/take_all_branches.jl")
include("program_sparsity/terms.jl")
include("program_sparsity/linearity.jl")
include("program_sparsity/hessian.jl")
include("program_sparsity/blas.jl")

end # module
25 changes: 25 additions & 0 deletions src/program_sparsity/blas.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using LinearAlgebra
import LinearAlgebra.BLAS

# generic implementations

"""
    @reroute f g(T1, T2, ...)

Intercept calls to `f` under `HessianSparsityContext` and re-dispatch them,
via `invoke`, to the generic method `g` whose signature is `Tuple{T1, T2, ...}`.
This sends BLAS-specialized calls through generic Julia code that the
sparsity tracer can recurse into.
"""
macro reroute(f, g)
    quote
        function Cassette.overdub(ctx::HessianSparsityContext,
                                  f::typeof($(esc(f))),
                                  args...)
            # Fix: removed leftover debug `println("rerouted")` which printed
            # to stdout on every rerouted call.
            Cassette.overdub(
                ctx,
                invoke,
                $(esc(g.args[1])),
                $(esc(:(Tuple{$(g.args[2:end]...)}))),
                args...)
        end
    end
end

# Route BLAS-specialized methods to their generic LinearAlgebra fallbacks so
# tagged arguments flow through trackable Julia code instead of opaque BLAS.
@reroute BLAS.dot dot(Any, Any)
@reroute BLAS.axpy! axpy!(Any,
                          AbstractArray,
                          AbstractArray)
152 changes: 152 additions & 0 deletions src/program_sparsity/hessian.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
using Cassette
import Cassette: tag, untag, Tagged, metadata, hasmetadata, istagged, canrecurse
import Core: SSAValue
using SparseArrays

# Tags:
# Cassette context used for Hessian sparsity detection. Values derived from
# the input are tagged with metadata describing which input indices they
# depend on and how (linearly vs. nonlinearly).
Cassette.@context HessianSparsityContext

# A value of type T tagged under some Cassette tag A.
const TaggedOf{T} = Tagged{A, T} where A

# Metadata carried on tagged values: either `Input` (marks the untouched
# input array) or a `TermCombination` (the tracked dependence terms).
const HTagType = Union{Input, TermCombination}
Cassette.metadatatype(::Type{<:HessianSparsityContext}, ::DataType) = HTagType

# A value is "tainted" when it carries term metadata, i.e. depends on the input.
istainted(ctx::HessianSparsityContext, x) = ismetatype(x, ctx, TermCombination)

# Make the sparsity machinery's own helpers see this context when called
# inside traced code.
Cassette.overdub(ctx::HessianSparsityContext, f::typeof(istainted), x) = istainted(ctx, x)
Cassette.overdub(ctx::HessianSparsityContext, f::typeof(this_here_predicate!)) = this_here_predicate!(ctx.metadata)

# getindex on the input where the indices themselves are tagged (this happens
# e.g. when `iterate` carries tagged state). Indexing is only permitted when
# no index actually depends on the input (its terms are `isone`); an
# input-dependent index would make the sparsity pattern value-dependent.
function Cassette.overdub(ctx::HessianSparsityContext,
                          f::typeof(getindex),
                          X::Tagged,
                          idx::Tagged...)
    if any(i->ismetatype(i, ctx, TermCombination) && !isone(metadata(i, ctx)), idx)
        error("getindex call depends on input. Cannot determine Hessian sparsity")
    end
    # The indices are constants: strip their tags and retry the call.
    Cassette.overdub(ctx, f, X, map(i->untag(i, ctx), idx)...)
end

# plugs an ambiguity between the Tagged-index and Integer-index methods
# above/below for the zero-index call `x[]` when `x` is tagged.
function Cassette.overdub(ctx::HessianSparsityContext,
                          f::typeof(getindex),
                          X::Tagged)
    Cassette.recurse(ctx, f, X)
end

# getindex on the tagged input array: reading element `i` produces a scalar
# tagged with the single linear term {i => 1}.
function Cassette.overdub(ctx::HessianSparsityContext,
                          f::typeof(getindex),
                          X::Tagged,
                          idx::Integer...)
    if ismetatype(X, ctx, Input)
        val = Cassette.fallback(ctx, f, X, idx...)
        # Convert the (possibly cartesian) index into a linear index.
        i = LinearIndices(untag(X, ctx))[idx...]
        tag(val, ctx, TermCombination(Set([Dict(i=>1)])))
    else
        Cassette.recurse(ctx, f, X, idx...)
    end
end

# unsafe_copyto! from the tagged input `Y` into `X`: perform the raw copy,
# then rewrite X's Cassette metadata so each copied destination slot carries
# the linear term of its source index in the input.
# NOTE(review): the term index is taken directly from the linear positions
# ystart:ystart+len-1 of `Y` — assumes the input's term indices are linear
# indices into `Y`; confirm for non-vector inputs.
function Cassette.overdub(ctx::HessianSparsityContext,
                          f::typeof(Base.unsafe_copyto!),
                          X::Tagged,
                          xstart,
                          Y::Tagged,
                          ystart,
                          len)
    if ismetatype(Y, ctx, Input)
        val = Cassette.fallback(ctx, f, X, xstart, Y, ystart, len)
        nometa = Cassette.NoMetaMeta()
        X.meta.meta[xstart:xstart+len-1] .= (i->Cassette.Meta(TermCombination(Set([Dict(i=>1)])), nometa)).(ystart:ystart+len-1)
        val
    else
        Cassette.recurse(ctx, f, X, xstart, Y, ystart, len)
    end
end
# `copy` of the tagged input stays tagged as `Input`, so later getindex
# calls on the copy still produce per-element terms.
function Cassette.overdub(ctx::HessianSparsityContext,
                          f::typeof(copy),
                          X::Tagged)
    if ismetatype(X, ctx, Input)
        val = Cassette.fallback(ctx, f, X)
        tag(val, ctx, Input())
    else
        Cassette.recurse(ctx, f, X)
    end
end

# Constant functions (their linearity table entry is `nothing`): the result
# does not depend on the arguments, so reduce to the identity combination.
combine_terms(::Nothing, terms...) = one(TermCombination)

# 1-arg functions: Val{true}() means linear (zero second derivative), so the
# argument's terms pass through unchanged; Val{false}() means nonlinear, so
# the terms interact with themselves.
combine_terms(::Val{true}, term) = term
combine_terms(::Val{false}, term) = term * term

# 2-arg functions
#
# `linearity` is the tuple (linear11, linear22, linear12), where linearIJ
# being true means the second derivative d^2 f / (dxI dxJ) vanishes.
function combine_terms(::Val{linearity}, term1, term2) where linearity
    linear11, linear22, linear12 = linearity

    # Diagonal contribution of the first argument: nothing when linear in x1
    # and also linear in the cross term (term1 appears via `cross` below),
    # term1 alone when only the diagonal vanishes, term1*term1 otherwise.
    diag1 = linear11 ? (linear12 ? zero(TermCombination) : term1) :
                       term1 * term1

    # Diagonal contribution of the second argument, symmetrically.
    diag2 = linear22 ? (linear12 ? zero(TermCombination) : term2) :
                       term2 * term2

    # Cross contribution between the two arguments.
    cross = linear12 ? term1 + term2 : term1 * term2

    diag1 + diag2 + cross
end


# Hessian overdub
#
# Fetch the TermCombination attached to `x` in this context, falling back to
# the identity combination when `x` carries no term metadata.
function getterms(ctx, x)
    if ismetatype(x, ctx, TermCombination)
        return metadata(x, ctx)
    else
        return one(TermCombination)
    end
end

# Evaluate `f(args...)` outside the tracer and tag the result with the terms
# obtained by combining the arguments' terms according to `f`'s linearity.
function hessian_overdub(ctx::HessianSparsityContext, f, linearity, args...)
    t = combine_terms(linearity, map(x->getterms(ctx, x), args)...)
    val = Cassette.fallback(ctx, f, args...)
    tag(val, ctx, t)
end
# Property access on a value whose terms are non-constant cannot be tracked;
# allow it only when the terms are the identity (`isone`).
function Cassette.overdub(ctx::HessianSparsityContext,
                          f::typeof(getproperty),
                          x::Tagged, prop)
    if ismetatype(x, ctx, TermCombination) && !isone(metadata(x, ctx))
        error("property of a non-constant term accessed")
    else
        Cassette.recurse(ctx, f, x, prop)
    end
end

# The callee itself may arrive as a Tagged value; peel the tag off before
# consulting the linearity tables (see linearity.jl).
haslinearity(ctx::HessianSparsityContext, f, nargs) = haslinearity(untag(f, ctx), nargs)
linearity(ctx::HessianSparsityContext, f, nargs) = linearity(untag(f, ctx), nargs)

# Generic fallback overdub: if any argument carries terms and `f`'s linearity
# is registered, combine the terms directly without recursing into `f`;
# otherwise recurse so nested calls get traced.
function Cassette.overdub(ctx::HessianSparsityContext,
                          f,
                          args...)
    tainted = any(x->ismetatype(x, ctx, TermCombination), args)
    val = if tainted && haslinearity(ctx, f, Val{nfields(args)}())
        l = linearity(ctx, f, Val{nfields(args)}())
        hessian_overdub(ctx, f, l, args...)
    else
        val = Cassette.recurse(ctx, f, args...)
        # Disabled diagnostic: warn when a tainted call dropped its metadata,
        # i.e. a function of unknown linearity swallowed the terms.
        #if tainted && !ismetatype(val, ctx, TermCombination)
        #    @warn("Don't know the linearity of function $f")
        #end
        val
    end
    val
end
60 changes: 60 additions & 0 deletions src/program_sparsity/linearity.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
using SpecialFunctions
import Base.Broadcast

# Functions whose result is independent of their arguments (none registered
# yet; kept for the registration loop below).
const constant_funcs = []

# 1-arg functions with vanishing second derivative (linear in their input).
const monadic_linear = [deg2rad, +, rad2deg, transpose, -, conj]

# 1-arg functions with a nonzero second derivative.
const monadic_nonlinear = [asind, log1p, acsch, erfc, digamma, acos, asec, acosh, airybiprime, acsc, cscd, log, tand, log10, csch, asinh, airyai, abs2, gamma, lgamma, erfcx, bessely0, cosh, sin, cos, atan, cospi, cbrt, acosd, bessely1, acoth, erfcinv, erf, dawson, inv, acotd, airyaiprime, erfinv, trigamma, asecd, besselj1, exp, acot, sqrt, sind, sinpi, asech, log2, tan, invdigamma, airybi, exp10, sech, erfi, coth, asin, cotd, cosd, sinh, abs, besselj0, csc, tanh, secd, atand, sec, acscd, cot, exp2, expm1, atanh]

# 2-arg functions grouped by Hessian structure. The Val parameter encodes
# (linear11, linear22, linear12): whether d²f/dx1², d²f/dx2² and d²f/dx1dx2
# vanish, respectively. Comparisons are deliberately treated as linear.
diadic_of_linearity(::Val{(true, true, true)}) = [+, rem2pi, -, >, isless, <, isequal, max, min, convert]
# (GitHub inline-review widget removed from the scraped diff; reviewer's
# note, truncated at extraction time, read: "comment that")
diadic_of_linearity(::Val{(true, true, false)}) = [*]
# diadic_of_linearity(::Val{(true, false, true)}) = [besselk, hankelh2, bessely, besselj, besseli, polygamma, hankelh1]
diadic_of_linearity(::Val{(true, false, false)}) = [/]
diadic_of_linearity(::Val{(false, true, false)}) = [\]
diadic_of_linearity(::Val{(false, false, false)}) = [hypot, atan, mod, rem, lbeta, ^, beta]
# Any pattern not listed above registers no functions.
diadic_of_linearity(::Val) = []

# Default: the linearity of `f` for `nargs` arguments is unknown.
haslinearity(f, nargs) = false

# some functions strip the linearity metadata: constant functions report
# `nothing`, meaning the result is independent of the inputs.
for f in constant_funcs
    @eval begin
        haslinearity(::typeof($f), ::Val) = true
        linearity(::typeof($f), ::Val) = nothing
    end
end

# linearity of a single input function is either
# Val{true}() or Val{false}()
#
for f in monadic_linear
    @eval begin
        haslinearity(::typeof($f), ::Val{1}) = true
        linearity(::typeof($f), ::Val{1}) = Val{true}()
    end
end

# Register the 1-arg nonlinear functions.
for f in monadic_nonlinear
    @eval begin
        haslinearity(::typeof($f), ::Val{1}) = true
        linearity(::typeof($f), ::Val{1}) = Val{false}()
    end
end

# linearity of a 2-arg function is:
# Val{(linear11, linear22, linear12)}()
#
# linearIJ refers to the zeroness of d^2/dxIxJ
# Enumerate all 8 bit-masks and register every function listed for that
# pattern by diadic_of_linearity.
for linearity_mask = 0:2^3-1
    lin = Val{map(x->x!=0, (linearity_mask & 4,
                            linearity_mask & 2,
                            linearity_mask & 1))}()

    for f in diadic_of_linearity(lin)
        @eval begin
            haslinearity(::typeof($f), ::Val{2}) = true
            linearity(::typeof($f), ::Val{2}) = $lin
        end
    end
end
25 changes: 24 additions & 1 deletion src/program_sparsity/program_sparsity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,28 @@ function sparsity!(f!, Y, X, args...; sparsity=Sparsity(length(Y), length(X)),
alldone(path) && break
reset!(path)
end
sparsity
sparse(sparsity)
end

"""
    hsparsity(f, X, args...; verbose=true)

Detect the Hessian sparsity pattern of `f(X, args...)` with respect to the
input `X`. Runs `f` under `HessianSparsityContext` once per execution path
(paths are enumerated via `Path`/`BranchesPass`), accumulates the term
combinations from every path, and converts the union into a sparse matrix
with `_sparse(terms, length(X))`.

Extra arguments wrapped in `Fixed` are passed through untagged; all other
extra arguments are tagged with a constant (identity) term.
"""
function hsparsity(f, X, args...; verbose=true)

    terms = zero(TermCombination)
    path = Path()
    while true
        # A fresh context per path; tagging tracks input dependence and
        # hooks are disabled.
        ctx = HessianSparsityContext(metadata=path, pass=BranchesPass)
        ctx = Cassette.enabletagging(ctx, f)
        ctx = Cassette.disablehooks(ctx)
        val = Cassette.recurse(ctx,
                               f,
                               tag(X, ctx, Input()),
                               # TODO: make this recursive
                               map(arg -> arg isa Fixed ?
                                   arg.value : tag(arg, ctx, one(TermCombination)), args)...)
        # Accumulate the terms discovered along this path.
        terms += metadata(val, ctx)
        verbose && println("Explored path: ", path)
        alldone(path) && break
        reset!(path)
    end

    _sparse(terms, length(X))
end
2 changes: 1 addition & 1 deletion src/program_sparsity/take_all_branches.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
istainted(ctx, x) = ismetatype(x, ctx, ProvinanceSet)
# Jacobian-sparsity taint check, restricted to SparsityContext so it does
# not clash with the HessianSparsityContext method of the same name.
istainted(ctx::SparsityContext, x) = ismetatype(x, ctx, ProvinanceSet)

# Route the helpers through the context when they are called in traced code.
Cassette.overdub(ctx::SparsityContext, f::typeof(istainted), x) = istainted(ctx, x)
Cassette.overdub(ctx::SparsityContext, f::typeof(this_here_predicate!)) = this_here_predicate!(ctx)
Expand Down
Loading