diff --git a/Project.toml b/Project.toml
index c2fd76fc..60205040 100644
--- a/Project.toml
+++ b/Project.toml
@@ -12,6 +12,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
+SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
diff --git a/src/SparseDiffTools.jl b/src/SparseDiffTools.jl
index b2c60b7b..f60af7cd 100644
--- a/src/SparseDiffTools.jl
+++ b/src/SparseDiffTools.jl
@@ -27,7 +27,7 @@ export contract_color,
         numback_hesvec,numback_hesvec!,
         autoback_hesvec,autoback_hesvec!,
         JacVec,HesVec,HesVecGrad,
-        Sparsity, sparsity!
+        Sparsity, sparsity!, hsparsity
 
 
 include("coloring/high_level.jl")
@@ -40,5 +40,9 @@ include("program_sparsity/program_sparsity.jl")
 include("program_sparsity/sparsity_tracker.jl")
 include("program_sparsity/path.jl")
 include("program_sparsity/take_all_branches.jl")
+include("program_sparsity/terms.jl")
+include("program_sparsity/linearity.jl")
+include("program_sparsity/hessian.jl")
+include("program_sparsity/blas.jl")
 
 end # module
diff --git a/src/program_sparsity/blas.jl b/src/program_sparsity/blas.jl
new file mode 100644
index 00000000..c031fb50
--- /dev/null
+++ b/src/program_sparsity/blas.jl
@@ -0,0 +1,27 @@
+using LinearAlgebra
+import LinearAlgebra.BLAS
+
+# generic implementations
+
+# `@reroute f g(ArgTypes...)` defines a `Cassette.overdub` method for
+# `HessianSparsityContext` that replaces a call to `f` with
+# `invoke(g, Tuple{ArgTypes...}, args...)`, overdubbed in the same context.
+macro reroute(f, g)
+    quote
+        function Cassette.overdub(ctx::HessianSparsityContext,
+                                  f::typeof($(esc(f))),
+                                  args...)
+            Cassette.overdub(
+                ctx,
+                invoke,
+                $(esc(g.args[1])),
+                $(esc(:(Tuple{$(g.args[2:end]...)}))),
+                args...)
+        end
+    end
+end
+
+@reroute BLAS.dot dot(Any, Any)
+@reroute BLAS.axpy! axpy!(Any,
+                          AbstractArray,
+                          AbstractArray)
diff --git a/src/program_sparsity/hessian.jl b/src/program_sparsity/hessian.jl
new file mode 100644
index 00000000..8d5734b9
--- /dev/null
+++ b/src/program_sparsity/hessian.jl
@@ -0,0 +1,158 @@
+using Cassette
+import Cassette: tag, untag, Tagged, metadata, hasmetadata, istagged, canrecurse
+import Core: SSAValue
+using SparseArrays
+
+# Tags:
+Cassette.@context HessianSparsityContext
+
+const TaggedOf{T} = Tagged{A, T} where A
+
+const HTagType = Union{Input, TermCombination}
+Cassette.metadatatype(::Type{<:HessianSparsityContext}, ::DataType) = HTagType
+
+istainted(ctx::HessianSparsityContext, x) = ismetatype(x, ctx, TermCombination)
+
+Cassette.overdub(ctx::HessianSparsityContext, f::typeof(istainted), x) = istainted(ctx, x)
+Cassette.overdub(ctx::HessianSparsityContext, f::typeof(this_here_predicate!)) = this_here_predicate!(ctx.metadata)
+
+# getindex on the input
+# A tagged index is accepted only when its term is constant (`isone`); the
+# index tags are then stripped and the call is dispatched again.
+function Cassette.overdub(ctx::HessianSparsityContext,
+                          f::typeof(getindex),
+                          X::Tagged,
+                          idx::Tagged...)
+    if any(i->ismetatype(i, ctx, TermCombination) && !isone(metadata(i, ctx)), idx)
+        error("getindex call depends on input. Cannot determine Hessian sparsity")
+    end
+    Cassette.overdub(ctx, f, X, map(i->untag(i, ctx), idx)...)
+end
+
+# plugs an ambiguity
+function Cassette.overdub(ctx::HessianSparsityContext,
+                          f::typeof(getindex),
+                          X::Tagged)
+    Cassette.recurse(ctx, f, X)
+end
+
+# Reading element `i` (linear index) of the input produces the term x_i.
+function Cassette.overdub(ctx::HessianSparsityContext,
+                          f::typeof(getindex),
+                          X::Tagged,
+                          idx::Integer...)
+    if ismetatype(X, ctx, Input)
+        val = Cassette.fallback(ctx, f, X, idx...)
+        i = LinearIndices(untag(X, ctx))[idx...]
+        tag(val, ctx, TermCombination(Set([Dict(i=>1)])))
+    else
+        Cassette.recurse(ctx, f, X, idx...)
+    end
+end
+
+# Copying from the input `Y` into `X`: the metadata of the destination slots
+# is set to the single-variable terms of the source positions.
+function Cassette.overdub(ctx::HessianSparsityContext,
+                          f::typeof(Base.unsafe_copyto!),
+                          X::Tagged,
+                          xstart,
+                          Y::Tagged,
+                          ystart,
+                          len)
+    if ismetatype(Y, ctx, Input)
+        val = Cassette.fallback(ctx, f, X, xstart, Y, ystart, len)
+        nometa = Cassette.NoMetaMeta()
+        X.meta.meta[xstart:xstart+len-1] .= (i->Cassette.Meta(TermCombination(Set([Dict(i=>1)])), nometa)).(ystart:ystart+len-1)
+        val
+    else
+        Cassette.recurse(ctx, f, X, xstart, Y, ystart, len)
+    end
+end
+# A copy of the input is tagged as the input again.
+function Cassette.overdub(ctx::HessianSparsityContext,
+                          f::typeof(copy),
+                          X::Tagged)
+    if ismetatype(X, ctx, Input)
+        val = Cassette.fallback(ctx, f, X)
+        tag(val, ctx, Input())
+    else
+        Cassette.recurse(ctx, f, X)
+    end
+end
+
+# `nothing` linearity: the result is a constant term.
+combine_terms(::Nothing, terms...) = one(TermCombination)
+
+# 1-arg functions
+combine_terms(::Val{true}, term) = term
+combine_terms(::Val{false}, term) = term * term
+
+# 2-arg functions
+# `linearity` is the tuple `(linear11, linear22, linear12)`.
+function combine_terms(::Val{linearity}, term1, term2) where linearity
+
+    linear11, linear22, linear12 = linearity
+    term = zero(TermCombination)
+    if linear11
+        if !linear12
+            term += term1
+        end
+    else
+        term += term1 * term1
+    end
+
+    if linear22
+        if !linear12
+            term += term2
+        end
+    else
+        term += term2 * term2
+    end
+
+    if linear12
+        term += term1 + term2
+    else
+        term += term1 * term2
+    end
+    term
+end
+
+
+# Hessian overdub
+#
+function getterms(ctx, x)
+    ismetatype(x, ctx, TermCombination) ? metadata(x, ctx) : one(TermCombination)
+end
+
+function hessian_overdub(ctx::HessianSparsityContext, f, linearity, args...)
+    t = combine_terms(linearity, map(x->getterms(ctx, x), args)...)
+    val = Cassette.fallback(ctx, f, args...)
+    tag(val, ctx, t)
+end
+function Cassette.overdub(ctx::HessianSparsityContext,
+                          f::typeof(getproperty),
+                          x::Tagged, prop)
+    if ismetatype(x, ctx, TermCombination) && !isone(metadata(x, ctx))
+        error("property of a non-constant term accessed")
+    else
+        Cassette.recurse(ctx, f, x, prop)
+    end
+end
+
+haslinearity(ctx::HessianSparsityContext, f, nargs) = haslinearity(untag(f, ctx), nargs)
+linearity(ctx::HessianSparsityContext, f, nargs) = linearity(untag(f, ctx), nargs)
+
+# A call with a tainted argument and a registered linearity is evaluated with
+# `Cassette.fallback` and tagged with the combined terms; any other call is
+# traced recursively.
+function Cassette.overdub(ctx::HessianSparsityContext,
+                          f,
+                          args...)
+    tainted = any(x->ismetatype(x, ctx, TermCombination), args)
+    nargs = Val{nfields(args)}()
+    if tainted && haslinearity(ctx, f, nargs)
+        hessian_overdub(ctx, f, linearity(ctx, f, nargs), args...)
+    else
+        Cassette.recurse(ctx, f, args...)
+    end
+end
diff --git a/src/program_sparsity/linearity.jl b/src/program_sparsity/linearity.jl
new file mode 100644
index 00000000..f1d9ffb3
--- /dev/null
+++ b/src/program_sparsity/linearity.jl
@@ -0,0 +1,60 @@
+using SpecialFunctions
+import Base.Broadcast
+
+const constant_funcs = []
+
+const monadic_linear = [deg2rad, +, rad2deg, transpose, -, conj]
+
+const monadic_nonlinear = [asind, log1p, acsch, erfc, digamma, acos, asec, acosh, airybiprime, acsc, cscd, log, tand, log10, csch, asinh, airyai, abs2, gamma, lgamma, erfcx, bessely0, cosh, sin, cos, atan, cospi, cbrt, acosd, bessely1, acoth, erfcinv, erf, dawson, inv, acotd, airyaiprime, erfinv, trigamma, asecd, besselj1, exp, acot, sqrt, sind, sinpi, asech, log2, tan, invdigamma, airybi, exp10, sech, erfi, coth, asin, cotd, cosd, sinh, abs, besselj0, csc, tanh, secd, atand, sec, acscd, cot, exp2, expm1, atanh]
+
+diadic_of_linearity(::Val{(true, true, true)}) = [+, rem2pi, -, >, isless, <, isequal, max, min, convert]
+diadic_of_linearity(::Val{(true, true, false)}) = [*]
+#diadic_of_linearity(::Val{(true, false, true)}) = [besselk, hankelh2, bessely, besselj, besseli, polygamma, hankelh1]
+diadic_of_linearity(::Val{(true, false, false)}) = [/]
+diadic_of_linearity(::Val{(false, true, false)}) = [\]
+diadic_of_linearity(::Val{(false, false, false)}) = [hypot, atan, mod, rem, lbeta, ^, beta]
+diadic_of_linearity(::Val) = []
+
+haslinearity(f, nargs) = false
+
+# some functions strip the linearity metadata
+
+for f in constant_funcs
+    @eval begin
+        haslinearity(::typeof($f), ::Val) = true
+        linearity(::typeof($f), ::Val) = nothing
+    end
+end
+
+# linearity of a single input function is either
+# Val{true}() or Val{false}()
+#
+for f in monadic_linear
+    @eval begin
+        haslinearity(::typeof($f), ::Val{1}) = true
+        linearity(::typeof($f), ::Val{1}) = Val{true}()
+    end
+end
+for f in monadic_nonlinear
+    @eval begin
+        haslinearity(::typeof($f), ::Val{1}) = true
+        linearity(::typeof($f), ::Val{1}) = Val{false}()
+    end
+end
+
+# linearity of a 2-arg function is:
+# Val{(linear11, linear22, linear12)}()
+#
+# linearIJ refers to the zeroness of d^2/dxIxJ
+for linearity_mask = 0:2^3-1
+    lin = Val{map(x->x!=0, (linearity_mask & 4,
+                            linearity_mask & 2,
+                            linearity_mask & 1))}()
+
+    for f in diadic_of_linearity(lin)
+        @eval begin
+            haslinearity(::typeof($f), ::Val{2}) = true
+            linearity(::typeof($f), ::Val{2}) = $lin
+        end
+    end
+end
diff --git a/src/program_sparsity/program_sparsity.jl b/src/program_sparsity/program_sparsity.jl
index 7b8bab17..963a8cdc 100644
--- a/src/program_sparsity/program_sparsity.jl
+++ b/src/program_sparsity/program_sparsity.jl
@@ -38,5 +38,40 @@ function sparsity!(f!, Y, X, args...; sparsity=Sparsity(length(Y), length(X)),
         alldone(path) && break
         reset!(path)
     end
-    sparsity
+    sparse(sparsity)
+end
+
+"""
+    hsparsity(f, X, args...; verbose=true)
+
+Return the Hessian sparsity pattern of `f` with respect to `X` as a symmetric
+`length(X)`×`length(X)` sparse Boolean matrix.
+
+`f` is traced once per execution path until all paths are explored, and the
+terms found on each path are accumulated. Extra `args` are passed on to `f`;
+an argument wrapped in `Fixed` is passed by value without a tag. When `verbose`
+is true, every explored path is printed. A result that carries no term
+metadata (e.g. a constant) contributes no entries.
+"""
+function hsparsity(f, X, args...; verbose=true)
+
+    terms = zero(TermCombination)
+    path = Path()
+    while true
+        ctx = HessianSparsityContext(metadata=path, pass=BranchesPass)
+        ctx = Cassette.enabletagging(ctx, f)
+        ctx = Cassette.disablehooks(ctx)
+        val = Cassette.recurse(ctx,
+                               f,
+                               tag(X, ctx, Input()),
+                               # TODO: make this recursive
+                               map(arg -> arg isa Fixed ?
+                                   arg.value : tag(arg, ctx, one(TermCombination)), args)...)
+        terms += getterms(ctx, val)
+        verbose && println("Explored path: ", path)
+        alldone(path) && break
+        reset!(path)
+    end
+
+    _sparse(terms, length(X))
 end
diff --git a/src/program_sparsity/take_all_branches.jl b/src/program_sparsity/take_all_branches.jl
index b1853c24..701633e8 100644
--- a/src/program_sparsity/take_all_branches.jl
+++ b/src/program_sparsity/take_all_branches.jl
@@ -1,4 +1,4 @@
-istainted(ctx, x) = ismetatype(x, ctx, ProvinanceSet)
+istainted(ctx::SparsityContext, x) = ismetatype(x, ctx, ProvinanceSet)
 
 Cassette.overdub(ctx::SparsityContext, f::typeof(istainted), x) = istainted(ctx, x)
 Cassette.overdub(ctx::SparsityContext, f::typeof(this_here_predicate!)) = this_here_predicate!(ctx)
diff --git a/src/program_sparsity/terms.jl b/src/program_sparsity/terms.jl
new file mode 100644
index 00000000..f5517633
--- /dev/null
+++ b/src/program_sparsity/terms.jl
@@ -0,0 +1,100 @@
+struct TermCombination
+    terms::Set{Dict{Int, Int}} # idx => pow
+end
+
+@eval Base.zero(::Type{TermCombination}) = $(TermCombination(Set{Dict{Int,Int}}()))
+@eval Base.one(::Type{TermCombination}) = $(TermCombination(Set([Dict{Int,Int}()])))
+
+# Two combinations are equal when they produce the same Hessian sparsity pattern.
+function Base.:(==)(comb1::TermCombination, comb2::TermCombination)
+    comb1.terms == comb2.terms && return true
+
+    n1 = reduce(max, (k for (k,_) in Iterators.flatten(comb1.terms)), init=0)
+    n2 = reduce(max, (k for (k,_) in Iterators.flatten(comb2.terms)), init=0)
+    n = max(n1, n2)
+
+    _sparse(comb1, n) == _sparse(comb2, n)
+end
+
+# Keep `hash` consistent with the pattern-based `==`: combinations that are
+# `==` have the same set of structural non-zeros.
+function Base.hash(comb::TermCombination, h::UInt)
+    n = reduce(max, (k for (k,_) in Iterators.flatten(comb.terms)), init=0)
+    rows, cols, _ = findnz(_sparse(comb, n))
+    hash(Set(zip(rows, cols)), hash(:TermCombination, h))
+end
+
+function Base.:+(comb1::TermCombination, comb2::TermCombination)
+    if isone(comb1) && !iszero(comb2)
+        return comb2
+    elseif isone(comb2) && !iszero(comb1)
+        return comb1
+    elseif comb1 === comb2
+        return comb1
+    end
+    TermCombination(union(comb1.terms, comb2.terms))
+end
+
+Base.:+(comb1::TermCombination) = comb1
+
+# Multiply two monomials (idx => pow); powers are capped at 2 because
+# `_sparse` only distinguishes powers 1 and >= 2.
+function _merge(dict1, dict2)
+    d = copy(dict1)
+    for (k, v) in dict2
+        d[k] = min(2, get(dict1, k, 0) + v)
+    end
+    d
+end
+
+function Base.:*(comb1::TermCombination, comb2::TermCombination)
+    if iszero(comb1) || iszero(comb2)
+        return zero(TermCombination)
+    elseif isone(comb1)
+        return comb2
+    elseif isone(comb2)
+        return comb1
+    elseif comb1 === comb2 # squaring optimization
+        terms = comb1.terms
+        # turns out it's enough to track
+        # a^2*b^2
+        # and a^2 + b^2 + ab
+        # have the same hessian sparsity
+        t = Dict(k=>2 for (k,_) in
+                 Iterators.flatten(terms))
+        TermCombination(Set([t]))
+    else
+        Set([_merge(dict1, dict2)
+             for dict1 in comb1.terms,
+             dict2 in comb2.terms]) |> TermCombination
+    end
+end
+Base.:*(comb1::TermCombination) = comb1
+Base.iszero(c::TermCombination) = isempty(c.terms)
+# `!isempty` guard: `all` is vacuously true on the empty (zero) combination.
+Base.isone(c::TermCombination) = !isempty(c.terms) && all(isempty, c.terms)
+
+# Symmetric n×n Boolean pattern of `t`: a variable with power >= 2 marks its
+# diagonal entry, and every pair of variables in one term marks (i,j) and (j,i).
+function _sparse(t::TermCombination, n)
+    I = Int[]
+    J = Int[]
+    for dict in t.terms
+        kv = collect(pairs(dict))
+        for i in 1:length(kv)
+            k, v = kv[i]
+            if v>=2
+                push!(I, k)
+                push!(J, k)
+            end
+            for j in i+1:length(kv)
+                if v >= 1 && kv[j][2] >= 1
+                    push!(I, k)
+                    push!(J, kv[j][1])
+                end
+            end
+        end
+    end
+    s1 = sparse(I,J,fill!(BitVector(undef, length(I)), true),n,n)
+    s1 .| s1'
+end
diff --git a/test/program_sparsity/common.jl b/test/program_sparsity/common.jl
index eb9b0176..61de55a0 100644
--- a/test/program_sparsity/common.jl
+++ b/test/program_sparsity/common.jl
@@ -5,7 +5,7 @@ import Cassette: tag, untag, Tagged, metadata, hasmetadata, istagged
 using SparseDiffTools
 import SparseDiffTools: Path, BranchesPass, SparsityContext, Fixed,
                         Input, Output, ProvinanceSet, Tainted, istainted,
-                        alldone, reset!
+                        alldone, reset!, HessianSparsityContext, TermCombination
 
 function tester(f, Y, X, args...; sparsity=Sparsity(length(Y), length(X)))
 
@@ -34,6 +34,32 @@ testmeta(args...) = tester(args...)[1].metadata
 testval(args...) = tester(args...) |> ((ctx,val),) -> untag(val, ctx)
 testtag(args...) = tester(args...) |> ((ctx,val),) -> metadata(val, ctx)
 
+function htester(f, X, args...)
+
+    path = Path()
+    ctx = HessianSparsityContext(metadata=path, pass=BranchesPass)
+    ctx = Cassette.enabletagging(ctx, f)
+    ctx = Cassette.disablehooks(ctx)
+
+    val = nothing
+    while true
+        val = Cassette.overdub(ctx,
+                               f,
+                               tag(X, ctx, Input()),
+                               map(arg -> arg isa Fixed ?
+                                   arg.value :
+                                   tag(arg, ctx, one(TermCombination)), args)...)
+        println("Explored path: ", path)
+        alldone(path) && break
+        reset!(path)
+    end
+    return ctx, val
+end
+htestmeta(args...) = htester(args...)[1].metadata
+htestval(args...) = htester(args...) |> ((ctx,val),) -> untag(val, ctx)
+htesttag(args...) = htester(args...) |> ((ctx,val),) -> metadata(val, ctx)
+
+
 using Test
 
 Base.show(io::IO, ::Type{<:Cassette.Context}) = print(io, "ctx")
diff --git a/test/program_sparsity/hessian.jl b/test/program_sparsity/hessian.jl
new file mode 100644
index 00000000..71371d07
--- /dev/null
+++ b/test/program_sparsity/hessian.jl
@@ -0,0 +1,71 @@
+import SparseDiffTools: TermCombination
+using Test
+
+Term(i...) = TermCombination(Set([Dict(j=>1 for j in i)]))
+
+# algebraic identities
+@test iszero(zero(TermCombination))
+@test !isone(zero(TermCombination))
+@test isone(one(TermCombination))
+@test iszero(zero(TermCombination) * Term(1))
+@test hash(Term(1)*Term(2)) == hash(Term(1) + Term(2) + Term(1)*Term(2))
+
+@test htesttag(x->x, [1,2]) == Input()
+@test htesttag(x->x[1], [1,2]) == Term(1)
+
+# Tuple / struct
+@test htesttag(x->(x[1],x[2])[2], [1,2]) == Term(2)
+
+# 1-arg linear
+@test htesttag(x->deg2rad(x[1]), [1,2]) == Term(1)
+
+# 1-arg nonlinear
+@test htesttag(x->sin(x[1]), [1,2]) == (Term(1) * Term(1))
+
+# 2-arg (true,true,true)
+@test htesttag(x->x[1]+x[2], [1,2]) == Term(1)+Term(2)
+
+# 2-arg (true,true, false)
+@test htesttag(x->x[1]*x[2], [1,2]) == Term(1)*Term(2)
+
+# 2-arg (true,false,false)
+@test htesttag(x->x[1]/x[2], [1,2]) == Term(1)*Term(2)*Term(2)
+
+# 2-arg (false,true,false)
+@test htesttag(x->x[1]\x[2], [1,2]) == Term(1)*Term(1)*Term(2)
+
+# 2-arg (false,false,false)
+@test htesttag(x->hypot(x[1], x[2]), [1,2]) == (Term(1) + Term(2)) * (Term(1) + Term(2))
+
+
+### Array operations
+
+# copy
+@test htesttag(x->copy(x)[1], [1,2]) == Term(1)
+@test htesttag(x->x[:][1], [1,2]) == Term(1)
+@test htesttag(x->x[1:1][1], [1,2]) == Term(1)
+
+# tests `iterate`
+function mysum(x)
+    s = 0
+    for a in x
+        s += a
+    end
+    s
+end
+@test htesttag(mysum, [1,2]).terms == (Term(1) + Term(2)).terms
+@test htesttag(mysum, [1,2.]).terms == (Term(1) + Term(2)).terms
+
+using LinearAlgebra
+
+# integer dot product falls back to generic
+@test htesttag(x->dot(x,x), [1,2,3]) == sum(Term(i)*Term(i) for i=1:3)
+
+# reroutes to generic implementation (blas.jl)
+@test htesttag(x->dot(x,x), [1,2,3.]) == sum(Term(i)*Term(i) for i=1:3)
+@test htesttag(x->x'x, [1,2,3.]) == sum(Term(i)*Term(i) for i=1:3)
+
+# broadcast
+@test htesttag(x->sum(x[1] .+ x[2]), [1,2,3.]) == Term(1) + Term(2)
+@test htesttag(x->sum(x .+ x), [1,2,3.]) == sum(Term(i) for i=1:3)
+@test htesttag(x->sum(x .* x), [1,2,3.]) == sum(Term(i)*Term(i) for i=1:3)
diff --git a/test/program_sparsity/paraboloid.jl b/test/program_sparsity/paraboloid.jl
new file mode 100644
index 00000000..2844c627
--- /dev/null
+++ b/test/program_sparsity/paraboloid.jl
@@ -0,0 +1,26 @@
+using SparseArrays
+using LinearAlgebra
+using SparseDiffTools
+
+struct ParaboloidStruct{T<:Number, Tm<:AbstractMatrix{T},
+                        Tv<:AbstractArray{T}}
+    mat::Tm
+    vec::Tv
+    xt::Tv
+    alpha::T
+end
+
+function quad(x::Vector, param)
+    mat = param.mat
+    xt = x-param.vec
+    return 0.5*dot(xt, mat*xt)
+end
+
+function _paraboloidproblem(N::Int;
+                            mat::AbstractArray{T,2} = sparse(Diagonal(float(1:N))),
+                            alpha::T=10.0,
+                            x0::AbstractVector{T} = ones(N)) where T <: Number
+    hsparsity(quad, x0, ParaboloidStruct(mat, x0, similar(x0), alpha); verbose=false)
+end
+
+@test isdiag(_paraboloidproblem(10))
diff --git a/test/program_sparsity/testall.jl b/test/program_sparsity/testall.jl
index ba746ac3..c26827bc 100644
--- a/test/program_sparsity/testall.jl
+++ b/test/program_sparsity/testall.jl
@@ -2,3 +2,5 @@ include("common.jl")
 
 @testset "Basics" begin include("basics.jl") end
 @testset "Exploration" begin include("ifsandbuts.jl") end
+@testset "Hessian sparsity" begin include("hessian.jl") end
+@testset "Paraboloid example" begin include("paraboloid.jl") end
diff --git a/test/runtests.jl b/test/runtests.jl
index 2577c9a5..fd0df5d3 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -9,4 +9,4 @@ using Test
 @testset "Integration test" begin include("test_integration.jl") end
 @testset "Special matrices" begin include("test_specialmatrices.jl") end
 @testset "Jac Vecs and Hes Vecs" begin include("test_jaches_products.jl") end
-@testset "Jacobian sparsity computation" begin include("program_sparsity/testall.jl") end
\ No newline at end of file
+@testset "Program sparsity computation" begin include("program_sparsity/testall.jl") end