Hessian sparsity detection #46
Merged
23 commits
5517877 nonlinearity tracking for Hessian sparsity (shashi)
f5d53d5 turn TermCombination into Sparsity (shashi)
19d01d6 optimize squaring (shashi)
516f616 minor refactor (shashi)
55c74cb make Hessian sparsity symmetric, change sparsity and hsparsity to ret… (shashi)
0c8cd9f Refactor & Merge remote-tracking branch 'origin/master' into s/hessian (shashi)
611e9e0 throw when a function of unknown linearity is called as a leaf-call (shashi)
ae9e4c6 optimize squaring with an equivalent expression (shashi)
ffa488e overdub unsafe_copyto! (shashi)
7b1f9ca add SpecialFunctions dependency (shashi)
e9d8788 take all branches -- treat comparison as linear (shashi)
f369d50 make max and min linear (shashi)
135dc9d more linearity info, getting broadcast to work (shashi)
e3cf262 handle getindex when index is tagged -- fixes iterate with state (shashi)
ad1463d overdub copy on input (shashi)
a11d95f the function itself maybe tagged, peel the tag off before checking li… (shashi)
78c19e4 revert some frantic linearity mess (shashi)
536d10a hessian test helpers (shashi)
6c42074 reroute blas calls to generic methods (shashi)
e36801f use Set of Dicts for terms (shashi)
1fb80d6 fix a wrecked branch (shashi)
dab5ba6 add hessian sparsity detection tests (shashi)
58c2402 fix ambiguity for x[] where x is tagged (shashi)
@@ -0,0 +1,25 @@
using LinearAlgebra
import LinearAlgebra.BLAS

# generic implementations

macro reroute(f, g)
    quote
        function Cassette.overdub(ctx::HessianSparsityContext,
                                  f::typeof($(esc(f))),
                                  args...)
            println("rerouted")
            Cassette.overdub(
                ctx,
                invoke,
                $(esc(g.args[1])),
                $(esc(:(Tuple{$(g.args[2:end]...)}))),
                args...)
        end
    end
end

@reroute BLAS.dot dot(Any, Any)
@reroute BLAS.axpy! axpy!(Any,
                          AbstractArray,
                          AbstractArray)
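The `@reroute` macro above relies on `invoke` to call a more generic method than ordinary dispatch would select. The following minimal sketch shows only that mechanism; `describe` is a hypothetical function made up for illustration and is not part of this PR or of Cassette.

```julia
# Hypothetical functions for illustration; they are not part of the PR.
describe(x) = "generic"              # fallback method, signature Tuple{Any}
describe(x::Float64) = "fast path"   # specialized method

# Ordinary dispatch picks the most specific method for a Float64 argument.
specialized = describe(1.0)

# invoke forces the method matching Tuple{Any}, skipping the specialized one.
generic = invoke(describe, Tuple{Any}, 1.0)
```

In the macro, the same idea appears to be used so that a call to a BLAS routine is traced through a generic Julia method, which the context can recurse into, rather than through an opaque library call.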
@@ -0,0 +1,152 @@
using Cassette
import Cassette: tag, untag, Tagged, metadata, hasmetadata, istagged, canrecurse
import Core: SSAValue
using SparseArrays

# Tags:
Cassette.@context HessianSparsityContext

const TaggedOf{T} = Tagged{A, T} where A

const HTagType = Union{Input, TermCombination}
Cassette.metadatatype(::Type{<:HessianSparsityContext}, ::DataType) = HTagType

istainted(ctx::HessianSparsityContext, x) = ismetatype(x, ctx, TermCombination)

Cassette.overdub(ctx::HessianSparsityContext, f::typeof(istainted), x) = istainted(ctx, x)
Cassette.overdub(ctx::HessianSparsityContext, f::typeof(this_here_predicate!)) = this_here_predicate!(ctx.metadata)

# getindex on the input
function Cassette.overdub(ctx::HessianSparsityContext,
                          f::typeof(getindex),
                          X::Tagged,
                          idx::Tagged...)
    if any(i->ismetatype(i, ctx, TermCombination) && !isone(metadata(i, ctx)), idx)
        error("getindex call depends on input. Cannot determine Hessian sparsity")
    end
    Cassette.overdub(ctx, f, X, map(i->untag(i, ctx), idx)...)
end

# plugs an ambiguity
function Cassette.overdub(ctx::HessianSparsityContext,
                          f::typeof(getindex),
                          X::Tagged)
    Cassette.recurse(ctx, f, X)
end

function Cassette.overdub(ctx::HessianSparsityContext,
                          f::typeof(getindex),
                          X::Tagged,
                          idx::Integer...)
    if ismetatype(X, ctx, Input)
        val = Cassette.fallback(ctx, f, X, idx...)
        i = LinearIndices(untag(X, ctx))[idx...]
        tag(val, ctx, TermCombination(Set([Dict(i=>1)])))
    else
        Cassette.recurse(ctx, f, X, idx...)
    end
end

function Cassette.overdub(ctx::HessianSparsityContext,
                          f::typeof(Base.unsafe_copyto!),
                          X::Tagged,
                          xstart,
                          Y::Tagged,
                          ystart,
                          len)
    if ismetatype(Y, ctx, Input)
        val = Cassette.fallback(ctx, f, X, xstart, Y, ystart, len)
        nometa = Cassette.NoMetaMeta()
        X.meta.meta[xstart:xstart+len-1] .= (i->Cassette.Meta(TermCombination(Set([Dict(i=>1)])), nometa)).(ystart:ystart+len-1)
        val
    else
        Cassette.recurse(ctx, f, X, xstart, Y, ystart, len)
    end
end
function Cassette.overdub(ctx::HessianSparsityContext,
                          f::typeof(copy),
                          X::Tagged)
    if ismetatype(X, ctx, Input)
        val = Cassette.fallback(ctx, f, X)
        tag(val, ctx, Input())
    else
        Cassette.recurse(ctx, f, X)
    end
end

combine_terms(::Nothing, terms...) = one(TermCombination)

# 1-arg functions
combine_terms(::Val{true}, term) = term
combine_terms(::Val{false}, term) = term * term

# 2-arg functions
function combine_terms(::Val{linearity}, term1, term2) where linearity

    linear11, linear22, linear12 = linearity
    term = zero(TermCombination)
    if linear11
        if !linear12
            term += term1
        end
    else
        term += term1 * term1
    end

    if linear22
        if !linear12
            term += term2
        end
    else
        term += term2 * term2
    end

    if linear12
        term += term1 + term2
    else
        term += term1 * term2
    end
    term
end


# Hessian overdub
#
function getterms(ctx, x)
    ismetatype(x, ctx, TermCombination) ? metadata(x, ctx) : one(TermCombination)
end

function hessian_overdub(ctx::HessianSparsityContext, f, linearity, args...)
    t = combine_terms(linearity, map(x->getterms(ctx, x), args)...)
    val = Cassette.fallback(ctx, f, args...)
    tag(val, ctx, t)
end
function Cassette.overdub(ctx::HessianSparsityContext,
                          f::typeof(getproperty),
                          x::Tagged, prop)
    if ismetatype(x, ctx, TermCombination) && !isone(metadata(x, ctx))
        error("property of a non-constant term accessed")
    else
        Cassette.recurse(ctx, f, x, prop)
    end
end

haslinearity(ctx::HessianSparsityContext, f, nargs) = haslinearity(untag(f, ctx), nargs)
linearity(ctx::HessianSparsityContext, f, nargs) = linearity(untag(f, ctx), nargs)

function Cassette.overdub(ctx::HessianSparsityContext,
                          f,
                          args...)
    tainted = any(x->ismetatype(x, ctx, TermCombination), args)
    val = if tainted && haslinearity(ctx, f, Val{nfields(args)}())
        l = linearity(ctx, f, Val{nfields(args)}())
        hessian_overdub(ctx, f, l, args...)
    else
        val = Cassette.recurse(ctx, f, args...)
        #if tainted && !ismetatype(val, ctx, TermCombination)
        #    @warn("Don't know the linearity of function $f")
        #end
        val
    end
    val
end
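To make the term bookkeeping in `combine_terms` easier to follow, here is a self-contained sketch of the idea, written without Cassette. `TermCombination` itself is not shown in this diff, so `Monomial`, `Terms`, `input_term`, `add_terms`, `mul_terms`, `nonlinear` and `hessian_pattern` are all hypothetical stand-ins, loosely modeled on the `Set([Dict(i=>1)])` construction used above; the real type may behave differently. The sketch assumes Julia 1.5 or later for `mergewith`.

```julia
# Hypothetical stand-in for TermCombination: a set of monomials, where each
# monomial maps an input index to its degree (capped at 2).
const Monomial = Dict{Int,Int}
const Terms = Set{Monomial}

# The term attached to input x[i]: a single monomial of degree 1.
input_term(i::Int) = Terms([Monomial(i => 1)])

# Sum of two expressions: the union of their monomials.
add_terms(a::Terms, b::Terms) = union(a, b)

# Product of two expressions: all pairwise monomial products, adding degrees
# and capping them at 2, since only "degree >= 2 or not" matters here.
function mul_terms(a::Terms, b::Terms)
    out = Terms()
    for m in a, n in b
        push!(out, mergewith((p, q) -> min(p + q, 2), m, n))
    end
    out
end

# A nonlinear one-argument function squares its argument's terms,
# mirroring combine_terms(::Val{false}, term) = term * term.
nonlinear(t::Terms) = mul_terms(t, t)

# Mark H[i, j] whenever x[i] and x[j] occur in the same monomial;
# a diagonal entry is marked only if the variable has degree 2.
function hessian_pattern(t::Terms, n::Int)
    H = falses(n, n)
    for m in t
        vars = collect(keys(m))
        for i in vars, j in vars
            if i != j || m[i] >= 2
                H[i, j] = true
            end
        end
    end
    H
end

# Terms for f(x) = x[1] * x[2] + sin(x[3])
t = add_terms(mul_terms(input_term(1), input_term(2)), nonlinear(input_term(3)))
pattern = hessian_pattern(t, 3)
```

For this `f`, the pattern marks the off-diagonal pair (1, 2) and the diagonal entry (3, 3), which matches the Hessian of `x[1] * x[2] + sin(x[3])` computed by hand.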
@@ -0,0 +1,60 @@
using SpecialFunctions
import Base.Broadcast

const constant_funcs = []

const monadic_linear = [deg2rad, +, rad2deg, transpose, -, conj]

const monadic_nonlinear = [asind, log1p, acsch, erfc, digamma, acos, asec, acosh, airybiprime, acsc, cscd, log, tand, log10, csch, asinh, airyai, abs2, gamma, lgamma, erfcx, bessely0, cosh, sin, cos, atan, cospi, cbrt, acosd, bessely1, acoth, erfcinv, erf, dawson, inv, acotd, airyaiprime, erfinv, trigamma, asecd, besselj1, exp, acot, sqrt, sind, sinpi, asech, log2, tan, invdigamma, airybi, exp10, sech, erfi, coth, asin, cotd, cosd, sinh, abs, besselj0, csc, tanh, secd, atand, sec, acscd, cot, exp2, expm1, atanh]

diadic_of_linearity(::Val{(true, true, true)}) = [+, rem2pi, -, >, isless, <, isequal, max, min, convert]
diadic_of_linearity(::Val{(true, true, false)}) = [*]
#diadic_of_linearit(::(Val{(true, false, true)}) = [besselk, hankelh2, bessely, besselj, besseli, polygamma, hankelh1]
diadic_of_linearity(::Val{(true, false, false)}) = [/]
diadic_of_linearity(::Val{(false, true, false)}) = [\]
diadic_of_linearity(::Val{(false, false, false)}) = [hypot, atan, mod, rem, lbeta, ^, beta]
diadic_of_linearity(::Val) = []

haslinearity(f, nargs) = false

# some functions strip the linearity metadata

for f in constant_funcs
    @eval begin
        haslinearity(::typeof($f), ::Val) = true
        linearity(::typeof($f), ::Val) = nothing
    end
end

# linearity of a single input function is either
# Val{true}() or Val{false}()
#
for f in monadic_linear
    @eval begin
        haslinearity(::typeof($f), ::Val{1}) = true
        linearity(::typeof($f), ::Val{1}) = Val{true}()
    end
end
# linearity of a 2-arg function is:
# Val{(linear11, linear22, linear12)}()
#
# linearIJ refers to the zeroness of d^2/dxIxJ
for f in monadic_nonlinear
    @eval begin
        haslinearity(::typeof($f), ::Val{1}) = true
        linearity(::typeof($f), ::Val{1}) = Val{false}()
    end
end

for linearity_mask = 0:2^3-1
    lin = Val{map(x->x!=0, (linearity_mask & 4,
                            linearity_mask & 2,
                            linearity_mask & 1))}()

    for f in diadic_of_linearity(lin)
        @eval begin
            haslinearity(::typeof($f), ::Val{2}) = true
            linearity(::typeof($f), ::Val{2}) = $lin
        end
    end
end
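The final loop walks the integers 0 through 7 and decodes each one into a `(linear11, linear22, linear12)` triple. The sketch below repeats only that decoding step as a standalone function so the bit layout is easy to inspect; `decode_mask` is a hypothetical helper name and is not defined in the PR.

```julia
# Same decoding as the loop in the diff:
# bit 4 -> linear11, bit 2 -> linear22, bit 1 -> linear12.
decode_mask(mask) = map(x -> x != 0, (mask & 4, mask & 2, mask & 1))

# All eight triples that the loop visits, in order.
all_masks = [decode_mask(m) for m in 0:2^3-1]

mul_entry = decode_mask(6)   # (true, true, false): the triple listed for *
div_entry = decode_mask(4)   # (true, false, false): the triple listed for /
```

Reading the triple for `*`: the second derivatives with respect to each argument alone vanish, while the mixed derivative does not, which is why multiplication contributes a product of its two arguments' terms.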
comment that