diff --git a/Project.toml b/Project.toml
index 5678c050a..b00eb4531 100644
--- a/Project.toml
+++ b/Project.toml
@@ -10,6 +10,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
 ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
 
 [compat]
@@ -19,5 +20,6 @@ Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9"
 ChainRulesCore = "0.9.7, 0.10"
 Distributions = "0.23.8, 0.24, 0.25"
 MacroTools = "0.5.6"
+Symbolics = "1"
 ZygoteRules = "0.2"
 julia = "1.3"
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl
index a46c941a1..5293ea297 100644
--- a/src/DynamicPPL.jl
+++ b/src/DynamicPPL.jl
@@ -128,6 +128,7 @@ include("compiler.jl")
 include("prob_macro.jl")
 include("compat/ad.jl")
 include("loglikelihoods.jl")
+include("symbolic/Symbolic.jl")
 include("submodel_macro.jl")
 
 end # module
diff --git a/src/symbolic/Symbolic.jl b/src/symbolic/Symbolic.jl
new file mode 100644
index 000000000..36630e88d
--- /dev/null
+++ b/src/symbolic/Symbolic.jl
@@ -0,0 +1,88 @@
+module Symbolic
+
+using ..DynamicPPL: DynamicPPL
+import ..DynamicPPL:
+    Model, VarInfo, AbstractSampler, SampleFromPrior, VarName, DefaultContext
+
+using Random: Random
+using Bijectors: Bijectors
+using Distributions
+using Symbolics: Symbolics
+import Symbolics: SymbolicUtils
+
+# `issym(x)` is `true` only for `Symbolics.Num` and `SymbolicUtils.Symbolic` values.
+issym(::Union{Symbolics.Num,SymbolicUtils.Symbolic}) = true
+issym(::Any) = false
+
+include("rules.jl")
+include("contexts.jl")
+
+# Let symbolic values pass through unchanged on the right-hand side of `~`.
+function DynamicPPL.check_tilde_rhs(rhs::Union{Symbolics.Num,SymbolicUtils.Symbolic})
+    return rhs
+end
+
+# Convenience method: prepend `Random.GLOBAL_RNG` to the arguments and forward.
+function symbolize(args...; kwargs...)
+    return symbolize(Random.GLOBAL_RNG, args...; kwargs...)
+end
+function symbolize(
+    rng::Random.AbstractRNG,
+    m::Model,
+    vi::VarInfo=VarInfo(m);
+    spl=SampleFromPrior(),
+    ctx=DefaultContext(),
+    include_data=false,
+)
+    # NOTE: `include_data` is accepted but not used anywhere in this method yet.
+    # Concrete pass: evaluate the model so `vi[spl]` gives the parameter vector.
+    m(rng, vi, spl, ctx)
+    θ_orig = vi[spl]
+
+    # Symbolic pass: substitute one symbolic variable per entry of `θ_orig`, re-evaluate.
+    # TODO: don't `collect` once symbolic arrays are mature enough.
+    Symbolics.@variables θ[1:length(θ_orig)]
+    vi = VarInfo{Real}(vi, spl, θ)
+    m(vi, ctx)
+
+    return vi, θ
+end
+
+# Variables found by `Symbolics.get_variables` in the arguments of `vn`'s right-hand side.
+function dependencies(ctx::SymbolicContext, vn::VarName)
+    r = Symbolics.value(ctx.vn2rights[vn])
+    # A non-symbolic right-hand side has no dependencies.
+    issym(r) || return Any[]
+    # `init` keeps the reduction well-defined when `r` has an empty argument list.
+    args = SymbolicUtils.arguments(r)
+    return mapreduce(Symbolics.get_variables, vcat, args; init=Any[])
+end
+
+# Dependency map of every recorded variable; symbolic keys/values iff `symbolic`.
+function dependencies(ctx::SymbolicContext, symbolic=false)
+    vn2var = ctx.vn2var
+    var2vn = Dict(values(vn2var) .=> keys(vn2var))
+    return Dict(
+        (symbolic ? vn2var[vn] : vn) =>
+            map(x -> symbolic ? x : var2vn[x], dependencies(ctx, vn)) for
+        vn in keys(ctx.vn2var)
+    )
+end
+
+# Trace `m` with a fresh `SymbolicContext` and return the resulting dependency map.
+function dependencies(m::Model, symbolic=false)
+    ctx = SymbolicContext(DefaultContext())
+    # Only the entries recorded in `ctx` matter; the returned `(vi, θ)` is discarded.
+    symbolize(m, VarInfo(m); ctx=ctx)
+    return dependencies(ctx, symbolic)
+end
+
+# Symbolic log-density of `m`, with `logpdf` terms rewritten by `analytic_rw`.
+function symbolic_logp(m::Model)
+    vi, θ = symbolize(m)
+    lp = DynamicPPL.getlogp(vi)
+    # Rewrite on the unwrapped expression, then wrap the terms in `Num` again.
+    lp_analytic_num = addnum_rw(analytic_rw(Symbolics.value(lp)))
+    return lp_analytic_num, θ
+end
+end
diff --git a/src/symbolic/contexts.jl b/src/symbolic/contexts.jl
new file mode 100644
index 000000000..b11ec2ff1
--- /dev/null
+++ b/src/symbolic/contexts.jl
@@ -0,0 +1,58 @@
+# Context wrapping `ctx`. For every assumed variable whose distribution or stored value
+# is symbolic, `vn2var` keeps the stored value and `vn2rights` the right-hand side of `~`.
+struct SymbolicContext{Ctx} <: DynamicPPL.AbstractContext
+    ctx::Ctx
+    vn2var::Dict
+    vn2rights::Dict
+end
+SymbolicContext() = SymbolicContext(DefaultContext())
+SymbolicContext(ctx) = SymbolicContext(ctx, Dict(), Dict())
+
+Symbolics.@register Distributions.loglikelihood(dist, x)
+
+# Bookkeeping shared by both `tilde_assume` methods: remember `vn` when the
+# distribution OR the value stored in `vi` is symbolic.
+function _record_assume!(ctx::SymbolicContext, right, vn, vi)
+    if Symbolic.issym(right) || (haskey(vi, vn) && Symbolic.issym(vi[vn]))
+        ctx.vn2var[vn] = vi[vn]
+        ctx.vn2rights[vn] = right
+    end
+    return nothing
+end
+
+# assume
+function DynamicPPL.tilde_assume(ctx::SymbolicContext, right, vn, inds, vi)
+    _record_assume!(ctx, right, vn, vi)
+    return DynamicPPL.tilde_assume(ctx.ctx, right, vn, inds, vi)
+end
+
+function DynamicPPL.tilde_assume(rng, ctx::SymbolicContext, sampler, right, vn, inds, vi)
+    _record_assume!(ctx, right, vn, vi)
+    return DynamicPPL.tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, vi)
+end
+
+# TODO: Make it more useful when working with symbolic observations.
+# observe
+# Both methods currently just forward to the wrapped context.
+# TODO: implement special handling for when `right` or `left` is symbolic.
+function DynamicPPL.tilde_observe(ctx::SymbolicContext, right, left, vi)
+    return DynamicPPL.tilde_observe(ctx.ctx, right, left, vi)
+end
+function DynamicPPL.tilde_observe(ctx::SymbolicContext, sampler, right, left, vi)
+    return DynamicPPL.tilde_observe(ctx.ctx, sampler, right, left, vi)
+end
+
+# `assume` for a symbolic distribution: requires `vn` to be present in `vi` already.
+function DynamicPPL.assume(dist::Symbolics.Num, vn::VarName, vi)
+    haskey(vi, vn) || error("variable $vn does not exist")
+    r = vi[vn]
+    return r, Bijectors.logpdf_with_trans(dist, r, DynamicPPL.istrans(vi, vn))
+end
+
+function DynamicPPL.observe(right::Symbolics.Num, left, vi)
+    return Distributions.loglikelihood(right, left)
+end
+
+# TODO: Implement `dot_tilde_*` methods.
+
+
diff --git a/src/symbolic/rules.jl b/src/symbolic/rules.jl
new file mode 100644
index 000000000..a075b1e98
--- /dev/null
+++ b/src/symbolic/rules.jl
@@ -0,0 +1,80 @@
+using Bijectors: Bijectors
+using Symbolics: Symbolics
+using Symbolics.SymbolicUtils
+
+Symbolics.@register Bijectors.logpdf_with_trans(dist, r, istrans)
+
+# Predicates
+isdist(d) = d isa Type && d <: Distribution
+islogpdf(f::Function) = f === Distributions.logpdf
+islogpdf(::Any) = false
+
+# HACK: Apparently these two methods are needed for disambiguation.
+# TODO: Open issue.
+function Symbolics.:<ₑ(a::Real, b::Symbolics.Num)
+    lhs, rhs = Symbolics.value(a), Symbolics.value(b)
+    return Symbolics.:<ₑ(lhs, rhs)
+end
+function Symbolics.:<ₑ(a::Symbolics.Num, b::Real)
+    lhs, rhs = Symbolics.value(a), Symbolics.value(b)
+    return Symbolics.:<ₑ(lhs, rhs)
+end
+
+#############
+### Rules ###
+#############
+# HACK: Rewriters are wrapped so that `Num` can be stripped and added back; that way
+# Jacobians and the like remain available at the end.
+const rmnum_rule = @rule (~x) => Symbolics.value(~x)
+const addnum_rule = @rule (~x) => Symbolics.Num(~x)
+
+# Rules for working directly with `x ~ Distribution` statements:
+# one turns a statement into a `logpdf` call, the other into a `rand` call.
+const logpdf_rule = @rule (~x ~ ~d) =>
+    Distributions.logpdf(Symbolics.Num(~d), Symbolics.Num(~x))
+const rand_rule = @rule (~x ~ ~d) => Distributions.rand(Symbolics.Num(~d))
+
+# We don't want to trace into `Bijectors.logpdf_with_trans`, so it is replaced by
+# a plain `logpdf` call (the `istrans` argument is dropped).
+islogpdf_with_trans(f::Function) = f === Bijectors.logpdf_with_trans
+islogpdf_with_trans(::Any) = false
+const logpdf_with_trans_rule = @rule (~f::islogpdf_with_trans)(~dist, ~x, ~istrans) =>
+    logpdf(~dist, ~x)
+
+# Attempt to expand `logpdf` to get analytical expressions.
+# `getlogpdf(d, args)` is expected to return a callable with the signature
+#
+#     f(args..., x)
+#
+# whose result is the logpdf.
+# HACK: only `Normal` and `Gamma` have an analytic replacement so far.
+import Distributions: StatsFuns
+# `getlogpdf(d, args)` returns the `StatsFuns` log-density function matching the
+# distribution type `d`, or `d` itself when no replacement is known. `args` is unused.
+# Dispatching on the type (rather than looking up `Symbol(d)`) does not depend on how
+# the type prints, so module-qualified names and parametrised types match as well.
+getlogpdf(::Type{<:Distributions.Normal}, args) = StatsFuns.normlogpdf
+getlogpdf(::Type{<:Distributions.Gamma}, args) = StatsFuns.gammalogpdf
+getlogpdf(d, args) = d
+
+# Replace `logpdf(D(args...), x)` by `getlogpdf(D, args)(args..., x)`, wrapping every
+# argument in `Num`.
+const analytic_rule = @rule (~f::islogpdf)((~d::isdist)(~~args), ~x) =>
+    getlogpdf(~d, ~~args)(map(Symbolics.Num, (~~args))..., Symbolics.Num(~x))
+
+#################
+### Rewriters ###
+#################
+# TODO: these should probably be instantiated when needed, rather than here.
+const analytic_rw = Rewriters.Postwalk(
+    Rewriters.Chain((
+        rmnum_rule, # 0. Remove `Num` so we only work with terms from `SymbolicUtils.jl`.
+        logpdf_with_trans_rule, # 1. Replace `logpdf_with_trans` with `logpdf`.
+        analytic_rule, # 2. Attempt to replace `logpdf` with analytic expression.
+    ))
+)
+
+# Walkers that strip `Num` from, or add `Num` back to, every term; the latter allows
+# differentiation of the rewritten expression.
+const rmnum_rw = Rewriters.Postwalk(Rewriters.PassThrough(rmnum_rule))
+const addnum_rw = Rewriters.Postwalk(Rewriters.PassThrough(addnum_rule))
diff --git a/src/varinfo.jl b/src/varinfo.jl
index 64c122dc2..9b9da278c 100644
--- a/src/varinfo.jl
+++ b/src/varinfo.jl
@@ -122,6 +122,13 @@ function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector)
     )
 end
 
+# Build a `VarInfo` whose metadata comes from `newmetadata(old_vi.metadata, …, x)`, with
+# a fresh `Base.RefValue{T}` initialised to `0.0` and `old_vi`'s `num_produce` count.
+function VarInfo{T}(old_vi::TypedVarInfo, spl, x::AbstractVector) where {T}
+    md = newmetadata(old_vi.metadata, Val(getspace(spl)), x)
+    return VarInfo(md, Base.RefValue{T}(0.0), Ref(get_num_produce(old_vi)))
+end
+
 function VarInfo(
     rng::Random.AbstractRNG,
     model::Model,