Skip to content

Commit 6341f93

Browse files
Merge pull request #46 from shashi/s/hessian
Hessian sparsity detection
2 parents 05a90f9 + 58c2402 commit 6341f93

File tree

13 files changed

+485
-5
lines changed

13 files changed

+485
-5
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1212
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
1313
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1414
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
15+
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1516
VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f"
1617
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1718

src/SparseDiffTools.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ export contract_color,
2727
numback_hesvec,numback_hesvec!,
2828
autoback_hesvec,autoback_hesvec!,
2929
JacVec,HesVec,HesVecGrad,
30-
Sparsity, sparsity!
30+
Sparsity, sparsity!, hsparsity
3131

3232

3333
include("coloring/high_level.jl")
@@ -40,5 +40,9 @@ include("program_sparsity/program_sparsity.jl")
4040
include("program_sparsity/sparsity_tracker.jl")
4141
include("program_sparsity/path.jl")
4242
include("program_sparsity/take_all_branches.jl")
43+
include("program_sparsity/terms.jl")
44+
include("program_sparsity/linearity.jl")
45+
include("program_sparsity/hessian.jl")
46+
include("program_sparsity/blas.jl")
4347

4448
end # module

src/program_sparsity/blas.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
using LinearAlgebra
2+
import LinearAlgebra.BLAS
3+
4+
# generic implementations
5+
6+
macro reroute(f, g)
7+
quote
8+
function Cassette.overdub(ctx::HessianSparsityContext,
9+
f::typeof($(esc(f))),
10+
args...)
11+
println("rerouted")
12+
Cassette.overdub(
13+
ctx,
14+
invoke,
15+
$(esc(g.args[1])),
16+
$(esc(:(Tuple{$(g.args[2:end]...)}))),
17+
args...)
18+
end
19+
end
20+
end
21+
22+
@reroute BLAS.dot dot(Any, Any)
23+
@reroute BLAS.axpy! axpy!(Any,
24+
AbstractArray,
25+
AbstractArray)

src/program_sparsity/hessian.jl

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
using Cassette
2+
import Cassette: tag, untag, Tagged, metadata, hasmetadata, istagged, canrecurse
3+
import Core: SSAValue
4+
using SparseArrays
5+
6+
# Tags:
7+
Cassette.@context HessianSparsityContext
8+
9+
const TaggedOf{T} = Tagged{A, T} where A
10+
11+
const HTagType = Union{Input, TermCombination}
12+
Cassette.metadatatype(::Type{<:HessianSparsityContext}, ::DataType) = HTagType
13+
14+
istainted(ctx::HessianSparsityContext, x) = ismetatype(x, ctx, TermCombination)
15+
16+
Cassette.overdub(ctx::HessianSparsityContext, f::typeof(istainted), x) = istainted(ctx, x)
17+
Cassette.overdub(ctx::HessianSparsityContext, f::typeof(this_here_predicate!)) = this_here_predicate!(ctx.metadata)
18+
19+
# getindex on the input
20+
function Cassette.overdub(ctx::HessianSparsityContext,
21+
f::typeof(getindex),
22+
X::Tagged,
23+
idx::Tagged...)
24+
if any(i->ismetatype(i, ctx, TermCombination) && !isone(metadata(i, ctx)), idx)
25+
error("getindex call depends on input. Cannot determine Hessian sparsity")
26+
end
27+
Cassette.overdub(ctx, f, X, map(i->untag(i, ctx), idx)...)
28+
end
29+
30+
# plugs an ambiguity
31+
function Cassette.overdub(ctx::HessianSparsityContext,
32+
f::typeof(getindex),
33+
X::Tagged)
34+
Cassette.recurse(ctx, f, X)
35+
end
36+
37+
function Cassette.overdub(ctx::HessianSparsityContext,
38+
f::typeof(getindex),
39+
X::Tagged,
40+
idx::Integer...)
41+
if ismetatype(X, ctx, Input)
42+
val = Cassette.fallback(ctx, f, X, idx...)
43+
i = LinearIndices(untag(X, ctx))[idx...]
44+
tag(val, ctx, TermCombination(Set([Dict(i=>1)])))
45+
else
46+
Cassette.recurse(ctx, f, X, idx...)
47+
end
48+
end
49+
50+
function Cassette.overdub(ctx::HessianSparsityContext,
51+
f::typeof(Base.unsafe_copyto!),
52+
X::Tagged,
53+
xstart,
54+
Y::Tagged,
55+
ystart,
56+
len)
57+
if ismetatype(Y, ctx, Input)
58+
val = Cassette.fallback(ctx, f, X, xstart, Y, ystart, len)
59+
nometa = Cassette.NoMetaMeta()
60+
X.meta.meta[xstart:xstart+len-1] .= (i->Cassette.Meta(TermCombination(Set([Dict(i=>1)])), nometa)).(ystart:ystart+len-1)
61+
val
62+
else
63+
Cassette.recurse(ctx, f, X, xstart, Y, ystart, len)
64+
end
65+
end
66+
function Cassette.overdub(ctx::HessianSparsityContext,
67+
f::typeof(copy),
68+
X::Tagged)
69+
if ismetatype(X, ctx, Input)
70+
val = Cassette.fallback(ctx, f, X)
71+
tag(val, ctx, Input())
72+
else
73+
Cassette.recurse(ctx, f, X)
74+
end
75+
end
76+
77+
combine_terms(::Nothing, terms...) = one(TermCombination)
78+
79+
# 1-arg functions
80+
combine_terms(::Val{true}, term) = term
81+
combine_terms(::Val{false}, term) = term * term
82+
83+
# 2-arg functions
84+
function combine_terms(::Val{linearity}, term1, term2) where linearity
85+
86+
linear11, linear22, linear12 = linearity
87+
term = zero(TermCombination)
88+
if linear11
89+
if !linear12
90+
term += term1
91+
end
92+
else
93+
term += term1 * term1
94+
end
95+
96+
if linear22
97+
if !linear12
98+
term += term2
99+
end
100+
else
101+
term += term2 * term2
102+
end
103+
104+
if linear12
105+
term += term1 + term2
106+
else
107+
term += term1 * term2
108+
end
109+
term
110+
end
111+
112+
113+
# Hessian overdub
114+
#
115+
function getterms(ctx, x)
116+
ismetatype(x, ctx, TermCombination) ? metadata(x, ctx) : one(TermCombination)
117+
end
118+
119+
function hessian_overdub(ctx::HessianSparsityContext, f, linearity, args...)
120+
t = combine_terms(linearity, map(x->getterms(ctx, x), args)...)
121+
val = Cassette.fallback(ctx, f, args...)
122+
tag(val, ctx, t)
123+
end
124+
function Cassette.overdub(ctx::HessianSparsityContext,
125+
f::typeof(getproperty),
126+
x::Tagged, prop)
127+
if ismetatype(x, ctx, TermCombination) && !isone(metadata(x, ctx))
128+
error("property of a non-constant term accessed")
129+
else
130+
Cassette.recurse(ctx, f, x, prop)
131+
end
132+
end
133+
134+
haslinearity(ctx::HessianSparsityContext, f, nargs) = haslinearity(untag(f, ctx), nargs)
135+
linearity(ctx::HessianSparsityContext, f, nargs) = linearity(untag(f, ctx), nargs)
136+
137+
function Cassette.overdub(ctx::HessianSparsityContext,
138+
f,
139+
args...)
140+
tainted = any(x->ismetatype(x, ctx, TermCombination), args)
141+
val = if tainted && haslinearity(ctx, f, Val{nfields(args)}())
142+
l = linearity(ctx, f, Val{nfields(args)}())
143+
hessian_overdub(ctx, f, l, args...)
144+
else
145+
val = Cassette.recurse(ctx, f, args...)
146+
#if tainted && !ismetatype(val, ctx, TermCombination)
147+
# @warn("Don't know the linearity of function $f")
148+
#end
149+
val
150+
end
151+
val
152+
end

src/program_sparsity/linearity.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
using SpecialFunctions
2+
import Base.Broadcast
3+
4+
const constant_funcs = []
5+
6+
const monadic_linear = [deg2rad, +, rad2deg, transpose, -, conj]
7+
8+
const monadic_nonlinear = [asind, log1p, acsch, erfc, digamma, acos, asec, acosh, airybiprime, acsc, cscd, log, tand, log10, csch, asinh, airyai, abs2, gamma, lgamma, erfcx, bessely0, cosh, sin, cos, atan, cospi, cbrt, acosd, bessely1, acoth, erfcinv, erf, dawson, inv, acotd, airyaiprime, erfinv, trigamma, asecd, besselj1, exp, acot, sqrt, sind, sinpi, asech, log2, tan, invdigamma, airybi, exp10, sech, erfi, coth, asin, cotd, cosd, sinh, abs, besselj0, csc, tanh, secd, atand, sec, acscd, cot, exp2, expm1, atanh]
9+
10+
diadic_of_linearity(::Val{(true, true, true)}) = [+, rem2pi, -, >, isless, <, isequal, max, min, convert]
11+
diadic_of_linearity(::Val{(true, true, false)}) = [*]
12+
#diadic_of_linearit(::(Val{(true, false, true)}) = [besselk, hankelh2, bessely, besselj, besseli, polygamma, hankelh1]
13+
diadic_of_linearity(::Val{(true, false, false)}) = [/]
14+
diadic_of_linearity(::Val{(false, true, false)}) = [\]
15+
diadic_of_linearity(::Val{(false, false, false)}) = [hypot, atan, mod, rem, lbeta, ^, beta]
16+
diadic_of_linearity(::Val) = []
17+
18+
haslinearity(f, nargs) = false
19+
20+
# some functions strip the linearity metadata
21+
22+
for f in constant_funcs
23+
@eval begin
24+
haslinearity(::typeof($f), ::Val) = true
25+
linearity(::typeof($f), ::Val) = nothing
26+
end
27+
end
28+
29+
# linearity of a single input function is either
30+
# Val{true}() or Val{false}()
31+
#
32+
for f in monadic_linear
33+
@eval begin
34+
haslinearity(::typeof($f), ::Val{1}) = true
35+
linearity(::typeof($f), ::Val{1}) = Val{true}()
36+
end
37+
end
38+
# linearity of a 2-arg function is:
39+
# Val{(linear11, linear22, linear12)}()
40+
#
41+
# linearIJ refers to the zeroness of d^2/dxIxJ
42+
for f in monadic_nonlinear
43+
@eval begin
44+
haslinearity(::typeof($f), ::Val{1}) = true
45+
linearity(::typeof($f), ::Val{1}) = Val{false}()
46+
end
47+
end
48+
49+
for linearity_mask = 0:2^3-1
50+
lin = Val{map(x->x!=0, (linearity_mask & 4,
51+
linearity_mask & 2,
52+
linearity_mask & 1))}()
53+
54+
for f in diadic_of_linearity(lin)
55+
@eval begin
56+
haslinearity(::typeof($f), ::Val{2}) = true
57+
linearity(::typeof($f), ::Val{2}) = $lin
58+
end
59+
end
60+
end

src/program_sparsity/program_sparsity.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,28 @@ function sparsity!(f!, Y, X, args...; sparsity=Sparsity(length(Y), length(X)),
3838
alldone(path) && break
3939
reset!(path)
4040
end
41-
sparsity
41+
sparse(sparsity)
42+
end
43+
44+
function hsparsity(f, X, args...; verbose=true)
45+
46+
terms = zero(TermCombination)
47+
path = Path()
48+
while true
49+
ctx = HessianSparsityContext(metadata=path, pass=BranchesPass)
50+
ctx = Cassette.enabletagging(ctx, f)
51+
ctx = Cassette.disablehooks(ctx)
52+
val = Cassette.recurse(ctx,
53+
f,
54+
tag(X, ctx, Input()),
55+
# TODO: make this recursive
56+
map(arg -> arg isa Fixed ?
57+
arg.value : tag(arg, ctx, one(TermCombination)), args)...)
58+
terms += metadata(val, ctx)
59+
verbose && println("Explored path: ", path)
60+
alldone(path) && break
61+
reset!(path)
62+
end
63+
64+
_sparse(terms, length(X))
4265
end

src/program_sparsity/take_all_branches.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
istainted(ctx, x) = ismetatype(x, ctx, ProvinanceSet)
1+
istainted(ctx::SparsityContext, x) = ismetatype(x, ctx, ProvinanceSet)
22

33
Cassette.overdub(ctx::SparsityContext, f::typeof(istainted), x) = istainted(ctx, x)
44
Cassette.overdub(ctx::SparsityContext, f::typeof(this_here_predicate!)) = this_here_predicate!(ctx)

0 commit comments

Comments
 (0)