From 783d81d36419fbe883e5062613aa5392acf88115 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 26 Apr 2021 14:37:26 +0200 Subject: [PATCH 01/11] relax constraints to allow use with Symbolics.jl --- src/compiler.jl | 4 +--- src/context_implementations.jl | 6 +++--- src/varinfo.jl | 5 +++++ 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 39203f5ee..deb22ee2b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -226,9 +226,7 @@ variables. """ function generate_tilde(left, right) @gensym tmpright - top = [:($tmpright = $right), - :($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}} - || throw(ArgumentError($DISTMSG)))] + top = [:($tmpright = $right),] if left isa Symbol || left isa Expr @gensym out vn inds isassumption diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 6b3542acd..671461073 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -119,7 +119,7 @@ end function assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, - dist::Distribution, + dist, vn::VarName, vi, ) @@ -144,12 +144,12 @@ end function observe( spl::Union{SampleFromPrior, SampleFromUniform}, - dist::Distribution, + dist, value, vi, ) increment_num_produce!(vi) - return Distributions.loglikelihood(dist, value) + return sum(Distributions.logpdf(dist, value)) end # .~ functions diff --git a/src/varinfo.jl b/src/varinfo.jl index de369cceb..d0f67e903 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -116,6 +116,11 @@ function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector) VarInfo(md, Base.RefValue{eltype(x)}(getlogp(old_vi)), Ref(get_num_produce(old_vi))) end +function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector, lp::T) where {T} + md = newmetadata(old_vi.metadata, Val(getspace(spl)), x) + VarInfo(md, Base.RefValue{T}(lp), Ref(get_num_produce(old_vi))) +end + function VarInfo( rng::Random.AbstractRNG, model::Model, From 0421aed7824fe62b4db335084a59a1b16e82c911 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 27 Apr 2021 05:25:08 +0200 Subject: [PATCH 02/11] added symbolic stuff --- src/DynamicPPL.jl | 5 ++- src/symbolic/Symbolic.jl | 77 ++++++++++++++++++++++++++++++++++++++++ src/symbolic/contexts.jl | 29 +++++++++++++++ src/symbolic/rules.jl | 70 ++++++++++++++++++++++++++++++++++++ 4 files changed, 180 insertions(+), 1 deletion(-) create mode 100644 src/symbolic/Symbolic.jl create mode 100644 src/symbolic/contexts.jl create mode 100644 src/symbolic/rules.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e369fee3f..193e71d7a 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -95,7 +95,9 @@ export AbstractVarInfo, logjoint, pointwise_loglikelihoods, # Convenience macros - @addlogprob! + @addlogprob!, +# Symbolics + Symbolic # Reexport using Distributions: loglikelihood @@ -123,5 +125,6 @@ include("compiler.jl") include("prob_macro.jl") include("compat/ad.jl") include("loglikelihoods.jl") +include("symbolic/Symbolic.jl") end # module diff --git a/src/symbolic/Symbolic.jl b/src/symbolic/Symbolic.jl new file mode 100644 index 000000000..2a50d67e3 --- /dev/null +++ b/src/symbolic/Symbolic.jl @@ -0,0 +1,77 @@ +module Symbolic + +import ..DynamicPPL +import ..DynamicPPL: Model, VarInfo, AbstractSampler, SampleFromPrior, VarName, DefaultContext + +import Random +import Bijectors +using Distributions +import Symbolics +import Symbolics: SymbolicUtils + +issym(x::Union{Symbolics.Num, SymbolicUtils.Symbolic}) = true +issym(x) = false + +include("rules.jl") +include("contexts.jl") + +symbolize(args...; kwargs...) = symbolize(Random.GLOBAL_RNG, args...; kwargs...) +function symbolize( + rng::Random.AbstractRNG, + m::Model, + vi::VarInfo = VarInfo(m); + spl = SampleFromPrior(), + ctx = DefaultContext(), + include_data = false +) + m(rng, vi, spl, ctx); + θ_orig = vi[spl] + + # Symbolic `logpdf` for fixed observations. + Symbolics.@variables θ[1:length(θ_orig)] + vi = VarInfo(vi, spl, θ, zero(eltype(θ))); + m(rng, vi, spl, ctx); + + return vi, θ +end + +function dependencies(ctx::SymbolicContext, vn::VarName) + right = ctx.vn2rights[vn] + r = Symbolics.value(right) + + if !issym(r) + # No dependencies. + return [] + end + + args = SymbolicUtils.arguments(r) + return mapreduce(vcat, args) do a + Symbolics.get_variables(a) + end +end +function dependencies(ctx::SymbolicContext, symbolic = false) + vn2var = ctx.vn2var + var2vn = Dict(values(vn2var) .=> keys(vn2var)) + return Dict( + (symbolic ? vn2var[vn] : vn) => map(x -> symbolic ? x : var2vn[x], dependencies(ctx, vn)) + for vn in keys(ctx.vn2var) + ) +end + +function dependencies(m::Model, symbolic = false) + ctx = SymbolicContext(DefaultContext()) + vi = symbolize(m, VarInfo(m), ctx = ctx) + + return dependencies(ctx, symbolic) +end + + +function symbolic_logp(m::Model) + vi, θ = symbolize(m) + lp = DynamicPPL.getlogp(vi) + lp_analytic = analytic_rw(Symbolics.value(lp)) + lp_analytic_num = addnum_rw(lp_analytic) + + return lp_analytic_num, θ +end +end diff --git a/src/symbolic/contexts.jl b/src/symbolic/contexts.jl new file mode 100644 index 000000000..98c4dba0a --- /dev/null +++ b/src/symbolic/contexts.jl @@ -0,0 +1,29 @@ +struct SymbolicContext{Ctx} <: DynamicPPL.AbstractContext + ctx::Ctx + vn2var::Dict + vn2rights::Dict +end +SymbolicContext() = SymbolicContext(DefaultContext()) +SymbolicContext(ctx) = SymbolicContext(ctx, Dict(), Dict()) + +# assume +function DynamicPPL.tilde(rng, ctx::SymbolicContext, sampler, right, vn::VarName, inds, vi) + if Symbolic.issym(right) || (haskey(vi, vn) && Symbolic.issym(vi[vn])) + # Distribution is symbolic OR variable is. + ctx.vn2var[vn] = vi[vn] + ctx.vn2rights[vn] = right + end + + return DynamicPPL.tilde(rng, ctx.ctx, sampler, right, vn, inds, vi) +end + + +# TODO: Make it more useful when working with symbolic observations. +# observe +function DynamicPPL.tilde(ctx::SymbolicContext, sampler, right, left, vi) + if Symbolic.issym(right) || Symbolic.issym(left) + # TODO: implement + end + + return DynamicPPL.tilde(ctx.ctx, sampler, right, left, vi) +end diff --git a/src/symbolic/rules.jl b/src/symbolic/rules.jl new file mode 100644 index 000000000..b3a4f991c --- /dev/null +++ b/src/symbolic/rules.jl @@ -0,0 +1,70 @@ +import Bijectors +import Symbolics +using Symbolics.SymbolicUtils + +Symbolics.@register Bijectors.logpdf_with_trans(dist, r, istrans) + +# Some predicates +isdist(d) = (d isa Type) && (d <: Distribution) +islogpdf(f::Function) = f === Distributions.logpdf +islogpdf(x) = false + +# HACK: Apparently this is needed for disambiguiation. +# TODO: Open issue. +Symbolics.:<ₑ(a::Real, b::Symbolics.Num) = Symbolics.:<ₑ(Symbolics.value(a), Symbolics.value(b)) + +############# +### Rules ### +############# +# HACK: We'll wrap rewriters to add back `Num`. This way we can get jacobians and whatnot at then end. +const rmnum_rule = @rule (~x) => Symbolics.value(~x) +const addnum_rule = @rule (~x) => Symbolics.Num(~x) + +# In the case where we want to work directly with the `x ~ Distribution` statements, the following rules can be useful: +const logpdf_rule = @rule (~x ~ ~d) => Distributions.logpdf(Symbolics.Num(~d), Symbolics.Num(~x)); +const rand_rule = @rule (~x ~ ~d) => Distributions.rand(Symbolics.Num(~d)) + +# We don't want to trace into `Bijectors.logpdf_with_trans`, so we just replace it with `logpdf`. +islogpdf_with_trans(f::Function) = f === Bijectors.logpdf_with_trans +islogpdf_with_trans(x) = false +const logpdf_with_trans_rule = @rule (~f::islogpdf_with_trans)(~dist, ~x, ~istrans) => logpdf(~dist, ~x) + +# Attempt to expand `logpdf` to get analytical expressions. +# The idea is that `getlogpdf(d, args)` should return a method of the following signature: +# +# f(args..., x) +# +# which returns the logpdf. +# HACK: this is very hacky but you get the idea +import Distributions: StatsFuns +function getlogpdf(d, args) + replacements = Dict( + :Normal => StatsFuns.normlogpdf, + :Gamma => StatsFuns.gammalogpdf + ) + + dsym = Symbol(d) + if haskey(replacements, dsym) + return replacements[dsym] + else + return d + end +end + +const analytic_rule = @rule (~f::islogpdf)((~d::isdist)(~~args), ~x) => getlogpdf(~d, ~~args)(map(Symbolics.Num, (~~args))..., Symbolics.Num(~x)) + + +################# +### Rewriters ### +################# +const analytic_rw = Rewriters.Postwalk( + Rewriters.Chain(( + rmnum_rule, # 0. Remove `Num` so we're only working stuff from `SymbolicUtils.jl`. + logpdf_with_trans_rule, # 1. Replace `logpdf_with_trans` with `logpdf`. + analytic_rule, # 2. Attempt to replace `logpdf` with analytic expression. + )) +) + +# So we add back `Num` to all terms to allow differentiation. +const rmnum_rw = Rewriters.Postwalk(Rewriters.PassThrough(rmnum_rule)) +const addnum_rw = Rewriters.Postwalk(Rewriters.PassThrough(addnum_rule)) From 14ad023c59586e61f451dbce8f942fc7495762e9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 27 Apr 2021 05:26:09 +0200 Subject: [PATCH 03/11] added Symbolic as dependency --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 07d68608e..e04f3012f 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" [compat] AbstractMCMC = "2, 3.0" From 03067a00c468f397df2f764cd1b6ce20fb9cad4a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 27 Apr 2021 05:26:43 +0200 Subject: [PATCH 04/11] added version bound to Symbolics --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index e04f3012f..cf6dda58b 100644 --- a/Project.toml +++ b/Project.toml @@ -20,5 +20,6 @@ Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9" ChainRulesCore = "0.9.7" Distributions = "0.23.8, 0.24" MacroTools = "0.5.6" +Symbolics = "0.1.24" NaturalSort = "1" julia = "1.3" From 785242ea0c5b83f30786b946cd16a80405e5492b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 28 Apr 2021 11:34:07 +0200 Subject: [PATCH 05/11] added another method for disambiguiation --- src/symbolic/rules.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/symbolic/rules.jl b/src/symbolic/rules.jl index b3a4f991c..d0a8421ce 100644 --- a/src/symbolic/rules.jl +++ b/src/symbolic/rules.jl @@ -12,6 +12,7 @@ islogpdf(x) = false # HACK: Apparently this is needed for disambiguiation. # TODO: Open issue. Symbolics.:<ₑ(a::Real, b::Symbolics.Num) = Symbolics.:<ₑ(Symbolics.value(a), Symbolics.value(b)) +Symbolics.:<ₑ(a::Symbolics.Num, b::Real) = Symbolics.:<ₑ(Symbolics.value(a), Symbolics.value(b)) ############# ### Rules ### @@ -57,6 +58,7 @@ const analytic_rule = @rule (~f::islogpdf)((~d::isdist)(~~args), ~x) => getlogpd ################# ### Rewriters ### ################# +# TODO: these should probably be instantiated when needed, rather than here. const analytic_rw = Rewriters.Postwalk( Rewriters.Chain(( rmnum_rule, # 0. Remove `Num` so we're only working stuff from `SymbolicUtils.jl`. From a3a9742830e2318ea817cbbc4425b6d26ec6caf5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 15:23:45 +0100 Subject: [PATCH 06/11] update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index cf6dda58b..2396a6a9d 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,6 @@ Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9" ChainRulesCore = "0.9.7" Distributions = "0.23.8, 0.24" MacroTools = "0.5.6" -Symbolics = "0.1.24" NaturalSort = "1" +Symbolics = "0.1.24" julia = "1.3" From 817533ce893f86e23b0a42d38c9edc2cac070de6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 16:29:44 +0100 Subject: [PATCH 07/11] updated symbolics versioning --- Project.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index b0fc8b806..3d7f636f9 100644 --- a/Project.toml +++ b/Project.toml @@ -20,7 +20,6 @@ Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9" ChainRulesCore = "0.9.7, 0.10" Distributions = "0.23.8, 0.24, 0.25" MacroTools = "0.5.6" -NaturalSort = "1" -Symbolics = "0.1.24" +Symbolics = "1" ZygoteRules = "0.2" julia = "1.3" From daa1319feed0670000c55da68d74f28c2af5d176 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 16:30:55 +0100 Subject: [PATCH 08/11] added some hacks to make it at least run --- src/symbolic/contexts.jl | 22 ++++++++++++++++++---- src/varinfo.jl | 5 +++++ 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/symbolic/contexts.jl b/src/symbolic/contexts.jl index 98c4dba0a..cc9eb62fc 100644 --- a/src/symbolic/contexts.jl +++ b/src/symbolic/contexts.jl @@ -7,23 +7,37 @@ SymbolicContext() = SymbolicContext(DefaultContext()) SymbolicContext(ctx) = SymbolicContext(ctx, Dict(), Dict()) # assume -function DynamicPPL.tilde(rng, ctx::SymbolicContext, sampler, right, vn::VarName, inds, vi) +function DynamicPPL.tilde_assume(rng, ctx::SymbolicContext, sampler, right, vn, inds, vi) if Symbolic.issym(right) || (haskey(vi, vn) && Symbolic.issym(vi[vn])) # Distribution is symbolic OR variable is. ctx.vn2var[vn] = vi[vn] ctx.vn2rights[vn] = right end - return DynamicPPL.tilde(rng, ctx.ctx, sampler, right, vn, inds, vi) + return DynamicPPL.tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, vi) end # TODO: Make it more useful when working with symbolic observations. # observe -function DynamicPPL.tilde(ctx::SymbolicContext, sampler, right, left, vi) +function DynamicPPL.tilde_observe(ctx::SymbolicContext, sampler, right, left, vi) if Symbolic.issym(right) || Symbolic.issym(left) # TODO: implement end - return DynamicPPL.tilde(ctx.ctx, sampler, right, left, vi) + return DynamicPPL.tilde_observe(ctx.ctx, sampler, right, left, vi) end + +function DynamicPPL.assume(dist::Symbolics.Num, vn::VarName, vi) + if !haskey(vi, vn) + error("variable $vn does not exist") + end + r = vi[vn] + return r, Bijectors.logpdf_with_trans(dist, vi[vn], DynamicPPL.istrans(vi, vn)) +end + +function DynamicPPL.observe(right::Symbolics.Num, left, vi) + return Distributions.loglikelihood(right, left) +end + +Symbolics.@register Distributions.loglikelihood(dist, x) diff --git a/src/varinfo.jl b/src/varinfo.jl index b43f39a9f..4bad2a50a 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -127,6 +127,11 @@ function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector, lp::T) where {T} VarInfo(md, Base.RefValue{T}(lp), Ref(get_num_produce(old_vi))) end +function VarInfo{T}(old_vi::TypedVarInfo, spl, x::AbstractVector) where {T} + md = newmetadata(old_vi.metadata, Val(getspace(spl)), x) + VarInfo(md, Base.RefValue{T}(0.0), Ref(get_num_produce(old_vi))) +end + function VarInfo( rng::Random.AbstractRNG, model::Model, From 70c99970fc77a2b5a1263dc6b895fc8c76078d43 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 16:31:39 +0100 Subject: [PATCH 09/11] updated symbolize --- src/symbolic/Symbolic.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/symbolic/Symbolic.jl b/src/symbolic/Symbolic.jl index 2a50d67e3..13c84de88 100644 --- a/src/symbolic/Symbolic.jl +++ b/src/symbolic/Symbolic.jl @@ -29,8 +29,8 @@ function symbolize( # Symbolic `logpdf` for fixed observations. Symbolics.@variables θ[1:length(θ_orig)] - vi = VarInfo(vi, spl, θ, zero(eltype(θ))); - m(rng, vi, spl, ctx); + vi = VarInfo{Real}(vi, spl, θ, 0.0); + m(vi, ctx); return vi, θ end From 23141be4cfa82fa0a5fdf0a1d5829e0d802c6771 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 16:35:05 +0100 Subject: [PATCH 10/11] formatting --- src/symbolic/Symbolic.jl | 40 +++++++++++++++++++++------------------- src/symbolic/contexts.jl | 1 - src/symbolic/rules.jl | 29 ++++++++++++++++------------- src/varinfo.jl | 4 ++-- 4 files changed, 39 insertions(+), 35 deletions(-) diff --git a/src/symbolic/Symbolic.jl b/src/symbolic/Symbolic.jl index 13c84de88..696103e19 100644 --- a/src/symbolic/Symbolic.jl +++ b/src/symbolic/Symbolic.jl @@ -1,15 +1,16 @@ module Symbolic -import ..DynamicPPL -import ..DynamicPPL: Model, VarInfo, AbstractSampler, SampleFromPrior, VarName, DefaultContext +using ..DynamicPPL: DynamicPPL +import ..DynamicPPL: + Model, VarInfo, AbstractSampler, SampleFromPrior, VarName, DefaultContext -import Random -import Bijectors +using Random: Random +using Bijectors: Bijectors using Distributions -import Symbolics +using Symbolics: Symbolics import Symbolics: SymbolicUtils -issym(x::Union{Symbolics.Num, SymbolicUtils.Symbolic}) = true +issym(x::Union{Symbolics.Num,SymbolicUtils.Symbolic}) = true issym(x) = false include("rules.jl") @@ -19,18 +20,19 @@ symbolize(args...; kwargs...) = symbolize(Random.GLOBAL_RNG, args...; kwargs...) function symbolize( rng::Random.AbstractRNG, m::Model, - vi::VarInfo = VarInfo(m); - spl = SampleFromPrior(), - ctx = DefaultContext(), - include_data = false + vi::VarInfo=VarInfo(m); + spl=SampleFromPrior(), + ctx=DefaultContext(), + include_data=false, ) - m(rng, vi, spl, ctx); + m(rng, vi, spl, ctx) θ_orig = vi[spl] # Symbolic `logpdf` for fixed observations. + # TODO: don't `collect` once symbolic arrays are mature enough. Symbolics.@variables θ[1:length(θ_orig)] - vi = VarInfo{Real}(vi, spl, θ, 0.0); - m(vi, ctx); + vi = VarInfo{Real}(vi, spl, θ, 0.0) + m(vi, ctx) return vi, θ end @@ -49,23 +51,23 @@ function dependencies(ctx::SymbolicContext, vn::VarName) Symbolics.get_variables(a) end end -function dependencies(ctx::SymbolicContext, symbolic = false) +function dependencies(ctx::SymbolicContext, symbolic=false) vn2var = ctx.vn2var var2vn = Dict(values(vn2var) .=> keys(vn2var)) return Dict( - (symbolic ? vn2var[vn] : vn) => map(x -> symbolic ? x : var2vn[x], dependencies(ctx, vn)) - for vn in keys(ctx.vn2var) + (symbolic ? vn2var[vn] : vn) => + map(x -> symbolic ? x : var2vn[x], dependencies(ctx, vn)) for + vn in keys(ctx.vn2var) ) end -function dependencies(m::Model, symbolic = false) +function dependencies(m::Model, symbolic=false) ctx = SymbolicContext(DefaultContext()) - vi = symbolize(m, VarInfo(m), ctx = ctx) + vi = symbolize(m, VarInfo(m); ctx=ctx) return dependencies(ctx, symbolic) end - function symbolic_logp(m::Model) vi, θ = symbolize(m) lp = DynamicPPL.getlogp(vi) diff --git a/src/symbolic/contexts.jl b/src/symbolic/contexts.jl index cc9eb62fc..faf371e8f 100644 --- a/src/symbolic/contexts.jl +++ b/src/symbolic/contexts.jl @@ -17,7 +17,6 @@ function DynamicPPL.tilde_assume(rng, ctx::SymbolicContext, sampler, right, vn, return DynamicPPL.tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, vi) end - # TODO: Make it more useful when working with symbolic observations. # observe function DynamicPPL.tilde_observe(ctx::SymbolicContext, sampler, right, left, vi) diff --git a/src/symbolic/rules.jl b/src/symbolic/rules.jl index d0a8421ce..a075b1e98 100644 --- a/src/symbolic/rules.jl +++ b/src/symbolic/rules.jl @@ -1,5 +1,5 @@ -import Bijectors -import Symbolics +using Bijectors: Bijectors +using Symbolics: Symbolics using Symbolics.SymbolicUtils Symbolics.@register Bijectors.logpdf_with_trans(dist, r, istrans) @@ -11,8 +11,12 @@ islogpdf(x) = false # HACK: Apparently this is needed for disambiguiation. # TODO: Open issue. -Symbolics.:<ₑ(a::Real, b::Symbolics.Num) = Symbolics.:<ₑ(Symbolics.value(a), Symbolics.value(b)) -Symbolics.:<ₑ(a::Symbolics.Num, b::Real) = Symbolics.:<ₑ(Symbolics.value(a), Symbolics.value(b)) +function Symbolics.:<ₑ(a::Real, b::Symbolics.Num) + return Symbolics.:<ₑ(Symbolics.value(a), Symbolics.value(b)) +end +function Symbolics.:<ₑ(a::Symbolics.Num, b::Real) + return Symbolics.:<ₑ(Symbolics.value(a), Symbolics.value(b)) +end ############# ### Rules ### @@ -22,13 +26,15 @@ const rmnum_rule = @rule (~x) => Symbolics.value(~x) const addnum_rule = @rule (~x) => Symbolics.Num(~x) # In the case where we want to work directly with the `x ~ Distribution` statements, the following rules can be useful: -const logpdf_rule = @rule (~x ~ ~d) => Distributions.logpdf(Symbolics.Num(~d), Symbolics.Num(~x)); +const logpdf_rule = @rule (~x ~ ~d) => + Distributions.logpdf(Symbolics.Num(~d), Symbolics.Num(~x)); const rand_rule = @rule (~x ~ ~d) => Distributions.rand(Symbolics.Num(~d)) # We don't want to trace into `Bijectors.logpdf_with_trans`, so we just replace it with `logpdf`. islogpdf_with_trans(f::Function) = f === Bijectors.logpdf_with_trans islogpdf_with_trans(x) = false -const logpdf_with_trans_rule = @rule (~f::islogpdf_with_trans)(~dist, ~x, ~istrans) => logpdf(~dist, ~x) +const logpdf_with_trans_rule = @rule (~f::islogpdf_with_trans)(~dist, ~x, ~istrans) => + logpdf(~dist, ~x) # Attempt to expand `logpdf` to get analytical expressions. # The idea is that `getlogpdf(d, args)` should return a method of the following signature: @@ -39,11 +45,8 @@ const logpdf_with_trans_rule = @rule (~f::islogpdf_with_trans)(~dist, ~x, ~istra # HACK: this is very hacky but you get the idea import Distributions: StatsFuns function getlogpdf(d, args) - replacements = Dict( - :Normal => StatsFuns.normlogpdf, - :Gamma => StatsFuns.gammalogpdf - ) - + replacements = Dict(:Normal => StatsFuns.normlogpdf, :Gamma => StatsFuns.gammalogpdf) + dsym = Symbol(d) if haskey(replacements, dsym) return replacements[dsym] @@ -52,8 +55,8 @@ function getlogpdf(d, args) end end -const analytic_rule = @rule (~f::islogpdf)((~d::isdist)(~~args), ~x) => getlogpdf(~d, ~~args)(map(Symbolics.Num, (~~args))..., Symbolics.Num(~x)) - +const analytic_rule = @rule (~f::islogpdf)((~d::isdist)(~~args), ~x) => + getlogpdf(~d, ~~args)(map(Symbolics.Num, (~~args))..., Symbolics.Num(~x)) ################# ### Rewriters ### diff --git a/src/varinfo.jl b/src/varinfo.jl index 4bad2a50a..d7991c73d 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -124,12 +124,12 @@ end function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector, lp::T) where {T} md = newmetadata(old_vi.metadata, Val(getspace(spl)), x) - VarInfo(md, Base.RefValue{T}(lp), Ref(get_num_produce(old_vi))) + return VarInfo(md, Base.RefValue{T}(lp), Ref(get_num_produce(old_vi))) end function VarInfo{T}(old_vi::TypedVarInfo, spl, x::AbstractVector) where {T} md = newmetadata(old_vi.metadata, Val(getspace(spl)), x) - VarInfo(md, Base.RefValue{T}(0.0), Ref(get_num_produce(old_vi))) + return VarInfo(md, Base.RefValue{T}(0.0), Ref(get_num_produce(old_vi))) end function VarInfo( From b19700efcc51f5dcaf432dd0d1e0d815deb9892c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 23:47:31 +0100 Subject: [PATCH 11/11] make it runnable on newer package versions --- src/symbolic/Symbolic.jl | 5 ++++- src/symbolic/contexts.jl | 24 ++++++++++++++++++++++-- src/varinfo.jl | 5 ----- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/src/symbolic/Symbolic.jl b/src/symbolic/Symbolic.jl index 696103e19..36630e88d 100644 --- a/src/symbolic/Symbolic.jl +++ b/src/symbolic/Symbolic.jl @@ -16,6 +16,9 @@ issym(x) = false include("rules.jl") include("contexts.jl") +# Allow `num` to appear on the RHS of `~`. +DynamicPPL.check_tilde_rhs(x::Union{Symbolics.Num,SymbolicUtils.Symbolic}) = x + symbolize(args...; kwargs...) = symbolize(Random.GLOBAL_RNG, args...; kwargs...) function symbolize( rng::Random.AbstractRNG, @@ -31,7 +34,7 @@ function symbolize( # Symbolic `logpdf` for fixed observations. # TODO: don't `collect` once symbolic arrays are mature enough. Symbolics.@variables θ[1:length(θ_orig)] - vi = VarInfo{Real}(vi, spl, θ, 0.0) + vi = VarInfo{Real}(vi, spl, θ) m(vi, ctx) return vi, θ diff --git a/src/symbolic/contexts.jl b/src/symbolic/contexts.jl index faf371e8f..b11ec2ff1 100644 --- a/src/symbolic/contexts.jl +++ b/src/symbolic/contexts.jl @@ -6,7 +6,19 @@ end SymbolicContext() = SymbolicContext(DefaultContext()) SymbolicContext(ctx) = SymbolicContext(ctx, Dict(), Dict()) +Symbolics.@register Distributions.loglikelihood(dist, x) + # assume +function DynamicPPL.tilde_assume(ctx::SymbolicContext, right, vn, inds, vi) + if Symbolic.issym(right) || (haskey(vi, vn) && Symbolic.issym(vi[vn])) + # Distribution is symbolic OR variable is. + ctx.vn2var[vn] = vi[vn] + ctx.vn2rights[vn] = right + end + + return DynamicPPL.tilde_assume(ctx.ctx, right, vn, inds, vi) +end + function DynamicPPL.tilde_assume(rng, ctx::SymbolicContext, sampler, right, vn, inds, vi) if Symbolic.issym(right) || (haskey(vi, vn) && Symbolic.issym(vi[vn])) # Distribution is symbolic OR variable is. @@ -19,6 +31,12 @@ end # TODO: Make it more useful when working with symbolic observations. # observe +function DynamicPPL.tilde_observe(ctx::SymbolicContext, right, left, vi) + if Symbolic.issym(right) || Symbolic.issym(left) + # TODO: implement + end + return DynamicPPL.tilde_observe(ctx.ctx, right, left, vi) +end function DynamicPPL.tilde_observe(ctx::SymbolicContext, sampler, right, left, vi) if Symbolic.issym(right) || Symbolic.issym(left) # TODO: implement @@ -32,11 +50,13 @@ function DynamicPPL.assume(dist::Symbolics.Num, vn::VarName, vi) error("variable $vn does not exist") end r = vi[vn] - return r, Bijectors.logpdf_with_trans(dist, vi[vn], DynamicPPL.istrans(vi, vn)) + return r, Bijectors.logpdf_with_trans(dist, r, DynamicPPL.istrans(vi, vn)) end function DynamicPPL.observe(right::Symbolics.Num, left, vi) return Distributions.loglikelihood(right, left) end -Symbolics.@register Distributions.loglikelihood(dist, x) +# TODO: Implement `dot_tilde_*` methods. + + diff --git a/src/varinfo.jl b/src/varinfo.jl index 15a9bb02b..9b9da278c 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -122,11 +122,6 @@ function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector) ) end -function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector, lp::T) where {T} - md = newmetadata(old_vi.metadata, Val(getspace(spl)), x) - return VarInfo(md, Base.RefValue{T}(lp), Ref(get_num_produce(old_vi))) -end - function VarInfo{T}(old_vi::TypedVarInfo, spl, x::AbstractVector) where {T} md = newmetadata(old_vi.metadata, Val(getspace(spl)), x) return VarInfo(md, Base.RefValue{T}(0.0), Ref(get_num_produce(old_vi)))