From 551787713cb11b4eefaac0677b1e69c8a642ff77 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Fri, 5 Jul 2019 18:06:54 -0400 Subject: [PATCH 01/22] nonlinearity tracking for Hessian sparsity --- src/program_sparsity/hessian.jl | 96 ++++++++++++++++++++++++ src/program_sparsity/linearity.jl | 41 ++++++++++ src/program_sparsity/program_sparsity.jl | 18 +++++ src/program_sparsity/terms.jl | 27 +++++++ 4 files changed, 182 insertions(+) create mode 100644 src/program_sparsity/hessian.jl create mode 100644 src/program_sparsity/linearity.jl create mode 100644 src/program_sparsity/terms.jl diff --git a/src/program_sparsity/hessian.jl b/src/program_sparsity/hessian.jl new file mode 100644 index 00000000..a6421d3c --- /dev/null +++ b/src/program_sparsity/hessian.jl @@ -0,0 +1,96 @@ +using Cassette +import Cassette: tag, untag, Tagged, metadata, hasmetadata, istagged, canrecurse +import Core: SSAValue +using SparseArrays + +# Tags: +Cassette.@context HessianSparsityContext + +include("terms.jl") +const HTagType = Union{Input, TermCombination} +Cassette.metadatatype(::Type{<:HessianSparsityContext}, ::DataType) = HTagType + +# getindex on the input +function Cassette.overdub(ctx::HessianSparsityContext, + f::typeof(getindex), + X::Tagged, + idx::Int...) + if ismetatype(X, ctx, Input) + i = LinearIndices(untag(X, ctx))[idx...] + val = Cassette.fallback(ctx, f, X, idx...) + tag(val, ctx, TermCombination([Dict(i=>1)])) + else + Cassette.recurse(ctx, f, X, idx...) 
+ end +end + +# linearity of a single input function is either +# Val{true}() or Val{false}() +# +# linearity of a 2-arg function is: +# Val{(linear11, linear22, linear12)}() +# +# linearIJ refers to the zeroness of d^2/dxIxJ + +include("linearity.jl") +# 1-arg functions +combine_terms(::Val{true}, term) = term +combine_terms(::Val{false}, term) = term * term + +# 2-arg functions +function combine_terms(::Val{linearity}, term1, term2) where linearity + + linear11, linear22, linear12 = linearity + term = zero(TermCombination) + if linear11 + term += term1 + else + term += term1 * term1 + end + + if linear22 + term += term2 + else + term += term2 * term2 + end + + if linear12 + if !linear11 + term += term1 + end + if !linear22 + term += term2 + end + else + term += term1 * term2 + end + term +end + + +# Hessian overdub +# +function getterms(ctx, x::Tagged) + ismetatype(x, ctx, TermCombination) ? metadata(x, ctx) : one(TermCombination) +end + +function hessian_overdub(ctx::HessianSparsityContext, f, linearity, args...) + if any(x->ismetatype(x, ctx, TermCombination), args) + t = combine_terms(linearity, map(x->getterms(ctx, x), args)...) + val = Cassette.fallback(ctx, f, args...) + tag(val, ctx, t) + else + Cassette.recurse(ctx, f, args...) + end +end + +function Cassette.overdub(ctx::HessianSparsityContext, + f, + args...) + if haslinearity(f, Val{nfields(args)}()) + l = linearity(f, Val{nfields(args)}()) + return hessian_overdub(ctx, f, l, args...) + else + return Cassette.recurse(ctx, f, args...) 
+ end +end diff --git a/src/program_sparsity/linearity.jl b/src/program_sparsity/linearity.jl new file mode 100644 index 00000000..02b0b84c --- /dev/null +++ b/src/program_sparsity/linearity.jl @@ -0,0 +1,41 @@ +using SpecialFunctions + +const monadic_linear = [deg2rad, +, rad2deg, transpose, -] +const monadic_nonlinear = [asind, log1p, acsch, erfc, digamma, acos, asec, acosh, airybiprime, acsc, cscd, log, tand, log10, csch, asinh, airyai, abs2, gamma, lgamma, erfcx, bessely0, cosh, sin, cos, atan, cospi, cbrt, acosd, bessely1, acoth, erfcinv, erf, dawson, inv, acotd, airyaiprime, erfinv, trigamma, asecd, besselj1, exp, acot, sqrt, sind, sinpi, asech, log2, tan, invdigamma, airybi, exp10, sech, erfi, coth, asin, cotd, cosd, sinh, abs, besselj0, csc, tanh, secd, atand, sec, acscd, cot, exp2, expm1, atanh] + +diadic_of_linearity(::Val{(true, true, true)}) = [+, rem2pi, -] +diadic_of_linearity(::Val{(true, true, false)}) = [*] +diadic_of_linearity(::Val{(true, false, true)}) = [] +#diadic_of_linearit(::(Val{(true, false, true)}) = [besselk, hankelh2, bessely, besselj, besseli, polygamma, hankelh1] +diadic_of_linearity(::Val{(true, false, false)}) = [/] +diadic_of_linearity(::Val{(false, true, true)}) = [] +diadic_of_linearity(::Val{(false, true, false)}) = [\] +diadic_of_linearity(::Val{(false, false, true)}) = [] +diadic_of_linearity(::Val{(false, false, false)}) = [hypot, atan, max, min, mod, rem, lbeta, ^, beta] + +haslinearity(f, nargs) = false +for f in monadic_linear + @eval begin + haslinearity(::typeof($f), ::Val{1}) = true + linearity(::typeof($f), ::Val{1}) = Val{true}() + end +end +for f in monadic_nonlinear + @eval begin + haslinearity(::typeof($f), ::Val{1}) = true + linearity(::typeof($f), ::Val{1}) = Val{false}() + end +end + +for linearity_mask = 0:2^3-1 + lin = Val{map(x->x!=0, (linearity_mask & 4, + linearity_mask & 2, + linearity_mask & 1))}() + + for f in diadic_of_linearity(lin) + @eval begin + haslinearity(::typeof($f), ::Val{2}) = true + 
linearity(::typeof($f), ::Val{2}) = $lin + end + end +end diff --git a/src/program_sparsity/program_sparsity.jl b/src/program_sparsity/program_sparsity.jl index bedfb6a3..6c45c864 100644 --- a/src/program_sparsity/program_sparsity.jl +++ b/src/program_sparsity/program_sparsity.jl @@ -1,7 +1,10 @@ include("sparsity_tracker.jl") +include("hessian.jl") include("path.jl") include("take_all_branches.jl") +export Sparsity, sparsity, hsparsity + struct Fixed value end @@ -42,3 +45,18 @@ function sparsity!(f!, Y, X, args...; sparsity=Sparsity(length(Y), length(X))) end sparsity end + +function hsparsity(f, X, args...) + ctx = HessianSparsityContext() + ctx = Cassette.enabletagging(ctx, f) + ctx = Cassette.disablehooks(ctx) + + val = Cassette.recurse(ctx, + f, + tag(X, ctx, Input()), + # TODO: make this recursive + map(arg -> arg isa Fixed ? + arg.value : tag(arg, ctx, TermCombination([[]])), args)...) + + metadata(val, ctx), untag(val, ctx) +end diff --git a/src/program_sparsity/terms.jl b/src/program_sparsity/terms.jl new file mode 100644 index 00000000..544402a3 --- /dev/null +++ b/src/program_sparsity/terms.jl @@ -0,0 +1,27 @@ +struct TermCombination + terms::Vector{Dict{Int, Int}} # idx => pow +end + +Base.zero(::Type{TermCombination}) = TermCombination([]) +Base.one(::Type{TermCombination}) = TermCombination([Dict{Int,Int}()]) + +function Base.:+(comb1::TermCombination, comb2::TermCombination) + TermCombination(vcat(comb1.terms, comb2.terms)) +end +Base.:+(comb1::TermCombination) = comb1 + +function _merge(dict1, dict2) + d = copy(dict1) + for (k, v) in dict2 + d[k] = get(dict1, k, 0) + v + end + d +end + +function Base.:*(comb1::TermCombination, comb2::TermCombination) + vec([_merge(dict1, dict2) + for dict1 in comb1.terms, + dict2 in comb2.terms]) |> TermCombination +end +Base.:*(comb1::TermCombination) = comb1 + From f5d53d5cf923d1b05780581676c13ca36f877741 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Sat, 6 Jul 2019 01:03:43 -0400 Subject: [PATCH 02/22] turn 
TermCombination into Sparsity --- src/program_sparsity/program_sparsity.jl | 2 +- src/program_sparsity/terms.jl | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/program_sparsity/program_sparsity.jl b/src/program_sparsity/program_sparsity.jl index 6c45c864..b9ab4568 100644 --- a/src/program_sparsity/program_sparsity.jl +++ b/src/program_sparsity/program_sparsity.jl @@ -58,5 +58,5 @@ function hsparsity(f, X, args...) map(arg -> arg isa Fixed ? arg.value : tag(arg, ctx, TermCombination([[]])), args)...) - metadata(val, ctx), untag(val, ctx) + Sparsity(metadata(val, ctx), length(X)) end diff --git a/src/program_sparsity/terms.jl b/src/program_sparsity/terms.jl index 544402a3..05a9ce55 100644 --- a/src/program_sparsity/terms.jl +++ b/src/program_sparsity/terms.jl @@ -25,3 +25,24 @@ function Base.:*(comb1::TermCombination, comb2::TermCombination) end Base.:*(comb1::TermCombination) = comb1 +function Sparsity(t::TermCombination, n) + I = Int[] + J = Int[] + for dict in t.terms + kv = collect(pairs(dict)) + for i in 1:length(kv) + k, v = kv[i] + if v>=2 + push!(I, k) + push!(J, k) + end + for j in i+1:length(kv) + if v >= 1 && kv[j][2] >= 1 + push!(I, k) + push!(J, kv[j][1]) + end + end + end + end + Sparsity(n,n,I,J) +end From 19d01d6740f9bf2009621888c3855ed21c9c28c3 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Sat, 6 Jul 2019 01:40:11 -0400 Subject: [PATCH 03/22] optimize squaring --- src/program_sparsity/terms.jl | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/program_sparsity/terms.jl b/src/program_sparsity/terms.jl index 05a9ce55..be7601c7 100644 --- a/src/program_sparsity/terms.jl +++ b/src/program_sparsity/terms.jl @@ -19,9 +19,24 @@ function _merge(dict1, dict2) end function Base.:*(comb1::TermCombination, comb2::TermCombination) - vec([_merge(dict1, dict2) - for dict1 in comb1.terms, - dict2 in comb2.terms]) |> TermCombination + if comb1 === comb2 # squaring optimization 
+ terms = comb1.terms + # square each term + t1 = [Dict(k=>2 for (k,_) in dict) + for dict in comb1.terms] + # multiply each term + t2 = Dict{Int,Int}[] + for i in 1:length(terms) + for j in i+1:length(terms) + push!(t2, _merge(terms[i], terms[j])) + end + end + TermCombination(vcat(t1, t2)) + else + vec([_merge(dict1, dict2) + for dict1 in comb1.terms, + dict2 in comb2.terms]) |> TermCombination + end end Base.:*(comb1::TermCombination) = comb1 From 516f616ba0a64781ecfbafe59dc19625ea3f361f Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Sat, 6 Jul 2019 01:40:24 -0400 Subject: [PATCH 04/22] minor refactor --- src/program_sparsity/hessian.jl | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/program_sparsity/hessian.jl b/src/program_sparsity/hessian.jl index a6421d3c..20de78de 100644 --- a/src/program_sparsity/hessian.jl +++ b/src/program_sparsity/hessian.jl @@ -43,24 +43,23 @@ function combine_terms(::Val{linearity}, term1, term2) where linearity linear11, linear22, linear12 = linearity term = zero(TermCombination) if linear11 - term += term1 + if !linear12 + term += term1 + end else term += term1 * term1 end if linear22 - term += term2 + if !linear12 + term += term2 + end else term += term2 * term2 end if linear12 - if !linear11 - term += term1 - end - if !linear22 - term += term2 - end + term += term1 + term2 else term += term1 * term2 end From 55c74cb824830f0063c87e91837f8a024b5a9a40 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Sat, 6 Jul 2019 01:40:59 -0400 Subject: [PATCH 05/22] make Hessian sparsity symmetric, change sparsity and hsparsity to return a SparseMatrix instead of Sparsity --- src/program_sparsity/program_sparsity.jl | 4 ++-- src/program_sparsity/terms.jl | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/program_sparsity/program_sparsity.jl b/src/program_sparsity/program_sparsity.jl index b9ab4568..a6fb087f 100644 --- a/src/program_sparsity/program_sparsity.jl +++ 
b/src/program_sparsity/program_sparsity.jl @@ -43,7 +43,7 @@ function sparsity!(f!, Y, X, args...; sparsity=Sparsity(length(Y), length(X))) alldone(path) && break reset!(path) end - sparsity + sparse(sparsity) end function hsparsity(f, X, args...) @@ -58,5 +58,5 @@ function hsparsity(f, X, args...) map(arg -> arg isa Fixed ? arg.value : tag(arg, ctx, TermCombination([[]])), args)...) - Sparsity(metadata(val, ctx), length(X)) + sparse(metadata(val, ctx), length(X)) end diff --git a/src/program_sparsity/terms.jl b/src/program_sparsity/terms.jl index be7601c7..c1804cad 100644 --- a/src/program_sparsity/terms.jl +++ b/src/program_sparsity/terms.jl @@ -40,7 +40,7 @@ function Base.:*(comb1::TermCombination, comb2::TermCombination) end Base.:*(comb1::TermCombination) = comb1 -function Sparsity(t::TermCombination, n) +function sparse(t::TermCombination, n) I = Int[] J = Int[] for dict in t.terms @@ -59,5 +59,6 @@ function Sparsity(t::TermCombination, n) end end end - Sparsity(n,n,I,J) + s1 = sparse(I,J,true,n,n) + s1 .| s1' end From 611e9e0d3aaf6913f7036c701132f8548172e53e Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Sat, 6 Jul 2019 19:56:18 -0400 Subject: [PATCH 06/22] throw when a function of unknown linearity is called as a leaf-call --- src/program_sparsity/hessian.jl | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/program_sparsity/hessian.jl b/src/program_sparsity/hessian.jl index d4174045..c3fae277 100644 --- a/src/program_sparsity/hessian.jl +++ b/src/program_sparsity/hessian.jl @@ -59,27 +59,28 @@ end # Hessian overdub # -function getterms(ctx, x::Tagged) +function getterms(ctx, x) ismetatype(x, ctx, TermCombination) ? metadata(x, ctx) : one(TermCombination) end function hessian_overdub(ctx::HessianSparsityContext, f, linearity, args...) - if any(x->ismetatype(x, ctx, TermCombination), args) - t = combine_terms(linearity, map(x->getterms(ctx, x), args)...) - val = Cassette.fallback(ctx, f, args...) 
- tag(val, ctx, t) - else - Cassette.recurse(ctx, f, args...) - end + t = combine_terms(linearity, map(x->getterms(ctx, x), args)...) + val = Cassette.fallback(ctx, f, args...) + tag(val, ctx, t) end function Cassette.overdub(ctx::HessianSparsityContext, f, args...) - if haslinearity(f, Val{nfields(args)}()) + tainted = any(x->ismetatype(x, ctx, TermCombination), args) + if tainted && haslinearity(f, Val{nfields(args)}()) l = linearity(f, Val{nfields(args)}()) return hessian_overdub(ctx, f, l, args...) else - return Cassette.recurse(ctx, f, args...) + val = Cassette.recurse(ctx, f, args...) + if tainted && !ismetatype(val, ctx, TermCombination) + error("Don't know the linearity of function $f") + end + return val end end From ae9e4c6bb7d87ef0a2b0df2b2dd1fb7c2d042659 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Sat, 6 Jul 2019 19:56:41 -0400 Subject: [PATCH 07/22] optimize squaring with an equivalent expression --- src/program_sparsity/terms.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/program_sparsity/terms.jl b/src/program_sparsity/terms.jl index b9202403..61f77c89 100644 --- a/src/program_sparsity/terms.jl +++ b/src/program_sparsity/terms.jl @@ -13,7 +13,7 @@ Base.:+(comb1::TermCombination) = comb1 function _merge(dict1, dict2) d = copy(dict1) for (k, v) in dict2 - d[k] = get(dict1, k, 0) + v + d[k] = min(2, get(dict1, k, 0) + v) end d end @@ -21,6 +21,14 @@ end function Base.:*(comb1::TermCombination, comb2::TermCombination) if comb1 === comb2 # squaring optimization terms = comb1.terms + # turns out it's enough to track + # a^2*b^2 + # and a^2 + b^2 + ab + # have the same hessian sparsity + t = Dict(k=>2 for (k,_) in + Iterators.flatten(terms)) + TermCombination([t]) + #= # square each term t1 = [Dict(k=>2 for (k,_) in dict) for dict in comb1.terms] @@ -32,6 +40,7 @@ function Base.:*(comb1::TermCombination, comb2::TermCombination) end end TermCombination(vcat(t1, t2)) + =# else vec([_merge(dict1, dict2) for dict1 in 
comb1.terms, From ffa488e40847c1fa35af7ce3776167cb2e1525eb Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Sat, 6 Jul 2019 19:57:31 -0400 Subject: [PATCH 08/22] overdub unsafe_copyto! --- src/program_sparsity/hessian.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/program_sparsity/hessian.jl b/src/program_sparsity/hessian.jl index c3fae277..edf75908 100644 --- a/src/program_sparsity/hessian.jl +++ b/src/program_sparsity/hessian.jl @@ -23,6 +23,23 @@ function Cassette.overdub(ctx::HessianSparsityContext, end end +function Cassette.overdub(ctx::HessianSparsityContext, + f::typeof(Base.unsafe_copyto!), + X::Tagged, + xstart::Int, + Y::Tagged, + ystart::Int, + len::Int) + if ismetatype(Y, ctx, Input) + val = Cassette.fallback(ctx, f, X, xstart, Y, ystart, len) + nometa = Cassette.NoMetaMeta() + X.meta.meta[xstart:xstart+len-1] .= (i->Cassette.Meta(TermCombination([Dict(i=>1)]), nometa)).(ystart:ystart+len-1) + val + else + Cassette.recurse(ctx, f, X, xstart, Y, ystart, len) + end +end + # 1-arg functions combine_terms(::Val{true}, term) = term combine_terms(::Val{false}, term) = term * term From 7b1f9ca727087bce1392586bc31d265d67b8c5c2 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Sat, 6 Jul 2019 20:10:39 -0400 Subject: [PATCH 09/22] add SpecialFunctions dependency --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index c2fd76fc..60205040 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" From e9d87886b848cbe7178492bf68d128e09173ff54 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Sat, 6 Jul 2019 20:32:06 
-0400 Subject: [PATCH 10/22] take all branches -- treat comparison as linear --- src/program_sparsity/hessian.jl | 5 ++++ src/program_sparsity/linearity.jl | 2 +- src/program_sparsity/program_sparsity.jl | 30 ++++++++++++++--------- src/program_sparsity/take_all_branches.jl | 2 +- 4 files changed, 26 insertions(+), 13 deletions(-) diff --git a/src/program_sparsity/hessian.jl b/src/program_sparsity/hessian.jl index edf75908..980575a6 100644 --- a/src/program_sparsity/hessian.jl +++ b/src/program_sparsity/hessian.jl @@ -9,6 +9,11 @@ Cassette.@context HessianSparsityContext const HTagType = Union{Input, TermCombination} Cassette.metadatatype(::Type{<:HessianSparsityContext}, ::DataType) = HTagType +istainted(ctx::HessianSparsityContext, x) = ismetatype(x, ctx, TermCombination) + +Cassette.overdub(ctx::HessianSparsityContext, f::typeof(istainted), x) = istainted(ctx, x) +Cassette.overdub(ctx::HessianSparsityContext, f::typeof(this_here_predicate!)) = this_here_predicate!(ctx.metadata) + # getindex on the input function Cassette.overdub(ctx::HessianSparsityContext, f::typeof(getindex), diff --git a/src/program_sparsity/linearity.jl b/src/program_sparsity/linearity.jl index ce8db136..b24ae414 100644 --- a/src/program_sparsity/linearity.jl +++ b/src/program_sparsity/linearity.jl @@ -10,7 +10,7 @@ const monadic_nonlinear = [asind, log1p, acsch, erfc, digamma, acos, asec, acosh # Val{(linear11, linear22, linear12)}() # # linearIJ refers to the zeroness of d^2/dxIxJ -diadic_of_linearity(::Val{(true, true, true)}) = [+, rem2pi, -] +diadic_of_linearity(::Val{(true, true, true)}) = [+, rem2pi, -, >, isless, <, isequal] diadic_of_linearity(::Val{(true, true, false)}) = [*] diadic_of_linearity(::Val{(true, false, true)}) = [] #diadic_of_linearit(::(Val{(true, false, true)}) = [besselk, hankelh2, bessely, besselj, besseli, polygamma, hankelh1] diff --git a/src/program_sparsity/program_sparsity.jl b/src/program_sparsity/program_sparsity.jl index 60c1c844..963a8cdc 100644 --- 
a/src/program_sparsity/program_sparsity.jl +++ b/src/program_sparsity/program_sparsity.jl @@ -41,17 +41,25 @@ function sparsity!(f!, Y, X, args...; sparsity=Sparsity(length(Y), length(X)), sparse(sparsity) end -function hsparsity(f, X, args...) - ctx = HessianSparsityContext() - ctx = Cassette.enabletagging(ctx, f) - ctx = Cassette.disablehooks(ctx) +function hsparsity(f, X, args...; verbose=true) - val = Cassette.recurse(ctx, - f, - tag(X, ctx, Input()), - # TODO: make this recursive - map(arg -> arg isa Fixed ? - arg.value : tag(arg, ctx, TermCombination([[]])), args)...) + terms = zero(TermCombination) + path = Path() + while true + ctx = HessianSparsityContext(metadata=path, pass=BranchesPass) + ctx = Cassette.enabletagging(ctx, f) + ctx = Cassette.disablehooks(ctx) + val = Cassette.recurse(ctx, + f, + tag(X, ctx, Input()), + # TODO: make this recursive + map(arg -> arg isa Fixed ? + arg.value : tag(arg, ctx, one(TermCombination)), args)...) + terms += metadata(val, ctx) + verbose && println("Explored path: ", path) + alldone(path) && break + reset!(path) + end - _sparse(metadata(val, ctx), length(X)) + _sparse(terms, length(X)) end diff --git a/src/program_sparsity/take_all_branches.jl b/src/program_sparsity/take_all_branches.jl index b1853c24..701633e8 100644 --- a/src/program_sparsity/take_all_branches.jl +++ b/src/program_sparsity/take_all_branches.jl @@ -1,4 +1,4 @@ -istainted(ctx, x) = ismetatype(x, ctx, ProvinanceSet) +istainted(ctx::SparsityContext, x) = ismetatype(x, ctx, ProvinanceSet) Cassette.overdub(ctx::SparsityContext, f::typeof(istainted), x) = istainted(ctx, x) Cassette.overdub(ctx::SparsityContext, f::typeof(this_here_predicate!)) = this_here_predicate!(ctx) From f369d503cbbc8092ebd6372d9334b954b96362c7 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Sat, 6 Jul 2019 20:35:44 -0400 Subject: [PATCH 11/22] make max and min linear --- src/program_sparsity/linearity.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/src/program_sparsity/linearity.jl b/src/program_sparsity/linearity.jl index b24ae414..60bdc18b 100644 --- a/src/program_sparsity/linearity.jl +++ b/src/program_sparsity/linearity.jl @@ -10,7 +10,7 @@ const monadic_nonlinear = [asind, log1p, acsch, erfc, digamma, acos, asec, acosh # Val{(linear11, linear22, linear12)}() # # linearIJ refers to the zeroness of d^2/dxIxJ -diadic_of_linearity(::Val{(true, true, true)}) = [+, rem2pi, -, >, isless, <, isequal] +diadic_of_linearity(::Val{(true, true, true)}) = [+, rem2pi, -, >, isless, <, isequal, max, min] diadic_of_linearity(::Val{(true, true, false)}) = [*] diadic_of_linearity(::Val{(true, false, true)}) = [] #diadic_of_linearit(::(Val{(true, false, true)}) = [besselk, hankelh2, bessely, besselj, besseli, polygamma, hankelh1] @@ -18,7 +18,7 @@ diadic_of_linearity(::Val{(true, false, false)}) = [/] diadic_of_linearity(::Val{(false, true, true)}) = [] diadic_of_linearity(::Val{(false, true, false)}) = [\] diadic_of_linearity(::Val{(false, false, true)}) = [] -diadic_of_linearity(::Val{(false, false, false)}) = [hypot, atan, max, min, mod, rem, lbeta, ^, beta] +diadic_of_linearity(::Val{(false, false, false)}) = [hypot, atan, mod, rem, lbeta, ^, beta] haslinearity(f, nargs) = false for f in monadic_linear From 135dc9deb4cfaa3b21e22266dfcf47958ec6dfd7 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Sun, 7 Jul 2019 00:45:02 -0400 Subject: [PATCH 12/22] more linearity info, getting broadcast to work --- src/program_sparsity/hessian.jl | 15 +++++++++++++++ src/program_sparsity/linearity.jl | 31 ++++++++++++++++++++++--------- src/program_sparsity/terms.jl | 2 ++ 3 files changed, 39 insertions(+), 9 deletions(-) diff --git a/src/program_sparsity/hessian.jl b/src/program_sparsity/hessian.jl index 980575a6..f6e15480 100644 --- a/src/program_sparsity/hessian.jl +++ b/src/program_sparsity/hessian.jl @@ -45,6 +45,8 @@ function Cassette.overdub(ctx::HessianSparsityContext, end end +combine_terms(::Nothing, terms...) 
= one(TermCombination) + # 1-arg functions combine_terms(::Val{true}, term) = term combine_terms(::Val{false}, term) = term * term @@ -90,10 +92,23 @@ function hessian_overdub(ctx::HessianSparsityContext, f, linearity, args...) val = Cassette.fallback(ctx, f, args...) tag(val, ctx, t) end +function Cassette.overdub(ctx::HessianSparsityContext, + f::typeof(getproperty), + x::Tagged, prop) + if ismetatype(x, ctx, TermCombination) && !isone(metadata(x, ctx)) + Cassette.fallback(ctx, f, x, prop) + error("property of a non-constant term accessed") + else + Cassette.fallback(ctx, f, x, prop) + end +end function Cassette.overdub(ctx::HessianSparsityContext, f, args...) + if length(args) > 2 + return Cassette.recurse(ctx, f, args...) + end tainted = any(x->ismetatype(x, ctx, TermCombination), args) if tainted && haslinearity(f, Val{nfields(args)}()) l = linearity(f, Val{nfields(args)}()) diff --git a/src/program_sparsity/linearity.jl b/src/program_sparsity/linearity.jl index 60bdc18b..14b4ad91 100644 --- a/src/program_sparsity/linearity.jl +++ b/src/program_sparsity/linearity.jl @@ -1,16 +1,12 @@ using SpecialFunctions +import Base.Broadcast -# linearity of a single input function is either -# Val{true}() or Val{false}() -# -const monadic_linear = [deg2rad, +, rad2deg, transpose, -] +const constant_funcs = [typeof, Broadcast.combine_styles, Broadcast.result_style] + +const monadic_linear = [deg2rad, +, rad2deg, transpose, -, Base.broadcasted] const monadic_nonlinear = [asind, log1p, acsch, erfc, digamma, acos, asec, acosh, airybiprime, acsc, cscd, log, tand, log10, csch, asinh, airyai, abs2, gamma, lgamma, erfcx, bessely0, cosh, sin, cos, atan, cospi, cbrt, acosd, bessely1, acoth, erfcinv, erf, dawson, inv, acotd, airyaiprime, erfinv, trigamma, asecd, besselj1, exp, acot, sqrt, sind, sinpi, asech, log2, tan, invdigamma, airybi, exp10, sech, erfi, coth, asin, cotd, cosd, sinh, abs, besselj0, csc, tanh, secd, atand, sec, acscd, cot, exp2, expm1, atanh] -# linearity of a 
2-arg function is: -# Val{(linear11, linear22, linear12)}() -# -# linearIJ refers to the zeroness of d^2/dxIxJ -diadic_of_linearity(::Val{(true, true, true)}) = [+, rem2pi, -, >, isless, <, isequal, max, min] +diadic_of_linearity(::Val{(true, true, true)}) = [+, rem2pi, -, >, isless, <, isequal, max, min, convert, conj, Broadcast.broadcasted] diadic_of_linearity(::Val{(true, true, false)}) = [*] diadic_of_linearity(::Val{(true, false, true)}) = [] #diadic_of_linearit(::(Val{(true, false, true)}) = [besselk, hankelh2, bessely, besselj, besseli, polygamma, hankelh1] @@ -21,12 +17,29 @@ diadic_of_linearity(::Val{(false, false, true)}) = [] diadic_of_linearity(::Val{(false, false, false)}) = [hypot, atan, mod, rem, lbeta, ^, beta] haslinearity(f, nargs) = false + +# some functions strip the linearity metadata + +for f in constant_funcs + @eval begin + haslinearity(::typeof($f), ::Val) = true + linearity(::typeof($f), ::Val) = nothing + end +end + +# linearity of a single input function is either +# Val{true}() or Val{false}() +# for f in monadic_linear @eval begin haslinearity(::typeof($f), ::Val{1}) = true linearity(::typeof($f), ::Val{1}) = Val{true}() end end +# linearity of a 2-arg function is: +# Val{(linear11, linear22, linear12)}() +# +# linearIJ refers to the zeroness of d^2/dxIxJ for f in monadic_nonlinear @eval begin haslinearity(::typeof($f), ::Val{1}) = true diff --git a/src/program_sparsity/terms.jl b/src/program_sparsity/terms.jl index 61f77c89..9d19a6de 100644 --- a/src/program_sparsity/terms.jl +++ b/src/program_sparsity/terms.jl @@ -48,6 +48,8 @@ function Base.:*(comb1::TermCombination, comb2::TermCombination) end end Base.:*(comb1::TermCombination) = comb1 +Base.iszero(c::TermCombination) = isempty(c.terms) +Base.isone(c::TermCombination) = all(isempty, c.terms) function _sparse(t::TermCombination, n) I = Int[] From e3cf2629bcbb763ce18373f78af5ac8f7a2b148b Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Sun, 7 Jul 2019 16:41:46 -0400 Subject: [PATCH 
13/22] handle getindex when index is tagged -- fixes iterate with state --- src/program_sparsity/hessian.jl | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/program_sparsity/hessian.jl b/src/program_sparsity/hessian.jl index f6e15480..4533587e 100644 --- a/src/program_sparsity/hessian.jl +++ b/src/program_sparsity/hessian.jl @@ -6,6 +6,8 @@ using SparseArrays # Tags: Cassette.@context HessianSparsityContext +const TaggedOf{T} = Tagged{A, T} where A + const HTagType = Union{Input, TermCombination} Cassette.metadatatype(::Type{<:HessianSparsityContext}, ::DataType) = HTagType @@ -18,10 +20,19 @@ Cassette.overdub(ctx::HessianSparsityContext, f::typeof(this_here_predicate!)) = function Cassette.overdub(ctx::HessianSparsityContext, f::typeof(getindex), X::Tagged, - idx::Int...) + idx::Tagged...) + if any(i->ismetatype(i, ctx, TermCombination) && !isone(metadata(i, ctx)), idx) + error("getindex call depends on input. Cannot determine Hessian sparsity") + end + Cassette.overdub(ctx, f, X, map(i->untag(i, ctx), idx)...) +end +function Cassette.overdub(ctx::HessianSparsityContext, + f::typeof(getindex), + X::Tagged, + idx::Integer...) if ismetatype(X, ctx, Input) - i = LinearIndices(untag(X, ctx))[idx...] val = Cassette.fallback(ctx, f, X, idx...) + i = LinearIndices(untag(X, ctx))[idx...] tag(val, ctx, TermCombination([Dict(i=>1)])) else Cassette.recurse(ctx, f, X, idx...) 
# `copy` of the input array is still the input for sparsity purposes:
# propagate the Input tag to the copy so subsequent indexing is tracked.
function Cassette.overdub(ctx::HessianSparsityContext,
                          f::typeof(copy),
                          X::Tagged)
    if ismetatype(X, ctx, Input)
        val = Cassette.fallback(ctx, f, X)
        tag(val, ctx, Input())
    else
        # Fix: this branch previously recursed with undefined variables
        # (`xstart`, `Y`, `ystart`, `len`, copy-pasted from the
        # unsafe_copyto! overdub) and threw UndefVarError for any
        # non-Input tagged argument.
        Cassette.recurse(ctx, f, X)
    end
end
# Accessing a property of a tagged value is only allowed when the value is a
# constant (its term combination is `one`); anything else would smuggle input
# dependence through a struct field without being tracked.
function Cassette.overdub(ctx::HessianSparsityContext,
                          f::typeof(getproperty),
                          x::Tagged, prop)
    if ismetatype(x, ctx, TermCombination) && !isone(metadata(x, ctx))
        error("property of a non-constant term accessed")
    else
        Cassette.recurse(ctx, f, x, prop)
    end
end

# The called function may itself be a tagged value inside the context;
# strip the tag before consulting the linearity tables.
haslinearity(ctx::HessianSparsityContext, f, nargs) = haslinearity(untag(f, ctx), nargs)
linearity(ctx::HessianSparsityContext, f, nargs) = linearity(untag(f, ctx), nargs)

# Generic fallback: if any argument carries Hessian terms and the linearity
# of `f` is known, combine the argument terms directly; otherwise recurse so
# the known primitives inside `f` are handled individually.
#
# Cleanup: removed the redundant `val = if ...; val = ...; val end; val`
# double assignment and the dead commented-out warning block.
function Cassette.overdub(ctx::HessianSparsityContext,
                          f,
                          args...)
    tainted = any(x -> ismetatype(x, ctx, TermCombination), args)
    if tainted && haslinearity(ctx, f, Val{nfields(args)}())
        l = linearity(ctx, f, Val{nfields(args)}())
        return hessian_overdub(ctx, f, l, args...)
    else
        return Cassette.recurse(ctx, f, args...)
    end
end
# Drive `f` under HessianSparsityContext, re-running it until every reachable
# branch combination recorded in the Path has been explored. `X` is tagged as
# the Input; each extra argument is passed through unwrapped when given as
# `Fixed`, and otherwise tagged as a constant (`one(TermCombination)`).
# Returns `(ctx, val)` where `val` is the (possibly tagged) result of the
# last exploration.
function htester(f, X, args...)
    path = Path()
    ctx = HessianSparsityContext(metadata=path, pass=BranchesPass)
    ctx = Cassette.enabletagging(ctx, f)
    ctx = Cassette.disablehooks(ctx)

    prepare(arg) = arg isa Fixed ? arg.value : tag(arg, ctx, one(TermCombination))

    local val
    while true
        val = Cassette.overdub(ctx, f, tag(X, ctx, Input()), map(prepare, args)...)
        println("Explored path: ", path)
        alldone(path) && break
        reset!(path)
    end
    return ctx, val
end
# Reroute a specialized (e.g. BLAS) method to its generic Julia
# implementation so tagged values can flow through pure-Julia code.
#
#     @reroute BLAS.dot dot(Any, Any)
#
# rewrites calls to `BLAS.dot` inside the context into
# `invoke(dot, Tuple{Any, Any}, args...)`.
#
# Fix: removed the stray `println("rerouted")` debug statement, which
# printed on every rerouted call in library code.
macro reroute(f, g)
    quote
        function Cassette.overdub(ctx::HessianSparsityContext,
                                  f::typeof($(esc(f))),
                                  args...)
            Cassette.overdub(
                ctx,
                invoke,
                $(esc(g.args[1])),
                $(esc(:(Tuple{$(g.args[2:end]...)}))),
                args...)
        end
    end
end
# A set of monomial terms sufficient to track Hessian sparsity: each Dict
# maps input index => power, and the Set represents the sum of those terms.
struct TermCombination
    terms::Set{Dict{Int, Int}} # idx => pow
end

# The additive and multiplicative identities are precomputed once at
# definition time (via @eval interpolation) so `zero`/`one` always return
# the same singleton objects.
@eval Base.zero(::Type{TermCombination}) = $(TermCombination(Set{Dict{Int,Int}}()))
@eval Base.one(::Type{TermCombination}) = $(TermCombination(Set([Dict{Int,Int}()])))

# Zero: no terms at all.
Base.iszero(c::TermCombination) = isempty(c.terms)
# One: only empty (constant) terms. NOTE(review): vacuously true for the
# zero combination as well — confirm callers never need to distinguish them.
Base.isone(c::TermCombination) = all(isempty, c.terms)

function Base.:+(a::TermCombination, b::TermCombination)
    # Adding a constant contributes nothing to the Hessian structure.
    isone(a) && !iszero(b) && return b
    isone(b) && !iszero(a) && return a
    # Identical object: the union would be a no-op.
    a === b && return a
    return TermCombination(union(a.terms, b.terms))
end

Base.:+(a::TermCombination) = a
a wrecked branch --- src/program_sparsity/hessian.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/program_sparsity/hessian.jl b/src/program_sparsity/hessian.jl index 74a885e7..f708d990 100644 --- a/src/program_sparsity/hessian.jl +++ b/src/program_sparsity/hessian.jl @@ -26,6 +26,7 @@ function Cassette.overdub(ctx::HessianSparsityContext, end Cassette.overdub(ctx, f, X, map(i->untag(i, ctx), idx)...) end + function Cassette.overdub(ctx::HessianSparsityContext, f::typeof(getindex), X::Tagged, @@ -62,7 +63,7 @@ function Cassette.overdub(ctx::HessianSparsityContext, val = Cassette.fallback(ctx, f, X) tag(val, ctx, Input()) else - Cassette.recurse(ctx, f, X, xstart, Y, ystart, len) + Cassette.recurse(ctx, f, X) end end From dab5ba6edacf92a4bee3a7d69a0fc482e7ade7f5 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Sun, 7 Jul 2019 21:30:05 -0400 Subject: [PATCH 21/22] add hessian sparsity detection tests --- test/program_sparsity/hessian.jl | 64 +++++++++++++++++++++++++++++ test/program_sparsity/paraboloid.jl | 26 ++++++++++++ test/program_sparsity/testall.jl | 2 + test/runtests.jl | 2 +- 4 files changed, 93 insertions(+), 1 deletion(-) create mode 100644 test/program_sparsity/hessian.jl create mode 100644 test/program_sparsity/paraboloid.jl diff --git a/test/program_sparsity/hessian.jl b/test/program_sparsity/hessian.jl new file mode 100644 index 00000000..71371d07 --- /dev/null +++ b/test/program_sparsity/hessian.jl @@ -0,0 +1,64 @@ +import SparseDiffTools: TermCombination +using Test + +Term(i...) 
= TermCombination(Set([Dict(j=>1 for j in i)])) + +@test htesttag(x->x, [1,2]) == Input() +@test htesttag(x->x[1], [1,2]) == Term(1) + +# Tuple / struct +@test htesttag(x->(x[1],x[2])[2], [1,2]) == Term(2) + +# 1-arg linear +@test htesttag(x->deg2rad(x[1]), [1,2]) == Term(1) + +# 1-arg nonlinear +@test htesttag(x->sin(x[1]), [1,2]) == (Term(1) * Term(1)) + +# 2-arg (true,true,true) +@test htesttag(x->x[1]+x[2], [1,2]) == Term(1)+Term(2) + +# 2-arg (true,true, false) +@test htesttag(x->x[1]*x[2], [1,2]) == Term(1)*Term(2) + +# 2-arg (true,false,false) +@test htesttag(x->x[1]/x[2], [1,2]) == Term(1)*Term(2)*Term(2) + +# 2-arg (false,true,false) +@test htesttag(x->x[1]\x[2], [1,2]) == Term(1)*Term(1)*Term(2) + +# 2-arg (false,false,false) +@test htesttag(x->hypot(x[1], x[2]), [1,2]) == (Term(1) + Term(2)) * (Term(1) + Term(2)) + + +### Array operations + +# copy +@test htesttag(x->copy(x)[1], [1,2]) == Term(1) +@test htesttag(x->x[:][1], [1,2]) == Term(1) +@test htesttag(x->x[1:1][1], [1,2]) == Term(1) + +# tests `iterate` +function mysum(x) + s = 0 + for a in x + s += a + end + s +end +@test htesttag(mysum, [1,2]).terms == (Term(1) + Term(2)).terms +@test htesttag(mysum, [1,2.]).terms == (Term(1) + Term(2)).terms + +using LinearAlgebra + +# integer dot product falls back to generic +@test htesttag(x->dot(x,x), [1,2,3]) == sum(Term(i)*Term(i) for i=1:3) + +# reroutes to generic implementation (blas.jl) +@test htesttag(x->dot(x,x), [1,2,3.]) == sum(Term(i)*Term(i) for i=1:3) +@test htesttag(x->x'x, [1,2,3.]) == sum(Term(i)*Term(i) for i=1:3) + +# broadcast +@test htesttag(x->sum(x[1] .+ x[2]), [1,2,3.]) == Term(1) + Term(2) +@test htesttag(x->sum(x .+ x), [1,2,3.]) == sum(Term(i) for i=1:3) +@test htesttag(x->sum(x .* x), [1,2,3.]) == sum(Term(i)*Term(i) for i=1:3) diff --git a/test/program_sparsity/paraboloid.jl b/test/program_sparsity/paraboloid.jl new file mode 100644 index 00000000..2844c627 --- /dev/null +++ b/test/program_sparsity/paraboloid.jl @@ -0,0 +1,26 @@ 
# Quadratic form 0.5 * (x - c)' * M * (x - c), where `param` supplies the
# matrix `param.mat` and the center `param.vec`.
function quad(x::Vector, param)
    shifted = x - param.vec
    return 0.5 * dot(shifted, param.mat * shifted)
end
# Resolve the method ambiguity between the Tagged-index and Integer-index
# getindex overdubs for the zero-index call form `x[]`.
function Cassette.overdub(ctx::HessianSparsityContext,
                          f::typeof(getindex),
                          X::Tagged)
    return Cassette.recurse(ctx, f, X)
end