From e32b59fc0c51d89b68fff66fa790da34f646b7b6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 18 May 2021 20:36:20 +0200 Subject: [PATCH 001/107] initial stuff --- test/benchmarks/Project.toml | 6 ++ test/benchmarks/benchmark_body.jmd | 30 ++++++++++ test/benchmarks/benchmarks.jmd | 96 ++++++++++++++++++++++++++++++ test/benchmarks/utils.jl | 49 +++++++++++++++ 4 files changed, 181 insertions(+) create mode 100644 test/benchmarks/Project.toml create mode 100644 test/benchmarks/benchmark_body.jmd create mode 100644 test/benchmarks/benchmarks.jmd create mode 100644 test/benchmarks/utils.jl diff --git a/test/benchmarks/Project.toml b/test/benchmarks/Project.toml new file mode 100644 index 000000000..34406308c --- /dev/null +++ b/test/benchmarks/Project.toml @@ -0,0 +1,6 @@ +[deps] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +DiffUtils = "8294860b-85a6-42f8-8c35-d911f667b5f6" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" +Weave = "44d3d7a6-8a23-5bf8-98c5-b353f8df5ec9" diff --git a/test/benchmarks/benchmark_body.jmd b/test/benchmarks/benchmark_body.jmd new file mode 100644 index 000000000..a82d9ff52 --- /dev/null +++ b/test/benchmarks/benchmark_body.jmd @@ -0,0 +1,30 @@ +```julia +@time model_def(data)(); +``` + +```julia +m = time_model_def(model_def, data); +``` + +```julia +suite = make_suite(m); +run(suite) +``` + +```julia; wrap=false +typed = typed_code(m) +``` + +```julia; echo=false; results="hidden" +# Serialize the output of `typed_code` so we can compare later. +haskey(WEAVE_ARGS, :prefix) && serialize("$(WEAVE_ARGS[:prefix])_$(m.name).jls", string(typed)); +``` + +```julia; wrap=false +if haskey(WEAVE_ARGS, :prefix_old) + # We want to compare the generated code to the previous version. + import DiffUtils + typed_old = deserialize("$(WEAVE_ARGS[:prefix_old])_$(m.name).jls"); + DiffUtils.diff(typed_old, string(typed), width=130) +end +``` diff --git a/test/benchmarks/benchmarks.jmd b/test/benchmarks/benchmarks.jmd new file mode 100644 index 000000000..0cc6e1de2 --- /dev/null +++ b/test/benchmarks/benchmarks.jmd @@ -0,0 +1,96 @@ +# Benchmarks + +## Setup + +```julia +using BenchmarkTools, DynamicPPL, Distributions, Serialization +``` + +```julia +include("utils.jl") +``` + +## Models + +### `demo1` + +```julia +@model function demo1(x) + m ~ Normal() + x ~ Normal(m, 1) + + return (m = m, x = x) +end + +model_def = demo1; +data = 1.0; +``` + +```julia; results="markup"; echo=false +weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS) +``` + +### `demo2` + +```julia +@model function demo2(y) + # Our prior belief about the probability of heads in a coin. + p ~ Beta(1, 1) + + # The number of observations. + N = length(y) + for n in 1:N + # Heads or tails of a coin are drawn from a Bernoulli distribution. + y[n] ~ Bernoulli(p) + end +end + +model_def = demo2; +data = rand(0:1, 10); +``` + +```julia; results="markup"; echo=false +weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS) +``` + +### `demo3` + +```julia +@model function demo3(x) + D, N = size(x) + + # Draw the parameters for cluster 1. + μ1 ~ Normal() + + # Draw the parameters for cluster 2. + μ2 ~ Normal() + + μ = [μ1, μ2] + + # Comment out this line if you instead want to draw the weights. + w = [0.5, 0.5] + + # Draw assignments for each datum and generate it from a multivariate normal. + k = Vector{Int}(undef, N) + for i in 1:N + k[i] ~ Categorical(w) + x[:,i] ~ MvNormal([μ[k[i]], μ[k[i]]], 1.) + end + return k +end + +model_def = demo3 + +# Construct 30 data points for each cluster. +N = 30 + +# Parameters for each cluster, we assume that each cluster is Gaussian distributed in the example. +μs = [-3.5, 0.0] + +# Construct the data points. +data = mapreduce(c -> rand(MvNormal([μs[c], μs[c]], 1.), N), hcat, 1:2); +``` + +```julia; echo=false +weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS) +``` diff --git a/test/benchmarks/utils.jl b/test/benchmarks/utils.jl new file mode 100644 index 000000000..f6c31c1c2 --- /dev/null +++ b/test/benchmarks/utils.jl @@ -0,0 +1,49 @@ +using DynamicPPL +using BenchmarkTools + +import Weave +import Markdown + +function time_model_def(model_def, args...) + return @time model_def(args...) +end + +function benchmark_untyped_varinfo!(suite, m) + vi = VarInfo() + # Populate. + m(vi) + # Evaluate. + suite["evaluation_untyped"] = @benchmarkable $m($vi) + return suite +end + +function benchmark_typed_varinfo!(suite, m) + # Populate. + vi = VarInfo(m) + # Evaluate. + suite["evaluation_typed"] = @benchmarkable $m($vi) + return suite +end + +function typed_code(m, vi = VarInfo(m)) + rng = DynamicPPL.Random.MersenneTwister(42); + spl = DynamicPPL.SampleFromPrior() + ctx = DynamicPPL.DefaultContext() + + return Main.@code_typed m.f(rng, m, vi, spl, ctx, m.args...) +end + +function make_suite(m) + suite = BenchmarkGroup() + benchmark_untyped_varinfo!(suite, m) + benchmark_typed_varinfo!(suite, m) + + return suite +end + +function weave_child(indoc; mod, args, kwargs...) + doc = Weave.WeaveDoc(indoc, nothing) + doc = Weave.run_doc(doc, doctype = "github", mod = mod, args = args, kwargs...) + rendered = Weave.render_doc(doc) + return display(Markdown.parse(rendered)) +end From e83a6711da574d6bc5adcdc441c713a9ffc73550 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 18 May 2021 22:31:55 +0200 Subject: [PATCH 002/107] moved benchmark folder and added README --- {test/benchmarks => benchmarks}/Project.toml | 0 benchmarks/README.md | 8 ++++++++ {test/benchmarks => benchmarks}/benchmark_body.jmd | 0 {test/benchmarks => benchmarks}/benchmarks.jmd | 0 {test/benchmarks => benchmarks}/utils.jl | 0 5 files changed, 8 insertions(+) rename {test/benchmarks => benchmarks}/Project.toml (100%) create mode 100644 benchmarks/README.md rename {test/benchmarks => benchmarks}/benchmark_body.jmd (100%) rename {test/benchmarks => benchmarks}/benchmarks.jmd (100%) rename {test/benchmarks => benchmarks}/utils.jl (100%) diff --git a/test/benchmarks/Project.toml b/benchmarks/Project.toml similarity index 100% rename from test/benchmarks/Project.toml rename to benchmarks/Project.toml diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 000000000..565217753 --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,8 @@ +To run the benchmarks, simply do: +```sh +julia --project -e 'using Weave; Weave.weave("benchmarks.jmd", doctype="github", args=Dict(:benchmarkbody => "benchmark_body.jmd"));' +``` + +Furthermore: +- If you want to save the output of `code_typed` for the evaluator of the different models, add a `:prefix => "myprefix"` to the `args`. +- If `:prefix_old` is specified in `args`, a `diff` of the `code_typed` loaded using `:prefix_old` and the output of `code_typed` for the current run will be included in the weaved document. diff --git a/test/benchmarks/benchmark_body.jmd b/benchmarks/benchmark_body.jmd similarity index 100% rename from test/benchmarks/benchmark_body.jmd rename to benchmarks/benchmark_body.jmd diff --git a/test/benchmarks/benchmarks.jmd b/benchmarks/benchmarks.jmd similarity index 100% rename from test/benchmarks/benchmarks.jmd rename to benchmarks/benchmarks.jmd diff --git a/test/benchmarks/utils.jl b/benchmarks/utils.jl similarity index 100% rename from test/benchmarks/utils.jl rename to benchmarks/utils.jl From 3ab2beed11a07bbb2b7009e2956252dbf1998463 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 31 May 2021 20:10:31 +0100 Subject: [PATCH 003/107] unwrap distributions and varnames at model-level --- src/compiler.jl | 50 ++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index bef7d11c2..0be9d4d44 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -52,6 +52,45 @@ end check_tilde_rhs(x::Distribution) = x check_tilde_rhs(x::AbstractArray{<:Distribution}) = x +""" + unwrap_right_vn(right, vn) +Return the unwrapped distribution on the right-hand side and variable name on the left-hand +side of a `~` expression such as `x ~ Normal()`. +This is used mainly to unwrap `NamedDist` distributions. +""" +unwrap_right_vn(right, vn) = right, vn +unwrap_right_vn(right::NamedDist, vn) = unwrap_right_vn(right.dist, right.name) + +""" + unwrap_right_left_vns(context, right, left, vns) +Return the unwrapped distributions on the right-hand side and values and variable names on the +left-hand side of a `.~` expression such as `x .~ Normal()`. +This is used mainly to unwrap `NamedDist` distributions and adjust the indices of the +variables. +""" +unwrap_right_left_vns(right, left, vns) = right, left, vns +function unwrap_right_left_vns(right::NamedDist, left, vns) + return unwrap_right_left_vns(right.dist, left, right.name) +end +function unwrap_right_left_vns( + right::MultivariateDistribution, left::AbstractMatrix, vn::VarName +) + vns = map(axes(left, 2)) do i + return VarName(vn, (vn.indexing..., Tuple(i))) + end + return unwrap_right_left_vns(right, left, vns) +end +function unwrap_right_left_vns( + right::Union{Distribution,AbstractArray{<:Distribution}}, + left::AbstractArray, + vn::VarName, +) + vns = map(CartesianIndices(left)) do i + return VarName(vn, (vn.indexing..., Tuple(i))) + end + return unwrap_right_left_vns(right, left, vns) +end + ################# # Main Compiler # ################# @@ -264,8 +303,9 @@ function generate_tilde(left, right) __rng__, __context__, __sampler__, - $(DynamicPPL.check_tilde_rhs)($right), - $vn, + $(DynamicPPL.unwrap_right_vn)( + $(DynamicPPL.check_tilde_rhs)($right), $vn + )..., $inds, __varinfo__, ) @@ -314,9 +354,9 @@ function generate_dot_tilde(left, right) __rng__, __context__, __sampler__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - $vn, + $(DynamicPPL.unwrap_right_left_vns)( + $(DynamicPPL.check_tilde_rhs)($right), $left, $vn + )..., $inds, __varinfo__, ) From a549d1fc7b5b149b0fdf96754f366e825bcf31e9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 31 May 2021 20:10:54 +0100 Subject: [PATCH 004/107] removed _tilde and renamed tilde_assume and others --- src/context_implementations.jl | 175 +++++++++++---------------------- 1 file changed, 60 insertions(+), 115 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index afc5e4da3..f1977fe80 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -18,79 +18,72 @@ _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) _getindex(x, inds::Tuple{}) = x # assume -function tilde(rng, ctx::DefaultContext, sampler, right, vn::VarName, _, vi) - return _tilde(rng, sampler, right, vn, vi) +function tilde_assume(rng, ctx::DefaultContext, sampler, right, vn::VarName, _, vi) + return assume(rng, sampler, right, vn, vi) end -function tilde(rng, ctx::PriorContext, sampler, right, vn::VarName, inds, vi) +function tilde_assume(rng, ctx::PriorContext, sampler, right, vn::VarName, inds, vi) if ctx.vars !== nothing vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return _tilde(rng, sampler, right, vn, vi) + return assume(rng, sampler, right, vn, vi) end -function tilde(rng, ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi) +function tilde_assume(rng, ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi) if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return _tilde(rng, sampler, NoDist(right), vn, vi) + return assume(rng, sampler, NoDist(right), vn, vi) end -function tilde(rng, ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi) - return tilde(rng, ctx.ctx, sampler, right, left, inds, vi) +function tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi) + return tilde_assume(rng, ctx.ctx, sampler, right, left, inds, vi) end -function tilde(rng, ctx::PrefixContext, sampler, right, vn::VarName, inds, vi) - return tilde(rng, ctx.ctx, sampler, right, prefix(ctx, vn), inds, vi) +function tilde_assume(rng, ctx::PrefixContext, sampler, right, vn::VarName, inds, vi) + return tilde_assume(rng, ctx.ctx, sampler, right, prefix(ctx, vn), inds, vi) end """ - tilde_assume(rng, ctx, sampler, right, vn, inds, vi) + tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `tilde(rng, ctx, sampler, right, vn, inds, vi)`. +Falls back to `tilde_assume!(rng, ctx, sampler, right, vn, inds, vi)`. """ -function tilde_assume(rng, ctx, sampler, right, vn, inds, vi) - value, logp = tilde(rng, ctx, sampler, right, vn, inds, vi) +function tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) + value, logp = tilde_assume(rng, ctx, sampler, right, vn, inds, vi) acclogp!(vi, logp) return value end -function _tilde(rng, sampler, right, vn::VarName, vi) - return assume(rng, sampler, right, vn, vi) -end -function _tilde(rng, sampler, right::NamedDist, vn::VarName, vi) - return _tilde(rng, sampler, right.dist, right.name, vi) -end - # observe -function tilde(ctx::DefaultContext, sampler, right, left, vi) - return _tilde(sampler, right, left, vi) +function tilde_observe(ctx::DefaultContext, sampler, right, left, vi) + return observe(sampler, right, left, vi) end -function tilde(ctx::PriorContext, sampler, right, left, vi) +function tilde_observe(ctx::PriorContext, sampler, right, left, vi) return 0 end -function tilde(ctx::LikelihoodContext, sampler, right, left, vi) - return _tilde(sampler, right, left, vi) +function tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) + return observe(sampler, right, left, vi) end -function tilde(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * tilde(ctx.ctx, sampler, right, left, vi) +function tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) + return ctx.loglike_scalar * tilde_observe(ctx.ctx, sampler, right, left, vi) end -function tilde(ctx::PrefixContext, sampler, right, left, vi) - return tilde(ctx.ctx, sampler, right, left, vi) +function tilde_observe(ctx::PrefixContext, sampler, right, left, vi) + return tilde_observe(ctx.ctx, sampler, right, left, vi) end """ - tilde_observe(ctx, sampler, right, left, vname, vinds, vi) + tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information about variable name +Falls back to `tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe(ctx, sampler, right, left, vname, vinds, vi) - logp = tilde(ctx, sampler, right, left, vi) +function tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) + logp = tilde_observe(ctx, sampler, right, left, vi) acclogp!(vi, logp) return left end @@ -103,7 +96,7 @@ return the observed value. Falls back to `tilde(ctx, sampler, right, left, vi)`. """ -function tilde_observe(ctx, sampler, right, left, vi) +function tilde_observe!(ctx, sampler, right, left, vi) logp = tilde(ctx, sampler, right, left, vi) acclogp!(vi, logp) return left @@ -151,80 +144,44 @@ end # .~ functions # assume -function dot_tilde(rng, ctx::DefaultContext, sampler, right, left, vn::VarName, _, vi) - vns, dist = get_vns_and_dist(right, left, vn) - return _dot_tilde(rng, sampler, dist, left, vns, vi) +function dot_tilde_assume(rng, ctx::DefaultContext, sampler, right, left, vns, _, vi) + return dot_assume(rng, sampler, right, left, vns, vi) end -function dot_tilde(rng, ctx::LikelihoodContext, sampler, right, left, vn::VarName, inds, vi) - if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) - var = _getindex(getfield(ctx.vars, getsym(vn)), inds) - vns, dist = get_vns_and_dist(right, var, vn) - set_val!(vi, vns, dist, var) +function dot_tilde_assume(rng, ctx::LikelihoodContext, sampler, right, left, vns::AbstractArray{<:VarName{sym}}, inds, vi) where {sym} + if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) + var = _getindex(getfield(ctx.vars, sym), inds) + set_val!(vi, vns, right, var) settrans!.(Ref(vi), false, vns) - else - vns, dist = get_vns_and_dist(right, left, vn) end - return _dot_tilde(rng, sampler, NoDist.(dist), left, vns, vi) + return dot_assume(rng, sampler, NoDist.(right), left, vns, vi) end -function dot_tilde(rng, ctx::MiniBatchContext, sampler, right, left, vn::VarName, inds, vi) - return dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left, vns, inds, vi) + return dot_tilde_assume(rng, ctx.ctx, sampler, right, left, vns, inds, vi) end -function dot_tilde(rng, ctx::PriorContext, sampler, right, left, vn::VarName, inds, vi) +function dot_tilde_assume(rng, ctx::PriorContext, sampler, right, left, vns::AbstractArray{<:VarName{sym}}, inds, vi) where {sym} if ctx.vars !== nothing - var = _getindex(getfield(ctx.vars, getsym(vn)), inds) - vns, dist = get_vns_and_dist(right, var, vn) - set_val!(vi, vns, dist, var) + var = _getindex(getfield(ctx.vars, sym), inds) + set_val!(vi, vns, right, var) settrans!.(Ref(vi), false, vns) - else - vns, dist = get_vns_and_dist(right, left, vn) end - return _dot_tilde(rng, sampler, dist, left, vns, vi) + return dot_assume(rng, sampler, right, left, vns, vi) end """ - dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) + dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `dot_tilde(rng, ctx, sampler, right, left, vn, inds, vi)`. +Falls back to `dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi)`. """ -function dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) - value, logp = dot_tilde(rng, ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) + value, logp = dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) acclogp!(vi, logp) return value end -function get_vns_and_dist(dist::NamedDist, var, vn::VarName) - return get_vns_and_dist(dist.dist, var, dist.name) -end -function get_vns_and_dist(dist::MultivariateDistribution, var::AbstractMatrix, vn::VarName) - getvn = i -> VarName(vn, (vn.indexing..., (Colon(), i))) - return getvn.(1:size(var, 2)), dist -end -function get_vns_and_dist( - dist::Union{Distribution,AbstractArray{<:Distribution}}, var::AbstractArray, vn::VarName -) - getvn = ind -> VarName(vn, (vn.indexing..., Tuple(ind))) - return getvn.(CartesianIndices(var)), dist -end - -function _dot_tilde(rng, sampler, right, left, vns::AbstractArray{<:VarName}, vi) - return dot_assume(rng, sampler, right, vns, left, vi) -end - # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics -function _dot_tilde( - rng, - sampler::AbstractSampler, - right::Union{MultivariateDistribution,AbstractVector{<:MultivariateDistribution}}, - left::AbstractMatrix{>:AbstractVector}, - vn::AbstractVector{<:VarName}, - vi, -) - return throw(DimensionMismatch(AMBIGUITY_MSG)) -end - function dot_assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -348,61 +305,49 @@ function set_val!( end # observe -function dot_tilde(ctx::DefaultContext, sampler, right, left, vi) - return _dot_tilde(sampler, right, left, vi) +function dot_tilde_observe(ctx::DefaultContext, sampler, right, left, vi) + return dot_observe(sampler, right, left, vi) end -function dot_tilde(ctx::PriorContext, sampler, right, left, vi) +function dot_tilde_observe(ctx::PriorContext, sampler, right, left, vi) return 0 end -function dot_tilde(ctx::LikelihoodContext, sampler, right, left, vi) - return _dot_tilde(sampler, right, left, vi) +function dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) + return dot_observe(sampler, right, left, vi) end -function dot_tilde(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * dot_tilde(ctx.ctx, sampler, right, left, vi) +function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) + return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vi) end """ - dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi) + dot_tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs), accumulate the log probability, and return the observed value. -Falls back to `dot_tilde(ctx, sampler, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe(ctx, sampler, right, left, vn, inds, vi) - logp = dot_tilde(ctx, sampler, right, left, vi) +function dot_tilde_observe!(ctx, sampler, right, left, vn, inds, vi) + logp = dot_tilde_observe(ctx, sampler, right, left, vi) acclogp!(vi, logp) return left end """ - dot_tilde_observe(ctx, sampler, right, left, vi) + dot_tilde_observe!(ctx, sampler, right, left, vi) Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value. -Falls back to `dot_tilde(ctx, sampler, right, left, vi)`. +Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)`. """ -function dot_tilde_observe(ctx, sampler, right, left, vi) - logp = dot_tilde(ctx, sampler, right, left, vi) +function dot_tilde_observe!(ctx, sampler, right, left, vi) + logp = dot_tilde_observe(ctx, sampler, right, left, vi) acclogp!(vi, logp) return left end -function _dot_tilde(sampler, right, left::AbstractArray, vi) - return dot_observe(sampler, right, left, vi) -end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics -function _dot_tilde( - sampler::AbstractSampler, - right::Union{MultivariateDistribution,AbstractVector{<:MultivariateDistribution}}, - left::AbstractMatrix{>:AbstractVector}, - vi, -) - return throw(DimensionMismatch(AMBIGUITY_MSG)) -end - function dot_observe( spl::Union{SampleFromPrior,SampleFromUniform}, dist::MultivariateDistribution, From e0f77bc67c4a51f1eab941323e30634a0e8c5c02 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 31 May 2021 20:11:33 +0100 Subject: [PATCH 005/107] formatting --- src/context_implementations.jl | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index f1977fe80..9f78a10ec 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -147,7 +147,16 @@ end function dot_tilde_assume(rng, ctx::DefaultContext, sampler, right, left, vns, _, vi) return dot_assume(rng, sampler, right, left, vns, vi) end -function dot_tilde_assume(rng, ctx::LikelihoodContext, sampler, right, left, vns::AbstractArray{<:VarName{sym}}, inds, vi) where {sym} +function dot_tilde_assume( + rng, + ctx::LikelihoodContext, + sampler, + right, + left, + vns::AbstractArray{<:VarName{sym}}, + inds, + vi, +) where {sym} if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) var = _getindex(getfield(ctx.vars, sym), inds) set_val!(vi, vns, right, var) @@ -158,7 +167,16 @@ end function dot_tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left, vns, inds, vi) return dot_tilde_assume(rng, ctx.ctx, sampler, right, left, vns, inds, vi) end -function dot_tilde_assume(rng, ctx::PriorContext, sampler, right, left, vns::AbstractArray{<:VarName{sym}}, inds, vi) where {sym} +function dot_tilde_assume( + rng, + ctx::PriorContext, + sampler, + right, + left, + vns::AbstractArray{<:VarName{sym}}, + inds, + vi, +) where {sym} if ctx.vars !== nothing var = _getindex(getfield(ctx.vars, sym), inds) set_val!(vi, vns, right, var) From 8e4fa91db88448d2d0a73fbe3fc86b7644b19223 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 31 May 2021 20:39:48 +0100 Subject: [PATCH 006/107] updated compiler for new tilde-methods --- src/compiler.jl | 14 +++++++------- src/context_implementations.jl | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 0be9d4d44..20d8bf8ef 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -62,7 +62,7 @@ unwrap_right_vn(right, vn) = right, vn unwrap_right_vn(right::NamedDist, vn) = unwrap_right_vn(right.dist, right.name) """ - unwrap_right_left_vns(context, right, left, vns) + unwrap_right_left_vns(right, left, vns) Return the unwrapped distributions on the right-hand side and values and variable names on the left-hand side of a `.~` expression such as `x .~ Normal()`. This is used mainly to unwrap `NamedDist` distributions and adjust the indices of the @@ -281,7 +281,7 @@ function generate_tilde(left, right) # If the LHS is a literal, it is always an observation if !(left isa Symbol || left isa Expr) return quote - $(DynamicPPL.tilde_observe)( + $(DynamicPPL.tilde_observe!)( __context__, __sampler__, $(DynamicPPL.check_tilde_rhs)($right), @@ -299,7 +299,7 @@ function generate_tilde(left, right) $inds = $(vinds(left)) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption - $left = $(DynamicPPL.tilde_assume)( + $left = $(DynamicPPL.tilde_assume!)( __rng__, __context__, __sampler__, @@ -310,7 +310,7 @@ function generate_tilde(left, right) __varinfo__, ) else - $(DynamicPPL.tilde_observe)( + $(DynamicPPL.tilde_observe!)( __context__, __sampler__, $(DynamicPPL.check_tilde_rhs)($right), @@ -332,7 +332,7 @@ function generate_dot_tilde(left, right) # If the LHS is a literal, it is always an observation if !(left isa Symbol || left isa Expr) return quote - $(DynamicPPL.dot_tilde_observe)( + $(DynamicPPL.dot_tilde_observe!)( __context__, __sampler__, $(DynamicPPL.check_tilde_rhs)($right), @@ -350,7 +350,7 @@ function generate_dot_tilde(left, right) $inds = $(vinds(left)) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption - $left .= $(DynamicPPL.dot_tilde_assume)( + $left .= $(DynamicPPL.dot_tilde_assume!)( __rng__, __context__, __sampler__, @@ -361,7 +361,7 @@ function generate_dot_tilde(left, right) __varinfo__, ) else - $(DynamicPPL.dot_tilde_observe)( + $(DynamicPPL.dot_tilde_observe!)( __context__, __sampler__, $(DynamicPPL.check_tilde_rhs)($right), diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 9f78a10ec..8d4c5c2e2 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -97,7 +97,7 @@ return the observed value. Falls back to `tilde(ctx, sampler, right, left, vi)`. """ function tilde_observe!(ctx, sampler, right, left, vi) - logp = tilde(ctx, sampler, right, left, vi) + logp = tilde_observe(ctx, sampler, right, left, vi) acclogp!(vi, logp) return left end From 1c9a2d58e1e4e4965d8b098f8c43027f867887c8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 31 May 2021 21:02:45 +0100 Subject: [PATCH 007/107] fixed calls to dot_assume --- src/context_implementations.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 8d4c5c2e2..0698b6cdf 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -145,7 +145,7 @@ end # assume function dot_tilde_assume(rng, ctx::DefaultContext, sampler, right, left, vns, _, vi) - return dot_assume(rng, sampler, right, left, vns, vi) + return dot_assume(rng, sampler, right, vns, left, vi) end function dot_tilde_assume( rng, @@ -162,7 +162,7 @@ function dot_tilde_assume( set_val!(vi, vns, right, var) settrans!.(Ref(vi), false, vns) end - return dot_assume(rng, sampler, NoDist.(right), left, vns, vi) + return dot_assume(rng, sampler, NoDist.(right), vns, left, vi) end function dot_tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left, vns, inds, vi) return dot_tilde_assume(rng, ctx.ctx, sampler, right, left, vns, inds, vi) @@ -182,7 +182,7 @@ function dot_tilde_assume( set_val!(vi, vns, right, var) settrans!.(Ref(vi), false, vns) end - return dot_assume(rng, sampler, right, left, vns, vi) + return dot_assume(rng, sampler, right, vns, left, vi) end """ From d70e1be46058e912121b41dd0e2f0724c57474c1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:21:20 +0100 Subject: [PATCH 008/107] added sampling context and unwrap_childcontext --- src/contexts.jl | 63 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/src/contexts.jl b/src/contexts.jl index 4d4f30bdc..1ee43f2b2 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,3 +1,39 @@ +""" + unwrap_childcontext(context::AbstractContext) + +Return a tuple of the child context of a `context`, or `nothing` if the context does +not wrap any other context, and a function `f(c::AbstractContext)` that constructs +an instance of `context` in which the child context is replaced with `c`. + +Falls back to `(nothing, _ -> context)`. +""" +function unwrap_childcontext(context::AbstractContext) + reconstruct_context(@nospecialize(x)) = context + return nothing, reconstruct_context +end + +""" + SamplingContext(rng, sampler, context) + +Create a context that allows you to sample parameters with the `sampler` when running the model. +The `context` determines how the returned log density is computed when running the model. + +See also: [`JointContext`](@ref), [`LoglikelihoodContext`](@ref), [`PriorContext`](@ref) +""" +struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext + rng::R + sampler::S + context::C +end + +function unwrap_childcontext(context::SamplingContext) + child = context.context + function reconstruct_samplingcontext(c::AbstractContext) + return SamplingContext(context.rng, context.sampler, c) + end + return child, reconstruct_samplingcontext +end + """ struct DefaultContext <: AbstractContext end @@ -53,6 +89,25 @@ function MiniBatchContext(ctx=DefaultContext(); batch_size, npoints) return MiniBatchContext(ctx, npoints / batch_size) end +function unwrap_childcontext(context::MiniBatchContext) + child = context.context + function reconstruct_minibatchcontext(c::AbstractContext) + return MiniBatchContext(c, context.loglike_scalar) + end + return child, reconstruct_minibatchcontext +end + +""" + PrefixContext{Prefix}(context) + +Create a context that allows you to use the wrapped `context` when running the model and +adds the `Prefix` to all parameters. + +This context is useful in nested models to ensure that the names of the parameters are +unique. + +See also: [`@submodel`](@ref) +""" struct PrefixContext{Prefix,C} <: AbstractContext ctx::C end @@ -81,3 +136,11 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) end end + +function unwrap_childcontext(context::PrefixContext{P}) where {P} + child = context.context + function reconstruct_prefixcontext(c::AbstractContext) + return PrefixContext{P}(c) + end + return child, reconstruct_prefixcontext +end From f74399031eca9b5eaab824c16416a825c489f59c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:23:31 +0100 Subject: [PATCH 009/107] updated tilde methods --- src/context_implementations.jl | 451 +++++++++++++++++++++++++-------- 1 file changed, 352 insertions(+), 99 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 0698b6cdf..8aa8ddfca 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -18,28 +18,103 @@ _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) _getindex(x, inds::Tuple{}) = x # assume -function tilde_assume(rng, ctx::DefaultContext, sampler, right, vn::VarName, _, vi) - return assume(rng, sampler, right, vn, vi) +""" + tilde_assume(context::SamplingContext, right, vn, inds, vi) + +Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), +accumulate the log probability, and return the sampled value with a context associated +with a sampler. + +Falls back to +```julia +tilde_assume(context.rng, context.ctx, context.sampler, right, vn, inds, vi) +``` +if the context `context.ctx` does not call any other context, as indicated by +[`unwrap_childcontext`](@ref). Otherwise, calls `tilde_assume(c, right, vn, inds, vi)` +where `c` is a context in which the order of the sampling context and its child are swapped. +""" +function tilde_assume(context::SamplingContext, right, vn, inds, vi) + c, reconstruct_context = unwrap_childcontext(context) + child_of_c, reconstruct_c = unwrap_childcontext(c) + return if child_of_c === nothing + tilde_assume(context.rng, c, context.sampler, right, vn, inds, vi) + else + tilde_assume(reconstruct_c(reconstruct_context(child_of_c)), right, vn, inds, vi) + end +end + +# Leaf contexts +tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, inds, vi) +function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi) + return assume(rng, sampler, right, vn, inds, vi) +end + +function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + settrans!(vi, false, vn) + end + return tilde_assume(PriorContext(), right, vn, inds, vi) +end +function tilde_assume( + rng::Random.AbstractRNG, + context::PriorContext{<:NamedTuple}, + sampler, + right, + vn, + inds, + vi, +) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + settrans!(vi, false, vn) + end + return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) end -function tilde_assume(rng, ctx::PriorContext, sampler, right, vn::VarName, inds, vi) - if ctx.vars !== nothing - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) +function tilde_assume(::PriorContext, right, vn, inds, vi) + return assume(right, vn, inds, vi) +end +function tilde_assume(rng::Random.AbstractRNG, ::PriorContext, sampler, right, vn, inds, vi) + return assume(rng, sampler, right, vn, inds, vi) +end + +function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return assume(rng, sampler, right, vn, vi) + return tilde_assume(LikelihoodContext(), right, vn, inds, vi) end -function tilde_assume(rng, ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi) - if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) +function tilde_assume( + rng::Random.AbstractRNG, + context::LikelihoodContext{<:NamedTuple}, + sampler, + right, + vn, + inds, + vi, +) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return assume(rng, sampler, NoDist(right), vn, vi) + return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, inds, vi) +end +function tilde_assume(::LikelihoodContext, right, vn, inds, vi) + return assume(NoDist(right), vn, inds, vi) +end +function tilde_assume( + rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, inds, vi +) + return assume(rng, sampler, NoDist(right), vn, inds, vi) end -function tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, left, inds, vi) + +function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) + return tilde_assume(context.ctx, right, vn, inds, vi) end -function tilde_assume(rng, ctx::PrefixContext, sampler, right, vn::VarName, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, prefix(ctx, vn), inds, vi) + +function tilde_assume(context::PrefixContext, right, vn, inds, vi) + return tilde_assume(context.ctx, right, prefix(context, vn), inds, vi) end """ @@ -50,27 +125,76 @@ accumulate the log probability, and return the sampled value. Falls back to `tilde_assume!(rng, ctx, sampler, right, vn, inds, vi)`. """ -function tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) - value, logp = tilde_assume(rng, ctx, sampler, right, vn, inds, vi) +function tilde_assume!(ctx, sampler, right, vn, inds, vi) + value, logp = tilde_assume(ctx, sampler, right, vn, inds, vi) acclogp!(vi, logp) return value end # observe -function tilde_observe(ctx::DefaultContext, sampler, right, left, vi) - return observe(sampler, right, left, vi) +""" + tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) + +Handle observed variables with a `context` associated with a sampler. +Falls back to `tilde_observe(context.ctx, right, left, vname, vinds, vi)` ignoring +the information about the sampler if the context `context.ctx` does not call any other +context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls +`tilde_observe(c, right, left, vname, vinds, vi)` where `c` is a context in +which the order of the sampling context and its child are swapped. +""" +function tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) + c, reconstruct_context = unwrap_childcontext(context) + child_of_c, reconstruct_c = unwrap_childcontext(c) + fallback_context = if child_of_c !== nothing + reconstruct_c(reconstruct_context(child_of_c)) + else + c + end + return tilde_observe(fallback_context, right, left, vname, vinds, vi) end -function tilde_observe(ctx::PriorContext, sampler, right, left, vi) - return 0 + +""" + tilde_observe(context::SamplingContext, right, left, vi) + +Handle observed constants with a `context` associated with a sampler. +Falls back to `tilde_observe(context.ctx, right, left, vi)` ignoring +the information about the sampler if the context `context.ctx` does not call any other +context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls +`tilde_observe(c, right, left, vi)` where `c` is a context in +which the order of the sampling context and its child are swapped. +""" +function tilde_observe(context::SamplingContext, right, left, vi) + c, reconstruct_context = unwrap_childcontext(context) + child_of_c, reconstruct_c = unwrap_childcontext(c) + fallback_context = if child_of_c !== nothing + reconstruct_c(reconstruct_context(child_of_c)) + else + c + end + return tilde_observe(fallback_context, right, left, vi) +end + +# Leaf contexts +tilde_observe(::DefaultContext, right, left, vi) = observe(right, left, vi) +tilde_observe(::PriorContext, right, left, vi) = 0 +tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi) + +# `MiniBatchContext` +function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) + return context.loglike_scalar * tilde_observe(context.ctx, right, left, vi) end -function tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) - return observe(sampler, right, left, vi) +function tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) + return context.loglike_scalar * tilde_observe(context.ctx, right, left, vname, vinds, vi) end -function tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * tilde_observe(ctx.ctx, sampler, right, left, vi) + +# `PrefixContext` +function tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) + return tilde_observe( + context.ctx, right, left, prefix(context, vname), vinds, vi + ) end -function tilde_observe(ctx::PrefixContext, sampler, right, left, vi) - return tilde_observe(ctx.ctx, sampler, right, left, vi) +function tilde_observe(context::PrefixContext, right, left, vi) + return tilde_observe(context.ctx, right, left, vi) end """ @@ -112,77 +236,179 @@ function observe(spl::Sampler, weight) return error("DynamicPPL.observe: unmanaged inference algorithm: $(typeof(spl))") end +# fallback without sampler +function assume(dist::Distribution, vn::VarName, inds, vi) + if !haskey(vi, vn) + error("variable $vn does not exist") + end + r = vi[vn] + return r, Bijectors.logpdf_with_trans(dist, vi[vn], istrans(vi, vn)) +end + +# SampleFromPrior and SampleFromUniform function assume( - rng, spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, vi + rng::Random.AbstractRNG, + sampler::Union{SampleFromPrior,SampleFromUniform}, + dist::Distribution, + vn::VarName, + inds, + vi, ) + # Always overwrite the parameters with new ones. + r = init(rng, dist, sampler) if haskey(vi, vn) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if spl isa SampleFromUniform || is_flagged(vi, vn, "del") - unset_flag!(vi, vn, "del") - r = init(rng, dist, spl) - vi[vn] = vectorize(dist, r) - settrans!(vi, false, vn) - setorder!(vi, vn, get_num_produce(vi)) - else - r = vi[vn] - end + vi[vn] = vectorize(dist, r) + setorder!(vi, vn, get_num_produce(vi)) else - r = init(rng, dist, spl) - push!(vi, vn, r, dist, spl) - settrans!(vi, false, vn) + push!(vi, vn, r, dist, sampler) end + settrans!(vi, false, vn) return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) end -function observe( - spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, value, vi -) +# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`) +function observe(right::Distribution, left, vi) increment_num_produce!(vi) - return Distributions.loglikelihood(dist, value) + return Distributions.loglikelihood(right, left) end # .~ functions # assume -function dot_tilde_assume(rng, ctx::DefaultContext, sampler, right, left, vns, _, vi) +""" + dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) + +Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the +model inputs), accumulate the log probability, and return the sampled value for a context +associated with a sampler. + +Falls back to +```julia +dot_tilde_assume(context.rng, context.ctx, context.sampler, right, left, vn, inds, vi) +``` +if the context `context.ctx` does not call any other context, as indicated by +[`unwrap_childcontext`](@ref). Otherwise, calls `dot_tilde_assume(c, right, left, vn, inds, vi)` +where `c` is a context in which the order of the sampling context and its child are swapped. +""" +function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) + c, reconstruct_context = unwrap_childcontext(context) + child_of_c, reconstruct_c = unwrap_childcontext(c) + return if child_of_c === nothing + dot_tilde_assume(context.rng, c, context.sampler, right, left, vn, inds, vi) + else + dot_tilde_assume(reconstruct_c(reconstruct_context(child_of_c)), right, left, vn, inds, vi) + end +end + +# `DefaultContext` +function dot_tilde_assume(ctx::DefaultContext, sampler, right, left, vns, inds, vi) + return dot_assume(right, vns, left, vi) +end + +function dot_tilde_assume(rng, ctx::DefaultContext, sampler, right, left, vns, inds, vi) return dot_assume(rng, sampler, right, vns, left, vi) end + +# `LikelihoodContext` function dot_tilde_assume( - rng, - ctx::LikelihoodContext, + context::LikelihoodContext{<:NamedTuple}, right, left, vn, inds, vi +) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, inds, vi) + else + dot_tilde_assume(LikelihoodContext(), right, left, vn, inds, vi) + end +end +function dot_tilde_assume( + rng::Random.AbstractRNG, + context::LikelihoodContext{<:NamedTuple}, sampler, right, left, - vns::AbstractArray{<:VarName{sym}}, + vn, inds, vi, -) where {sym} - if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) - var = _getindex(getfield(ctx.vars, sym), inds) - set_val!(vi, vns, right, var) - settrans!.(Ref(vi), false, vns) +) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, inds, vi) + else + dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, inds, vi) end - return dot_assume(rng, sampler, NoDist.(right), vns, left, vi) end -function dot_tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left, vns, inds, vi) - return dot_tilde_assume(rng, ctx.ctx, sampler, right, left, vns, inds, vi) +function dot_tilde_assume(context::LikelihoodContext, right, left, vn, inds, vi) + value, logp = dot_assume(NoDist.(right), left, vn, inds, vi) + acclogp!(vi, logp) + return value end function dot_tilde_assume( - rng, - ctx::PriorContext, + rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, inds, vi +) + value, logp = dot_assume(rng, sampler, NoDist.(right), left, vn, inds, vi) + acclogp!(vi, logp) + return value +end + +# `PriorContext` +function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, inds, vi) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(PriorContext(), _right, _left, _vns, inds, vi) + else + dot_tilde_assume(PriorContext(), right, left, vn, inds, vi) + end +end +function dot_tilde_assume( + rng::Random.AbstractRNG, + context::PriorContext{<:NamedTuple}, sampler, right, left, - vns::AbstractArray{<:VarName{sym}}, + vn, inds, vi, -) where {sym} - if ctx.vars !== nothing - var = _getindex(getfield(ctx.vars, sym), inds) - set_val!(vi, vns, right, var) - settrans!.(Ref(vi), false, vns) +) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, inds, vi) + else + dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, inds, vi) end - return dot_assume(rng, sampler, right, vns, left, vi) +end +function dot_tilde_assume(context::PriorContext, right, left, vn, inds, vi) + value, logp = dot_assume(right, left, vn, inds, vi) + acclogp!(vi, logp) + return value +end +function dot_tilde_assume( + rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi +) + value, logp = dot_assume(rng, sampler, right, left, vn, inds, vi) + acclogp!(vi, logp) + return value +end + +# `MiniBatchContext` +function dot_tilde_assume(context::MiniBatchContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.ctx, right, left, vn, inds, vi) +end + +# `PrefixContext` +function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.ctx, right, prefix.(Ref(context), vn), inds, vi) end """ @@ -193,13 +419,26 @@ model inputs), accumulate the log probability, and return the sampled value. Falls back to `dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi)`. """ -function dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) - value, logp = dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume!(ctx, sampler, right, left, vn, inds, vi) + value, logp = dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi) acclogp!(vi, logp) return value end -# Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics +# `dot_assume` +function dot_assume( + dist::MultivariateDistribution, + var::AbstractMatrix, + vns::AbstractVector{<:VarName}, + inds, + vi, +) + @assert length(dist) == size(var, 1) + lp = sum(zip(vns, eachcol(var))) do vn, ri + return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) + end + return var, lp +end function dot_assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -214,6 +453,19 @@ function dot_assume( var .= r return var, lp end + +function dot_assume( + dists::Union{Distribution,AbstractArray{<:Distribution}}, + var::AbstractArray, + vns::AbstractArray{<:VarName}, + inds, + vi, +) + # Make sure `var` is not a matrix for multivariate distributions + lp = sum(Bijectors.logpdf_with_trans.(dists, var, istrans(vi, vns[1]))) + return var, lp +end + function dot_assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -323,18 +575,38 @@ function set_val!( end # observe -function dot_tilde_observe(ctx::DefaultContext, sampler, right, left, vi) - return dot_observe(sampler, right, left, vi) -end -function dot_tilde_observe(ctx::PriorContext, sampler, right, left, vi) - return 0 -end -function dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) - return dot_observe(sampler, right, left, vi) +""" + dot_tilde_observe(context::SamplingContext, right, left, vi) + +Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log +probability, and return the observed value for a context associated with a sampler. + +Falls back to `dot_tilde_observe(context.ctx, right, left, vi) ignoring the sampler. +""" +function dot_tilde_observe(context::SamplingContext, right, left, vi) + return dot_tilde_observe(context.ctx, right, left, vname, vinds, vi) end + +# Leaf contexts +dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(right, left, vi) +dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 +dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) = dot_observe(right, left, vi) + +# `MiniBatchContext` function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vi) end +function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vname, vinds, vi) + return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vname, vinds, vi) +end + +# `PrefixContext` +function dot_tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) + return dot_tilde_observe(context.ctx, right, left, prefix(context, vname), vinds, vi) +end +function dot_tilde_observe(context::PrefixContext, right, left, vi) + return dot_tilde_observe(context.ctx, right, left, vi) +end """ dot_tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) @@ -366,41 +638,22 @@ function dot_tilde_observe!(ctx, sampler, right, left, vi) end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics -function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, - dist::MultivariateDistribution, - value::AbstractMatrix, - vi, -) +function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) increment_num_produce!(vi) @debug "dist = $dist" @debug "value = $value" return Distributions.loglikelihood(dist, value) end -function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, - dists::Distribution, - value::AbstractArray, - vi, -) +function dot_observe(dists::Distribution, value::AbstractArray, vi) increment_num_produce!(vi) @debug "dists = $dists" @debug "value = $value" return Distributions.loglikelihood(dists, value) end -function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, - dists::AbstractArray{<:Distribution}, - value::AbstractArray, - vi, -) +function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) increment_num_produce!(vi) @debug "dists = $dists" @debug "value = $value" return sum(Distributions.loglikelihood.(dists, value)) end -function dot_observe(spl::Sampler, ::Any, ::Any, ::Any) - return error( - "[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing observe statement" - ) -end + From 3d2e7e2b4dfb2462e6732eb3f0b3ac5494d1696f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:23:48 +0100 Subject: [PATCH 010/107] updated model call signature --- src/model.jl | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/src/model.jl b/src/model.jl index 7189b590e..250b89721 100644 --- a/src/model.jl +++ b/src/model.jl @@ -88,12 +88,18 @@ function (model::Model)( sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) + return model(SamplingContext(rng, sampler, context), varinfo) +end + +(model::Model)(context::AbstractContext) = model(VarInfo(), context) +function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 - return evaluate_threadunsafe(rng, model, varinfo, sampler, context) + return evaluate_threadunsafe(model, varinfo, sampler, context) else - return evaluate_threadsafe(rng, model, varinfo, sampler, context) + return evaluate_threadsafe(model, varinfo, sampler, context) end end + function (model::Model)(args...) return model(Random.GLOBAL_RNG, args...) end @@ -109,7 +115,7 @@ function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) end """ - evaluate_threadunsafe(rng, model, varinfo, sampler, context) + evaluate_threadunsafe(model, varinfo, sampler, context) Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. @@ -118,13 +124,13 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe`](@ref) """ -function evaluate_threadunsafe(rng, model, varinfo, sampler, context) +function evaluate_threadunsafe(model, varinfo, sampler, context) resetlogp!(varinfo) - return _evaluate(rng, model, varinfo, sampler, context) + return _evaluate(model, varinfo, sampler, context) end """ - evaluate_threadsafe(rng, model, varinfo, sampler, context) + evaluate_threadsafe(model, varinfo, sampler, context) Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`. @@ -134,24 +140,24 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe`](@ref) """ -function evaluate_threadsafe(rng, model, varinfo, sampler, context) +function evaluate_threadsafe(model, varinfo, sampler, context) resetlogp!(varinfo) wrapper = ThreadSafeVarInfo(varinfo) - result = _evaluate(rng, model, wrapper, sampler, context) + result = _evaluate(model, wrapper, sampler, context) setlogp!(varinfo, getlogp(wrapper)) return result end """ - _evaluate(rng, model::Model, varinfo, sampler, context) + _evaluate(model::Model, varinfo, sampler, context) Evaluate the `model` with the arguments matching the given `sampler` and `varinfo` object. """ @generated function _evaluate( - rng, model::Model{_F,argnames}, varinfo, sampler, context + model::Model{_F,argnames}, varinfo, sampler, context ) where {_F,argnames} unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] - return :(model.f(rng, model, varinfo, sampler, context, $(unwrap_args...))) + return :(model.f(model, varinfo, sampler, context, $(unwrap_args...))) end """ @@ -183,7 +189,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), DefaultContext()) + model(varinfo, DefaultContext()) return getlogp(varinfo) end @@ -195,7 +201,7 @@ Return the log prior probability of variables `varinfo` for the probabilistic `m See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), PriorContext()) + model(varinfo, PriorContext()) return getlogp(varinfo) end @@ -207,7 +213,7 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`. See also [`logjoint`](@ref) and [`logprior`](@ref). """ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), LikelihoodContext()) + model(varinfo, LikelihoodContext()) return getlogp(varinfo) end From 4f1d39694083bae41ac7e94cbb19640639512879 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:24:04 +0100 Subject: [PATCH 011/107] updated compiler --- src/compiler.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 20d8bf8ef..bc906f58c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -394,10 +394,8 @@ function build_output(modelinfo, linenumbernode) # Add the internal arguments to the user-specified arguments (positional + keywords). evaluatordef[:args] = vcat( [ - :(__rng__::$(Random.AbstractRNG)), :(__model__::$(DynamicPPL.Model)), :(__varinfo__::$(DynamicPPL.AbstractVarInfo)), - :(__sampler__::$(DynamicPPL.AbstractSampler)), :(__context__::$(DynamicPPL.AbstractContext)), ], modelinfo[:allargs_exprs], @@ -407,7 +405,15 @@ function build_output(modelinfo, linenumbernode) evaluatordef[:kwargs] = [] # Replace the user-provided function body with the version created by DynamicPPL. - evaluatordef[:body] = modelinfo[:body] + evaluatordef[:body] = quote + # in case someone accessed these + if __context__ isa $(DynamicPPL.SamplingContext) + __rng__ = __context__.rng + __sampler__ = __context__.sampler + end + + $(modelinfo[:body]) + end ## Build the model function. From b187d74efcfb5b9482f39022252477b5a0bc2cb9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:27:37 +0100 Subject: [PATCH 012/107] formatting --- src/context_implementations.jl | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 8aa8ddfca..a0bf0381a 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -45,7 +45,9 @@ end # Leaf contexts tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, inds, vi) -function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi) +function tilde_assume( + rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi +) return assume(rng, sampler, right, vn, inds, vi) end @@ -184,14 +186,13 @@ function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) return context.loglike_scalar * tilde_observe(context.ctx, right, left, vi) end function tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) - return context.loglike_scalar * tilde_observe(context.ctx, right, left, vname, vinds, vi) + return context.loglike_scalar * + tilde_observe(context.ctx, right, left, vname, vinds, vi) end # `PrefixContext` function tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) - return tilde_observe( - context.ctx, right, left, prefix(context, vname), vinds, vi - ) + return tilde_observe(context.ctx, right, left, prefix(context, vname), vinds, vi) end function tilde_observe(context::PrefixContext, right, left, vi) return tilde_observe(context.ctx, right, left, vi) @@ -296,7 +297,9 @@ function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) return if child_of_c === nothing dot_tilde_assume(context.rng, c, context.sampler, right, left, vn, inds, vi) else - dot_tilde_assume(reconstruct_c(reconstruct_context(child_of_c)), right, left, vn, inds, vi) + dot_tilde_assume( + reconstruct_c(reconstruct_context(child_of_c)), right, left, vn, inds, vi + ) end end @@ -590,14 +593,17 @@ end # Leaf contexts dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(right, left, vi) dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 -dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) = dot_observe(right, left, vi) +function dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) + return dot_observe(right, left, vi) +end # `MiniBatchContext` function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vi) end function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vname, vinds, vi) - return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vname, vinds, vi) + return ctx.loglike_scalar * + dot_tilde_observe(ctx.ctx, sampler, right, left, vname, vinds, vi) end # `PrefixContext` @@ -656,4 +662,3 @@ function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, @debug "value = $value" return sum(Distributions.loglikelihood.(dists, value)) end - From ee99f8ce5676c5cb571417e6dd4b3d9570922bf5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:30:54 +0100 Subject: [PATCH 013/107] added getsym for vectors --- src/varname.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/varname.jl b/src/varname.jl index bb936a4ce..40c5c25e9 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -39,3 +39,7 @@ Possibly existing indices of `varname` are neglected. ) where {s,missings,_F,_a,_T} return s in missings end + + +# HACK: Type-piracy. Is this really the way to go? +AbstractPPL.getsym(::AbstractVector{<:VarName{sym}}) where {sym} = sym From c4845d08b34b58bc30762b4be944c8924c9794e2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:35:12 +0100 Subject: [PATCH 014/107] Update src/varname.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/varname.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/varname.jl b/src/varname.jl index 40c5c25e9..343bb0da8 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -40,6 +40,5 @@ Possibly existing indices of `varname` are neglected. return s in missings end - # HACK: Type-piracy. Is this really the way to go? AbstractPPL.getsym(::AbstractVector{<:VarName{sym}}) where {sym} = sym From a0c05f39315c93d9ed43cefe227255310172f339 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:43:42 +0100 Subject: [PATCH 015/107] fixed some signatures for Model --- src/model.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/model.jl b/src/model.jl index 250b89721..8d353d2de 100644 --- a/src/model.jl +++ b/src/model.jl @@ -88,7 +88,7 @@ function (model::Model)( sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) - return model(SamplingContext(rng, sampler, context), varinfo) + return model(varinfo, SamplingContext(rng, sampler, context)) end (model::Model)(context::AbstractContext) = model(VarInfo(), context) @@ -115,7 +115,7 @@ function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) end """ - evaluate_threadunsafe(model, varinfo, sampler, context) + evaluate_threadunsafe(model, varinfo, context) Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. @@ -124,13 +124,13 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe`](@ref) """ -function evaluate_threadunsafe(model, varinfo, sampler, context) +function evaluate_threadunsafe(model, varinfo, context) resetlogp!(varinfo) - return _evaluate(model, varinfo, sampler, context) + return _evaluate(model, varinfo, context) end """ - evaluate_threadsafe(model, varinfo, sampler, context) + evaluate_threadsafe(model, varinfo, context) Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`. @@ -140,24 +140,24 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe`](@ref) """ -function evaluate_threadsafe(model, varinfo, sampler, context) +function evaluate_threadsafe(model, varinfo, context) resetlogp!(varinfo) wrapper = ThreadSafeVarInfo(varinfo) - result = _evaluate(model, wrapper, sampler, context) + result = _evaluate(model, wrapper, context) setlogp!(varinfo, getlogp(wrapper)) return result end """ - _evaluate(model::Model, varinfo, sampler, context) + _evaluate(model::Model, varinfo, context) Evaluate the `model` with the arguments matching the given `sampler` and `varinfo` object. """ @generated function _evaluate( - model::Model{_F,argnames}, varinfo, sampler, context + model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] - return :(model.f(model, varinfo, sampler, context, $(unwrap_args...))) + return :(model.f(model, varinfo, context, $(unwrap_args...))) end """ From 307cd7e1f3a1a02bd79284ccfc640848961e2bd9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:49:44 +0100 Subject: [PATCH 016/107] fixed a method call --- src/model.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/model.jl b/src/model.jl index 8d353d2de..3a01f9bf3 100644 --- a/src/model.jl +++ b/src/model.jl @@ -94,9 +94,9 @@ end (model::Model)(context::AbstractContext) = model(VarInfo(), context) function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 - return evaluate_threadunsafe(model, varinfo, sampler, context) + return evaluate_threadunsafe(model, varinfo, context) else - return evaluate_threadsafe(model, varinfo, sampler, context) + return evaluate_threadsafe(model, varinfo, context) end end @@ -151,7 +151,7 @@ end """ _evaluate(model::Model, varinfo, context) -Evaluate the `model` with the arguments matching the given `sampler` and `varinfo` object. +Evaluate the `model` with the arguments matching the given `context` and `varinfo` object. """ @generated function _evaluate( model::Model{_F,argnames}, varinfo, context From 597277119922a24d0783b78134a8f541eb05dc0e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 01:00:49 +0100 Subject: [PATCH 017/107] fixed method signatures --- src/compiler.jl | 8 ------ src/context_implementations.jl | 48 +++++++++++++++++----------------- 2 files changed, 24 insertions(+), 32 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index bc906f58c..8201a82f4 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -283,7 +283,6 @@ function generate_tilde(left, right) return quote $(DynamicPPL.tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__, @@ -300,9 +299,7 @@ function generate_tilde(left, right) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption $left = $(DynamicPPL.tilde_assume!)( - __rng__, __context__, - __sampler__, $(DynamicPPL.unwrap_right_vn)( $(DynamicPPL.check_tilde_rhs)($right), $vn )..., @@ -312,7 +309,6 @@ function generate_tilde(left, right) else $(DynamicPPL.tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, $vn, @@ -334,7 +330,6 @@ function generate_dot_tilde(left, right) return quote $(DynamicPPL.dot_tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__, @@ -351,9 +346,7 @@ function generate_dot_tilde(left, right) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption $left .= $(DynamicPPL.dot_tilde_assume!)( - __rng__, __context__, - __sampler__, $(DynamicPPL.unwrap_right_left_vns)( $(DynamicPPL.check_tilde_rhs)($right), $left, $vn )..., @@ -363,7 +356,6 @@ function generate_dot_tilde(left, right) else $(DynamicPPL.dot_tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, $vn, diff --git a/src/context_implementations.jl b/src/context_implementations.jl index a0bf0381a..d6ff3b5bd 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -120,15 +120,15 @@ function tilde_assume(context::PrefixContext, right, vn, inds, vi) end """ - tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) + tilde_assume!(ctx, right, vn, inds, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `tilde_assume!(rng, ctx, sampler, right, vn, inds, vi)`. +Falls back to `tilde_assume!(ctx, right, vn, inds, vi)`. """ -function tilde_assume!(ctx, sampler, right, vn, inds, vi) - value, logp = tilde_assume(ctx, sampler, right, vn, inds, vi) +function tilde_assume!(ctx, right, vn, inds, vi) + value, logp = tilde_assume(ctx, right, vn, inds, vi) acclogp!(vi, logp) return value end @@ -199,30 +199,30 @@ function tilde_observe(context::PrefixContext, right, left, vi) end """ - tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) + tilde_observe!(ctx, right, left, vname, vinds, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable name +Falls back to `tilde_observe(ctx, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) - logp = tilde_observe(ctx, sampler, right, left, vi) +function tilde_observe!(ctx, right, left, vname, vinds, vi) + logp = tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end """ - tilde_observe(ctx, sampler, right, left, vi) + tilde_observe(ctx, right, left, vi) Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the observed value. -Falls back to `tilde(ctx, sampler, right, left, vi)`. +Falls back to `tilde(ctx, right, left, vi)`. """ -function tilde_observe!(ctx, sampler, right, left, vi) - logp = tilde_observe(ctx, sampler, right, left, vi) +function tilde_observe!(ctx, right, left, vi) + logp = tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end @@ -415,15 +415,15 @@ function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) end """ - dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) + dot_tilde_assume!(ctx, right, left, vn, inds, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi)`. +Falls back to `dot_tilde_assume(ctx, right, left, vn, inds, vi)`. """ -function dot_tilde_assume!(ctx, sampler, right, left, vn, inds, vi) - value, logp = dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume!(ctx, right, left, vn, inds, vi) + value, logp = dot_tilde_assume(ctx, right, left, vn, inds, vi) acclogp!(vi, logp) return value end @@ -615,30 +615,30 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi) end """ - dot_tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) + dot_tilde_observe!(ctx, right, left, vname, vinds, vi) Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs), accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe(ctx, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe!(ctx, sampler, right, left, vn, inds, vi) - logp = dot_tilde_observe(ctx, sampler, right, left, vi) +function dot_tilde_observe!(ctx, right, left, vn, inds, vi) + logp = dot_tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end """ - dot_tilde_observe!(ctx, sampler, right, left, vi) + dot_tilde_observe!(ctx, right, left, vi) Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)`. +Falls back to `dot_tilde_observe(ctx, right, left, vi)`. """ -function dot_tilde_observe!(ctx, sampler, right, left, vi) - logp = dot_tilde_observe(ctx, sampler, right, left, vi) +function dot_tilde_observe!(ctx, right, left, vi) + logp = dot_tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end From c4ecd0e676ab58aa13ef9995289d4368868d212c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 01:08:56 +0100 Subject: [PATCH 018/107] sort of fixed the matchingvalue functionality for model --- src/model.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index 3a01f9bf3..2d74949c1 100644 --- a/src/model.jl +++ b/src/model.jl @@ -157,7 +157,10 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] - return :(model.f(model, varinfo, context, $(unwrap_args...))) + return quote + sampler = context isa $(SamplingContext) ? context.sampler : SampleFromPrior() + model.f(model, varinfo, context, $(unwrap_args...)) + end end """ From a34b51cd60fffe8ff45903153550ff3486680f91 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 03:36:55 +0100 Subject: [PATCH 019/107] formatting --- src/compiler.jl | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 8201a82f4..dc70ae267 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -282,10 +282,7 @@ function generate_tilde(left, right) if !(left isa Symbol || left isa Expr) return quote $(DynamicPPL.tilde_observe!)( - __context__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - __varinfo__, + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end end @@ -329,10 +326,7 @@ function generate_dot_tilde(left, right) if !(left isa Symbol || left isa Expr) return quote $(DynamicPPL.dot_tilde_observe!)( - __context__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - __varinfo__, + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end end From b89ff7e3a3b6c1140b155ceb6997a2cabe5479cc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 04:09:53 +0100 Subject: [PATCH 020/107] removed redundant _tilde method --- src/context_implementations.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 0698b6cdf..b5a0fd923 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -102,8 +102,6 @@ function tilde_observe!(ctx, sampler, right, left, vi) return left end -_tilde(sampler, right, left, vi) = observe(sampler, right, left, vi) - function assume(rng, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end From e4a2cf81154e44b23af16e473d1361cf11225fc4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:02:41 +0100 Subject: [PATCH 021/107] removed left-over acclogp! that should not be here anymore --- src/context_implementations.jl | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 42a336479..b088577f5 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -345,16 +345,12 @@ function dot_tilde_assume( end end function dot_tilde_assume(context::LikelihoodContext, right, left, vn, inds, vi) - value, logp = dot_assume(NoDist.(right), left, vn, inds, vi) - acclogp!(vi, logp) - return value + return dot_assume(NoDist.(right), left, vn, inds, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, inds, vi ) - value, logp = dot_assume(rng, sampler, NoDist.(right), left, vn, inds, vi) - acclogp!(vi, logp) - return value + return dot_assume(rng, sampler, NoDist.(right), left, vn, inds, vi) end # `PriorContext` @@ -390,16 +386,12 @@ function dot_tilde_assume( end end function dot_tilde_assume(context::PriorContext, right, left, vn, inds, vi) - value, logp = dot_assume(right, left, vn, inds, vi) - acclogp!(vi, logp) - return value + return dot_assume(right, left, vn, inds, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi ) - value, logp = dot_assume(rng, sampler, right, left, vn, inds, vi) - acclogp!(vi, logp) - return value + return dot_assume(rng, sampler, right, left, vn, inds, vi) end # `MiniBatchContext` From 7605785fff5407d558dd920720180efc0e41d885 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:04:29 +0100 Subject: [PATCH 022/107] export SamplingContext --- src/DynamicPPL.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index acdb98183..3ad30972c 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -76,6 +76,7 @@ export AbstractVarInfo, SampleFromPrior, SampleFromUniform, # Contexts + SamplingContext, DefaultContext, LikelihoodContext, PriorContext, From 354ac52b0d2115b2df7d437407b98140a208ba5d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:38:25 +0100 Subject: [PATCH 023/107] use context instead of ctx to refer to contexts --- src/context_implementations.jl | 64 +++++++++++++++++----------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index b088577f5..f859b4619 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -120,15 +120,15 @@ function tilde_assume(context::PrefixContext, right, vn, inds, vi) end """ - tilde_assume!(ctx, right, vn, inds, vi) + tilde_assume!(context, right, vn, inds, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `tilde_assume!(ctx, right, vn, inds, vi)`. +Falls back to `tilde_assume!(context, right, vn, inds, vi)`. """ -function tilde_assume!(ctx, right, vn, inds, vi) - value, logp = tilde_assume(ctx, right, vn, inds, vi) +function tilde_assume!(context, right, vn, inds, vi) + value, logp = tilde_assume(context, right, vn, inds, vi) acclogp!(vi, logp) return value end @@ -199,30 +199,30 @@ function tilde_observe(context::PrefixContext, right, left, vi) end """ - tilde_observe!(ctx, right, left, vname, vinds, vi) + tilde_observe!(context, right, left, vname, vinds, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `tilde_observe(ctx, right, left, vi)` ignoring the information about variable name +Falls back to `tilde_observe(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!(ctx, right, left, vname, vinds, vi) - logp = tilde_observe(ctx, right, left, vi) +function tilde_observe!(context, right, left, vname, vinds, vi) + logp = tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end """ - tilde_observe(ctx, right, left, vi) + tilde_observe(context, right, left, vi) Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the observed value. -Falls back to `tilde(ctx, right, left, vi)`. +Falls back to `tilde(context, right, left, vi)`. """ -function tilde_observe!(ctx, right, left, vi) - logp = tilde_observe(ctx, right, left, vi) +function tilde_observe!(context, right, left, vi) + logp = tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end @@ -302,11 +302,11 @@ function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) end # `DefaultContext` -function dot_tilde_assume(ctx::DefaultContext, sampler, right, left, vns, inds, vi) +function dot_tilde_assume(::DefaultContext, sampler, right, left, vns, inds, vi) return dot_assume(right, vns, left, vi) end -function dot_tilde_assume(rng, ctx::DefaultContext, sampler, right, left, vns, inds, vi) +function dot_tilde_assume(rng, ::DefaultContext, sampler, right, left, vns, inds, vi) return dot_assume(rng, sampler, right, vns, left, vi) end @@ -405,15 +405,15 @@ function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) end """ - dot_tilde_assume!(ctx, right, left, vn, inds, vi) + dot_tilde_assume!(context, right, left, vn, inds, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `dot_tilde_assume(ctx, right, left, vn, inds, vi)`. +Falls back to `dot_tilde_assume(context, right, left, vn, inds, vi)`. """ -function dot_tilde_assume!(ctx, right, left, vn, inds, vi) - value, logp = dot_tilde_assume(ctx, right, left, vn, inds, vi) +function dot_tilde_assume!(context, right, left, vn, inds, vi) + value, logp = dot_tilde_assume(context, right, left, vn, inds, vi) acclogp!(vi, logp) return value end @@ -583,17 +583,17 @@ end # Leaf contexts dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(right, left, vi) dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 -function dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) +function dot_tilde_observe(context::LikelihoodContext, sampler, right, left, vi) return dot_observe(right, left, vi) end # `MiniBatchContext` -function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vi) +function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vi) + return context.loglike_scalar * dot_tilde_observe(context.ctx, sampler, right, left, vi) end -function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vname, vinds, vi) - return ctx.loglike_scalar * - dot_tilde_observe(ctx.ctx, sampler, right, left, vname, vinds, vi) +function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) + return context.loglike_scalar * + dot_tilde_observe(context.ctx, sampler, right, left, vname, vinds, vi) end # `PrefixContext` @@ -605,30 +605,30 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi) end """ - dot_tilde_observe!(ctx, right, left, vname, vinds, vi) + dot_tilde_observe!(context, right, left, vname, vinds, vi) Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs), accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe!(ctx, right, left, vn, inds, vi) - logp = dot_tilde_observe(ctx, right, left, vi) +function dot_tilde_observe!(context, right, left, vn, inds, vi) + logp = dot_tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end """ - dot_tilde_observe!(ctx, right, left, vi) + dot_tilde_observe!(context, right, left, vi) Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, right, left, vi)`. +Falls back to `dot_tilde_observe(context, right, left, vi)`. """ -function dot_tilde_observe!(ctx, right, left, vi) - logp = dot_tilde_observe(ctx, right, left, vi) +function dot_tilde_observe!(context, right, left, vi) + logp = dot_tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end From b7a2b3b5b5483eb11741a9a2c7b3135abaecbcad Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:38:46 +0100 Subject: [PATCH 024/107] formatting --- src/context_implementations.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index f859b4619..a8f279804 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -591,7 +591,9 @@ end function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vi) return context.loglike_scalar * dot_tilde_observe(context.ctx, sampler, right, left, vi) end -function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) +function dot_tilde_observe( + context::MiniBatchContext, sampler, right, left, vname, vinds, vi +) return context.loglike_scalar * dot_tilde_observe(context.ctx, sampler, right, left, vname, vinds, vi) end From 9e0fc9a9eecb6f74d54e0b4a1fad9cf94c0b41eb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:39:41 +0100 Subject: [PATCH 025/107] use context instead of ctx for variables --- src/contexts.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 1ee43f2b2..8598fb633 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -71,7 +71,7 @@ LikelihoodContext() = LikelihoodContext(nothing) """ struct MiniBatchContext{Tctx, T} <: AbstractContext - ctx::Tctx + context::Tctx loglike_scalar::T end @@ -82,11 +82,11 @@ This is useful in batch-based stochastic gradient descent algorithms to be optim `log(prior) + log(likelihood of all the data points)` in the expectation. """ struct MiniBatchContext{Tctx,T} <: AbstractContext - ctx::Tctx + context::Tctx loglike_scalar::T end -function MiniBatchContext(ctx=DefaultContext(); batch_size, npoints) - return MiniBatchContext(ctx, npoints / batch_size) +function MiniBatchContext(context=DefaultContext(); batch_size, npoints) + return MiniBatchContext(context, npoints / batch_size) end function unwrap_childcontext(context::MiniBatchContext) @@ -109,23 +109,23 @@ unique. See also: [`@submodel`](@ref) """ struct PrefixContext{Prefix,C} <: AbstractContext - ctx::C + context::C end -function PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} - return PrefixContext{Prefix,typeof(ctx)}(ctx) +function PrefixContext{Prefix}(context::AbstractContext) where {Prefix} + return PrefixContext{Prefix,typeof(context)}(context) end const PREFIX_SEPARATOR = Symbol(".") function PrefixContext{PrefixInner}( - ctx::PrefixContext{PrefixOuter} + context::PrefixContext{PrefixOuter} ) where {PrefixInner,PrefixOuter} if @generated :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, _prefix_seperator, PrefixInner)))}( - ctx.ctx + context.context )) else - PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(ctx.ctx) + PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(context.context) end end From 7a4a1a38ca6895c401e366e9b2707117b5ce36e5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:40:18 +0100 Subject: [PATCH 026/107] use context instead of ctx to refer to contexts --- src/context_implementations.jl | 47 ++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index a8f279804..e66501aee 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -27,9 +27,9 @@ with a sampler. Falls back to ```julia -tilde_assume(context.rng, context.ctx, context.sampler, right, vn, inds, vi) +tilde_assume(context.rng, context.context, context.sampler, right, vn, inds, vi) ``` -if the context `context.ctx` does not call any other context, as indicated by +if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls `tilde_assume(c, right, vn, inds, vi)` where `c` is a context in which the order of the sampling context and its child are swapped. """ @@ -112,11 +112,11 @@ function tilde_assume( end function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) - return tilde_assume(context.ctx, right, vn, inds, vi) + return tilde_assume(context.context, right, vn, inds, vi) end function tilde_assume(context::PrefixContext, right, vn, inds, vi) - return tilde_assume(context.ctx, right, prefix(context, vn), inds, vi) + return tilde_assume(context.context, right, prefix(context, vn), inds, vi) end """ @@ -138,8 +138,8 @@ end tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) Handle observed variables with a `context` associated with a sampler. -Falls back to `tilde_observe(context.ctx, right, left, vname, vinds, vi)` ignoring -the information about the sampler if the context `context.ctx` does not call any other +Falls back to `tilde_observe(context.context, right, left, vname, vinds, vi)` ignoring +the information about the sampler if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls `tilde_observe(c, right, left, vname, vinds, vi)` where `c` is a context in which the order of the sampling context and its child are swapped. @@ -159,8 +159,8 @@ end tilde_observe(context::SamplingContext, right, left, vi) Handle observed constants with a `context` associated with a sampler. -Falls back to `tilde_observe(context.ctx, right, left, vi)` ignoring -the information about the sampler if the context `context.ctx` does not call any other +Falls back to `tilde_observe(context.context, right, left, vi)` ignoring +the information about the sampler if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls `tilde_observe(c, right, left, vi)` where `c` is a context in which the order of the sampling context and its child are swapped. @@ -183,19 +183,19 @@ tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi) # `MiniBatchContext` function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) - return context.loglike_scalar * tilde_observe(context.ctx, right, left, vi) + return context.loglike_scalar * tilde_observe(context.context, right, left, vi) end function tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) return context.loglike_scalar * - tilde_observe(context.ctx, right, left, vname, vinds, vi) + tilde_observe(context.context, right, left, vname, vinds, vi) end # `PrefixContext` function tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) - return tilde_observe(context.ctx, right, left, prefix(context, vname), vinds, vi) + return tilde_observe(context.context, right, left, prefix(context, vname), vinds, vi) end function tilde_observe(context::PrefixContext, right, left, vi) - return tilde_observe(context.ctx, right, left, vi) + return tilde_observe(context.context, right, left, vi) end """ @@ -283,9 +283,9 @@ associated with a sampler. Falls back to ```julia -dot_tilde_assume(context.rng, context.ctx, context.sampler, right, left, vn, inds, vi) +dot_tilde_assume(context.rng, context.context, context.sampler, right, left, vn, inds, vi) ``` -if the context `context.ctx` does not call any other context, as indicated by +if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls `dot_tilde_assume(c, right, left, vn, inds, vi)` where `c` is a context in which the order of the sampling context and its child are swapped. """ @@ -396,12 +396,12 @@ end # `MiniBatchContext` function dot_tilde_assume(context::MiniBatchContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.ctx, right, left, vn, inds, vi) + return dot_tilde_assume(context.context, right, left, vn, inds, vi) end # `PrefixContext` function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.ctx, right, prefix.(Ref(context), vn), inds, vi) + return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), inds, vi) end """ @@ -574,10 +574,10 @@ end Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value for a context associated with a sampler. -Falls back to `dot_tilde_observe(context.ctx, right, left, vi) ignoring the sampler. +Falls back to `dot_tilde_observe(context.context, right, left, vi) ignoring the sampler. """ function dot_tilde_observe(context::SamplingContext, right, left, vi) - return dot_tilde_observe(context.ctx, right, left, vname, vinds, vi) + return dot_tilde_observe(context.context, right, left, vname, vinds, vi) end # Leaf contexts @@ -589,21 +589,24 @@ end # `MiniBatchContext` function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vi) - return context.loglike_scalar * dot_tilde_observe(context.ctx, sampler, right, left, vi) + return context.loglike_scalar * + dot_tilde_observe(context.context, sampler, right, left, vi) end function dot_tilde_observe( context::MiniBatchContext, sampler, right, left, vname, vinds, vi ) return context.loglike_scalar * - dot_tilde_observe(context.ctx, sampler, right, left, vname, vinds, vi) + dot_tilde_observe(context.context, sampler, right, left, vname, vinds, vi) end # `PrefixContext` function dot_tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) - return dot_tilde_observe(context.ctx, right, left, prefix(context, vname), vinds, vi) + return dot_tilde_observe( + context.context, right, left, prefix(context, vname), vinds, vi + ) end function dot_tilde_observe(context::PrefixContext, right, left, vi) - return dot_tilde_observe(context.ctx, right, left, vi) + return dot_tilde_observe(context.context, right, left, vi) end """ From 7899473512e85089f0a6497823e804c0a2a7e12c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Jun 2021 01:20:27 +0100 Subject: [PATCH 027/107] Update src/compiler.jl Co-authored-by: David Widmann --- src/compiler.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index dc70ae267..8734b72ed 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -392,12 +392,6 @@ function build_output(modelinfo, linenumbernode) # Replace the user-provided function body with the version created by DynamicPPL. evaluatordef[:body] = quote - # in case someone accessed these - if __context__ isa $(DynamicPPL.SamplingContext) - __rng__ = __context__.rng - __sampler__ = __context__.sampler - end - $(modelinfo[:body]) end From 1630476c742f82e3f45096923e39a4b2da6150a2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Jun 2021 01:20:41 +0100 Subject: [PATCH 028/107] Update src/context_implementations.jl Co-authored-by: David Widmann --- src/context_implementations.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e66501aee..4fd787c86 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -138,6 +138,7 @@ end tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) Handle observed variables with a `context` associated with a sampler. + Falls back to `tilde_observe(context.context, right, left, vname, vinds, vi)` ignoring the information about the sampler if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls From 6892d2b1276ef92f6765c3272673ac7a58682465 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Jun 2021 01:37:22 +0100 Subject: [PATCH 029/107] Apply suggestions from code review Co-authored-by: David Widmann --- src/context_implementations.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 4fd787c86..5647cd5fc 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -160,6 +160,7 @@ end tilde_observe(context::SamplingContext, right, left, vi) Handle observed constants with a `context` associated with a sampler. + Falls back to `tilde_observe(context.context, right, left, vi)` ignoring the information about the sampler if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls From 13da1b473654cb06cf621a99bc7a1904e3865102 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Jun 2021 02:00:10 +0100 Subject: [PATCH 030/107] added some whitespace to some docstrings --- src/compiler.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 20d8bf8ef..2e368d32b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -54,8 +54,10 @@ check_tilde_rhs(x::AbstractArray{<:Distribution}) = x """ unwrap_right_vn(right, vn) + Return the unwrapped distribution on the right-hand side and variable name on the left-hand side of a `~` expression such as `x ~ Normal()`. + This is used mainly to unwrap `NamedDist` distributions. """ unwrap_right_vn(right, vn) = right, vn @@ -63,8 +65,10 @@ unwrap_right_vn(right::NamedDist, vn) = unwrap_right_vn(right.dist, right.name) """ unwrap_right_left_vns(right, left, vns) + Return the unwrapped distributions on the right-hand side and values and variable names on the left-hand side of a `.~` expression such as `x .~ Normal()`. + This is used mainly to unwrap `NamedDist` distributions and adjust the indices of the variables. """ From d76e5b3d188596008d51093b2dec9b6cf3d725c1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Jun 2021 02:23:12 +0100 Subject: [PATCH 031/107] deprecated tilde and dot_tilde plus exported new versions --- src/DynamicPPL.jl | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index acdb98183..319798780 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -83,10 +83,12 @@ export AbstractVarInfo, PrefixContext, assume, dot_assume, - observer, + observe, dot_observe, - tilde, - dot_tilde, + tilde_assume, + tilde_observe, + dot_tilde_assume, + dot_tilde_observe, # Pseudo distributions NamedDist, NoDist, @@ -128,4 +130,11 @@ include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") +# Deprecations. +@deprecate tilde(rng, ctx, sampler, right, vn, inds, vi) tilde_assume(rng, ctx, sampler, right, vn, inds, vi) +@deprecate tilde(ctx, sampler, right, left, vi) tilde_observe(ctx, sampler, right, left, vi) + +@deprecate dot_tilde(rng, ctx, sampler, right, left, vn, inds, vi) dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) +@deprecate dot_tilde(ctx, sampler, right, left, vi) dot_tilde_observe(ctx, sampler, right, left, vi) + end # module From 805966979870ece251bc5304f20815a18a685ae1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Jun 2021 02:29:09 +0100 Subject: [PATCH 032/107] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/DynamicPPL.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 319798780..0eda04b28 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -131,10 +131,16 @@ include("loglikelihoods.jl") include("submodel_macro.jl") # Deprecations. -@deprecate tilde(rng, ctx, sampler, right, vn, inds, vi) tilde_assume(rng, ctx, sampler, right, vn, inds, vi) +@deprecate tilde(rng, ctx, sampler, right, vn, inds, vi) tilde_assume( + rng, ctx, sampler, right, vn, inds, vi +) @deprecate tilde(ctx, sampler, right, left, vi) tilde_observe(ctx, sampler, right, left, vi) -@deprecate dot_tilde(rng, ctx, sampler, right, left, vn, inds, vi) dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) -@deprecate dot_tilde(ctx, sampler, right, left, vi) dot_tilde_observe(ctx, sampler, right, left, vi) +@deprecate dot_tilde(rng, ctx, sampler, right, left, vn, inds, vi) dot_tilde_assume( + rng, ctx, sampler, right, left, vn, inds, vi +) +@deprecate dot_tilde(ctx, sampler, right, left, vi) dot_tilde_observe( + ctx, sampler, right, left, vi +) end # module From 43ef8d1a659af36312d61e96e622f4cab39cf21e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Jun 2021 23:42:16 +0100 Subject: [PATCH 033/107] minor version bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 96c60a14f..76b94cb49 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.11.0" +version = "0.11.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From 1015f0e3a248aacb3039dd4adee670504de4412f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 13:56:32 +0100 Subject: [PATCH 034/107] added impl of matchingvalue for contexts --- src/compiler.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index e647df99c..352d46418 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -435,8 +435,12 @@ end """ matchingvalue(sampler, vi, value) + matchingvalue(context::AbstractContext, vi, value) -Convert the `value` to the correct type for the `sampler` and the `vi` object. +Convert the `value` to the correct type for the `sampler` or `context` and the `vi` object. + +For a `context` that is _not_ a `SamplingContext`, we fall back to +`matchingvalue(SampleFromPrior(), vi, value)`. """ function matchingvalue(sampler, vi, value) T = typeof(value) @@ -453,6 +457,13 @@ function matchingvalue(sampler, vi, value) end matchingvalue(sampler, vi, value::FloatOrArrayType) = get_matching_type(sampler, vi, value) +function matchingvalue(context::AbstractContext, vi, value) + return matchingvalue(SampleFromPrior(), vi, value) +end +function matchingvalue(context::SamplingContext, vi, value) + return matchingvalue(context.sampler, vi, value) +end + """ get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T} From 23c86a76d27248f3b2c71eb6f72ea85439dd62b7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 14:32:06 +0100 Subject: [PATCH 035/107] reverted the change that makes assume always resample --- src/context_implementations.jl | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index d1fd7b0ba..e5e89ca55 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -255,15 +255,23 @@ function assume( inds, vi, ) - # Always overwrite the parameters with new ones. - r = init(rng, dist, sampler) if haskey(vi, vn) - vi[vn] = vectorize(dist, r) - setorder!(vi, vn, get_num_produce(vi)) + # Always overwrite the parameters with new ones for `SampleFromUniform`. + if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") + unset_flag!(vi, vn, "del") + r = init(rng, dist, sampler) + vi[vn] = vectorize(dist, r) + settrans!(vi, false, vn) + setorder!(vi, vn, get_num_produce(vi)) + else + r = vi[vn] + end else + r = init(rng, dist, sampler) push!(vi, vn, r, dist, sampler) + settrans!(vi, false, vn) end - settrans!(vi, false, vn) + return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) end From 17f5abe88edefca1a1307987300197d19210da34 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 14:39:20 +0100 Subject: [PATCH 036/107] removed the inds arguments from assume and dot_assume to stay non-breaking --- src/context_implementations.jl | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e5e89ca55..50e3a8e4b 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -44,11 +44,11 @@ function tilde_assume(context::SamplingContext, right, vn, inds, vi) end # Leaf contexts -tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, inds, vi) +tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, vi) function tilde_assume( rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi ) - return assume(rng, sampler, right, vn, inds, vi) + return assume(rng, sampler, right, vn, vi) end function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) @@ -74,10 +74,10 @@ function tilde_assume( return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) end function tilde_assume(::PriorContext, right, vn, inds, vi) - return assume(right, vn, inds, vi) + return assume(right, vn, vi) end function tilde_assume(rng::Random.AbstractRNG, ::PriorContext, sampler, right, vn, inds, vi) - return assume(rng, sampler, right, vn, inds, vi) + return assume(rng, sampler, right, vn, vi) end function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) @@ -103,12 +103,12 @@ function tilde_assume( return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, inds, vi) end function tilde_assume(::LikelihoodContext, right, vn, inds, vi) - return assume(NoDist(right), vn, inds, vi) + return assume(NoDist(right), vn, vi) end function tilde_assume( rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, inds, vi ) - return assume(rng, sampler, NoDist(right), vn, inds, vi) + return assume(rng, sampler, NoDist(right), vn, vi) end function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) @@ -238,7 +238,7 @@ function observe(spl::Sampler, weight) end # fallback without sampler -function assume(dist::Distribution, vn::VarName, inds, vi) +function assume(dist::Distribution, vn::VarName, vi) if !haskey(vi, vn) error("variable $vn does not exist") end @@ -252,7 +252,6 @@ function assume( sampler::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, - inds, vi, ) if haskey(vi, vn) @@ -355,12 +354,12 @@ function dot_tilde_assume( end end function dot_tilde_assume(context::LikelihoodContext, right, left, vn, inds, vi) - return dot_assume(NoDist.(right), left, vn, inds, vi) + return dot_assume(NoDist.(right), left, vn, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, inds, vi ) - return dot_assume(rng, sampler, NoDist.(right), left, vn, inds, vi) + return dot_assume(rng, sampler, NoDist.(right), left, vn, vi) end # `PriorContext` @@ -396,12 +395,12 @@ function dot_tilde_assume( end end function dot_tilde_assume(context::PriorContext, right, left, vn, inds, vi) - return dot_assume(right, left, vn, inds, vi) + return dot_assume(right, left, vn, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi ) - return dot_assume(rng, sampler, right, left, vn, inds, vi) + return dot_assume(rng, sampler, right, left, vn, vi) end # `MiniBatchContext` @@ -433,7 +432,6 @@ function dot_assume( dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, - inds, vi, ) @assert length(dist) == size(var, 1) @@ -460,7 +458,6 @@ function dot_assume( dists::Union{Distribution,AbstractArray{<:Distribution}}, var::AbstractArray, vns::AbstractArray{<:VarName}, - inds, vi, ) # Make sure `var` is not a matrix for multivariate distributions From dbd61f04d1e5719633e36c525fc04b9190659b81 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Mon, 7 Jun 2021 15:32:41 +0100 Subject: [PATCH 037/107] Update src/context_implementations.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/context_implementations.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 50e3a8e4b..16704a814 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -429,10 +429,7 @@ end # `dot_assume` function dot_assume( - dist::MultivariateDistribution, - var::AbstractMatrix, - vns::AbstractVector{<:VarName}, - vi, + dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi ) @assert length(dist) == size(var, 1) lp = sum(zip(vns, eachcol(var))) do vn, ri From b10ba3f17f9f8f493893f8728e5f9aaf160dfd73 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 16:04:28 +0100 Subject: [PATCH 038/107] added missing sampler arg to tilde_observe --- src/context_implementations.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 50e3a8e4b..a65336670 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -170,18 +170,19 @@ which the order of the sampling context and its child are swapped. function tilde_observe(context::SamplingContext, right, left, vi) c, reconstruct_context = unwrap_childcontext(context) child_of_c, reconstruct_c = unwrap_childcontext(c) - fallback_context = if child_of_c !== nothing - reconstruct_c(reconstruct_context(child_of_c)) + return if child_of_c === nothing + tilde_observe(c, context.sampler, right, left, vi) else - c + tilde_observe( + reconstruct_c(reconstruct_context(child_of_c)), right, left, vi + ) end - return tilde_observe(fallback_context, right, left, vi) end # Leaf contexts -tilde_observe(::DefaultContext, right, left, vi) = observe(right, left, vi) -tilde_observe(::PriorContext, right, left, vi) = 0 -tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi) +tilde_observe(::DefaultContext, sampler, right, left, vi) = observe(right, left, vi) +tilde_observe(::PriorContext, sampler, right, left, vi) = 0 +tilde_observe(::LikelihoodContext, sampler, right, left, vi) = observe(right, left, vi) # `MiniBatchContext` function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) From bc5029f45c20a8dfa777b8e370c7af16797c890c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 16:06:28 +0100 Subject: [PATCH 039/107] added missing sampler argument in dot_tilde_observe --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index a65336670..04ddac2c5 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -583,7 +583,7 @@ probability, and return the observed value for a context associated with a sampl Falls back to `dot_tilde_observe(context.context, right, left, vi) ignoring the sampler. """ function dot_tilde_observe(context::SamplingContext, right, left, vi) - return dot_tilde_observe(context.context, right, left, vname, vinds, vi) + return dot_tilde_observe(context.context, context.sampler, right, left, vi) end # Leaf contexts From 7eac33dc6072023a0ac4ab59cb26f6e8b52efb32 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 16:06:49 +0100 Subject: [PATCH 040/107] fixed order of arguments in some dot_assume calls --- src/context_implementations.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 04ddac2c5..77a7475a3 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -360,7 +360,7 @@ end function dot_tilde_assume( rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, inds, vi ) - return dot_assume(rng, sampler, NoDist.(right), left, vn, vi) + return dot_assume(rng, sampler, NoDist.(right), vn, left, vi) end # `PriorContext` @@ -401,7 +401,7 @@ end function dot_tilde_assume( rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi ) - return dot_assume(rng, sampler, right, left, vn, vi) + return dot_assume(rng, sampler, right, vn, left, vi) end # `MiniBatchContext` From 85994813e350ea3e132e6f2dc0fc3d4896b4e641 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 16:07:01 +0100 Subject: [PATCH 041/107] formatting --- src/context_implementations.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 77a7475a3..d7d51c72a 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -430,10 +430,7 @@ end # `dot_assume` function dot_assume( - dist::MultivariateDistribution, - var::AbstractMatrix, - vns::AbstractVector{<:VarName}, - vi, + dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi ) @assert length(dist) == size(var, 1) lp = sum(zip(vns, eachcol(var))) do vn, ri From 90a8c4562e10153ef9323200c59d417411ec009b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 16:07:22 +0100 Subject: [PATCH 042/107] formatting --- src/context_implementations.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index d7d51c72a..431af5a5e 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -173,9 +173,7 @@ function tilde_observe(context::SamplingContext, right, left, vi) return if child_of_c === nothing tilde_observe(c, context.sampler, right, left, vi) else - tilde_observe( - reconstruct_c(reconstruct_context(child_of_c)), right, left, vi - ) + tilde_observe(reconstruct_c(reconstruct_context(child_of_c)), right, left, vi) end end From f9d4ff83d2ed6b4e67426e636c7c1006a66cafa5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 16:07:36 +0100 Subject: [PATCH 043/107] added missing sampler argument in tilde_observe for SamplingContext --- src/context_implementations.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 431af5a5e..5dce075a2 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -148,12 +148,13 @@ which the order of the sampling context and its child are swapped. function tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) c, reconstruct_context = unwrap_childcontext(context) child_of_c, reconstruct_c = unwrap_childcontext(c) - fallback_context = if child_of_c !== nothing - reconstruct_c(reconstruct_context(child_of_c)) + return if child_of_c === nothing + tilde_observe(c, context.sampler, right, left, vname, vinds, vi) else - c + tilde_observe( + reconstruct_c(reconstruct_context(child_of_c)), right, left, vname, vinds, vi + ) end - return tilde_observe(fallback_context, right, left, vname, vinds, vi) end """ From e424fe7b45821382b7ceb35d5294689730081334 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 16:08:15 +0100 Subject: [PATCH 044/107] added missing word in a docstring --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 5dce075a2..b67833d62 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -614,7 +614,7 @@ end """ dot_tilde_observe!(context, right, left, vname, vinds, vi) -Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs), +Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. Falls back to `dot_tilde_observe(context, right, left, vi)` ignoring the information about variable From 70957d27c66a3a1ce7b3ac751e68709531b5e548 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 08:03:50 +0100 Subject: [PATCH 045/107] updated submodel macro --- src/submodel_macro.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 92584ae8b..070a5aa4c 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -1,10 +1,8 @@ macro submodel(expr) return quote _evaluate( - $(esc(:__rng__)), $(esc(expr)), $(esc(:__varinfo__)), - $(esc(:__sampler__)), $(esc(:__context__)), ) end @@ -13,10 +11,8 @@ end macro submodel(prefix, expr) return quote _evaluate( - $(esc(:__rng__)), $(esc(expr)), $(esc(:__varinfo__)), - $(esc(:__sampler__)), PrefixContext{$(esc(Meta.quot(prefix)))}($(esc(:__context__))), ) end From d00cdcfb0ceed552a42abaadac2a8d4413e4bd4d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 13:58:46 +0100 Subject: [PATCH 046/107] removed unwrap_childcontext and related since its not needed for this PR --- src/context_implementations.jl | 130 +++++++++++++++------------------ src/contexts.jl | 38 ---------- 2 files changed, 58 insertions(+), 110 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index b67833d62..ae4d631ae 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -29,18 +29,9 @@ Falls back to ```julia tilde_assume(context.rng, context.context, context.sampler, right, vn, inds, vi) ``` -if the context `context.context` does not call any other context, as indicated by -[`unwrap_childcontext`](@ref). Otherwise, calls `tilde_assume(c, right, vn, inds, vi)` -where `c` is a context in which the order of the sampling context and its child are swapped. """ function tilde_assume(context::SamplingContext, right, vn, inds, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - return if child_of_c === nothing - tilde_assume(context.rng, c, context.sampler, right, vn, inds, vi) - else - tilde_assume(reconstruct_c(reconstruct_context(child_of_c)), right, vn, inds, vi) - end + return tilde_assume(context.rng, context.context, context.sampler, right, vn, inds, vi) end # Leaf contexts @@ -115,10 +106,18 @@ function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) return tilde_assume(context.context, right, vn, inds, vi) end +function tilde_assume(rng, context::MiniBatchContext, sampler, right, vn, inds, vi) + return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) +end + function tilde_assume(context::PrefixContext, right, vn, inds, vi) return tilde_assume(context.context, right, prefix(context, vn), inds, vi) end +function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) + return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), inds, vi) +end + """ tilde_assume!(context, right, vn, inds, vi) @@ -139,22 +138,12 @@ end Handle observed variables with a `context` associated with a sampler. -Falls back to `tilde_observe(context.context, right, left, vname, vinds, vi)` ignoring -the information about the sampler if the context `context.context` does not call any other -context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls -`tilde_observe(c, right, left, vname, vinds, vi)` where `c` is a context in -which the order of the sampling context and its child are swapped. +Falls back to `tilde_observe(context.context, right, left, vname, vinds, vi)`. """ function tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - return if child_of_c === nothing - tilde_observe(c, context.sampler, right, left, vname, vinds, vi) - else - tilde_observe( - reconstruct_c(reconstruct_context(child_of_c)), right, left, vname, vinds, vi - ) - end + return tilde_observe( + context.rng, context.context, context.sampler, right, left, vname, vinds, vi + ) end """ @@ -162,39 +151,31 @@ end Handle observed constants with a `context` associated with a sampler. -Falls back to `tilde_observe(context.context, right, left, vi)` ignoring -the information about the sampler if the context `context.context` does not call any other -context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls -`tilde_observe(c, right, left, vi)` where `c` is a context in -which the order of the sampling context and its child are swapped. +Falls back to `tilde_observe(context.context, right, left, vi)`. """ function tilde_observe(context::SamplingContext, right, left, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - return if child_of_c === nothing - tilde_observe(c, context.sampler, right, left, vi) - else - tilde_observe(reconstruct_c(reconstruct_context(child_of_c)), right, left, vi) - end + return tilde_observe(context.context, context.sampler, right, left, vi) end # Leaf contexts +tilde_observe(::DefaultContext, right, left, vi) = observe(right, left, vi) tilde_observe(::DefaultContext, sampler, right, left, vi) = observe(right, left, vi) +tilde_observe(::PriorContext, right, left, vi) = 0 tilde_observe(::PriorContext, sampler, right, left, vi) = 0 +tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi) tilde_observe(::LikelihoodContext, sampler, right, left, vi) = observe(right, left, vi) # `MiniBatchContext` -function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) +function tilde_observe(context::MiniBatchContext, right, left, vi) return context.loglike_scalar * tilde_observe(context.context, right, left, vi) end -function tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) - return context.loglike_scalar * - tilde_observe(context.context, right, left, vname, vinds, vi) +function tilde_observe(context::MiniBatchContext, right, left, vname, vi) + return context.loglike_scalar * tilde_observe(context.context, right, left, vname, vi) end # `PrefixContext` -function tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) - return tilde_observe(context.context, right, left, prefix(context, vname), vinds, vi) +function tilde_observe(context::PrefixContext, right, left, vname, vi) + return tilde_observe(context.context, right, left, prefix(context, vname), vi) end function tilde_observe(context::PrefixContext, right, left, vi) return tilde_observe(context.context, right, left, vi) @@ -294,25 +275,16 @@ Falls back to ```julia dot_tilde_assume(context.rng, context.context, context.sampler, right, left, vn, inds, vi) ``` -if the context `context.context` does not call any other context, as indicated by -[`unwrap_childcontext`](@ref). Otherwise, calls `dot_tilde_assume(c, right, left, vn, inds, vi)` -where `c` is a context in which the order of the sampling context and its child are swapped. """ function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - return if child_of_c === nothing - dot_tilde_assume(context.rng, c, context.sampler, right, left, vn, inds, vi) - else - dot_tilde_assume( - reconstruct_c(reconstruct_context(child_of_c)), right, left, vn, inds, vi - ) - end + return dot_tilde_assume( + context.rng, context.context, context.sampler, right, left, vn, inds, vi + ) end # `DefaultContext` -function dot_tilde_assume(::DefaultContext, sampler, right, left, vns, inds, vi) - return dot_assume(right, vns, left, vi) +function dot_tilde_assume(::DefaultContext, right, left, vns, inds, vi) + return dot_assume(right, left, vns, vi) end function dot_tilde_assume(rng, ::DefaultContext, sampler, right, left, vns, inds, vi) @@ -408,11 +380,23 @@ function dot_tilde_assume(context::MiniBatchContext, right, left, vn, inds, vi) return dot_tilde_assume(context.context, right, left, vn, inds, vi) end +function dot_tilde_assume( + rng, context::MiniBatchContext, sampler, right, left, vn, inds, vi +) + return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) +end + # `PrefixContext` function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), inds, vi) end +function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, inds, vi) + return dot_tilde_assume( + rng, context.context, sampler, right, prefix.(Ref(context), vn), inds, vi + ) +end + """ dot_tilde_assume!(context, right, left, vn, inds, vi) @@ -583,30 +567,23 @@ function dot_tilde_observe(context::SamplingContext, right, left, vi) end # Leaf contexts -dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(right, left, vi) +dot_tilde_observe(::DefaultContext, right, left, vi) = dot_observe(right, left, vi) +dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(sampler, right, left, vi) +dot_tilde_observe(::PriorContext, right, left, vi) = 0 dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 -function dot_tilde_observe(context::LikelihoodContext, sampler, right, left, vi) +function dot_tilde_observe(context::LikelihoodContext, right, left, vi) return dot_observe(right, left, vi) end +function dot_tilde_observe(context::LikelihoodContext, sampler, right, left, vi) + return dot_observe(sampler, right, left, vi) +end # `MiniBatchContext` -function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vi) - return context.loglike_scalar * - dot_tilde_observe(context.context, sampler, right, left, vi) -end -function dot_tilde_observe( - context::MiniBatchContext, sampler, right, left, vname, vinds, vi -) - return context.loglike_scalar * - dot_tilde_observe(context.context, sampler, right, left, vname, vinds, vi) +function dot_tilde_observe(context::MiniBatchContext, right, left, vi) + return context.loglike_scalar * dot_tilde_observe(context.context, right, left, vi) end # `PrefixContext` -function dot_tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) - return dot_tilde_observe( - context.context, right, left, prefix(context, vname), vinds, vi - ) -end function dot_tilde_observe(context::PrefixContext, right, left, vi) return dot_tilde_observe(context.context, right, left, vi) end @@ -641,18 +618,27 @@ function dot_tilde_observe!(context, right, left, vi) end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics +function dot_observe(::Union{SampleFromPrior,SampleFromUniform}, dist::MultivariateDistribution, value::AbstractMatrix, vi) + return dot_observe(dist, value, vi) +end function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) increment_num_produce!(vi) @debug "dist = $dist" @debug "value = $value" return Distributions.loglikelihood(dist, value) end +function dot_observe(::Union{SampleFromPrior,SampleFromUniform}, dists::Distribution, value::AbstractArray, vi) + return dot_observe(dists, value, vi) +end function dot_observe(dists::Distribution, value::AbstractArray, vi) increment_num_produce!(vi) @debug "dists = $dists" @debug "value = $value" return Distributions.loglikelihood(dists, value) end +function dot_observe(::Union{SampleFromPrior,SampleFromUniform}, dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) + return dot_observe(dists, value, vi) +end function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) increment_num_produce!(vi) @debug "dists = $dists" diff --git a/src/contexts.jl b/src/contexts.jl index 6daa18776..8093c88f3 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,17 +1,3 @@ -""" - unwrap_childcontext(context::AbstractContext) - -Return a tuple of the child context of a `context`, or `nothing` if the context does -not wrap any other context, and a function `f(c::AbstractContext)` that constructs -an instance of `context` in which the child context is replaced with `c`. - -Falls back to `(nothing, _ -> context)`. -""" -function unwrap_childcontext(context::AbstractContext) - reconstruct_context(@nospecialize(x)) = context - return nothing, reconstruct_context -end - """ SamplingContext(rng, sampler, context) @@ -26,14 +12,6 @@ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractConte context::C end -function unwrap_childcontext(context::SamplingContext) - child = context.context - function reconstruct_samplingcontext(c::AbstractContext) - return SamplingContext(context.rng, context.sampler, c) - end - return child, reconstruct_samplingcontext -end - """ struct DefaultContext <: AbstractContext end @@ -89,14 +67,6 @@ function MiniBatchContext(context=DefaultContext(); batch_size, npoints) return MiniBatchContext(context, npoints / batch_size) end -function unwrap_childcontext(context::MiniBatchContext) - child = context.context - function reconstruct_minibatchcontext(c::AbstractContext) - return MiniBatchContext(c, context.loglike_scalar) - end - return child, reconstruct_minibatchcontext -end - """ PrefixContext{Prefix}(context) @@ -136,11 +106,3 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) end end - -function unwrap_childcontext(context::PrefixContext{P}) where {P} - child = context.context - function reconstruct_prefixcontext(c::AbstractContext) - return PrefixContext{P}(c) - end - return child, reconstruct_prefixcontext -end From 639fd6ebe36bee13fc2372915e2480858012612c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 13:59:05 +0100 Subject: [PATCH 047/107] updated submodel macro --- src/submodel_macro.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 070a5aa4c..1d574e286 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -1,10 +1,6 @@ macro submodel(expr) return quote - _evaluate( - $(esc(expr)), - $(esc(:__varinfo__)), - $(esc(:__context__)), - ) + _evaluate($(esc(expr)), $(esc(:__varinfo__)), $(esc(:__context__))) end end From c9a06fb46c8a9d4f184e2b64cfd9c119d68c5cc3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 13:59:15 +0100 Subject: [PATCH 048/107] fixed evaluation implementations of dot_assume --- src/context_implementations.jl | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index ae4d631ae..6befa1f6d 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -416,10 +416,17 @@ function dot_assume( dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi ) @assert length(dist) == size(var, 1) - lp = sum(zip(vns, eachcol(var))) do vn, ri + # NOTE: We cannot work with `var` here because we might have a model of the form + # + # m = Vector{Float64}(undef, n) + # m .~ Normal() + # + # in which case `var` will have `undef` elements, even if `m` is present in `vi`. + r = get_and_set_val!(Random.GLOBAL_RNG, vi, vns, dist, SampleFromPrior()) + lp = sum(zip(vns, eachcol(r))) do vn, ri return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) end - return var, lp + return r, lp end function dot_assume( rng, @@ -441,9 +448,15 @@ function dot_assume( vns::AbstractArray{<:VarName}, vi, ) - # Make sure `var` is not a matrix for multivariate distributions - lp = sum(Bijectors.logpdf_with_trans.(dists, var, istrans(vi, vns[1]))) - return var, lp + # NOTE: We cannot work with `var` here because we might have a model of the form + # + # m = Vector{Float64}(undef, n) + # m .~ Normal() + # + # in which case `var` will have `undef` elements, even if `m` is present in `vi`. + r = get_and_set_val!(Random.GLOBAL_RNG, vi, vns, dists, SampleFromPrior()) + lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) + return r, lp end function dot_assume( From 2fe5f4016cb66ff5da3af57415dd096843363a1b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 13:59:45 +0100 Subject: [PATCH 049/107] updated pointwise_loglikelihoods and related --- src/loglikelihoods.jl | 99 ++++++++++++++++++++++++++----------------- 1 file changed, 60 insertions(+), 39 deletions(-) diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 89672127a..6fca717c6 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -1,80 +1,102 @@ # Context version struct PointwiseLikelihoodContext{A,Ctx} <: AbstractContext loglikelihoods::A - ctx::Ctx + context::Ctx end function PointwiseLikelihoodContext( - likelihoods=Dict{VarName,Vector{Float64}}(), ctx::AbstractContext=LikelihoodContext() + likelihoods=Dict{VarName,Vector{Float64}}(), + context::AbstractContext=LikelihoodContext(), ) - return PointwiseLikelihoodContext{typeof(likelihoods),typeof(ctx)}(likelihoods, ctx) + return PointwiseLikelihoodContext{typeof(likelihoods),typeof(context)}( + likelihoods, context + ) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}}, + vn::VarName, + logp::Real, ) - lookup = ctx.loglikelihoods + lookup = context.loglikelihoods ℓ = get!(lookup, vn, Float64[]) return push!(ℓ, logp) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{VarName,Float64}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{VarName,Float64}}, vn::VarName, logp::Real ) - return ctx.loglikelihoods[vn] = logp + return context.loglikelihoods[vn] = logp end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, + vn::VarName, + logp::Real, ) - lookup = ctx.loglikelihoods + lookup = context.loglikelihoods ℓ = get!(lookup, string(vn), Float64[]) return push!(ℓ, logp) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::VarName, logp::Real ) - return ctx.loglikelihoods[string(vn)] = logp + return context.loglikelihoods[string(vn)] = logp end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, vn::String, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, + vn::String, + logp::Real, ) - lookup = ctx.loglikelihoods + lookup = context.loglikelihoods ℓ = get!(lookup, vn, Float64[]) return push!(ℓ, logp) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::String, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::String, logp::Real ) - return ctx.loglikelihoods[vn] = logp + return context.loglikelihoods[vn] = logp end -function tilde_assume(rng, ctx::PointwiseLikelihoodContext, sampler, right, vn, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, vi) +function tilde_assume(context::PointwiseLikelihoodContext, right, vn, inds, vi) + return tilde_assume(context.context, right, vn, inds, vi) end -function dot_tilde_assume( - rng, ctx::PointwiseLikelihoodContext, sampler, right, left, vn, inds, vi -) - value, logp = dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.context, right, left, vn, inds, vi) +end + +function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) + # Defer literal `observe` to child-context. + return tilde_observe!(context.context, right, left, vi) +end +function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, vinds, vi) + # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. + # we have to intercept the call to `tilde_observe!`. + logp = tilde_observe(context.context, right, left, vi) acclogp!(vi, logp) - return value + + # Track loglikelihood value. + push!(context, vn, logp) + + return left end -function tilde_observe( - ctx::PointwiseLikelihoodContext, sampler, right, left, vname, vinds, vi -) - # This is slightly unfortunate since it is not completely generic... - # Ideally we would call `tilde_observe` recursively but then we don't get the - # loglikelihood value. - logp = tilde(ctx.ctx, sampler, right, left, vi) +function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) + # Defer literal `observe` to child-context. + return dot_tilde_observe(context.context, right, left, vi) +end +function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) + # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. + # we have to intercept the call to `dot_tilde_observe!`. + logp = dot_tilde_observe(context.context, right, left, vi) acclogp!(vi, logp) - # track loglikelihood value - push!(ctx, vname, logp) + # Track loglikelihood value. + push!(context, vn, logp) return left end @@ -150,30 +172,29 @@ Dict{VarName,Array{Float64,2}} with 4 entries: """ function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T} # Get the data by executing the model once - spl = SampleFromPrior() vi = VarInfo(model) - ctx = PointwiseLikelihoodContext(Dict{T,Vector{Float64}}()) + context = PointwiseLikelihoodContext(Dict{T,Vector{Float64}}()) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) for (sample_idx, chain_idx) in iters # Update the values - setval_and_resample!(vi, chain, sample_idx, chain_idx) + setval!(vi, chain, sample_idx, chain_idx) # Execute model - model(vi, spl, ctx) + model(vi, context) end niters = size(chain, 1) nchains = size(chain, 3) loglikelihoods = Dict( varname => reshape(logliks, niters, nchains) for - (varname, logliks) in ctx.loglikelihoods + (varname, logliks) in context.loglikelihoods ) return loglikelihoods end function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) - ctx = PointwiseLikelihoodContext(Dict{VarName,Float64}()) - model(varinfo, SampleFromPrior(), ctx) - return ctx.loglikelihoods + context = PointwiseLikelihoodContext(Dict{VarName,Vector{Float64}}()) + model(varinfo, context) + return context.loglikelihoods end From b532ca690c254b12732dcdb5f7c9a67577a763cf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 14:00:08 +0100 Subject: [PATCH 050/107] added proper tests for pointwise_loglikelihoods --- test/loglikelihoods.jl | 122 +++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 2 + 2 files changed, 124 insertions(+) create mode 100644 test/loglikelihoods.jl diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl new file mode 100644 index 000000000..5c1fdc082 --- /dev/null +++ b/test/loglikelihoods.jl @@ -0,0 +1,122 @@ +# A collection of models for which the mean-of-means for the posterior should +# be same. +@model function gdemo1(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} + # `dot_assume` and `observe` + m = TV(undef, length(x)) + m .~ Normal() + x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo2(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} + # `assume` with indexing and `observe` + m = TV(undef, length(x)) + for i in eachindex(m) + m[i] ~ Normal() + end + x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo3(x = 10 * ones(2)) + # Multivariate `assume` and `observe` + m ~ MvNormal(length(x), 1.0) + x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo4(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} + # `dot_assume` and `observe` with indexing + m = TV(undef, length(x)) + m .~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m[i], 0.5) + end +end + +# Using vector of `length` 1 here so the posterior of `m` is the same +# as the others. +@model function gdemo5(x = 10 * ones(1)) + # `assume` and `dot_observe` + m ~ Normal() + x .~ Normal(m, 0.5) +end + +# @model function gdemo6(::Type{TV} = Vector{Float64}) where {TV} +# # `assume` and literal `observe` +# m ~ MvNormal(length(x), 1.0) +# [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) +# end + +@model function gdemo7(::Type{TV} = Vector{Float64}) where {TV} + # `dot_assume` and literal `observe` with indexing + m = TV(undef, 2) + m .~ Normal() + for i in eachindex(m) + 10.0 ~ Normal(m[i], 0.5) + end +end + +# @model function gdemo8(::Type{TV} = Vector{Float64}) where {TV} +# # `assume` and literal `dot_observe` +# m ~ Normal() +# [10.0, ] .~ Normal(m, 0.5) +# end + +@model function _prior_dot_assume(::Type{TV} = Vector{Float64}) where {TV} + m = TV(undef, 2) + m .~ Normal() + + return m +end + +@model function gdemo9() + # Submodel prior + m = @submodel _prior_dot_assume() + for i in eachindex(m) + 10.0 ~ Normal(m[i], 0.5) + end +end + +@model function _likelihood_dot_observe(m, x) + x ~ MvNormal(m, 0.5 * ones(length(m))) +end + +@model function gdemo10(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} + m = TV(undef, length(x)) + m .~ Normal() + + # Submodel likelihood + @submodel _likelihood_dot_observe(m, x) +end + +const mean_of_mean_models = (gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo7(), gdemo9(), gdemo10()) + + +@testset "loglikelihoods.jl" begin + for m in mean_of_mean_models + vi = VarInfo(m) + + vns = vi.metadata.m.vns + if length(vns) == 1 && length(vi[vns[1]]) == 1 + # Only have one latent variable. + DynamicPPL.setval!(vi, [1.0, ], ["m", ]) + else + DynamicPPL.setval!(vi, [1.0, 1.0], ["m[1]", "m[2]"]) + end + + lls = pointwise_loglikelihoods(m, vi) + + if isempty(lls) + # One of the models with literal observations, so we just skip. + continue + end + + loglikelihood = if length(keys(lls)) == 1 && length(m.args.x) == 1 + # Only have one observation, so we need to double it + # for comparison with other models. + 2 * sum(lls[first(keys(lls))]) + else + sum(sum, values(lls)) + end + + @test loglikelihood ≈ -324.45158270528947 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 2b3d5d55c..d83be0eea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,6 +45,8 @@ include("test_util.jl") include("threadsafe.jl") include("serialization.jl") + + include("loglikelihoods.jl") end @testset "compat" begin From 4e2274e7abffc1058f370d32f4cc93e21df1d1f7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 14:00:24 +0100 Subject: [PATCH 051/107] updated DPPL tests to reflect recent changes --- test/compiler.jl | 11 +++++------ test/threadsafe.jl | 8 ++++---- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index 78b472563..d219f91ea 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -172,10 +172,10 @@ end @model function testmodel_missing3(x) x[1] ~ Bernoulli(0.5) global varinfo_ = __varinfo__ - global sampler_ = __sampler__ + global sampler_ = __context__.sampler global model_ = __model__ global context_ = __context__ - global rng_ = __rng__ + global rng_ = __context__.rng global lp = getlogp(__varinfo__) return x end @@ -184,18 +184,17 @@ end @test getlogp(varinfo) == lp @test varinfo_ isa AbstractVarInfo @test model_ === model - @test sampler_ === SampleFromPrior() - @test context_ === DefaultContext() + @test context_ isa SamplingContext @test rng_ isa Random.AbstractRNG # disable warnings @model function testmodel_missing4(x) x[1] ~ Bernoulli(0.5) global varinfo_ = __varinfo__ - global sampler_ = __sampler__ + global sampler_ = __context__.sampler global model_ = __model__ global context_ = __context__ - global rng_ = __rng__ + global rng_ = __context__.rng global lp = getlogp(__varinfo__) return x end false diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 746d6a5f8..7a2bdd039 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -61,14 +61,14 @@ # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. DynamicPPL.evaluate_threadsafe( - Random.GLOBAL_RNG, wthreads(x), vi, SampleFromPrior(), DefaultContext() + wthreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) ) @test getlogp(vi) ≈ lp_w_threads @test vi_ isa DynamicPPL.ThreadSafeVarInfo println(" evaluate_threadsafe:") @time DynamicPPL.evaluate_threadsafe( - Random.GLOBAL_RNG, wthreads(x), vi, SampleFromPrior(), DefaultContext() + wthreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) ) @model function wothreads(x) @@ -96,14 +96,14 @@ # Ensure that we use `VarInfo`. DynamicPPL.evaluate_threadunsafe( - Random.GLOBAL_RNG, wothreads(x), vi, SampleFromPrior(), DefaultContext() + wothreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) ) @test getlogp(vi) ≈ lp_w_threads @test vi_ isa VarInfo println(" evaluate_threadunsafe:") @time DynamicPPL.evaluate_threadunsafe( - Random.GLOBAL_RNG, wothreads(x), vi, SampleFromPrior(), DefaultContext() + wothreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) ) end end From ef6da4377024f68bbd41683857e37fade88498f0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 14:02:00 +0100 Subject: [PATCH 052/107] bump minor version since this will be breaking --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f7a5ba10d..db9f26b04 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.11.2" +version = "0.12.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From 10899f370c5335a53b0212146cddaf69ad43e62c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 14:03:03 +0100 Subject: [PATCH 053/107] formatting --- src/context_implementations.jl | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 6befa1f6d..77cbc0fb2 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -581,7 +581,9 @@ end # Leaf contexts dot_tilde_observe(::DefaultContext, right, left, vi) = dot_observe(right, left, vi) -dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(sampler, right, left, vi) +function dot_tilde_observe(::DefaultContext, sampler, right, left, vi) + return dot_observe(sampler, right, left, vi) +end dot_tilde_observe(::PriorContext, right, left, vi) = 0 dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 function dot_tilde_observe(context::LikelihoodContext, right, left, vi) @@ -631,7 +633,12 @@ function dot_tilde_observe!(context, right, left, vi) end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics -function dot_observe(::Union{SampleFromPrior,SampleFromUniform}, dist::MultivariateDistribution, value::AbstractMatrix, vi) +function dot_observe( + ::Union{SampleFromPrior,SampleFromUniform}, + dist::MultivariateDistribution, + value::AbstractMatrix, + vi, +) return dot_observe(dist, value, vi) end function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) @@ -640,7 +647,12 @@ function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) @debug "value = $value" return Distributions.loglikelihood(dist, value) end -function dot_observe(::Union{SampleFromPrior,SampleFromUniform}, dists::Distribution, value::AbstractArray, vi) +function dot_observe( + ::Union{SampleFromPrior,SampleFromUniform}, + dists::Distribution, + value::AbstractArray, + vi, +) return dot_observe(dists, value, vi) end function dot_observe(dists::Distribution, value::AbstractArray, vi) @@ -649,7 +661,12 @@ function dot_observe(dists::Distribution, value::AbstractArray, vi) @debug "value = $value" return Distributions.loglikelihood(dists, value) end -function dot_observe(::Union{SampleFromPrior,SampleFromUniform}, dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) +function dot_observe( + ::Union{SampleFromPrior,SampleFromUniform}, + dists::AbstractArray{<:Distribution}, + value::AbstractArray, + vi, +) return dot_observe(dists, value, vi) end function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) From 1f21ce4158f5ed256b0435eafdc08acf37b8aece Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 14:03:47 +0100 Subject: [PATCH 054/107] formatting --- test/loglikelihoods.jl | 33 +++++++++++++++++---------------- test/threadsafe.jl | 16 ++++++++++++---- 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index 5c1fdc082..4cc7325b8 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -1,28 +1,28 @@ # A collection of models for which the mean-of-means for the posterior should # be same. -@model function gdemo1(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} +@model function gdemo1(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and `observe` m = TV(undef, length(x)) m .~ Normal() - x ~ MvNormal(m, 0.5 * ones(length(x))) + return x ~ MvNormal(m, 0.5 * ones(length(x))) end -@model function gdemo2(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} +@model function gdemo2(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} # `assume` with indexing and `observe` m = TV(undef, length(x)) for i in eachindex(m) m[i] ~ Normal() end - x ~ MvNormal(m, 0.5 * ones(length(x))) + return x ~ MvNormal(m, 0.5 * ones(length(x))) end -@model function gdemo3(x = 10 * ones(2)) +@model function gdemo3(x=10 * ones(2)) # Multivariate `assume` and `observe` m ~ MvNormal(length(x), 1.0) - x ~ MvNormal(m, 0.5 * ones(length(x))) + return x ~ MvNormal(m, 0.5 * ones(length(x))) end -@model function gdemo4(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} +@model function gdemo4(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and `observe` with indexing m = TV(undef, length(x)) m .~ Normal() @@ -33,10 +33,10 @@ end # Using vector of `length` 1 here so the posterior of `m` is the same # as the others. -@model function gdemo5(x = 10 * ones(1)) +@model function gdemo5(x=10 * ones(1)) # `assume` and `dot_observe` m ~ Normal() - x .~ Normal(m, 0.5) + return x .~ Normal(m, 0.5) end # @model function gdemo6(::Type{TV} = Vector{Float64}) where {TV} @@ -45,7 +45,7 @@ end # [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) # end -@model function gdemo7(::Type{TV} = Vector{Float64}) where {TV} +@model function gdemo7(::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and literal `observe` with indexing m = TV(undef, 2) m .~ Normal() @@ -60,7 +60,7 @@ end # [10.0, ] .~ Normal(m, 0.5) # end -@model function _prior_dot_assume(::Type{TV} = Vector{Float64}) where {TV} +@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} m = TV(undef, 2) m .~ Normal() @@ -76,10 +76,10 @@ end end @model function _likelihood_dot_observe(m, x) - x ~ MvNormal(m, 0.5 * ones(length(m))) + return x ~ MvNormal(m, 0.5 * ones(length(m))) end -@model function gdemo10(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} +@model function gdemo10(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} m = TV(undef, length(x)) m .~ Normal() @@ -87,8 +87,9 @@ end @submodel _likelihood_dot_observe(m, x) end -const mean_of_mean_models = (gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo7(), gdemo9(), gdemo10()) - +const mean_of_mean_models = ( + gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo7(), gdemo9(), gdemo10() +) @testset "loglikelihoods.jl" begin for m in mean_of_mean_models @@ -97,7 +98,7 @@ const mean_of_mean_models = (gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), g vns = vi.metadata.m.vns if length(vns) == 1 && length(vi[vns[1]]) == 1 # Only have one latent variable. - DynamicPPL.setval!(vi, [1.0, ], ["m", ]) + DynamicPPL.setval!(vi, [1.0], ["m"]) else DynamicPPL.setval!(vi, [1.0, 1.0], ["m[1]", "m[2]"]) end diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 7a2bdd039..83c53ccd6 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -61,14 +61,18 @@ # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. DynamicPPL.evaluate_threadsafe( - wthreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) + wthreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) @test getlogp(vi) ≈ lp_w_threads @test vi_ isa DynamicPPL.ThreadSafeVarInfo println(" evaluate_threadsafe:") @time DynamicPPL.evaluate_threadsafe( - wthreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) + wthreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) @model function wothreads(x) @@ -96,14 +100,18 @@ # Ensure that we use `VarInfo`. DynamicPPL.evaluate_threadunsafe( - wothreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) + wothreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) @test getlogp(vi) ≈ lp_w_threads @test vi_ isa VarInfo println(" evaluate_threadunsafe:") @time DynamicPPL.evaluate_threadunsafe( - wothreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) + wothreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) end end From 70045061b0730bac5e76ed3c6b981511734b21f0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 14:25:56 +0100 Subject: [PATCH 055/107] renamed mean_of_mean_models used in tests --- test/loglikelihoods.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index 4cc7325b8..74fb88d70 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -87,12 +87,12 @@ end @submodel _likelihood_dot_observe(m, x) end -const mean_of_mean_models = ( +const gdemo_models = ( gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo7(), gdemo9(), gdemo10() ) @testset "loglikelihoods.jl" begin - for m in mean_of_mean_models + for m in gdemo_models vi = VarInfo(m) vns = vi.metadata.m.vns From fa6c4d6aed0312b69d054c3b46f7f438b69810c3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 08:05:29 +0100 Subject: [PATCH 056/107] bumped dppl version in integration tests --- test/turing/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/turing/Project.toml b/test/turing/Project.toml index a4f68621d..67b8d5645 100644 --- a/test/turing/Project.toml +++ b/test/turing/Project.toml @@ -5,6 +5,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] -DynamicPPL = "0.11" +DynamicPPL = "0.12" Turing = "0.15, 0.16" julia = "1.3" From 684d829b437e546a02d588d2204082f0be7df0ae Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 09:23:07 +0100 Subject: [PATCH 057/107] Apply suggestions from code review Co-authored-by: David Widmann --- src/compiler.jl | 4 +--- src/context_implementations.jl | 13 +++++++++---- src/contexts.jl | 2 +- src/loglikelihoods.jl | 2 +- src/model.jl | 7 ++----- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 352d46418..2fa94bcd9 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -395,9 +395,7 @@ function build_output(modelinfo, linenumbernode) evaluatordef[:kwargs] = [] # Replace the user-provided function body with the version created by DynamicPPL. - evaluatordef[:body] = quote - $(modelinfo[:body]) - end + evaluatordef[:body] = modelinfo[:body] ## Build the model function. diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 77cbc0fb2..259a7c1c3 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -124,7 +124,8 @@ end Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `tilde_assume!(context, right, vn, inds, vi)`. +By default, calls `tilde_assume(context, right, vn, inds, vi)` and accumulates the log +probability of `vi` with the returned value. """ function tilde_assume!(context, right, vn, inds, vi) value, logp = tilde_assume(context, right, vn, inds, vi) @@ -138,7 +139,10 @@ end Handle observed variables with a `context` associated with a sampler. -Falls back to `tilde_observe(context.context, right, left, vname, vinds, vi)`. +Falls back to +```julia +tilde_observe(context.rng, context.context, context.sampler, right, left, vname, vinds, vi) +``` """ function tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) return tilde_observe( @@ -151,7 +155,7 @@ end Handle observed constants with a `context` associated with a sampler. -Falls back to `tilde_observe(context.context, right, left, vi)`. +Falls back to `tilde_observe(context.context, context.sampler, right, left, vi)`. """ function tilde_observe(context::SamplingContext, right, left, vi) return tilde_observe(context.context, context.sampler, right, left, vi) @@ -202,7 +206,8 @@ end Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the observed value. -Falls back to `tilde(context, right, left, vi)`. +By default, calls `tilde_observe(context, right, left, vi)` and accumulates the log +probability of `vi` with the returned value. """ function tilde_observe!(context, right, left, vi) logp = tilde_observe(context, right, left, vi) diff --git a/src/contexts.jl b/src/contexts.jl index 8093c88f3..05ad8df0d 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -4,7 +4,7 @@ Create a context that allows you to sample parameters with the `sampler` when running the model. The `context` determines how the returned log density is computed when running the model. -See also: [`JointContext`](@ref), [`LoglikelihoodContext`](@ref), [`PriorContext`](@ref) +See also: [`DefaultContext`](@ref), [`LikelihoodContext`](@ref), [`PriorContext`](@ref) """ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext rng::R diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 6fca717c6..6c66e4ec4 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -87,7 +87,7 @@ end function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) # Defer literal `observe` to child-context. - return dot_tilde_observe(context.context, right, left, vi) + return dot_tilde_observe!(context.context, right, left, vi) end function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. diff --git a/src/model.jl b/src/model.jl index 2d74949c1..9ec047a44 100644 --- a/src/model.jl +++ b/src/model.jl @@ -156,11 +156,8 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf @generated function _evaluate( model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} - unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] - return quote - sampler = context isa $(SamplingContext) ? context.sampler : SampleFromPrior() - model.f(model, varinfo, context, $(unwrap_args...)) - end + unwrap_args = [:($matchingvalue(context, varinfo, model.args.$var)) for var in argnames] + return :(model.f(model, varinfo, context, $(unwrap_args...))) end """ From 07bb28416adccab4f4783d0d708c9645f49be327 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 12:29:54 +0100 Subject: [PATCH 058/107] Apply suggestions from code review Co-authored-by: David Widmann --- src/context_implementations.jl | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 259a7c1c3..6833a7856 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -191,13 +191,11 @@ end Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `tilde_observe(context, right, left, vi)` ignoring the information about variable name +Falls back to `tilde_observe!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ function tilde_observe!(context, right, left, vname, vinds, vi) - logp = tilde_observe(context, right, left, vi) - acclogp!(vi, logp) - return left + return tilde_observe!(context, right, left, vi) end """ @@ -578,7 +576,7 @@ end Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value for a context associated with a sampler. -Falls back to `dot_tilde_observe(context.context, right, left, vi) ignoring the sampler. +Falls back to `dot_tilde_observe(context.context, context.sampler, right, left, vi)`. """ function dot_tilde_observe(context::SamplingContext, right, left, vi) return dot_tilde_observe(context.context, context.sampler, right, left, vi) @@ -614,13 +612,11 @@ end Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(context, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ function dot_tilde_observe!(context, right, left, vn, inds, vi) - logp = dot_tilde_observe(context, right, left, vi) - acclogp!(vi, logp) - return left + return dot_tilde_observe!(context, right, left, vi) end """ From c7c6a3c066e1c1d229edbe18fd8a1f7a4e8a9fc6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 12:56:57 +0100 Subject: [PATCH 059/107] fixed ambiguity error --- src/compiler.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 352d46418..3924eae95 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -455,7 +455,9 @@ function matchingvalue(sampler, vi, value) return value end end -matchingvalue(sampler, vi, value::FloatOrArrayType) = get_matching_type(sampler, vi, value) +function matchingvalue(sampler::AbstractSampler, vi, value::FloatOrArrayType) + return get_matching_type(sampler, vi, value) +end function matchingvalue(context::AbstractContext, vi, value) return matchingvalue(SampleFromPrior(), vi, value) From 06d319c539a083bb2671c7aa146659cdcf638b16 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 12:36:52 +0000 Subject: [PATCH 060/107] Introduction of `SamplingContext`: keeping it simple (#259) This is #253 but the only motivation here is to get `SamplingContext` in, nothing relating to interactions with other contexts, etc. Co-authored-by: Hong Ge --- src/DynamicPPL.jl | 1 + src/compiler.jl | 37 ++- src/context_implementations.jl | 466 ++++++++++++++++++++++++++------- src/contexts.jl | 45 +++- src/loglikelihoods.jl | 99 ++++--- src/model.jl | 38 +-- src/submodel_macro.jl | 10 +- src/varname.jl | 3 + test/compiler.jl | 11 +- test/loglikelihoods.jl | 123 +++++++++ test/runtests.jl | 2 + test/threadsafe.jl | 16 +- test/turing/Project.toml | 2 +- 13 files changed, 653 insertions(+), 200 deletions(-) create mode 100644 test/loglikelihoods.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 9eb4d9675..914c0e12b 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -75,6 +75,7 @@ export AbstractVarInfo, SampleFromPrior, SampleFromUniform, # Contexts + SamplingContext, DefaultContext, LikelihoodContext, PriorContext, diff --git a/src/compiler.jl b/src/compiler.jl index 2e368d32b..7c812fb54 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -286,11 +286,7 @@ function generate_tilde(left, right) if !(left isa Symbol || left isa Expr) return quote $(DynamicPPL.tilde_observe!)( - __context__, - __sampler__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - __varinfo__, + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end end @@ -304,9 +300,7 @@ function generate_tilde(left, right) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption $left = $(DynamicPPL.tilde_assume!)( - __rng__, __context__, - __sampler__, $(DynamicPPL.unwrap_right_vn)( $(DynamicPPL.check_tilde_rhs)($right), $vn )..., @@ -316,7 +310,6 @@ function generate_tilde(left, right) else $(DynamicPPL.tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, $vn, @@ -337,11 +330,7 @@ function generate_dot_tilde(left, right) if !(left isa Symbol || left isa Expr) return quote $(DynamicPPL.dot_tilde_observe!)( - __context__, - __sampler__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - __varinfo__, + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end end @@ -355,9 +344,7 @@ function generate_dot_tilde(left, right) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption $left .= $(DynamicPPL.dot_tilde_assume!)( - __rng__, __context__, - __sampler__, $(DynamicPPL.unwrap_right_left_vns)( $(DynamicPPL.check_tilde_rhs)($right), $left, $vn )..., @@ -367,7 +354,6 @@ function generate_dot_tilde(left, right) else $(DynamicPPL.dot_tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, $vn, @@ -398,10 +384,8 @@ function build_output(modelinfo, linenumbernode) # Add the internal arguments to the user-specified arguments (positional + keywords). evaluatordef[:args] = vcat( [ - :(__rng__::$(Random.AbstractRNG)), :(__model__::$(DynamicPPL.Model)), :(__varinfo__::$(DynamicPPL.AbstractVarInfo)), - :(__sampler__::$(DynamicPPL.AbstractSampler)), :(__context__::$(DynamicPPL.AbstractContext)), ], modelinfo[:allargs_exprs], @@ -449,8 +433,12 @@ end """ matchingvalue(sampler, vi, value) + matchingvalue(context::AbstractContext, vi, value) + +Convert the `value` to the correct type for the `sampler` or `context` and the `vi` object. -Convert the `value` to the correct type for the `sampler` and the `vi` object. +For a `context` that is _not_ a `SamplingContext`, we fall back to +`matchingvalue(SampleFromPrior(), vi, value)`. """ function matchingvalue(sampler, vi, value) T = typeof(value) @@ -465,7 +453,16 @@ function matchingvalue(sampler, vi, value) return value end end -matchingvalue(sampler, vi, value::FloatOrArrayType) = get_matching_type(sampler, vi, value) +function matchingvalue(sampler::AbstractSampler, vi, value::FloatOrArrayType) + return get_matching_type(sampler, vi, value) +end + +function matchingvalue(context::AbstractContext, vi, value) + return matchingvalue(SampleFromPrior(), vi, value) +end +function matchingvalue(context::SamplingContext, vi, value) + return matchingvalue(context.sampler, vi, value) +end """ get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T} diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 60df298b5..6833a7856 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -18,86 +18,197 @@ _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) _getindex(x, inds::Tuple{}) = x # assume -function tilde_assume(rng, ctx::DefaultContext, sampler, right, vn::VarName, _, vi) +""" + tilde_assume(context::SamplingContext, right, vn, inds, vi) + +Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), +accumulate the log probability, and return the sampled value with a context associated +with a sampler. + +Falls back to +```julia +tilde_assume(context.rng, context.context, context.sampler, right, vn, inds, vi) +``` +""" +function tilde_assume(context::SamplingContext, right, vn, inds, vi) + return tilde_assume(context.rng, context.context, context.sampler, right, vn, inds, vi) +end + +# Leaf contexts +tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, vi) +function tilde_assume( + rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi +) return assume(rng, sampler, right, vn, vi) end -function tilde_assume(rng, ctx::PriorContext, sampler, right, vn::VarName, inds, vi) - if ctx.vars !== nothing - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) + +function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end + return tilde_assume(PriorContext(), right, vn, inds, vi) +end +function tilde_assume( + rng::Random.AbstractRNG, + context::PriorContext{<:NamedTuple}, + sampler, + right, + vn, + inds, + vi, +) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + settrans!(vi, false, vn) + end + return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) +end +function tilde_assume(::PriorContext, right, vn, inds, vi) + return assume(right, vn, vi) +end +function tilde_assume(rng::Random.AbstractRNG, ::PriorContext, sampler, right, vn, inds, vi) return assume(rng, sampler, right, vn, vi) end -function tilde_assume(rng, ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi) - if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) + +function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + settrans!(vi, false, vn) + end + return tilde_assume(LikelihoodContext(), right, vn, inds, vi) +end +function tilde_assume( + rng::Random.AbstractRNG, + context::LikelihoodContext{<:NamedTuple}, + sampler, + right, + vn, + inds, + vi, +) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end + return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, inds, vi) +end +function tilde_assume(::LikelihoodContext, right, vn, inds, vi) + return assume(NoDist(right), vn, vi) +end +function tilde_assume( + rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, inds, vi +) return assume(rng, sampler, NoDist(right), vn, vi) end -function tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, left, inds, vi) + +function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) + return tilde_assume(context.context, right, vn, inds, vi) end -function tilde_assume(rng, ctx::PrefixContext, sampler, right, vn::VarName, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, prefix(ctx, vn), inds, vi) + +function tilde_assume(rng, context::MiniBatchContext, sampler, right, vn, inds, vi) + return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) +end + +function tilde_assume(context::PrefixContext, right, vn, inds, vi) + return tilde_assume(context.context, right, prefix(context, vn), inds, vi) +end + +function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) + return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), inds, vi) end """ - tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) + tilde_assume!(context, right, vn, inds, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `tilde_assume!(rng, ctx, sampler, right, vn, inds, vi)`. +By default, calls `tilde_assume(context, right, vn, inds, vi)` and accumulates the log +probability of `vi` with the returned value. """ -function tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) - value, logp = tilde_assume(rng, ctx, sampler, right, vn, inds, vi) +function tilde_assume!(context, right, vn, inds, vi) + value, logp = tilde_assume(context, right, vn, inds, vi) acclogp!(vi, logp) return value end # observe -function tilde_observe(ctx::DefaultContext, sampler, right, left, vi) - return observe(sampler, right, left, vi) +""" + tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) + +Handle observed variables with a `context` associated with a sampler. + +Falls back to +```julia +tilde_observe(context.rng, context.context, context.sampler, right, left, vname, vinds, vi) +``` +""" +function tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) + return tilde_observe( + context.rng, context.context, context.sampler, right, left, vname, vinds, vi + ) +end + +""" + tilde_observe(context::SamplingContext, right, left, vi) + +Handle observed constants with a `context` associated with a sampler. + +Falls back to `tilde_observe(context.context, context.sampler, right, left, vi)`. +""" +function tilde_observe(context::SamplingContext, right, left, vi) + return tilde_observe(context.context, context.sampler, right, left, vi) end -function tilde_observe(ctx::PriorContext, sampler, right, left, vi) - return 0 + +# Leaf contexts +tilde_observe(::DefaultContext, right, left, vi) = observe(right, left, vi) +tilde_observe(::DefaultContext, sampler, right, left, vi) = observe(right, left, vi) +tilde_observe(::PriorContext, right, left, vi) = 0 +tilde_observe(::PriorContext, sampler, right, left, vi) = 0 +tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi) +tilde_observe(::LikelihoodContext, sampler, right, left, vi) = observe(right, left, vi) + +# `MiniBatchContext` +function tilde_observe(context::MiniBatchContext, right, left, vi) + return context.loglike_scalar * tilde_observe(context.context, right, left, vi) end -function tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) - return observe(sampler, right, left, vi) +function tilde_observe(context::MiniBatchContext, right, left, vname, vi) + return context.loglike_scalar * tilde_observe(context.context, right, left, vname, vi) end -function tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * tilde_observe(ctx.ctx, sampler, right, left, vi) + +# `PrefixContext` +function tilde_observe(context::PrefixContext, right, left, vname, vi) + return tilde_observe(context.context, right, left, prefix(context, vname), vi) end -function tilde_observe(ctx::PrefixContext, sampler, right, left, vi) - return tilde_observe(ctx.ctx, sampler, right, left, vi) +function tilde_observe(context::PrefixContext, right, left, vi) + return tilde_observe(context.context, right, left, vi) end """ - tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) + tilde_observe!(context, right, left, vname, vinds, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable name +Falls back to `tilde_observe!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) - logp = tilde_observe(ctx, sampler, right, left, vi) - acclogp!(vi, logp) - return left +function tilde_observe!(context, right, left, vname, vinds, vi) + return tilde_observe!(context, right, left, vi) end """ - tilde_observe(ctx, sampler, right, left, vi) + tilde_observe(context, right, left, vi) Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the observed value. -Falls back to `tilde(ctx, sampler, right, left, vi)`. +By default, calls `tilde_observe(context, right, left, vi)` and accumulates the log +probability of `vi` with the returned value. """ -function tilde_observe!(ctx, sampler, right, left, vi) - logp = tilde_observe(ctx, sampler, right, left, vi) +function tilde_observe!(context, right, left, vi) + logp = tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end @@ -110,14 +221,28 @@ function observe(spl::Sampler, weight) return error("DynamicPPL.observe: unmanaged inference algorithm: $(typeof(spl))") end +# fallback without sampler +function assume(dist::Distribution, vn::VarName, vi) + if !haskey(vi, vn) + error("variable $vn does not exist") + end + r = vi[vn] + return r, Bijectors.logpdf_with_trans(dist, vi[vn], istrans(vi, vn)) +end + +# SampleFromPrior and SampleFromUniform function assume( - rng, spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, vi + rng::Random.AbstractRNG, + sampler::Union{SampleFromPrior,SampleFromUniform}, + dist::Distribution, + vn::VarName, + vi, ) if haskey(vi, vn) # Always overwrite the parameters with new ones for `SampleFromUniform`. - if spl isa SampleFromUniform || is_flagged(vi, vn, "del") + if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") unset_flag!(vi, vn, "del") - r = init(rng, dist, spl) + r = init(rng, dist, sampler) vi[vn] = vectorize(dist, r) settrans!(vi, false, vn) setorder!(vi, vn, get_num_produce(vi)) @@ -125,79 +250,187 @@ function assume( r = vi[vn] end else - r = init(rng, dist, spl) - push!(vi, vn, r, dist, spl) + r = init(rng, dist, sampler) + push!(vi, vn, r, dist, sampler) settrans!(vi, false, vn) end + return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) end -function observe( - spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, value, vi -) +# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`) +function observe(right::Distribution, left, vi) increment_num_produce!(vi) - return Distributions.loglikelihood(dist, value) + return Distributions.loglikelihood(right, left) end # .~ functions # assume -function dot_tilde_assume(rng, ctx::DefaultContext, sampler, right, left, vns, _, vi) +""" + dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) + +Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the +model inputs), accumulate the log probability, and return the sampled value for a context +associated with a sampler. + +Falls back to +```julia +dot_tilde_assume(context.rng, context.context, context.sampler, right, left, vn, inds, vi) +``` +""" +function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) + return dot_tilde_assume( + context.rng, context.context, context.sampler, right, left, vn, inds, vi + ) +end + +# `DefaultContext` +function dot_tilde_assume(::DefaultContext, right, left, vns, inds, vi) + return dot_assume(right, left, vns, vi) +end + +function dot_tilde_assume(rng, ::DefaultContext, sampler, right, left, vns, inds, vi) return dot_assume(rng, sampler, right, vns, left, vi) end + +# `LikelihoodContext` function dot_tilde_assume( - rng, - ctx::LikelihoodContext, + context::LikelihoodContext{<:NamedTuple}, right, left, vn, inds, vi +) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, inds, vi) + else + dot_tilde_assume(LikelihoodContext(), right, left, vn, inds, vi) + end +end +function dot_tilde_assume( + rng::Random.AbstractRNG, + context::LikelihoodContext{<:NamedTuple}, sampler, right, left, - vns::AbstractArray{<:VarName{sym}}, + vn, inds, vi, -) where {sym} - if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) - var = _getindex(getfield(ctx.vars, sym), inds) - set_val!(vi, vns, right, var) - settrans!.(Ref(vi), false, vns) +) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, inds, vi) + else + dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, inds, vi) end - return dot_assume(rng, sampler, NoDist.(right), vns, left, vi) end -function dot_tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left, vns, inds, vi) - return dot_tilde_assume(rng, ctx.ctx, sampler, right, left, vns, inds, vi) +function dot_tilde_assume(context::LikelihoodContext, right, left, vn, inds, vi) + return dot_assume(NoDist.(right), left, vn, vi) end function dot_tilde_assume( - rng, - ctx::PriorContext, + rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, inds, vi +) + return dot_assume(rng, sampler, NoDist.(right), vn, left, vi) +end + +# `PriorContext` +function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, inds, vi) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(PriorContext(), _right, _left, _vns, inds, vi) + else + dot_tilde_assume(PriorContext(), right, left, vn, inds, vi) + end +end +function dot_tilde_assume( + rng::Random.AbstractRNG, + context::PriorContext{<:NamedTuple}, sampler, right, left, - vns::AbstractArray{<:VarName{sym}}, + vn, inds, vi, -) where {sym} - if ctx.vars !== nothing - var = _getindex(getfield(ctx.vars, sym), inds) - set_val!(vi, vns, right, var) - settrans!.(Ref(vi), false, vns) +) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, inds, vi) + else + dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, inds, vi) end - return dot_assume(rng, sampler, right, vns, left, vi) +end +function dot_tilde_assume(context::PriorContext, right, left, vn, inds, vi) + return dot_assume(right, left, vn, vi) +end +function dot_tilde_assume( + rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi +) + return dot_assume(rng, sampler, right, vn, left, vi) +end + +# `MiniBatchContext` +function dot_tilde_assume(context::MiniBatchContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.context, right, left, vn, inds, vi) +end + +function dot_tilde_assume( + rng, context::MiniBatchContext, sampler, right, left, vn, inds, vi +) + return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) +end + +# `PrefixContext` +function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), inds, vi) +end + +function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, inds, vi) + return dot_tilde_assume( + rng, context.context, sampler, right, prefix.(Ref(context), vn), inds, vi + ) end """ - dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) + dot_tilde_assume!(context, right, left, vn, inds, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi)`. +Falls back to `dot_tilde_assume(context, right, left, vn, inds, vi)`. """ -function dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) - value, logp = dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume!(context, right, left, vn, inds, vi) + value, logp = dot_tilde_assume(context, right, left, vn, inds, vi) acclogp!(vi, logp) return value end -# Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics +# `dot_assume` +function dot_assume( + dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi +) + @assert length(dist) == size(var, 1) + # NOTE: We cannot work with `var` here because we might have a model of the form + # + # m = Vector{Float64}(undef, n) + # m .~ Normal() + # + # in which case `var` will have `undef` elements, even if `m` is present in `vi`. + r = get_and_set_val!(Random.GLOBAL_RNG, vi, vns, dist, SampleFromPrior()) + lp = sum(zip(vns, eachcol(r))) do vn, ri + return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) + end + return r, lp +end function dot_assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -211,6 +444,24 @@ function dot_assume( lp = sum(Bijectors.logpdf_with_trans(dist, r, istrans(vi, vns[1]))) return r, lp end + +function dot_assume( + dists::Union{Distribution,AbstractArray{<:Distribution}}, + var::AbstractArray, + vns::AbstractArray{<:VarName}, + vi, +) + # NOTE: We cannot work with `var` here because we might have a model of the form + # + # m = Vector{Float64}(undef, n) + # m .~ Normal() + # + # in which case `var` will have `undef` elements, even if `m` is present in `vi`. + r = get_and_set_val!(Random.GLOBAL_RNG, vi, vns, dists, SampleFromPrior()) + lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) + return r, lp +end + function dot_assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -319,84 +570,109 @@ function set_val!( end # observe -function dot_tilde_observe(ctx::DefaultContext, sampler, right, left, vi) +""" + dot_tilde_observe(context::SamplingContext, right, left, vi) + +Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log +probability, and return the observed value for a context associated with a sampler. + +Falls back to `dot_tilde_observe(context.context, context.sampler, right, left, vi)`. +""" +function dot_tilde_observe(context::SamplingContext, right, left, vi) + return dot_tilde_observe(context.context, context.sampler, right, left, vi) +end + +# Leaf contexts +dot_tilde_observe(::DefaultContext, right, left, vi) = dot_observe(right, left, vi) +function dot_tilde_observe(::DefaultContext, sampler, right, left, vi) return dot_observe(sampler, right, left, vi) end -function dot_tilde_observe(ctx::PriorContext, sampler, right, left, vi) - return 0 +dot_tilde_observe(::PriorContext, right, left, vi) = 0 +dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 +function dot_tilde_observe(context::LikelihoodContext, right, left, vi) + return dot_observe(right, left, vi) end -function dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) +function dot_tilde_observe(context::LikelihoodContext, sampler, right, left, vi) return dot_observe(sampler, right, left, vi) end -function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vi) + +# `MiniBatchContext` +function dot_tilde_observe(context::MiniBatchContext, right, left, vi) + return context.loglike_scalar * dot_tilde_observe(context.context, right, left, vi) +end + +# `PrefixContext` +function dot_tilde_observe(context::PrefixContext, right, left, vi) + return dot_tilde_observe(context.context, right, left, vi) end """ - dot_tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) + dot_tilde_observe!(context, right, left, vname, vinds, vi) -Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs), +Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe!(ctx, sampler, right, left, vn, inds, vi) - logp = dot_tilde_observe(ctx, sampler, right, left, vi) - acclogp!(vi, logp) - return left +function dot_tilde_observe!(context, right, left, vn, inds, vi) + return dot_tilde_observe!(context, right, left, vi) end """ - dot_tilde_observe!(ctx, sampler, right, left, vi) + dot_tilde_observe!(context, right, left, vi) Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)`. +Falls back to `dot_tilde_observe(context, right, left, vi)`. """ -function dot_tilde_observe!(ctx, sampler, right, left, vi) - logp = dot_tilde_observe(ctx, sampler, right, left, vi) +function dot_tilde_observe!(context, right, left, vi) + logp = dot_tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, + ::Union{SampleFromPrior,SampleFromUniform}, dist::MultivariateDistribution, value::AbstractMatrix, vi, ) + return dot_observe(dist, value, vi) +end +function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) increment_num_produce!(vi) @debug "dist = $dist" @debug "value = $value" return Distributions.loglikelihood(dist, value) end function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, + ::Union{SampleFromPrior,SampleFromUniform}, dists::Distribution, value::AbstractArray, vi, ) + return dot_observe(dists, value, vi) +end +function dot_observe(dists::Distribution, value::AbstractArray, vi) increment_num_produce!(vi) @debug "dists = $dists" @debug "value = $value" return Distributions.loglikelihood(dists, value) end function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, + ::Union{SampleFromPrior,SampleFromUniform}, dists::AbstractArray{<:Distribution}, value::AbstractArray, vi, ) + return dot_observe(dists, value, vi) +end +function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) increment_num_produce!(vi) @debug "dists = $dists" @debug "value = $value" return sum(Distributions.loglikelihood.(dists, value)) end -function dot_observe(spl::Sampler, ::Any, ::Any, ::Any) - return error( - "[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing observe statement" - ) -end diff --git a/src/contexts.jl b/src/contexts.jl index 2c23531c6..05ad8df0d 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,3 +1,17 @@ +""" + SamplingContext(rng, sampler, context) + +Create a context that allows you to sample parameters with the `sampler` when running the model. +The `context` determines how the returned log density is computed when running the model. + +See also: [`DefaultContext`](@ref), [`LikelihoodContext`](@ref), [`PriorContext`](@ref) +""" +struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext + rng::R + sampler::S + context::C +end + """ struct DefaultContext <: AbstractContext end @@ -35,7 +49,7 @@ LikelihoodContext() = LikelihoodContext(nothing) """ struct MiniBatchContext{Tctx, T} <: AbstractContext - ctx::Tctx + context::Tctx loglike_scalar::T end @@ -46,31 +60,42 @@ This is useful in batch-based stochastic gradient descent algorithms to be optim `log(prior) + log(likelihood of all the data points)` in the expectation. """ struct MiniBatchContext{Tctx,T} <: AbstractContext - ctx::Tctx + context::Tctx loglike_scalar::T end -function MiniBatchContext(ctx=DefaultContext(); batch_size, npoints) - return MiniBatchContext(ctx, npoints / batch_size) +function MiniBatchContext(context=DefaultContext(); batch_size, npoints) + return MiniBatchContext(context, npoints / batch_size) end +""" + PrefixContext{Prefix}(context) + +Create a context that allows you to use the wrapped `context` when running the model and +adds the `Prefix` to all parameters. + +This context is useful in nested models to ensure that the names of the parameters are +unique. + +See also: [`@submodel`](@ref) +""" struct PrefixContext{Prefix,C} <: AbstractContext - ctx::C + context::C end -function PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} - return PrefixContext{Prefix,typeof(ctx)}(ctx) +function PrefixContext{Prefix}(context::AbstractContext) where {Prefix} + return PrefixContext{Prefix,typeof(context)}(context) end const PREFIX_SEPARATOR = Symbol(".") function PrefixContext{PrefixInner}( - ctx::PrefixContext{PrefixOuter} + context::PrefixContext{PrefixOuter} ) where {PrefixInner,PrefixOuter} if @generated :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)))}( - ctx.ctx + context.context )) else - PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(ctx.ctx) + PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(context.context) end end diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 89672127a..6c66e4ec4 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -1,80 +1,102 @@ # Context version struct PointwiseLikelihoodContext{A,Ctx} <: AbstractContext loglikelihoods::A - ctx::Ctx + context::Ctx end function PointwiseLikelihoodContext( - likelihoods=Dict{VarName,Vector{Float64}}(), ctx::AbstractContext=LikelihoodContext() + likelihoods=Dict{VarName,Vector{Float64}}(), + context::AbstractContext=LikelihoodContext(), ) - return PointwiseLikelihoodContext{typeof(likelihoods),typeof(ctx)}(likelihoods, ctx) + return PointwiseLikelihoodContext{typeof(likelihoods),typeof(context)}( + likelihoods, context + ) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}}, + vn::VarName, + logp::Real, ) - lookup = ctx.loglikelihoods + lookup = context.loglikelihoods ℓ = get!(lookup, vn, Float64[]) return push!(ℓ, logp) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{VarName,Float64}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{VarName,Float64}}, vn::VarName, logp::Real ) - return ctx.loglikelihoods[vn] = logp + return context.loglikelihoods[vn] = logp end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, + vn::VarName, + logp::Real, ) - lookup = ctx.loglikelihoods + lookup = context.loglikelihoods ℓ = get!(lookup, string(vn), Float64[]) return push!(ℓ, logp) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::VarName, logp::Real ) - return ctx.loglikelihoods[string(vn)] = logp + return context.loglikelihoods[string(vn)] = logp end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, vn::String, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, + vn::String, + logp::Real, ) - lookup = ctx.loglikelihoods + lookup = context.loglikelihoods ℓ = get!(lookup, vn, Float64[]) return push!(ℓ, logp) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::String, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::String, logp::Real ) - return ctx.loglikelihoods[vn] = logp + return context.loglikelihoods[vn] = logp end -function tilde_assume(rng, ctx::PointwiseLikelihoodContext, sampler, right, vn, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, vi) +function tilde_assume(context::PointwiseLikelihoodContext, right, vn, inds, vi) + return tilde_assume(context.context, right, vn, inds, vi) end -function dot_tilde_assume( - rng, ctx::PointwiseLikelihoodContext, sampler, right, left, vn, inds, vi -) - value, logp = dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.context, right, left, vn, inds, vi) +end + +function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) + # Defer literal `observe` to child-context. + return tilde_observe!(context.context, right, left, vi) +end +function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, vinds, vi) + # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. + # we have to intercept the call to `tilde_observe!`. + logp = tilde_observe(context.context, right, left, vi) acclogp!(vi, logp) - return value + + # Track loglikelihood value. + push!(context, vn, logp) + + return left end -function tilde_observe( - ctx::PointwiseLikelihoodContext, sampler, right, left, vname, vinds, vi -) - # This is slightly unfortunate since it is not completely generic... - # Ideally we would call `tilde_observe` recursively but then we don't get the - # loglikelihood value. - logp = tilde(ctx.ctx, sampler, right, left, vi) +function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) + # Defer literal `observe` to child-context. + return dot_tilde_observe!(context.context, right, left, vi) +end +function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) + # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. + # we have to intercept the call to `dot_tilde_observe!`. + logp = dot_tilde_observe(context.context, right, left, vi) acclogp!(vi, logp) - # track loglikelihood value - push!(ctx, vname, logp) + # Track loglikelihood value. + push!(context, vn, logp) return left end @@ -150,30 +172,29 @@ Dict{VarName,Array{Float64,2}} with 4 entries: """ function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T} # Get the data by executing the model once - spl = SampleFromPrior() vi = VarInfo(model) - ctx = PointwiseLikelihoodContext(Dict{T,Vector{Float64}}()) + context = PointwiseLikelihoodContext(Dict{T,Vector{Float64}}()) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) for (sample_idx, chain_idx) in iters # Update the values - setval_and_resample!(vi, chain, sample_idx, chain_idx) + setval!(vi, chain, sample_idx, chain_idx) # Execute model - model(vi, spl, ctx) + model(vi, context) end niters = size(chain, 1) nchains = size(chain, 3) loglikelihoods = Dict( varname => reshape(logliks, niters, nchains) for - (varname, logliks) in ctx.loglikelihoods + (varname, logliks) in context.loglikelihoods ) return loglikelihoods end function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) - ctx = PointwiseLikelihoodContext(Dict{VarName,Float64}()) - model(varinfo, SampleFromPrior(), ctx) - return ctx.loglikelihoods + context = PointwiseLikelihoodContext(Dict{VarName,Vector{Float64}}()) + model(varinfo, context) + return context.loglikelihoods end diff --git a/src/model.jl b/src/model.jl index 7189b590e..9ec047a44 100644 --- a/src/model.jl +++ b/src/model.jl @@ -88,12 +88,18 @@ function (model::Model)( sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) + return model(varinfo, SamplingContext(rng, sampler, context)) +end + +(model::Model)(context::AbstractContext) = model(VarInfo(), context) +function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 - return evaluate_threadunsafe(rng, model, varinfo, sampler, context) + return evaluate_threadunsafe(model, varinfo, context) else - return evaluate_threadsafe(rng, model, varinfo, sampler, context) + return evaluate_threadsafe(model, varinfo, context) end end + function (model::Model)(args...) return model(Random.GLOBAL_RNG, args...) end @@ -109,7 +115,7 @@ function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) end """ - evaluate_threadunsafe(rng, model, varinfo, sampler, context) + evaluate_threadunsafe(model, varinfo, context) Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. @@ -118,13 +124,13 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe`](@ref) """ -function evaluate_threadunsafe(rng, model, varinfo, sampler, context) +function evaluate_threadunsafe(model, varinfo, context) resetlogp!(varinfo) - return _evaluate(rng, model, varinfo, sampler, context) + return _evaluate(model, varinfo, context) end """ - evaluate_threadsafe(rng, model, varinfo, sampler, context) + evaluate_threadsafe(model, varinfo, context) Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`. @@ -134,24 +140,24 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe`](@ref) """ -function evaluate_threadsafe(rng, model, varinfo, sampler, context) +function evaluate_threadsafe(model, varinfo, context) resetlogp!(varinfo) wrapper = ThreadSafeVarInfo(varinfo) - result = _evaluate(rng, model, wrapper, sampler, context) + result = _evaluate(model, wrapper, context) setlogp!(varinfo, getlogp(wrapper)) return result end """ - _evaluate(rng, model::Model, varinfo, sampler, context) + _evaluate(model::Model, varinfo, context) -Evaluate the `model` with the arguments matching the given `sampler` and `varinfo` object. +Evaluate the `model` with the arguments matching the given `context` and `varinfo` object. """ @generated function _evaluate( - rng, model::Model{_F,argnames}, varinfo, sampler, context + model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} - unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] - return :(model.f(rng, model, varinfo, sampler, context, $(unwrap_args...))) + unwrap_args = [:($matchingvalue(context, varinfo, model.args.$var)) for var in argnames] + return :(model.f(model, varinfo, context, $(unwrap_args...))) end """ @@ -183,7 +189,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), DefaultContext()) + model(varinfo, DefaultContext()) return getlogp(varinfo) end @@ -195,7 +201,7 @@ Return the log prior probability of variables `varinfo` for the probabilistic `m See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), PriorContext()) + model(varinfo, PriorContext()) return getlogp(varinfo) end @@ -207,7 +213,7 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`. See also [`logjoint`](@ref) and [`logprior`](@ref). """ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), LikelihoodContext()) + model(varinfo, LikelihoodContext()) return getlogp(varinfo) end diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 92584ae8b..1d574e286 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -1,22 +1,14 @@ macro submodel(expr) return quote - _evaluate( - $(esc(:__rng__)), - $(esc(expr)), - $(esc(:__varinfo__)), - $(esc(:__sampler__)), - $(esc(:__context__)), - ) + _evaluate($(esc(expr)), $(esc(:__varinfo__)), $(esc(:__context__))) end end macro submodel(prefix, expr) return quote _evaluate( - $(esc(:__rng__)), $(esc(expr)), $(esc(:__varinfo__)), - $(esc(:__sampler__)), PrefixContext{$(esc(Meta.quot(prefix)))}($(esc(:__context__))), ) end diff --git a/src/varname.jl b/src/varname.jl index bb936a4ce..343bb0da8 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -39,3 +39,6 @@ Possibly existing indices of `varname` are neglected. ) where {s,missings,_F,_a,_T} return s in missings end + +# HACK: Type-piracy. Is this really the way to go? +AbstractPPL.getsym(::AbstractVector{<:VarName{sym}}) where {sym} = sym diff --git a/test/compiler.jl b/test/compiler.jl index 78b472563..d219f91ea 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -172,10 +172,10 @@ end @model function testmodel_missing3(x) x[1] ~ Bernoulli(0.5) global varinfo_ = __varinfo__ - global sampler_ = __sampler__ + global sampler_ = __context__.sampler global model_ = __model__ global context_ = __context__ - global rng_ = __rng__ + global rng_ = __context__.rng global lp = getlogp(__varinfo__) return x end @@ -184,18 +184,17 @@ end @test getlogp(varinfo) == lp @test varinfo_ isa AbstractVarInfo @test model_ === model - @test sampler_ === SampleFromPrior() - @test context_ === DefaultContext() + @test context_ isa SamplingContext @test rng_ isa Random.AbstractRNG # disable warnings @model function testmodel_missing4(x) x[1] ~ Bernoulli(0.5) global varinfo_ = __varinfo__ - global sampler_ = __sampler__ + global sampler_ = __context__.sampler global model_ = __model__ global context_ = __context__ - global rng_ = __rng__ + global rng_ = __context__.rng global lp = getlogp(__varinfo__) return x end false diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl new file mode 100644 index 000000000..74fb88d70 --- /dev/null +++ b/test/loglikelihoods.jl @@ -0,0 +1,123 @@ +# A collection of models for which the mean-of-means for the posterior should +# be same. +@model function gdemo1(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} + # `dot_assume` and `observe` + m = TV(undef, length(x)) + m .~ Normal() + return x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo2(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} + # `assume` with indexing and `observe` + m = TV(undef, length(x)) + for i in eachindex(m) + m[i] ~ Normal() + end + return x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo3(x=10 * ones(2)) + # Multivariate `assume` and `observe` + m ~ MvNormal(length(x), 1.0) + return x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo4(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} + # `dot_assume` and `observe` with indexing + m = TV(undef, length(x)) + m .~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m[i], 0.5) + end +end + +# Using vector of `length` 1 here so the posterior of `m` is the same +# as the others. +@model function gdemo5(x=10 * ones(1)) + # `assume` and `dot_observe` + m ~ Normal() + return x .~ Normal(m, 0.5) +end + +# @model function gdemo6(::Type{TV} = Vector{Float64}) where {TV} +# # `assume` and literal `observe` +# m ~ MvNormal(length(x), 1.0) +# [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) +# end + +@model function gdemo7(::Type{TV}=Vector{Float64}) where {TV} + # `dot_assume` and literal `observe` with indexing + m = TV(undef, 2) + m .~ Normal() + for i in eachindex(m) + 10.0 ~ Normal(m[i], 0.5) + end +end + +# @model function gdemo8(::Type{TV} = Vector{Float64}) where {TV} +# # `assume` and literal `dot_observe` +# m ~ Normal() +# [10.0, ] .~ Normal(m, 0.5) +# end + +@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} + m = TV(undef, 2) + m .~ Normal() + + return m +end + +@model function gdemo9() + # Submodel prior + m = @submodel _prior_dot_assume() + for i in eachindex(m) + 10.0 ~ Normal(m[i], 0.5) + end +end + +@model function _likelihood_dot_observe(m, x) + return x ~ MvNormal(m, 0.5 * ones(length(m))) +end + +@model function gdemo10(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} + m = TV(undef, length(x)) + m .~ Normal() + + # Submodel likelihood + @submodel _likelihood_dot_observe(m, x) +end + +const gdemo_models = ( + gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo7(), gdemo9(), gdemo10() +) + +@testset "loglikelihoods.jl" begin + for m in gdemo_models + vi = VarInfo(m) + + vns = vi.metadata.m.vns + if length(vns) == 1 && length(vi[vns[1]]) == 1 + # Only have one latent variable. + DynamicPPL.setval!(vi, [1.0], ["m"]) + else + DynamicPPL.setval!(vi, [1.0, 1.0], ["m[1]", "m[2]"]) + end + + lls = pointwise_loglikelihoods(m, vi) + + if isempty(lls) + # One of the models with literal observations, so we just skip. + continue + end + + loglikelihood = if length(keys(lls)) == 1 && length(m.args.x) == 1 + # Only have one observation, so we need to double it + # for comparison with other models. + 2 * sum(lls[first(keys(lls))]) + else + sum(sum, values(lls)) + end + + @test loglikelihood ≈ -324.45158270528947 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 2b3d5d55c..d83be0eea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,6 +45,8 @@ include("test_util.jl") include("threadsafe.jl") include("serialization.jl") + + include("loglikelihoods.jl") end @testset "compat" begin diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 746d6a5f8..83c53ccd6 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -61,14 +61,18 @@ # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. DynamicPPL.evaluate_threadsafe( - Random.GLOBAL_RNG, wthreads(x), vi, SampleFromPrior(), DefaultContext() + wthreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) @test getlogp(vi) ≈ lp_w_threads @test vi_ isa DynamicPPL.ThreadSafeVarInfo println(" evaluate_threadsafe:") @time DynamicPPL.evaluate_threadsafe( - Random.GLOBAL_RNG, wthreads(x), vi, SampleFromPrior(), DefaultContext() + wthreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) @model function wothreads(x) @@ -96,14 +100,18 @@ # Ensure that we use `VarInfo`. DynamicPPL.evaluate_threadunsafe( - Random.GLOBAL_RNG, wothreads(x), vi, SampleFromPrior(), DefaultContext() + wothreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) @test getlogp(vi) ≈ lp_w_threads @test vi_ isa VarInfo println(" evaluate_threadunsafe:") @time DynamicPPL.evaluate_threadunsafe( - Random.GLOBAL_RNG, wothreads(x), vi, SampleFromPrior(), DefaultContext() + wothreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) end end diff --git a/test/turing/Project.toml b/test/turing/Project.toml index a4f68621d..67b8d5645 100644 --- a/test/turing/Project.toml +++ b/test/turing/Project.toml @@ -5,6 +5,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] -DynamicPPL = "0.11" +DynamicPPL = "0.12" Turing = "0.15, 0.16" julia = "1.3" From cb996c6fd002d88ec825bb0d9ca4fd428902a86f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 14:40:33 +0100 Subject: [PATCH 061/107] Update src/DynamicPPL.jl Co-authored-by: David Widmann --- src/DynamicPPL.jl | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 914c0e12b..a46c941a1 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -130,17 +130,4 @@ include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") -# Deprecations. -@deprecate tilde(rng, ctx, sampler, right, vn, inds, vi) tilde_assume( - rng, ctx, sampler, right, vn, inds, vi -) -@deprecate tilde(ctx, sampler, right, left, vi) tilde_observe(ctx, sampler, right, left, vi) - -@deprecate dot_tilde(rng, ctx, sampler, right, left, vn, inds, vi) dot_tilde_assume( - rng, ctx, sampler, right, left, vn, inds, vi -) -@deprecate dot_tilde(ctx, sampler, right, left, vi) dot_tilde_observe( - ctx, sampler, right, left, vi -) - end # module From 03c9285c23b03441700e172307f3b33913a665bd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Jun 2021 12:31:06 +0100 Subject: [PATCH 062/107] added initial impl of SimpleVarInfo --- src/DynamicPPL.jl | 1 + src/simple_varinfo.jl | 105 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+) create mode 100644 src/simple_varinfo.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a46c941a1..5cde57f91 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -122,6 +122,7 @@ include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") include("varinfo.jl") +include("simple_varinfo.jl") include("threadsafe.jl") include("context_implementations.jl") include("compiler.jl") diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl new file mode 100644 index 000000000..6699a441b --- /dev/null +++ b/src/simple_varinfo.jl @@ -0,0 +1,105 @@ +""" + SimpleVarInfo{NT,T} <: AbstractVarInfo + +A simple wrapper of the parameters with a `logp` field for +accumulation of the logdensity. + +Currently only implemented for `NT <: NamedTuple`. + +## Notes +The major differences between this and `TypedVarInfo` are: +1. `SimpleVarInfo` does not require linearization. +2. `SimpleVarInfo` can use more efficient bijectors. +3. `SimpleVarInfo` only supports evaluation. +""" +struct SimpleVarInfo{NT,T} <: AbstractVarInfo + θ::NT + logp::Base.RefValue{T} +end + +SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, Ref(zero(T))) +SimpleVarInfo(θ) = SimpleVarInfo{Float64}(θ) + +function setlogp!(vi::SimpleVarInfo, logp) + vi.logp[] = logp + return vi +end + +function acclogp!(vi::SimpleVarInfo, logp) + vi.logp[] += logp + return vi +end + +function _getvalue(nt::NamedTuple, ::Val{sym}, inds=()) where {sym} + # Use `getproperty` instead of `getfield` + value = getproperty(nt, sym) + return _getindex(value, inds) +end + +getval(vi::SimpleVarInfo, vn::VarName{sym}) where {sym} = _getvalue(vi.θ, Val{sym}(), vn.indexing) +# `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than +# just `Vector`. +getval(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = map(vn -> getval(vi, vn), vns) +# To disambiguiate. +getval(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(vn -> getval(vi, vn), vns) + +haskey(vi::SimpleVarInfo, vn) = haskey(vi.θ, getsym(vn)) + +istrans(::SimpleVarInfo, vn::VarName) = false + +getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.θ +getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.θ +getindex(vi::SimpleVarInfo, spl::Sampler) = vi.θ +getindex(vi::SimpleVarInfo, vn::VarName) = getval(vi, vn) +getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = getval(vi, vns) + +# Context implementations +# Only evaluation makes sense for `SimpleVarInfo`, so we only implement this. +function assume(dist::Distribution, vn::VarName, vi::SimpleVarInfo{<:NamedTuple}) + left = vi[vn] + return left, Distributions.loglikelihood(dist, left) +end + +# function dot_tilde_assume!(context, right, left, vn, inds, vi::SimpleVarInfo) +# throw(MethodError(dot_tilde_assume!, (context, right, left, vn, inds, vi))) +# end + +function dot_assume( + dist::MultivariateDistribution, + var::AbstractMatrix, + vns::AbstractVector{<:VarName}, + vi::SimpleVarInfo, +) + @assert length(dist) == size(var, 1) + # NOTE: We cannot work with `var` here because we might have a model of the form + # + # m = Vector{Float64}(undef, n) + # m .~ Normal() + # + # in which case `var` will have `undef` elements, even if `m` is present in `vi`. + r = vi[vns] + lp = sum(zip(vns, eachcol(r))) do vn, ri + return Distributions.logpdf(dist, ri) + end + return r, lp +end + +function dot_assume( + dists::Union{Distribution,AbstractArray{<:Distribution}}, + var::AbstractArray, + vns::AbstractArray{<:VarName}, + vi::SimpleVarInfo{<:NamedTuple}, +) + # NOTE: We cannot work with `var` here because we might have a model of the form + # + # m = Vector{Float64}(undef, n) + # m .~ Normal() + # + # in which case `var` will have `undef` elements, even if `m` is present in `vi`. + r = vi[vns] + lp = sum(Distributions.logpdf.(dists, r)) + return r, lp +end + +# HACK: Allows us to re-use the impleemntation of `dot_tilde`, etc. for literals. +increment_num_produce!(::SimpleVarInfo) = nothing From f91952d9f71641c9a69896f4363cff881b17517b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Jun 2021 12:31:16 +0100 Subject: [PATCH 063/107] remove unnecessary debug statements to be compat with Zygote --- src/context_implementations.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 6833a7856..ab8fc7cab 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -644,8 +644,6 @@ function dot_observe( end function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) increment_num_produce!(vi) - @debug "dist = $dist" - @debug "value = $value" return Distributions.loglikelihood(dist, value) end function dot_observe( @@ -658,8 +656,6 @@ function dot_observe( end function dot_observe(dists::Distribution, value::AbstractArray, vi) increment_num_produce!(vi) - @debug "dists = $dists" - @debug "value = $value" return Distributions.loglikelihood(dists, value) end function dot_observe( @@ -672,7 +668,5 @@ function dot_observe( end function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) increment_num_produce!(vi) - @debug "dists = $dists" - @debug "value = $value" return sum(Distributions.loglikelihood.(dists, value)) end From 4d4b4893085551be951f1c7d8c17c6767d91663d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Jun 2021 12:31:27 +0100 Subject: [PATCH 064/107] make reconstruct slightly more generic --- src/utils.jl | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index e77a4ecdd..95b7f6a9a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -93,11 +93,10 @@ vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r)) # otherwise we will have error for MatrixDistribution. # Note this is not the case for MultivariateDistribution so I guess this might be lack of # support for some types related to matrices (like PDMat). -reconstruct(d::UnivariateDistribution, val::AbstractVector) = val[1] -reconstruct(d::MultivariateDistribution, val::AbstractVector) = copy(val) -function reconstruct(d::MatrixDistribution, val::AbstractVector) - return reshape(copy(val), size(d)) -end +reconstruct(d::Distribution, val::AbstractVector) = reconstruct(size(d), val) +reconstruct(::Tuple{}, val::AbstractVector) = val[1] +reconstruct(s::NTuple{1}, val::AbstractVector) = copy(val) +reconstruct(s::NTuple{2}, val::AbstractVector) = reshape(copy(val), s) function reconstruct!(r, d::Distribution, val::AbstractVector) return reconstruct!(r, d, val) end @@ -106,17 +105,17 @@ function reconstruct!(r, d::MultivariateDistribution, val::AbstractVector) return r end function reconstruct(d::Distribution, val::AbstractVector, n::Int) - return reconstruct(d, val, n) + return reconstruct(size(d), val, n) end -function reconstruct(d::UnivariateDistribution, val::AbstractVector, n::Int) +function reconstruct(::Tuple{}, val::AbstractVector, n::Int) return copy(val) end -function reconstruct(d::MultivariateDistribution, val::AbstractVector, n::Int) - return copy(reshape(val, size(d)[1], n)) +function reconstruct(s::NTuple{1}, val::AbstractVector, n::Int) + return copy(reshape(val, s[1], n)) end -function reconstruct(d::MatrixDistribution, val::AbstractVector, n::Int) - tmp = reshape(val, size(d)[1], size(d)[2], n) - orig = [tmp[:, :, i] for i in 1:size(tmp, 3)] +function reconstruct(s::NTuple{2}, val::AbstractVector, n::Int) + tmp = reshape(val, s..., n) + orig = [tmp[:, :, i] for i in 1:n] return orig end function reconstruct!(r, d::Distribution, val::AbstractVector, n::Int) From a68c045ead9ad23acd0dcd8fdf41bd3a8174a3b1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 18 Jun 2021 10:04:23 +0100 Subject: [PATCH 065/107] added a couple of convenience constructors --- src/DynamicPPL.jl | 1 + src/simple_varinfo.jl | 19 ++++++++++++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 5cde57f91..8659c3b4e 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -32,6 +32,7 @@ export AbstractVarInfo, VarInfo, UntypedVarInfo, TypedVarInfo, + SimpleVarInfo, getlogp, setlogp!, acclogp!, diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 6699a441b..3a2b01de1 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -18,7 +18,7 @@ struct SimpleVarInfo{NT,T} <: AbstractVarInfo end SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, Ref(zero(T))) -SimpleVarInfo(θ) = SimpleVarInfo{Float64}(θ) +SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ) function setlogp!(vi::SimpleVarInfo, logp) vi.logp[] = logp @@ -103,3 +103,20 @@ end # HACK: Allows us to re-use the impleemntation of `dot_tilde`, etc. for literals. increment_num_produce!(::SimpleVarInfo) = nothing + +# Interaction with `VarInfo` +SimpleVarInfo(vi::TypedVarInfo) = SimpleVarInfo{eltype(getlogp(vi))}(vi) +function SimpleVarInfo{T}(vi::VarInfo{<:NamedTuple{names}}) where {T<:Real, names} + vals = map(names) do n + let md = getfield(vi.metadata, n) + x = map(enumerate(md.ranges)) do (i, r) + reconstruct(md.dists[i], md.vals[r]) + end + + # TODO: Doesn't support batches of `MultivariateDistribution`? + length(x) == 1 ? x[1] : x + end + end + + return SimpleVarInfo{T}(NamedTuple{names}(vals)) +end From 9766aecb82d9e5c9419047b14c96b6e78a1ff595 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 18 Jun 2021 14:37:10 +0100 Subject: [PATCH 066/107] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/simple_varinfo.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 3a2b01de1..1b46e0d0a 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -36,7 +36,9 @@ function _getvalue(nt::NamedTuple, ::Val{sym}, inds=()) where {sym} return _getindex(value, inds) end -getval(vi::SimpleVarInfo, vn::VarName{sym}) where {sym} = _getvalue(vi.θ, Val{sym}(), vn.indexing) +function getval(vi::SimpleVarInfo, vn::VarName{sym}) where {sym} + return _getvalue(vi.θ, Val{sym}(), vn.indexing) +end # `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than # just `Vector`. getval(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = map(vn -> getval(vi, vn), vns) @@ -106,7 +108,7 @@ increment_num_produce!(::SimpleVarInfo) = nothing # Interaction with `VarInfo` SimpleVarInfo(vi::TypedVarInfo) = SimpleVarInfo{eltype(getlogp(vi))}(vi) -function SimpleVarInfo{T}(vi::VarInfo{<:NamedTuple{names}}) where {T<:Real, names} +function SimpleVarInfo{T}(vi::VarInfo{<:NamedTuple{names}}) where {T<:Real,names} vals = map(names) do n let md = getfield(vi.metadata, n) x = map(enumerate(md.ranges)) do (i, r) From 46b1c7884b8bd1fadb9fc5e8701b647b21508c57 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 18 Jun 2021 14:41:31 +0100 Subject: [PATCH 067/107] small fix --- src/simple_varinfo.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 1b46e0d0a..ff7b8a7c2 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -51,9 +51,12 @@ istrans(::SimpleVarInfo, vn::VarName) = false getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.θ getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.θ +# TODO: Should we do better? getindex(vi::SimpleVarInfo, spl::Sampler) = vi.θ getindex(vi::SimpleVarInfo, vn::VarName) = getval(vi, vn) getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = getval(vi, vns) +# HACK: Need to disambiguiate. +getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = getval(vi, vns) # Context implementations # Only evaluation makes sense for `SimpleVarInfo`, so we only implement this. From 3a645d623d99e35f897f56f807e2b77f528f2e43 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 02:00:47 +0100 Subject: [PATCH 068/107] return var_info from tilde-statements, allowing impl of immutable versions --- src/compiler.jl | 35 +++++++++++++++++++++++++++------- src/context_implementations.jl | 13 +++++-------- src/model.jl | 13 +++++++++++++ src/utils.jl | 2 +- 4 files changed, 47 insertions(+), 16 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 7466bc2c0..920262002 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -299,7 +299,7 @@ function generate_tilde(left, right) # If the LHS is a literal, it is always an observation if isliteral(left) return quote - $(DynamicPPL.tilde_observe!)( + _, __varinfo__ = $(DynamicPPL.tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end @@ -313,7 +313,7 @@ function generate_tilde(left, right) $inds = $(vinds(left)) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption - $left = $(DynamicPPL.tilde_assume!)( + $left, __varinfo__ = $(DynamicPPL.tilde_assume!)( __context__, $(DynamicPPL.unwrap_right_vn)( $(DynamicPPL.check_tilde_rhs)($right), $vn @@ -322,7 +322,7 @@ function generate_tilde(left, right) __varinfo__, ) else - $(DynamicPPL.tilde_observe!)( + _, __varinfo__ = $(DynamicPPL.tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, @@ -343,7 +343,7 @@ function generate_dot_tilde(left, right) # If the LHS is a literal, it is always an observation if isliteral(left) return quote - $(DynamicPPL.dot_tilde_observe!)( + _, __varinfo__ = $(DynamicPPL.dot_tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end @@ -357,7 +357,7 @@ function generate_dot_tilde(left, right) $inds = $(vinds(left)) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption - $left .= $(DynamicPPL.dot_tilde_assume!)( + _, __varinfo__ = $(DynamicPPL.dot_tilde_assume!)( __context__, $(DynamicPPL.unwrap_right_left_vns)( $(DynamicPPL.check_tilde_rhs)($right), $left, $vn @@ -366,7 +366,7 @@ function generate_dot_tilde(left, right) __varinfo__, ) else - $(DynamicPPL.dot_tilde_observe!)( + _, __varinfo = $(DynamicPPL.dot_tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, @@ -378,6 +378,27 @@ function generate_dot_tilde(left, right) end end +replace_returns(e) = e +replace_returns(e::Symbol) = e +function replace_returns(e::Expr) + if Meta.isexpr(e, :function) || Meta.isexpr(e, :->) + return e + end + + if Meta.isexpr(e, :return) + retval = if length(e.args) > 1 + Expr(:tuple, e.args...) + else + e.args[1] + end + return quote + return $retval, __varinfo__ + end + end + + return Expr(e.head, map(x -> replace_returns(x), e.args)...) +end + const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}} hasmissing(T::Type{<:AbstractArray{TA}}) where {TA<:AbstractArray} = hasmissing(TA) hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true @@ -409,7 +430,7 @@ function build_output(modelinfo, linenumbernode) evaluatordef[:kwargs] = [] # Replace the user-provided function body with the version created by DynamicPPL. - evaluatordef[:body] = modelinfo[:body] + evaluatordef[:body] = replace_returns(modelinfo[:body]) ## Build the model function. diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 3d492f5b1..b48520d03 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -129,8 +129,7 @@ probability of `vi` with the returned value. """ function tilde_assume!(context, right, vn, inds, vi) value, logp = tilde_assume(context, right, vn, inds, vi) - acclogp!(vi, logp) - return value + return value, acclogp!(vi, logp) end # observe @@ -213,8 +212,7 @@ probability of `vi` with the returned value. """ function tilde_observe!(context, right, left, vi) logp = tilde_observe(context, right, left, vi) - acclogp!(vi, logp) - return left + return left, acclogp!(vi, logp) end function assume(rng, spl::Sampler, dist) @@ -415,8 +413,8 @@ Falls back to `dot_tilde_assume(context, right, left, vn, inds, vi)`. """ function dot_tilde_assume!(context, right, left, vn, inds, vi) value, logp = dot_tilde_assume(context, right, left, vn, inds, vi) - acclogp!(vi, logp) - return value + left .= value + return value, acclogp!(vi, logp) end # `dot_assume` @@ -634,8 +632,7 @@ Falls back to `dot_tilde_observe(context, right, left, vi)`. """ function dot_tilde_observe!(context, right, left, vi) logp = dot_tilde_observe(context, right, left, vi) - acclogp!(vi, logp) - return left + return left, acclogp!(vi, logp) end # Falls back to non-sampler definition. diff --git a/src/model.jl b/src/model.jl index 9ec047a44..6fe2dcfa7 100644 --- a/src/model.jl +++ b/src/model.jl @@ -155,6 +155,19 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf """ @generated function _evaluate( model::Model{_F,argnames}, varinfo, context +) where {_F,argnames} + unwrap_args = [:($matchingvalue(context, varinfo, model.args.$var)) for var in argnames] + return :((first ∘ model.f)(model, varinfo, context, $(unwrap_args...))) +end + +""" + _evaluate_with_varinfo(model::Model, varinfo, context) + +Evaluate the `model` with the arguments matching the given `context` and `varinfo` object, +also returning the resulting `varinfo`. +""" +@generated function _evaluate_with_varinfo( + model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} unwrap_args = [:($matchingvalue(context, varinfo, model.args.$var)) for var in argnames] return :(model.f(model, varinfo, context, $(unwrap_args...))) diff --git a/src/utils.jl b/src/utils.jl index 95b7f6a9a..2ff537fe4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,7 +9,7 @@ Add the result of the evaluation of `ex` to the joint log probability. """ macro addlogprob!(ex) return quote - acclogp!($(esc(:(__varinfo__))), $(esc(ex))) + $(esc(:(__varinfo__))) = acclogp!($(esc(:(__varinfo__))), $(esc(ex))) end end From a2ec0bd4c3cb51ab9aff759012addf76f7b39fd2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 02:01:14 +0100 Subject: [PATCH 069/107] allow usage of non-Ref types in SimpleVarInfo --- src/simple_varinfo.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index ff7b8a7c2..696da63f3 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -14,18 +14,22 @@ The major differences between this and `TypedVarInfo` are: """ struct SimpleVarInfo{NT,T} <: AbstractVarInfo θ::NT - logp::Base.RefValue{T} + logp::T end -SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, Ref(zero(T))) +SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T)) SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ) -function setlogp!(vi::SimpleVarInfo, logp) +getlogp(vi::SimpleVarInfo{<:Any, <:Real}) = vi.logp +setlogp!(vi::SimpleVarInfo{<:Any, <:Real}, logp) = SimpleVarInfo(vi.θ, logp) +acclogp!(vi::SimpleVarInfo{<:Any, <:Real}, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) + +function setlogp!(vi::SimpleVarInfo{<:Any, <:Ref}, logp) vi.logp[] = logp return vi end -function acclogp!(vi::SimpleVarInfo, logp) +function acclogp!(vi::SimpleVarInfo{<:Any, <:Ref}, logp) vi.logp[] += logp return vi end From 1d9bc373cdb23fd70cca8cae26bc54af9a0d157e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 02:01:34 +0100 Subject: [PATCH 070/107] update submodel-macro --- src/submodel_macro.jl | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 1d574e286..96f1af6e6 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -1,15 +1,33 @@ macro submodel(expr) - return quote - _evaluate($(esc(expr)), $(esc(:__varinfo__)), $(esc(:__context__))) + args_tilde = getargs_tilde(expr) + return if args_tilde === nothing + # In this case we only want to get the `__varinfo__`. + quote + $(esc(:_)), $(esc(:__varinfo__)) = _evaluate_with_varinfo($(esc(expr)), $(esc(:__varinfo__)), $(esc(:__context__))) + end + else + # Here we also want the return-variable. + L, R = args_tilde + quote + $(esc(L)), $(esc(:__varinfo__)) = _evaluate_with_varinfo($(esc(R)), $(esc(:__varinfo__)), $(esc(:__context__))) + end end end macro submodel(prefix, expr) - return quote - _evaluate( - $(esc(expr)), - $(esc(:__varinfo__)), - PrefixContext{$(esc(Meta.quot(prefix)))}($(esc(:__context__))), - ) + ctx = :(PrefixContext{$(esc(Meta.quot(prefix)))}($(esc(:__context__)))) + + args_tilde = getargs_tilde(expr) + return if args_tilde === nothing + # In this case we only want to get the `__varinfo__`. + quote + $(esc(:_)), $(esc(:__varinfo__)) = _evaluate_with_varinfo($(esc(expr)), $(esc(:__varinfo__)), $(ctx)) + end + else + # Here we also want the return-variable. + L, R = args_tilde + quote + $(esc(L)), $(esc(:__varinfo__)) = _evaluate_with_varinfo($(esc(R)), $(esc(:__varinfo__)), $(ctx)) + end end end From cfd7f219504bd1fdb3082032cb741e11644883a0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 02:26:44 +0100 Subject: [PATCH 071/107] formatting and docstring for submodel-macro --- src/simple_varinfo.jl | 10 +++++----- src/submodel_macro.jl | 26 ++++++++++++++++++++++---- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 696da63f3..2f029925b 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -20,16 +20,16 @@ end SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T)) SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ) -getlogp(vi::SimpleVarInfo{<:Any, <:Real}) = vi.logp -setlogp!(vi::SimpleVarInfo{<:Any, <:Real}, logp) = SimpleVarInfo(vi.θ, logp) -acclogp!(vi::SimpleVarInfo{<:Any, <:Real}, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) +getlogp(vi::SimpleVarInfo{<:Any,<:Real}) = vi.logp +setlogp!(vi::SimpleVarInfo{<:Any,<:Real}, logp) = SimpleVarInfo(vi.θ, logp) +acclogp!(vi::SimpleVarInfo{<:Any,<:Real}, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) -function setlogp!(vi::SimpleVarInfo{<:Any, <:Ref}, logp) +function setlogp!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) vi.logp[] = logp return vi end -function acclogp!(vi::SimpleVarInfo{<:Any, <:Ref}, logp) +function acclogp!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) vi.logp[] += logp return vi end diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 96f1af6e6..917d80cc4 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -1,15 +1,29 @@ +""" + @submodel x ~ model(args...) + @submodel prefix x ~ model(args...) + +Treats `model` as a distribution, where `x` is the return-value of `model`. + +If `prefix` is specified, then variables sampled within `model` will be +prefixed by `prefix`. This is useful if you have variables of same names in +several models used together. +""" macro submodel(expr) args_tilde = getargs_tilde(expr) return if args_tilde === nothing # In this case we only want to get the `__varinfo__`. quote - $(esc(:_)), $(esc(:__varinfo__)) = _evaluate_with_varinfo($(esc(expr)), $(esc(:__varinfo__)), $(esc(:__context__))) + $(esc(:_)), $(esc(:__varinfo__)) = _evaluate_with_varinfo( + $(esc(expr)), $(esc(:__varinfo__)), $(esc(:__context__)) + ) end else # Here we also want the return-variable. L, R = args_tilde quote - $(esc(L)), $(esc(:__varinfo__)) = _evaluate_with_varinfo($(esc(R)), $(esc(:__varinfo__)), $(esc(:__context__))) + $(esc(L)), $(esc(:__varinfo__)) = _evaluate_with_varinfo( + $(esc(R)), $(esc(:__varinfo__)), $(esc(:__context__)) + ) end end end @@ -21,13 +35,17 @@ macro submodel(prefix, expr) return if args_tilde === nothing # In this case we only want to get the `__varinfo__`. quote - $(esc(:_)), $(esc(:__varinfo__)) = _evaluate_with_varinfo($(esc(expr)), $(esc(:__varinfo__)), $(ctx)) + $(esc(:_)), $(esc(:__varinfo__)) = _evaluate_with_varinfo( + $(esc(expr)), $(esc(:__varinfo__)), $(ctx) + ) end else # Here we also want the return-variable. L, R = args_tilde quote - $(esc(L)), $(esc(:__varinfo__)) = _evaluate_with_varinfo($(esc(R)), $(esc(:__varinfo__)), $(ctx)) + $(esc(L)), $(esc(:__varinfo__)) = _evaluate_with_varinfo( + $(esc(R)), $(esc(:__varinfo__)), $(ctx) + ) end end end From c200e7362eb0717a81ea235dc2c17fc7ff07b357 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 02:42:51 +0100 Subject: [PATCH 072/107] attempt at supporting implicit returns too --- src/compiler.jl | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 920262002..1a9dc8929 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -386,12 +386,16 @@ function replace_returns(e::Expr) end if Meta.isexpr(e, :return) - retval = if length(e.args) > 1 + retval_expr = if length(e.args) > 1 Expr(:tuple, e.args...) else e.args[1] end + # Use intermediate variable since this expression + # can be more complex than just a value, e.g. `return if ... end`. + @gensym retval return quote + $retval = $retval_expr return $retval, __varinfo__ end end @@ -399,6 +403,18 @@ function replace_returns(e::Expr) return Expr(e.head, map(x -> replace_returns(x), e.args)...) end +# If it's just a symbol, e.g. `f(x) = 1`, then we make it `f(x) = return 1`. +make_returns_explicit!(body) = Expr(:return, body) +function make_returns_explicit!(body::Expr) + # If it's already a return-statement, we return immediately. + if Meta.isexpr(body, :return) + return body + end + + body.args[end] = Expr(:return, body.args[end]) + return body +end + const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}} hasmissing(T::Type{<:AbstractArray{TA}}) where {TA<:AbstractArray} = hasmissing(TA) hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true @@ -430,7 +446,7 @@ function build_output(modelinfo, linenumbernode) evaluatordef[:kwargs] = [] # Replace the user-provided function body with the version created by DynamicPPL. - evaluatordef[:body] = replace_returns(modelinfo[:body]) + evaluatordef[:body] = replace_returns(make_returns_explicit!(modelinfo[:body])) ## Build the model function. From efeb812c750ed0517e81f4b35dbcedf471f577fc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 02:44:14 +0100 Subject: [PATCH 073/107] added a small comment --- src/compiler.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/compiler.jl b/src/compiler.jl index 1a9dc8929..200285f02 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -411,6 +411,7 @@ function make_returns_explicit!(body::Expr) return body end + # Otherwise we replace the last statement with a `return` statement. body.args[end] = Expr(:return, body.args[end]) return body end From 14b94956bdb14587b7321267dba20d98d54b0171 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 02:47:56 +0100 Subject: [PATCH 074/107] simplifed submodel macro a bit --- src/submodel_macro.jl | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 917d80cc4..8e59f3015 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -9,28 +9,15 @@ prefixed by `prefix`. This is useful if you have variables of same names in several models used together. """ macro submodel(expr) - args_tilde = getargs_tilde(expr) - return if args_tilde === nothing - # In this case we only want to get the `__varinfo__`. - quote - $(esc(:_)), $(esc(:__varinfo__)) = _evaluate_with_varinfo( - $(esc(expr)), $(esc(:__varinfo__)), $(esc(:__context__)) - ) - end - else - # Here we also want the return-variable. - L, R = args_tilde - quote - $(esc(L)), $(esc(:__varinfo__)) = _evaluate_with_varinfo( - $(esc(R)), $(esc(:__varinfo__)), $(esc(:__context__)) - ) - end - end + return submodel(expr) end macro submodel(prefix, expr) ctx = :(PrefixContext{$(esc(Meta.quot(prefix)))}($(esc(:__context__)))) + return submodel(expr, ctx) +end +function submodel(expr, ctx = esc(:__context__)) args_tilde = getargs_tilde(expr) return if args_tilde === nothing # In this case we only want to get the `__varinfo__`. From c3d9e7b09aa3bec696b6ab06617d3088342bad66 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 03:31:51 +0100 Subject: [PATCH 075/107] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/submodel_macro.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 8e59f3015..32a2bd583 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -17,7 +17,7 @@ macro submodel(prefix, expr) return submodel(expr, ctx) end -function submodel(expr, ctx = esc(:__context__)) +function submodel(expr, ctx=esc(:__context__)) args_tilde = getargs_tilde(expr) return if args_tilde === nothing # In this case we only want to get the `__varinfo__`. From 416e7736be46b203e71f0660c544de989a78aa9e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 12:51:17 +0100 Subject: [PATCH 076/107] fixed typo --- src/compiler.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 200285f02..edf037b03 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -406,8 +406,8 @@ end # If it's just a symbol, e.g. `f(x) = 1`, then we make it `f(x) = return 1`. make_returns_explicit!(body) = Expr(:return, body) function make_returns_explicit!(body::Expr) - # If it's already a return-statement, we return immediately. - if Meta.isexpr(body, :return) + # If the last statement is a return-statement, we don't do anything. + if Meta.isexpr(body.args[end], :return) return body end From b4b8b03edede2d3fee38c7be680f7117df8ff9b4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 12:54:11 +0100 Subject: [PATCH 077/107] use bang-bang convention --- src/compiler.jl | 12 ++++++------ src/context_implementations.jl | 30 +++++++++++++++--------------- src/loglikelihoods.jl | 16 ++++++++-------- 3 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index edf037b03..2b7945c9c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -299,7 +299,7 @@ function generate_tilde(left, right) # If the LHS is a literal, it is always an observation if isliteral(left) return quote - _, __varinfo__ = $(DynamicPPL.tilde_observe!)( + _, __varinfo__ = $(DynamicPPL.tilde_observe!!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end @@ -313,7 +313,7 @@ function generate_tilde(left, right) $inds = $(vinds(left)) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption - $left, __varinfo__ = $(DynamicPPL.tilde_assume!)( + $left, __varinfo__ = $(DynamicPPL.tilde_assume!!)( __context__, $(DynamicPPL.unwrap_right_vn)( $(DynamicPPL.check_tilde_rhs)($right), $vn @@ -322,7 +322,7 @@ function generate_tilde(left, right) __varinfo__, ) else - _, __varinfo__ = $(DynamicPPL.tilde_observe!)( + _, __varinfo__ = $(DynamicPPL.tilde_observe!!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, @@ -343,7 +343,7 @@ function generate_dot_tilde(left, right) # If the LHS is a literal, it is always an observation if isliteral(left) return quote - _, __varinfo__ = $(DynamicPPL.dot_tilde_observe!)( + _, __varinfo__ = $(DynamicPPL.dot_tilde_observe!!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end @@ -357,7 +357,7 @@ function generate_dot_tilde(left, right) $inds = $(vinds(left)) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption - _, __varinfo__ = $(DynamicPPL.dot_tilde_assume!)( + _, __varinfo__ = $(DynamicPPL.dot_tilde_assume!!)( __context__, $(DynamicPPL.unwrap_right_left_vns)( $(DynamicPPL.check_tilde_rhs)($right), $left, $vn @@ -366,7 +366,7 @@ function generate_dot_tilde(left, right) __varinfo__, ) else - _, __varinfo = $(DynamicPPL.dot_tilde_observe!)( + _, __varinfo = $(DynamicPPL.dot_tilde_observe!!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, diff --git a/src/context_implementations.jl b/src/context_implementations.jl index b48520d03..347d3403d 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -119,7 +119,7 @@ function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) end """ - tilde_assume!(context, right, vn, inds, vi) + tilde_assume!!(context, right, vn, inds, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. @@ -127,7 +127,7 @@ accumulate the log probability, and return the sampled value. By default, calls `tilde_assume(context, right, vn, inds, vi)` and accumulates the log probability of `vi` with the returned value. """ -function tilde_assume!(context, right, vn, inds, vi) +function tilde_assume!!(context, right, vn, inds, vi) value, logp = tilde_assume(context, right, vn, inds, vi) return value, acclogp!(vi, logp) end @@ -189,16 +189,16 @@ function tilde_observe(context::PrefixContext, right, left, vi) end """ - tilde_observe!(context, right, left, vname, vinds, vi) + tilde_observe!!(context, right, left, vname, vinds, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `tilde_observe!(context, right, left, vi)` ignoring the information about variable name +Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!(context, right, left, vname, vinds, vi) - return tilde_observe!(context, right, left, vi) +function tilde_observe!!(context, right, left, vname, vinds, vi) + return tilde_observe!!(context, right, left, vi) end """ @@ -210,7 +210,7 @@ return the observed value. By default, calls `tilde_observe(context, right, left, vi)` and accumulates the log probability of `vi` with the returned value. """ -function tilde_observe!(context, right, left, vi) +function tilde_observe!!(context, right, left, vi) logp = tilde_observe(context, right, left, vi) return left, acclogp!(vi, logp) end @@ -404,14 +404,14 @@ function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, end """ - dot_tilde_assume!(context, right, left, vn, inds, vi) + dot_tilde_assume!!(context, right, left, vn, inds, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value. Falls back to `dot_tilde_assume(context, right, left, vn, inds, vi)`. """ -function dot_tilde_assume!(context, right, left, vn, inds, vi) +function dot_tilde_assume!!(context, right, left, vn, inds, vi) value, logp = dot_tilde_assume(context, right, left, vn, inds, vi) left .= value return value, acclogp!(vi, logp) @@ -610,27 +610,27 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi) end """ - dot_tilde_observe!(context, right, left, vname, vinds, vi) + dot_tilde_observe!!(context, right, left, vname, vinds, vi) Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe!(context, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe!!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe!(context, right, left, vn, inds, vi) - return dot_tilde_observe!(context, right, left, vi) +function dot_tilde_observe!!(context, right, left, vn, inds, vi) + return dot_tilde_observe!!(context, right, left, vi) end """ - dot_tilde_observe!(context, right, left, vi) + dot_tilde_observe!!(context, right, left, vi) Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value. Falls back to `dot_tilde_observe(context, right, left, vi)`. """ -function dot_tilde_observe!(context, right, left, vi) +function dot_tilde_observe!!(context, right, left, vi) logp = dot_tilde_observe(context, right, left, vi) return left, acclogp!(vi, logp) end diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 6c66e4ec4..4b1a16486 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -69,13 +69,13 @@ function dot_tilde_assume(context::PointwiseLikelihoodContext, right, left, vn, return dot_tilde_assume(context.context, right, left, vn, inds, vi) end -function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) +function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi) # Defer literal `observe` to child-context. - return tilde_observe!(context.context, right, left, vi) + return tilde_observe!!(context.context, right, left, vi) end -function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, vinds, vi) +function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, vinds, vi) # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. - # we have to intercept the call to `tilde_observe!`. + # we have to intercept the call to `tilde_observe!!`. logp = tilde_observe(context.context, right, left, vi) acclogp!(vi, logp) @@ -85,13 +85,13 @@ function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, vi return left end -function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) +function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi) # Defer literal `observe` to child-context. - return dot_tilde_observe!(context.context, right, left, vi) + return dot_tilde_observe!!(context.context, right, left, vi) end -function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) +function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. - # we have to intercept the call to `dot_tilde_observe!`. + # we have to intercept the call to `dot_tilde_observe!!`. logp = dot_tilde_observe(context.context, right, left, vi) acclogp!(vi, logp) From a725a27c1a84c38de3e5921f76d3cb6897f93449 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 12:55:03 +0100 Subject: [PATCH 078/107] updated PointwiseLikelihoodContext --- src/loglikelihoods.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 4b1a16486..a12c8103c 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -77,12 +77,11 @@ function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, v # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. # we have to intercept the call to `tilde_observe!!`. logp = tilde_observe(context.context, right, left, vi) - acclogp!(vi, logp) # Track loglikelihood value. push!(context, vn, logp) - return left + return left, acclogp!(vi, logp) end function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi) @@ -93,12 +92,11 @@ function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, v # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. # we have to intercept the call to `dot_tilde_observe!!`. logp = dot_tilde_observe(context.context, right, left, vi) - acclogp!(vi, logp) # Track loglikelihood value. push!(context, vn, logp) - return left + return left, acclogp!(vi, logp) end """ From 5512670ad6a1aa1d9b6e5df544202d3039d1e65c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 13:24:31 +0100 Subject: [PATCH 079/107] fixed issue where we unnecessarily replace the return-statement --- src/compiler.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 2b7945c9c..09d3e9c9f 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -386,6 +386,9 @@ function replace_returns(e::Expr) end if Meta.isexpr(e, :return) + # NOTE: `return` always has an argument. In the case of + # `return`, the parsed expression will be `return nothing`. + # Hence we don't need any special handling for empty returns. retval_expr = if length(e.args) > 1 Expr(:tuple, e.args...) else @@ -394,9 +397,15 @@ function replace_returns(e::Expr) # Use intermediate variable since this expression # can be more complex than just a value, e.g. `return if ... end`. @gensym retval + + # If the return-value is already of the form we want, we don't do anything. return quote $retval = $retval_expr - return $retval, __varinfo__ + return if $retval isa Tuple{<:Any, $(DynamicPPL.AbstractVarInfo)} + $retval + else + $retval, __varinfo__ + end end end From 4c1ee70489b4cea67cb8a5795d7f664e4625897d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 13:27:17 +0100 Subject: [PATCH 080/107] check subtype in the retval --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 09d3e9c9f..937d6a368 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -401,7 +401,7 @@ function replace_returns(e::Expr) # If the return-value is already of the form we want, we don't do anything. return quote $retval = $retval_expr - return if $retval isa Tuple{<:Any, $(DynamicPPL.AbstractVarInfo)} + return if $retval isa Tuple{<:Any, <:$(DynamicPPL.AbstractVarInfo)} $retval else $retval, __varinfo__ From 26590b5c1f59851e2b748f7b83a35356fc094ed5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 13:27:34 +0100 Subject: [PATCH 081/107] formatting --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 937d6a368..561ea25af 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -401,7 +401,7 @@ function replace_returns(e::Expr) # If the return-value is already of the form we want, we don't do anything. return quote $retval = $retval_expr - return if $retval isa Tuple{<:Any, <:$(DynamicPPL.AbstractVarInfo)} + return if $retval isa Tuple{<:Any,<:$(DynamicPPL.AbstractVarInfo)} $retval else $retval, __varinfo__ From 42fd4144fb0dc7fc435553271a984d81655e56f4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 14:19:05 +0100 Subject: [PATCH 082/107] fixed type-instability in retval check --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 561ea25af..1cba5181b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -401,7 +401,7 @@ function replace_returns(e::Expr) # If the return-value is already of the form we want, we don't do anything. return quote $retval = $retval_expr - return if $retval isa Tuple{<:Any,<:$(DynamicPPL.AbstractVarInfo)} + return if $retval isa Tuple{Any,$(DynamicPPL.AbstractVarInfo)} $retval else $retval, __varinfo__ From f219545c49b59aff2665b95d5b3caab7cd2e1b89 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 14:53:39 +0100 Subject: [PATCH 083/107] introduced evaluate method for model --- src/model.jl | 60 ++++++++++++++++++++++--------------------- src/submodel_macro.jl | 4 +-- 2 files changed, 33 insertions(+), 31 deletions(-) diff --git a/src/model.jl b/src/model.jl index 6fe2dcfa7..d169aac8e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -82,17 +82,20 @@ Sample from the `model` using the `sampler` with random number generator `rng` a The method resets the log joint probability of `varinfo` and increases the evaluation number of `sampler`. """ -function (model::Model)( - rng::Random.AbstractRNG, - varinfo::AbstractVarInfo=VarInfo(), - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - return model(varinfo, SamplingContext(rng, sampler, context)) -end +(model::Model)(args...) = (first ∘ evaluate)(model, args...) + +""" + evaluate(model::Model[, rng, varinfo, sampler, context]) -(model::Model)(context::AbstractContext) = model(VarInfo(), context) -function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) +Sample from the `model` using the `sampler` with random number generator `rng` and the +`context`, and store the sample and log joint probability in `varinfo`. + +Returns both the return-value of the original model, and the resulting varinfo. + +The method resets the log joint probability of `varinfo` and increases the evaluation +number of `sampler`. +""" +function evaluate(model::Model, varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 return evaluate_threadunsafe(model, varinfo, context) else @@ -100,18 +103,30 @@ function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) end end -function (model::Model)(args...) - return model(Random.GLOBAL_RNG, args...) +function evaluate( + model::Model, + rng::Random.AbstractRNG, + varinfo::AbstractVarInfo=VarInfo(), + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + return evaluate(model, varinfo, SamplingContext(rng, sampler, context)) +end + +evaluate(model::Model, context::AbstractContext) = evaluate(model, VarInfo(), context) + +function evaluate(model::Model, args...) + return evaluate(model, Random.GLOBAL_RNG, args...) end # without VarInfo -function (model::Model)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) - return model(rng, VarInfo(), sampler, args...) +function evaluate(model::Model, rng::Random.AbstractRNG, sampler::AbstractSampler, args...) + return evaluate(model, rng, VarInfo(), sampler, args...) end # without VarInfo and without AbstractSampler -function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) - return model(rng, VarInfo(), SampleFromPrior(), context) +function evaluate(model::Model, rng::Random.AbstractRNG, context::AbstractContext) + return evaluate(model, rng, VarInfo(), SampleFromPrior(), context) end """ @@ -155,19 +170,6 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf """ @generated function _evaluate( model::Model{_F,argnames}, varinfo, context -) where {_F,argnames} - unwrap_args = [:($matchingvalue(context, varinfo, model.args.$var)) for var in argnames] - return :((first ∘ model.f)(model, varinfo, context, $(unwrap_args...))) -end - -""" - _evaluate_with_varinfo(model::Model, varinfo, context) - -Evaluate the `model` with the arguments matching the given `context` and `varinfo` object, -also returning the resulting `varinfo`. -""" -@generated function _evaluate_with_varinfo( - model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} unwrap_args = [:($matchingvalue(context, varinfo, model.args.$var)) for var in argnames] return :(model.f(model, varinfo, context, $(unwrap_args...))) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 32a2bd583..f9356a3c9 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -22,7 +22,7 @@ function submodel(expr, ctx=esc(:__context__)) return if args_tilde === nothing # In this case we only want to get the `__varinfo__`. quote - $(esc(:_)), $(esc(:__varinfo__)) = _evaluate_with_varinfo( + $(esc(:_)), $(esc(:__varinfo__)) = _evaluate( $(esc(expr)), $(esc(:__varinfo__)), $(ctx) ) end @@ -30,7 +30,7 @@ function submodel(expr, ctx=esc(:__context__)) # Here we also want the return-variable. L, R = args_tilde quote - $(esc(L)), $(esc(:__varinfo__)) = _evaluate_with_varinfo( + $(esc(L)), $(esc(:__varinfo__)) = _evaluate( $(esc(R)), $(esc(:__varinfo__)), $(ctx) ) end From ce1356629de825529544c1f9bfb0bbaa8493328a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 15:51:16 +0100 Subject: [PATCH 084/107] remove unnecessary type-requirement --- src/simple_varinfo.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 2f029925b..4e848e291 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -20,9 +20,9 @@ end SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T)) SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ) -getlogp(vi::SimpleVarInfo{<:Any,<:Real}) = vi.logp -setlogp!(vi::SimpleVarInfo{<:Any,<:Real}, logp) = SimpleVarInfo(vi.θ, logp) -acclogp!(vi::SimpleVarInfo{<:Any,<:Real}, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) +getlogp(vi::SimpleVarInfo) = vi.logp +setlogp!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, logp) +acclogp!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) function setlogp!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) vi.logp[] = logp From 3556b111e612895a3848418f5db83b7c9a0ce7e8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 16:00:21 +0100 Subject: [PATCH 085/107] make return-value check much nicer --- src/compiler.jl | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 1cba5181b..af7100772 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -394,24 +394,16 @@ function replace_returns(e::Expr) else e.args[1] end - # Use intermediate variable since this expression - # can be more complex than just a value, e.g. `return if ... end`. - @gensym retval - # If the return-value is already of the form we want, we don't do anything. - return quote - $retval = $retval_expr - return if $retval isa Tuple{Any,$(DynamicPPL.AbstractVarInfo)} - $retval - else - $retval, __varinfo__ - end - end + return :($(DynamicPPL.return_values)($retval_expr, __varinfo__)) end return Expr(e.head, map(x -> replace_returns(x), e.args)...) end +return_values(retval, varinfo::AbstractVarInfo) = (retval, varinfo) +return_values(retval::Tuple{Any,AbstractVarInfo}, ::AbstractVarInfo) = retval + # If it's just a symbol, e.g. `f(x) = 1`, then we make it `f(x) = return 1`. make_returns_explicit!(body) = Expr(:return, body) function make_returns_explicit!(body::Expr) From 599d09443a148f366f9a0f09f3defbeb3d5cd168 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 16:14:16 +0100 Subject: [PATCH 086/107] removed redundant creation of anonymous function --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index af7100772..13318ff87 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -398,7 +398,7 @@ function replace_returns(e::Expr) return :($(DynamicPPL.return_values)($retval_expr, __varinfo__)) end - return Expr(e.head, map(x -> replace_returns(x), e.args)...) + return Expr(e.head, map(replace_returns, e.args)...) end return_values(retval, varinfo::AbstractVarInfo) = (retval, varinfo) From 22b170c09a33f617c88e0d640c73c8c8594f5490 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 28 Jun 2021 21:24:42 +0100 Subject: [PATCH 087/107] dont use UnionAll in return_values --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 13318ff87..bea869712 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -402,7 +402,7 @@ function replace_returns(e::Expr) end return_values(retval, varinfo::AbstractVarInfo) = (retval, varinfo) -return_values(retval::Tuple{Any,AbstractVarInfo}, ::AbstractVarInfo) = retval +return_values(retval::Tuple{<:Any,<:AbstractVarInfo}, ::AbstractVarInfo) = retval # If it's just a symbol, e.g. `f(x) = 1`, then we make it `f(x) = return 1`. make_returns_explicit!(body) = Expr(:return, body) From 4606f163b867922e5e1bf7b0da8d3d3f156f8bdf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 28 Jun 2021 22:39:44 +0100 Subject: [PATCH 088/107] updated tests for submodel to reflect new syntax --- test/compiler.jl | 6 +++--- test/loglikelihoods.jl | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index 6f85e9453..703027fda 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -364,8 +364,8 @@ end end @model function demo_useval(x, y) - x1 = @submodel sub1 demo_return(x) - x2 = @submodel sub2 demo_return(y) + @submodel sub1 x1 ~ demo_return(x) + @submodel sub2 x2 ~ demo_return(y) return z ~ Normal(x1 + x2 + 100, 1.0) end @@ -399,7 +399,7 @@ end num_steps = length(y[1]) num_obs = length(y) @inbounds for i in 1:num_obs - x = @submodel $(Symbol("ar1_$i")) AR1(num_steps, α, μ, σ) + @submodel $(Symbol("ar1_$i")) x ~ AR1(num_steps, α, μ, σ) y[i] ~ MvNormal(x, 0.1) end end diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index 74fb88d70..f1ded1a0f 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -69,7 +69,7 @@ end @model function gdemo9() # Submodel prior - m = @submodel _prior_dot_assume() + @submodel m ~ _prior_dot_assume() for i in eachindex(m) 10.0 ~ Normal(m[i], 0.5) end From 68cb021885b1c0f603fd08e3e774fef1d42bf2bc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 30 Jun 2021 14:50:18 +0100 Subject: [PATCH 089/107] moved to using BangBang-convention for most methods --- Project.toml | 1 + src/DynamicPPL.jl | 35 +++++++++ src/compat/ad.jl | 4 +- src/context_implementations.jl | 48 ++++++------- src/loglikelihoods.jl | 8 +-- src/model.jl | 6 +- src/submodel_macro.jl | 1 + src/threadsafe.jl | 30 ++++---- src/utils.jl | 2 +- src/varinfo.jl | 125 +++++++++++++++++---------------- test/threadsafe.jl | 6 +- test/varinfo.jl | 30 ++++---- 12 files changed, 167 insertions(+), 129 deletions(-) diff --git a/Project.toml b/Project.toml index 921dc054d..7cc3db31d 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.12.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 8659c3b4e..88a5ad89f 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -27,16 +27,25 @@ import Base: keys, haskey +import BangBang: + push!!, + empty!! + # VarInfo export AbstractVarInfo, VarInfo, UntypedVarInfo, TypedVarInfo, SimpleVarInfo, + push!!, + empty!!, getlogp, setlogp!, acclogp!, resetlogp!, + setlogp!!, + acclogp!!, + resetlogp!!, get_num_produce, set_num_produce!, reset_num_produce!, @@ -45,12 +54,18 @@ export AbstractVarInfo, is_flagged, set_flag!, unset_flag!, + set_flag!!, + unset_flag!!, setgid!, updategid!, + setgid!!, + updategid!!, setorder!, istrans, link!, invlink!, + link!!, + invlink!!, tonamedtuple, # VarName (reexport from AbstractPPL) VarName, @@ -132,4 +147,24 @@ include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") +# Deprecations +@deprecate empty!(vi::VarInfo) empty!!(vi::VarInfo) +@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) +@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, sampler::AbstractSampler) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, sampler::AbstractSampler) +@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) +@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Set{Selector}) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Set{Selector}) + +@deprecate setlogp!(vi, logp) setlogp!!(vi, logp) +@deprecate acclogp!(vi, logp) acclogp!!(vi, logp) +@deprecate resetlogp!(vi) resetlogp!!(vi) + +@deprecate link!(vi, spl) link!!(vi, spl) +@deprecate invlink!(vi, spl) invlink!!(vi, spl) + +@deprecate set_flag!(vi, vn, flag) set_flag!!(vi, vn, flag) +@deprecate unset_flag!(vi, vn, flag) unset_flag!!(vi, vn, flag) + +@deprecate setgid!(vi, gid, vn) setgid!!(vi, gid, vn) +@deprecate updategid!(vi, vn, spl) updategid!!(vi, vn, spl) + end # module diff --git a/src/compat/ad.jl b/src/compat/ad.jl index 47a627506..664ce2b33 100644 --- a/src/compat/ad.jl +++ b/src/compat/ad.jl @@ -1,9 +1,9 @@ # See https://github.com/TuringLang/Turing.jl/issues/1199 -ChainRulesCore.@non_differentiable push!( +ChainRulesCore.@non_differentiable push!!( vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) -ChainRulesCore.@non_differentiable updategid!( +ChainRulesCore.@non_differentiable updategid!!( vi::AbstractVarInfo, vn::VarName, spl::Sampler ) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 347d3403d..64a958644 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -45,7 +45,7 @@ end function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) if haskey(context.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return tilde_assume(PriorContext(), right, vn, inds, vi) end @@ -60,7 +60,7 @@ function tilde_assume( ) if haskey(context.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) end @@ -74,7 +74,7 @@ end function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) if haskey(context.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return tilde_assume(LikelihoodContext(), right, vn, inds, vi) end @@ -89,7 +89,7 @@ function tilde_assume( ) if haskey(context.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, inds, vi) end @@ -129,7 +129,7 @@ probability of `vi` with the returned value. """ function tilde_assume!!(context, right, vn, inds, vi) value, logp = tilde_assume(context, right, vn, inds, vi) - return value, acclogp!(vi, logp) + return value, acclogp!!(vi, logp) end # observe @@ -212,7 +212,7 @@ probability of `vi` with the returned value. """ function tilde_observe!!(context, right, left, vi) logp = tilde_observe(context, right, left, vi) - return left, acclogp!(vi, logp) + return left, acclogp!!(vi, logp) end function assume(rng, spl::Sampler, dist) @@ -243,18 +243,18 @@ function assume( if haskey(vi, vn) # Always overwrite the parameters with new ones for `SampleFromUniform`. if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") - unset_flag!(vi, vn, "del") + unset_flag!!(vi, vn, "del") r = init(rng, dist, sampler) vi[vn] = vectorize(dist, r) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) setorder!(vi, vn, get_num_produce(vi)) else r = vi[vn] end else r = init(rng, dist, sampler) - push!(vi, vn, r, dist, sampler) - settrans!(vi, false, vn) + push!!(vi, vn, r, dist, sampler) + settrans!!(vi, false, vn) end return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) @@ -305,7 +305,7 @@ function dot_tilde_assume( var = _getindex(getfield(context.vars, getsym(vn)), inds) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!!.(Ref(vi), false, _vns) dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, inds, vi) else dot_tilde_assume(LikelihoodContext(), right, left, vn, inds, vi) @@ -325,7 +325,7 @@ function dot_tilde_assume( var = _getindex(getfield(context.vars, getsym(vn)), inds) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!!.(Ref(vi), false, _vns) dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, inds, vi) else dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, inds, vi) @@ -346,7 +346,7 @@ function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, var = _getindex(getfield(context.vars, getsym(vn)), inds) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!!.(Ref(vi), false, _vns) dot_tilde_assume(PriorContext(), _right, _left, _vns, inds, vi) else dot_tilde_assume(PriorContext(), right, left, vn, inds, vi) @@ -366,7 +366,7 @@ function dot_tilde_assume( var = _getindex(getfield(context.vars, getsym(vn)), inds) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!!.(Ref(vi), false, _vns) dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, inds, vi) else dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, inds, vi) @@ -414,7 +414,7 @@ Falls back to `dot_tilde_assume(context, right, left, vn, inds, vi)`. function dot_tilde_assume!!(context, right, left, vn, inds, vi) value, logp = dot_tilde_assume(context, right, left, vn, inds, vi) left .= value - return value, acclogp!(vi, logp) + return value, acclogp!!(vi, logp) end # `dot_assume` @@ -495,12 +495,12 @@ function get_and_set_val!( if haskey(vi, vns[1]) # Always overwrite the parameters with new ones for `SampleFromUniform`. if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - unset_flag!(vi, vns[1], "del") + unset_flag!!(vi, vns[1], "del") r = init(rng, dist, spl, n) for i in 1:n vn = vns[i] vi[vn] = vectorize(dist, r[:, i]) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) setorder!(vi, vn, get_num_produce(vi)) end else @@ -510,8 +510,8 @@ function get_and_set_val!( r = init(rng, dist, spl, n) for i in 1:n vn = vns[i] - push!(vi, vn, r[:, i], dist, spl) - settrans!(vi, false, vn) + push!!(vi, vn, r[:, i], dist, spl) + settrans!!(vi, false, vn) end end return r @@ -527,14 +527,14 @@ function get_and_set_val!( if haskey(vi, vns[1]) # Always overwrite the parameters with new ones for `SampleFromUniform`. if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - unset_flag!(vi, vns[1], "del") + unset_flag!!(vi, vns[1], "del") f = (vn, dist) -> init(rng, dist, spl) r = f.(vns, dists) for i in eachindex(vns) vn = vns[i] dist = dists isa AbstractArray ? dists[i] : dists vi[vn] = vectorize(dist, r[i]) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) setorder!(vi, vn, get_num_produce(vi)) end else @@ -543,8 +543,8 @@ function get_and_set_val!( else f = (vn, dist) -> init(rng, dist, spl) r = f.(vns, dists) - push!.(Ref(vi), vns, r, dists, Ref(spl)) - settrans!.(Ref(vi), false, vns) + push!!.(Ref(vi), vns, r, dists, Ref(spl)) + settrans!!.(Ref(vi), false, vns) end return r end @@ -632,7 +632,7 @@ Falls back to `dot_tilde_observe(context, right, left, vi)`. """ function dot_tilde_observe!!(context, right, left, vi) logp = dot_tilde_observe(context, right, left, vi) - return left, acclogp!(vi, logp) + return left, acclogp!!(vi, logp) end # Falls back to non-sampler definition. diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index a12c8103c..4ca015f3e 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -74,14 +74,14 @@ function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi) return tilde_observe!!(context.context, right, left, vi) end function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, vinds, vi) - # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. + # Need the `logp` value, so we cannot defer `acclogp!!` to child-context, i.e. # we have to intercept the call to `tilde_observe!!`. logp = tilde_observe(context.context, right, left, vi) # Track loglikelihood value. push!(context, vn, logp) - return left, acclogp!(vi, logp) + return left, acclogp!!(vi, logp) end function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi) @@ -89,14 +89,14 @@ function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, v return dot_tilde_observe!!(context.context, right, left, vi) end function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) - # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. + # Need the `logp` value, so we cannot defer `acclogp!!` to child-context, i.e. # we have to intercept the call to `dot_tilde_observe!!`. logp = dot_tilde_observe(context.context, right, left, vi) # Track loglikelihood value. push!(context, vn, logp) - return left, acclogp!(vi, logp) + return left, acclogp!!(vi, logp) end """ diff --git a/src/model.jl b/src/model.jl index d169aac8e..929f420ff 100644 --- a/src/model.jl +++ b/src/model.jl @@ -140,7 +140,7 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe`](@ref) """ function evaluate_threadunsafe(model, varinfo, context) - resetlogp!(varinfo) + resetlogp!!(varinfo) return _evaluate(model, varinfo, context) end @@ -156,10 +156,10 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe`](@ref) """ function evaluate_threadsafe(model, varinfo, context) - resetlogp!(varinfo) + resetlogp!!(varinfo) wrapper = ThreadSafeVarInfo(varinfo) result = _evaluate(model, wrapper, context) - setlogp!(varinfo, getlogp(wrapper)) + setlogp!!(varinfo, getlogp(wrapper)) return result end diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index f9356a3c9..25460eada 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -28,6 +28,7 @@ function submodel(expr, ctx=esc(:__context__)) end else # Here we also want the return-variable. + # TODO: Should we prefix by `L` by default? L, R = args_tilde quote $(esc(L)), $(esc(:__varinfo__)) = _evaluate( diff --git a/src/threadsafe.jl b/src/threadsafe.jl index c940f9e3f..9c59fa507 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -15,7 +15,7 @@ ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi # Instead of updating the log probability of the underlying variables we # just update the array of log probabilities. -function acclogp!(vi::ThreadSafeVarInfo, logp) +function acclogp!!(vi::ThreadSafeVarInfo, logp) vi.logps[Threads.threadid()][] += logp return vi end @@ -26,17 +26,17 @@ getlogp(vi::ThreadSafeVarInfo) = getlogp(vi.varinfo) + sum(getindex, vi.logps) # TODO: Make remaining methods thread-safe. -function resetlogp!(vi::ThreadSafeVarInfo) +function resetlogp!!(vi::ThreadSafeVarInfo) for x in vi.logps x[] = zero(x[]) end - return resetlogp!(vi.varinfo) + return resetlogp!!(vi.varinfo) end -function setlogp!(vi::ThreadSafeVarInfo, logp) +function setlogp!!(vi::ThreadSafeVarInfo, logp) for x in vi.logps x[] = zero(x[]) end - return setlogp!(vi.varinfo, logp) + return setlogp!!(vi.varinfo, logp) end get_num_produce(vi::ThreadSafeVarInfo) = get_num_produce(vi.varinfo) @@ -46,8 +46,8 @@ set_num_produce!(vi::ThreadSafeVarInfo, n::Int) = set_num_produce!(vi.varinfo, n syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) -function setgid!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName) - return setgid!(vi.varinfo, gid, vn) +function setgid!!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName) + return setgid!!(vi.varinfo, gid, vn) end setorder!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) = setorder!(vi.varinfo, vn, index) setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) @@ -55,8 +55,8 @@ setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo) haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) -link!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = link!(vi.varinfo, spl) -invlink!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = invlink!(vi.varinfo, spl) +link!!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = link!!(vi.varinfo, spl) +invlink!!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = invlink!!(vi.varinfo, spl) islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl) getindex(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex(vi.varinfo, spl) @@ -80,20 +80,20 @@ function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler) end isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo) -function empty!(vi::ThreadSafeVarInfo) - empty!(vi.varinfo) +function empty!!(vi::ThreadSafeVarInfo) + empty!!(vi.varinfo) fill!(vi.logps, zero(getlogp(vi))) return vi end -function push!( +function push!!( vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) - return push!(vi.varinfo, vn, r, dist, gidset) + return push!!(vi.varinfo, vn, r, dist, gidset) end -function unset_flag!(vi::ThreadSafeVarInfo, vn::VarName, flag::String) - return unset_flag!(vi.varinfo, vn, flag) +function unset_flag!!(vi::ThreadSafeVarInfo, vn::VarName, flag::String) + return unset_flag!!(vi.varinfo, vn, flag) end function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String) return is_flagged(vi.varinfo, vn, flag) diff --git a/src/utils.jl b/src/utils.jl index 2ff537fe4..14b7650fb 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,7 +9,7 @@ Add the result of the evaluation of `ex` to the joint log probability. """ macro addlogprob!(ex) return quote - $(esc(:(__varinfo__))) = acclogp!($(esc(:(__varinfo__))), $(esc(ex))) + $(esc(:(__varinfo__))) = acclogp!!($(esc(:(__varinfo__))), $(esc(ex))) end end diff --git a/src/varinfo.jl b/src/varinfo.jl index fe3262dd5..1f81f692d 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -335,14 +335,15 @@ getall(vi::TypedVarInfo) = vcat(_getall(vi.metadata)...) end """ - setall!(vi::VarInfo, val) + setall!!(vi::VarInfo, val) -Set the values of all the variables in `vi` to `val`. +Set the values of all the variables in `vi` to `val`, +mutating if it makese sense. The values may or may not be transformed to Euclidean space. """ -setall!(vi::UntypedVarInfo, val) = vi.metadata.vals .= val -setall!(vi::TypedVarInfo, val) = _setall!(vi.metadata, val) +setall!!(vi::UntypedVarInfo, val) = vi.metadata.vals .= val +setall!!(vi::TypedVarInfo, val) = _setall!(vi.metadata, val) @generated function _setall!(metadata::NamedTuple{names}, val, start=0) where {names} expr = Expr(:block) start = :(1) @@ -363,12 +364,12 @@ Return the set of sampler selectors associated with `vn` in `vi`. getgid(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).gids[getidx(vi, vn)] """ - settrans!(vi::VarInfo, trans::Bool, vn::VarName) + settrans!!(vi::VarInfo, trans::Bool, vn::VarName) -Set the `trans` flag value of `vn` in `vi`. +Set the `trans` flag value of `vn` in `vi`, mutating if it makes sense. """ -function settrans!(vi::AbstractVarInfo, trans::Bool, vn::VarName) - return trans ? set_flag!(vi, vn, "trans") : unset_flag!(vi, vn, "trans") +function settrans!!(vi::AbstractVarInfo, trans::Bool, vn::VarName) + return trans ? set_flag!!(vi, vn, "trans") : unset_flag!!(vi, vn, "trans") end """ @@ -504,11 +505,11 @@ end end """ - set_flag!(vi::VarInfo, vn::VarName, flag::String) + set_flag!!(vi::VarInfo, vn::VarName, flag::String) Set `vn`'s value for `flag` to `true` in `vi`. """ -function set_flag!(vi::VarInfo, vn::VarName, flag::String) +function set_flag!!(vi::VarInfo, vn::VarName, flag::String) return getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = true end @@ -586,16 +587,16 @@ end TypedVarInfo(vi::TypedVarInfo) = vi """ - empty!(vi::VarInfo) + empty!!(vi::VarInfo) Empty the fields of `vi.metadata` and reset `vi.logp[]` and `vi.num_produce[]` to -zeros. +zeros, mutating if it makes sense. This is useful when using a sampling algorithm that assumes an empty `vi`, e.g. `SMC`. """ -function empty!(vi::VarInfo) +function empty!!(vi::VarInfo) _empty!(vi.metadata) - resetlogp!(vi) + resetlogp!!(vi) reset_num_produce!(vi) return vi end @@ -628,11 +629,11 @@ Base.keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs) end """ - setgid!(vi::VarInfo, gid::Selector, vn::VarName) + setgid!!(vi::VarInfo, gid::Selector, vn::VarName) Add `gid` to the set of sampler selectors associated with `vn` in `vi`. """ -function setgid!(vi::VarInfo, gid::Selector, vn::VarName) +function setgid!!(vi::VarInfo, gid::Selector, vn::VarName) return push!(getmetadata(vi, vn).gids[getidx(vi, vn)], gid) end @@ -653,34 +654,34 @@ Return the log of the joint probability of the observed data and parameters samp getlogp(vi::AbstractVarInfo) = vi.logp[] """ - setlogp!(vi::VarInfo, logp) + setlogp!!(vi::VarInfo, logp) Set the log of the joint probability of the observed data and parameters sampled in -`vi` to `logp`. +`vi` to `logp`, mutating if it makes sense. """ -function setlogp!(vi::VarInfo, logp) +function setlogp!!(vi::VarInfo, logp) vi.logp[] = logp return vi end """ - acclogp!(vi::VarInfo, logp) + acclogp!!(vi::VarInfo, logp) Add `logp` to the value of the log of the joint probability of the observed data and -parameters sampled in `vi`. +parameters sampled in `vi`, mutating if it makes sense. """ -function acclogp!(vi::VarInfo, logp) +function acclogp!!(vi::VarInfo, logp) vi.logp[] += logp return vi end """ - resetlogp!(vi::AbstractVarInfo) + resetlogp!!(vi::AbstractVarInfo) Reset the value of the log of the joint probability of the observed data and parameters -sampled in `vi` to 0. +sampled in `vi` to 0, mutating if it makes sense. """ -resetlogp!(vi::AbstractVarInfo) = setlogp!(vi, zero(getlogp(vi))) +resetlogp!!(vi::AbstractVarInfo) = setlogp!!(vi, zero(getlogp(vi))) """ get_num_produce(vi::VarInfo) @@ -728,13 +729,13 @@ end # X -> R for all variables associated with given sampler """ - link!(vi::VarInfo, spl::Sampler) + link!!(vi::VarInfo, spl::Sampler) Transform the values of the random variables sampled by `spl` in `vi` from the support of their distributions to the Euclidean space and set their corresponding `"trans"` flag values to `true`. """ -function link!(vi::UntypedVarInfo, spl::Sampler) +function link!!(vi::UntypedVarInfo, spl::Sampler) # TODO: Change to a lazy iterator over `vns` vns = _getvns(vi, spl) if ~istrans(vi, vns[1]) @@ -747,16 +748,16 @@ function link!(vi::UntypedVarInfo, spl::Sampler) vectorize(dist, Bijectors.link(dist, reconstruct(dist, getval(vi, vn)))), vn, ) - settrans!(vi, true, vn) + settrans!!(vi, true, vn) end else @warn("[DynamicPPL] attempt to link a linked vi") end end -function link!(vi::TypedVarInfo, spl::AbstractSampler) - return link!(vi, spl, Val(getspace(spl))) +function link!!(vi::TypedVarInfo, spl::AbstractSampler) + return link!!(vi, spl, Val(getspace(spl))) end -function link!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) +function link!!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) vns = _getvns(vi, spl) return _link!(vi.metadata, vi, vns, spaceval) end @@ -783,7 +784,7 @@ end ), vn, ) - settrans!(vi, true, vn) + settrans!!(vi, true, vn) end else @warn("[DynamicPPL] attempt to link a linked vi") @@ -797,13 +798,13 @@ end # R -> X for all variables associated with given sampler """ - invlink!(vi::VarInfo, spl::AbstractSampler) + invlink!!(vi::VarInfo, spl::AbstractSampler) Transform the values of the random variables sampled by `spl` in `vi` from the Euclidean space back to the support of their distributions and sets their corresponding `"trans"` flag values to `false`. """ -function invlink!(vi::UntypedVarInfo, spl::AbstractSampler) +function invlink!!(vi::UntypedVarInfo, spl::AbstractSampler) vns = _getvns(vi, spl) if istrans(vi, vns[1]) for vn in vns @@ -814,16 +815,16 @@ function invlink!(vi::UntypedVarInfo, spl::AbstractSampler) vectorize(dist, Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn)))), vn, ) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end else @warn("[DynamicPPL] attempt to invlink an invlinked vi") end end -function invlink!(vi::TypedVarInfo, spl::AbstractSampler) - return invlink!(vi, spl, Val(getspace(spl))) +function invlink!!(vi::TypedVarInfo, spl::AbstractSampler) + return invlink!!(vi, spl, Val(getspace(spl))) end -function invlink!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) +function invlink!!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) vns = _getvns(vi, spl) return _invlink!(vi.metadata, vi, vns, spaceval) end @@ -852,7 +853,7 @@ end ), vn, ) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end else @warn("[DynamicPPL] attempt to invlink an invlinked vi") @@ -962,7 +963,7 @@ Set the current value(s) of the random variables sampled by `spl` in `vi` to `va The value(s) may or may not be transformed to Euclidean space. """ -setindex!(vi::AbstractVarInfo, val, spl::SampleFromPrior) = setall!(vi, val) +setindex!(vi::AbstractVarInfo, val, spl::SampleFromPrior) = setall!!(vi, val) setindex!(vi::UntypedVarInfo, val, spl::Sampler) = setval!(vi, val, _getranges(vi, spl)) function setindex!(vi::TypedVarInfo, val, spl::Sampler) # Gets a `NamedTuple` mapping each symbol to the indices in the symbol's `vals` field sampled from the sampler `spl` @@ -1086,42 +1087,42 @@ function Base.show(io::IO, vi::UntypedVarInfo) end """ - push!(vi::VarInfo, vn::VarName, r, dist::Distribution) + push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to -the `VarInfo` `vi`. +the `VarInfo` `vi`, mutating if it makes sense. """ -function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) - return push!(vi, vn, r, dist, Set{Selector}([])) +function push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) + return push!!(vi, vn, r, dist, Set{Selector}([])) end """ - push!(vi::VarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler) + push!!(vi::VarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler) Push a new random variable `vn` with a sampled value `r` sampled with a sampler `spl` -from a distribution `dist` to `VarInfo` `vi`. +from a distribution `dist` to `VarInfo` `vi`, if it makes sense. The sampler is passed here to invalidate its cache where defined. """ -function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler) - return push!(vi, vn, r, dist, spl.selector) +function push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler) + return push!!(vi, vn, r, dist, spl.selector) end -function push!( +function push!!( vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler ) - return push!(vi, vn, r, dist) + return push!!(vi, vn, r, dist) end """ - push!(vi::VarInfo, vn::VarName, r, dist::Distribution, gid::Selector) + push!!(vi::VarInfo, vn::VarName, r, dist::Distribution, gid::Selector) Push a new random variable `vn` with a sampled value `r` sampled with a sampler of selector `gid` from a distribution `dist` to `VarInfo` `vi`. """ -function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) - return push!(vi, vn, r, dist, Set([gid])) +function push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) + return push!!(vi, vn, r, dist, Set([gid])) end -function push!(vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}) +function push!!(vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}) if vi isa UntypedVarInfo @assert ~(vn in keys(vi)) "[push!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" elseif vi isa TypedVarInfo @@ -1174,11 +1175,11 @@ function is_flagged(vi::VarInfo, vn::VarName, flag::String) end """ - unset_flag!(vi::VarInfo, vn::VarName, flag::String) + unset_flag!!(vi::VarInfo, vn::VarName, flag::String) Set `vn`'s value for `flag` to `false` in `vi`. """ -function unset_flag!(vi::VarInfo, vn::VarName, flag::String) +function unset_flag!!(vi::VarInfo, vn::VarName, flag::String) return getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = false end @@ -1238,14 +1239,14 @@ end end """ - updategid!(vi::VarInfo, vn::VarName, spl::Sampler) + updategid!!(vi::VarInfo, vn::VarName, spl::Sampler) Set `vn`'s `gid` to `Set([spl.selector])`, if `vn` does not have a sampler selector linked and `vn`'s symbol is in the space of `spl`. """ -function updategid!(vi::AbstractVarInfo, vn::VarName, spl::Sampler) +function updategid!!(vi::AbstractVarInfo, vn::VarName, spl::Sampler) if inspace(vn, getspace(spl)) - setgid!(vi, spl.selector, vn) + setgid!!(vi, spl.selector, vn) end end @@ -1393,7 +1394,7 @@ function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys) if !isempty(indices) val = reduce(vcat, values[indices]) setval!(vi, val, vn) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return indices @@ -1474,11 +1475,11 @@ function _setval_and_resample_kernel!(vi::AbstractVarInfo, vn::VarName, values, if !isempty(indices) val = reduce(vcat, values[indices]) setval!(vi, val, vn) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) else # Ensures that we'll resample the variable corresponding to `vn` if we run # the model on `vi` again. - set_flag!(vi, vn, "del") + set_flag!!(vi, vn, "del") end return indices diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 83c53ccd6..bd1f4f154 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -17,17 +17,17 @@ lp = getlogp(vi) @test getlogp(threadsafe_vi) == lp - acclogp!(threadsafe_vi, 42) + acclogp!!(threadsafe_vi, 42) @test threadsafe_vi.logps[Threads.threadid()][] == 42 @test getlogp(vi) == lp @test getlogp(threadsafe_vi) == lp + 42 - resetlogp!(threadsafe_vi) + resetlogp!!(threadsafe_vi) @test iszero(getlogp(vi)) @test iszero(getlogp(threadsafe_vi)) @test all(iszero(x[]) for x in threadsafe_vi.logps) - setlogp!(threadsafe_vi, 42) + setlogp!!(threadsafe_vi, 42) @test getlogp(vi) == 42 @test getlogp(threadsafe_vi) == 42 @test all(iszero(x[]) for x in threadsafe_vi.logps) diff --git a/test/varinfo.jl b/test/varinfo.jl index 4c8ec43cb..f1cadfa8f 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -33,7 +33,7 @@ end @testset "Base" begin # Test Base functions: - # string, Symbol, ==, hash, in, keys, haskey, isempty, push!, empty!, + # string, Symbol, ==, hash, in, keys, haskey, isempty, push!!, empty!!, # getindex, setindex!, getproperty, setproperty! csym = gensym() vn1 = @varname x[1][2] @@ -46,7 +46,7 @@ @test inspace(vn1, (:x,)) function test_base!(vi) - empty!(vi) + empty!!(vi) @test getlogp(vi) == 0 @test get_num_produce(vi) == 0 @@ -58,7 +58,7 @@ @test isempty(vi) @test ~haskey(vi, vn) @test !(vn in keys(vi)) - push!(vi, vn, r, dist, gid) + push!!(vi, vn, r, dist, gid) @test ~isempty(vi) @test haskey(vi, vn) @test vn in keys(vi) @@ -75,9 +75,9 @@ @test vi[vn] == 3 * r @test vi[SampleFromPrior()][1] == 3 * r - empty!(vi) + empty!!(vi) @test isempty(vi) - push!(vi, vn, r, dist, gid) + push!!(vi, vn, r, dist, gid) function test_inspace() space = (:x, :y, @varname(z[1]), @varname(M[1:10, :])) @@ -98,7 +98,7 @@ end vi = VarInfo() test_base!(vi) - test_base!(empty!(TypedVarInfo(vi))) + test_base!(empty!!(TypedVarInfo(vi))) end @testset "flags" begin # Test flag setting: @@ -109,20 +109,20 @@ r = rand(dist) gid = Selector() - push!(vi, vn_x, r, dist, gid) + push!!(vi, vn_x, r, dist, gid) # del is set by default @test !is_flagged(vi, vn_x, "del") - set_flag!(vi, vn_x, "del") + set_flag!!(vi, vn_x, "del") @test is_flagged(vi, vn_x, "del") - unset_flag!(vi, vn_x, "del") + unset_flag!!(vi, vn_x, "del") @test !is_flagged(vi, vn_x, "del") end vi = VarInfo() test_varinfo!(vi) - test_varinfo!(empty!(TypedVarInfo(vi))) + test_varinfo!(empty!!(TypedVarInfo(vi))) end @testset "setgid!" begin vi = VarInfo() @@ -133,16 +133,16 @@ gid1 = Selector() gid2 = Selector(2, :HMC) - push!(vi, vn, r, dist, gid1) + push!!(vi, vn, r, dist, gid1) @test meta.gids[meta.idcs[vn]] == Set([gid1]) - setgid!(vi, gid2, vn) + setgid!!(vi, gid2, vn) @test meta.gids[meta.idcs[vn]] == Set([gid1, gid2]) - vi = empty!(TypedVarInfo(vi)) + vi = empty!!(TypedVarInfo(vi)) meta = vi.metadata - push!(vi, vn, r, dist, gid1) + push!!(vi, vn, r, dist, gid1) @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1]) - setgid!(vi, gid2, vn) + setgid!!(vi, gid2, vn) @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1, gid2]) end @testset "setval! & setval_and_resample!" begin From cb1fd8bd2f6da1a776082add522e0aa4a8d76b90 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 30 Jun 2021 14:51:18 +0100 Subject: [PATCH 090/107] remove SimpleVarInfo from this branch --- src/DynamicPPL.jl | 2 - src/simple_varinfo.jl | 131 ------------------------------------------ 2 files changed, 133 deletions(-) delete mode 100644 src/simple_varinfo.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 88a5ad89f..bda0b897b 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -36,7 +36,6 @@ export AbstractVarInfo, VarInfo, UntypedVarInfo, TypedVarInfo, - SimpleVarInfo, push!!, empty!!, getlogp, @@ -138,7 +137,6 @@ include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") include("varinfo.jl") -include("simple_varinfo.jl") include("threadsafe.jl") include("context_implementations.jl") include("compiler.jl") diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl deleted file mode 100644 index 4e848e291..000000000 --- a/src/simple_varinfo.jl +++ /dev/null @@ -1,131 +0,0 @@ -""" - SimpleVarInfo{NT,T} <: AbstractVarInfo - -A simple wrapper of the parameters with a `logp` field for -accumulation of the logdensity. - -Currently only implemented for `NT <: NamedTuple`. - -## Notes -The major differences between this and `TypedVarInfo` are: -1. `SimpleVarInfo` does not require linearization. -2. `SimpleVarInfo` can use more efficient bijectors. -3. `SimpleVarInfo` only supports evaluation. -""" -struct SimpleVarInfo{NT,T} <: AbstractVarInfo - θ::NT - logp::T -end - -SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T)) -SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ) - -getlogp(vi::SimpleVarInfo) = vi.logp -setlogp!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, logp) -acclogp!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) - -function setlogp!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) - vi.logp[] = logp - return vi -end - -function acclogp!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) - vi.logp[] += logp - return vi -end - -function _getvalue(nt::NamedTuple, ::Val{sym}, inds=()) where {sym} - # Use `getproperty` instead of `getfield` - value = getproperty(nt, sym) - return _getindex(value, inds) -end - -function getval(vi::SimpleVarInfo, vn::VarName{sym}) where {sym} - return _getvalue(vi.θ, Val{sym}(), vn.indexing) -end -# `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than -# just `Vector`. -getval(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = map(vn -> getval(vi, vn), vns) -# To disambiguiate. -getval(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(vn -> getval(vi, vn), vns) - -haskey(vi::SimpleVarInfo, vn) = haskey(vi.θ, getsym(vn)) - -istrans(::SimpleVarInfo, vn::VarName) = false - -getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.θ -getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.θ -# TODO: Should we do better? -getindex(vi::SimpleVarInfo, spl::Sampler) = vi.θ -getindex(vi::SimpleVarInfo, vn::VarName) = getval(vi, vn) -getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = getval(vi, vns) -# HACK: Need to disambiguiate. -getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = getval(vi, vns) - -# Context implementations -# Only evaluation makes sense for `SimpleVarInfo`, so we only implement this. -function assume(dist::Distribution, vn::VarName, vi::SimpleVarInfo{<:NamedTuple}) - left = vi[vn] - return left, Distributions.loglikelihood(dist, left) -end - -# function dot_tilde_assume!(context, right, left, vn, inds, vi::SimpleVarInfo) -# throw(MethodError(dot_tilde_assume!, (context, right, left, vn, inds, vi))) -# end - -function dot_assume( - dist::MultivariateDistribution, - var::AbstractMatrix, - vns::AbstractVector{<:VarName}, - vi::SimpleVarInfo, -) - @assert length(dist) == size(var, 1) - # NOTE: We cannot work with `var` here because we might have a model of the form - # - # m = Vector{Float64}(undef, n) - # m .~ Normal() - # - # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - r = vi[vns] - lp = sum(zip(vns, eachcol(r))) do vn, ri - return Distributions.logpdf(dist, ri) - end - return r, lp -end - -function dot_assume( - dists::Union{Distribution,AbstractArray{<:Distribution}}, - var::AbstractArray, - vns::AbstractArray{<:VarName}, - vi::SimpleVarInfo{<:NamedTuple}, -) - # NOTE: We cannot work with `var` here because we might have a model of the form - # - # m = Vector{Float64}(undef, n) - # m .~ Normal() - # - # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - r = vi[vns] - lp = sum(Distributions.logpdf.(dists, r)) - return r, lp -end - -# HACK: Allows us to re-use the impleemntation of `dot_tilde`, etc. for literals. -increment_num_produce!(::SimpleVarInfo) = nothing - -# Interaction with `VarInfo` -SimpleVarInfo(vi::TypedVarInfo) = SimpleVarInfo{eltype(getlogp(vi))}(vi) -function SimpleVarInfo{T}(vi::VarInfo{<:NamedTuple{names}}) where {T<:Real,names} - vals = map(names) do n - let md = getfield(vi.metadata, n) - x = map(enumerate(md.ranges)) do (i, r) - reconstruct(md.dists[i], md.vals[r]) - end - - # TODO: Doesn't support batches of `MultivariateDistribution`? - length(x) == 1 ? x[1] : x - end - end - - return SimpleVarInfo{T}(NamedTuple{names}(vals)) -end From 5936dd059f0dc5815832c2b8a6c5a421a97c994e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 30 Jun 2021 15:51:54 +0100 Subject: [PATCH 091/107] added a comment --- src/context_implementations.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 64a958644..278d7a5ad 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -413,6 +413,7 @@ Falls back to `dot_tilde_assume(context, right, left, vn, inds, vi)`. """ function dot_tilde_assume!!(context, right, left, vn, inds, vi) value, logp = dot_tilde_assume(context, right, left, vn, inds, vi) + # Mutation of `value` no longer occurs in main body, so we do it here. left .= value return value, acclogp!!(vi, logp) end From 426c465d4206e2f6e8403f6aa602262f8663c815 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 30 Jun 2021 15:51:59 +0100 Subject: [PATCH 092/107] reverted submodel macro to use = rather than ~ --- src/submodel_macro.jl | 10 +++++----- src/utils.jl | 15 +++++++++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 25460eada..23b6245ec 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -1,6 +1,6 @@ """ - @submodel x ~ model(args...) - @submodel prefix x ~ model(args...) + @submodel x = model(args...) + @submodel prefix x = model(args...) Treats `model` as a distribution, where `x` is the return-value of `model`. @@ -18,8 +18,8 @@ macro submodel(prefix, expr) end function submodel(expr, ctx=esc(:__context__)) - args_tilde = getargs_tilde(expr) - return if args_tilde === nothing + args_assign = getargs_assignment(expr) + return if args_assign === nothing # In this case we only want to get the `__varinfo__`. quote $(esc(:_)), $(esc(:__varinfo__)) = _evaluate( @@ -29,7 +29,7 @@ function submodel(expr, ctx=esc(:__context__)) else # Here we also want the return-variable. # TODO: Should we prefix by `L` by default? - L, R = args_tilde + L, R = args_assign quote $(esc(L)), $(esc(:__varinfo__)) = _evaluate( $(esc(R)), $(esc(:__varinfo__)), $(ctx) diff --git a/src/utils.jl b/src/utils.jl index 14b7650fb..76efe2298 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -44,6 +44,21 @@ function getargs_tilde(expr::Expr) end end +""" + getargs_assignment(x) + +Return the arguments `L` and `R`, if `x` is an expression of the form `L = R`, or `nothing` +otherwise. +""" +getargs_assignment(x) = nothing +function getargs_assignment(expr::Expr) + return MacroTools.@match expr begin + (L_ = R_) => (L, R) + x_ => nothing + end +end + + ############################################ # Julia 1.2 temporary fix - Julia PR 33303 # ############################################ From a8e55bd0a6cb0798b980dbb36be9f3b263d3875d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 30 Jun 2021 16:10:20 +0100 Subject: [PATCH 093/107] updated SimpleVarInfo impl --- src/DynamicPPL.jl | 1 + src/simple_varinfo.jl | 12 ++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a26516283..88a5ad89f 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -138,6 +138,7 @@ include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") include("varinfo.jl") +include("simple_varinfo.jl") include("threadsafe.jl") include("context_implementations.jl") include("compiler.jl") diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 4e848e291..501aa2185 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -21,15 +21,15 @@ SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T)) SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ) getlogp(vi::SimpleVarInfo) = vi.logp -setlogp!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, logp) -acclogp!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) +setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, logp) +acclogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) -function setlogp!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) +function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) vi.logp[] = logp return vi end -function acclogp!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) +function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) vi.logp[] += logp return vi end @@ -69,8 +69,8 @@ function assume(dist::Distribution, vn::VarName, vi::SimpleVarInfo{<:NamedTuple} return left, Distributions.loglikelihood(dist, left) end -# function dot_tilde_assume!(context, right, left, vn, inds, vi::SimpleVarInfo) -# throw(MethodError(dot_tilde_assume!, (context, right, left, vn, inds, vi))) +# function dot_tilde_assume!!(context, right, left, vn, inds, vi::SimpleVarInfo) +# throw(MethodError(dot_tilde_assume!!, (context, right, left, vn, inds, vi))) # end function dot_assume( From 149229f0dacd265e46691ee9c8c3f18af867712c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 30 Jun 2021 16:21:00 +0100 Subject: [PATCH 094/107] added a couple of missing deprecations --- src/DynamicPPL.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index bda0b897b..62bf73b83 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -162,6 +162,10 @@ include("submodel_macro.jl") @deprecate set_flag!(vi, vn, flag) set_flag!!(vi, vn, flag) @deprecate unset_flag!(vi, vn, flag) unset_flag!!(vi, vn, flag) +@deprecate settrans!(vi, trans, vn) settrans!!(vi, trans, vn) + +@deprecate setall!(vi, val) setall!!(vi, val) + @deprecate setgid!(vi, gid, vn) setgid!!(vi, gid, vn) @deprecate updategid!(vi, vn, spl) updategid!!(vi, vn, spl) From 809d23fbf58bbc9d08fafea9d28bccbe195afac0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 30 Jun 2021 16:21:35 +0100 Subject: [PATCH 095/107] updated tests --- test/compiler.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index 703027fda..2b9a2273b 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -364,8 +364,8 @@ end end @model function demo_useval(x, y) - @submodel sub1 x1 ~ demo_return(x) - @submodel sub2 x2 ~ demo_return(y) + @submodel sub1 x1 = demo_return(x) + @submodel sub2 x2 = demo_return(y) return z ~ Normal(x1 + x2 + 100, 1.0) end @@ -399,7 +399,7 @@ end num_steps = length(y[1]) num_obs = length(y) @inbounds for i in 1:num_obs - @submodel $(Symbol("ar1_$i")) x ~ AR1(num_steps, α, μ, σ) + @submodel $(Symbol("ar1_$i")) x = AR1(num_steps, α, μ, σ) y[i] ~ MvNormal(x, 0.1) end end From 07f684b0e9152172915d9dcf2226541c8ebaa5ce Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 30 Jun 2021 16:47:08 +0100 Subject: [PATCH 096/107] updated implementations of logjoint and others --- src/model.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/model.jl b/src/model.jl index 929f420ff..e3c83528a 100644 --- a/src/model.jl +++ b/src/model.jl @@ -204,8 +204,8 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - model(varinfo, DefaultContext()) - return getlogp(varinfo) + _, varinfo_new = evaluate(model, varinfo, DefaultContext()) + return getlogp(varinfo_new) end """ @@ -216,8 +216,8 @@ Return the log prior probability of variables `varinfo` for the probabilistic `m See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) - model(varinfo, PriorContext()) - return getlogp(varinfo) + _, varinfo_new = evaluate(model, varinfo, PriorContext()) + return getlogp(varinfo_new) end """ @@ -228,8 +228,8 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`. See also [`logjoint`](@ref) and [`logprior`](@ref). """ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - model(varinfo, LikelihoodContext()) - return getlogp(varinfo) + _, varinfo_new = evaluate(model, varinfo, LikelihoodContext()) + return getlogp(varinfo_new) end """ From b00ae474c19e1e1d9c3580576e4b17db803f3174 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 30 Jun 2021 16:48:33 +0100 Subject: [PATCH 097/107] formatting --- src/DynamicPPL.jl | 20 +++++++++++++------- src/utils.jl | 1 - 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 62bf73b83..389447344 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -27,9 +27,7 @@ import Base: keys, haskey -import BangBang: - push!!, - empty!! +import BangBang: push!!, empty!! # VarInfo export AbstractVarInfo, @@ -147,10 +145,18 @@ include("submodel_macro.jl") # Deprecations @deprecate empty!(vi::VarInfo) empty!!(vi::VarInfo) -@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) -@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, sampler::AbstractSampler) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, sampler::AbstractSampler) -@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) -@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Set{Selector}) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Set{Selector}) +@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) push!!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution +) +@deprecate push!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, sampler::AbstractSampler +) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, sampler::AbstractSampler) +@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) push!!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector +) +@deprecate push!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Set{Selector} +) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Set{Selector}) @deprecate setlogp!(vi, logp) setlogp!!(vi, logp) @deprecate acclogp!(vi, logp) acclogp!!(vi, logp) diff --git a/src/utils.jl b/src/utils.jl index 76efe2298..de1281ac8 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -58,7 +58,6 @@ function getargs_assignment(expr::Expr) end end - ############################################ # Julia 1.2 temporary fix - Julia PR 33303 # ############################################ From bfd7c789639df0395e33e0c1bf20557c27f60aa1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 2 Jul 2021 03:42:50 +0100 Subject: [PATCH 098/107] added eltype impl for SimpleVarInfo --- src/simple_varinfo.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 501aa2185..d5ca2fc13 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -62,6 +62,11 @@ getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = getval(vi, vns) # HACK: Need to disambiguiate. getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = getval(vi, vns) +# Necessary for `matchingvalue` to work properly. +function Base.eltype(vi::SimpleVarInfo{<:Any, T}, spl::Union{AbstractSampler,SampleFromPrior}) + return T +end + # Context implementations # Only evaluation makes sense for `SimpleVarInfo`, so we only implement this. function assume(dist::Distribution, vn::VarName, vi::SimpleVarInfo{<:NamedTuple}) From acb15eb9b9525eda0b57036f2b6864623d8ab1d3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 2 Jul 2021 03:45:38 +0100 Subject: [PATCH 099/107] formatting --- src/simple_varinfo.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index d5ca2fc13..12437844e 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -63,7 +63,9 @@ getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = getval(vi, vns) getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = getval(vi, vns) # Necessary for `matchingvalue` to work properly. -function Base.eltype(vi::SimpleVarInfo{<:Any, T}, spl::Union{AbstractSampler,SampleFromPrior}) +function Base.eltype( + vi::SimpleVarInfo{<:Any,T}, spl::Union{AbstractSampler,SampleFromPrior} +) return T end From 4828aab2f3ee108f286b461a057f8909c9dfbc4e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 6 Jul 2021 10:56:52 +0100 Subject: [PATCH 100/107] fixed eltype for SimpleVarInfo --- src/simple_varinfo.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 12437844e..c88bf0192 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -65,7 +65,7 @@ getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = getval(vi, vns) # Necessary for `matchingvalue` to work properly. function Base.eltype( vi::SimpleVarInfo{<:Any,T}, spl::Union{AbstractSampler,SampleFromPrior} -) +) where {T} return T end @@ -136,3 +136,5 @@ function SimpleVarInfo{T}(vi::VarInfo{<:NamedTuple{names}}) where {T<:Real,names return SimpleVarInfo{T}(NamedTuple{names}(vals)) end + +SimpleVarInfo(model::Model, args...) = SimpleVarInfo(VarInfo(Random.GLOBAL_RNG, model, args...)) From e67ca2a44155f714edc5c806683d4aff1aef8b48 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 8 Jul 2021 01:54:53 +0100 Subject: [PATCH 101/107] updated to work with master --- benchmarks/utils.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/benchmarks/utils.jl b/benchmarks/utils.jl index f6c31c1c2..320a8ca7a 100644 --- a/benchmarks/utils.jl +++ b/benchmarks/utils.jl @@ -28,9 +28,10 @@ end function typed_code(m, vi = VarInfo(m)) rng = DynamicPPL.Random.MersenneTwister(42); spl = DynamicPPL.SampleFromPrior() - ctx = DynamicPPL.DefaultContext() + ctx = DynamicPPL.SamplingContext(rng, spl, DynamicPPL.DefaultContext()) - return Main.@code_typed m.f(rng, m, vi, spl, ctx, m.args...) + results = code_typed(m.f, Base.typesof(m, vi, ctx, m.args...)) + return first(results) end function make_suite(m) From 6ac63b8d97e3a6f6f42f31917109e97b3baa0087 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 8 Jul 2021 02:41:18 +0100 Subject: [PATCH 102/107] changed the output structure a bit --- benchmarks/benchmark_body.jmd | 4 ++-- benchmarks/benchmarks.jmd | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmark_body.jmd b/benchmarks/benchmark_body.jmd index a82d9ff52..ac9953132 100644 --- a/benchmarks/benchmark_body.jmd +++ b/benchmarks/benchmark_body.jmd @@ -17,14 +17,14 @@ typed = typed_code(m) ```julia; echo=false; results="hidden" # Serialize the output of `typed_code` so we can compare later. -haskey(WEAVE_ARGS, :prefix) && serialize("$(WEAVE_ARGS[:prefix])_$(m.name).jls", string(typed)); +haskey(WEAVE_ARGS, :prefix) && serialize(joinpath("results", WEAVE_ARGS[:prefix],"$(m.name).jls"), string(typed)); ``` ```julia; wrap=false if haskey(WEAVE_ARGS, :prefix_old) # We want to compare the generated code to the previous version. import DiffUtils - typed_old = deserialize("$(WEAVE_ARGS[:prefix_old])_$(m.name).jls"); + typed_old = deserialize(joinpath("results", WEAVE_ARGS[:prefix_old], "$(m.name).jls")); DiffUtils.diff(typed_old, string(typed), width=130) end ``` diff --git a/benchmarks/benchmarks.jmd b/benchmarks/benchmarks.jmd index 0cc6e1de2..e28cf4695 100644 --- a/benchmarks/benchmarks.jmd +++ b/benchmarks/benchmarks.jmd @@ -10,6 +10,10 @@ using BenchmarkTools, DynamicPPL, Distributions, Serialization include("utils.jl") ``` +```julia; echo=false; results="hidden"; +mkpath(joinpath("results", WEAVE_ARGS[:prefix])) +``` + ## Models ### `demo1` From 5742f905e13a973dd1a41eab345817df44430613 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 8 Jul 2021 17:17:30 +0100 Subject: [PATCH 103/107] forgot to include src --- benchmarks/src/DynamicPPLBenchmarks.jl | 115 +++++++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 benchmarks/src/DynamicPPLBenchmarks.jl diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl new file mode 100644 index 000000000..60a49945c --- /dev/null +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -0,0 +1,115 @@ +module DynamicPPLBenchmarks + +using DynamicPPL +using BenchmarkTools + +import Weave +import Markdown + +import LibGit2, Pkg + +export weave_benchmarks + +function time_model_def(model_def, args...) + return @time model_def(args...) +end + +function benchmark_untyped_varinfo!(suite, m) + vi = VarInfo() + # Populate. + m(vi) + # Evaluate. + suite["evaluation_untyped"] = @benchmarkable $m($vi) + return suite +end + +function benchmark_typed_varinfo!(suite, m) + # Populate. + vi = VarInfo(m) + # Evaluate. + suite["evaluation_typed"] = @benchmarkable $m($vi) + return suite +end + +function typed_code(m, vi = VarInfo(m)) + rng = DynamicPPL.Random.MersenneTwister(42); + spl = DynamicPPL.SampleFromPrior() + ctx = DynamicPPL.SamplingContext(rng, spl, DynamicPPL.DefaultContext()) + + results = code_typed(m.f, Base.typesof(m, vi, ctx, m.args...)) + return first(results) +end + +function make_suite(m) + suite = BenchmarkGroup() + benchmark_untyped_varinfo!(suite, m) + benchmark_typed_varinfo!(suite, m) + + return suite +end + +function weave_child(indoc; mod, args, kwargs...) + # FIXME: Make this work for other output formats than just `github`. + doc = Weave.WeaveDoc(indoc, nothing) + doc = Weave.run_doc(doc, doctype = "github", mod = mod, args = args, kwargs...) + rendered = Weave.render_doc(doc) + return display(Markdown.parse(rendered)) +end + +function pkgversion(m::Module) + projecttoml_path = joinpath(dirname(pathof(m)), "..", "Project.toml") + return Pkg.TOML.parsefile(projecttoml_path)["version"] +end + +function default_name(; include_commit_id=false) + dppl_path = abspath(joinpath(dirname(pathof(DynamicPPL)), "..")) + + # Extract branch name and commit id + local name + try + githead = LibGit2.head(LibGit2.GitRepo(dppl_path)) + branchname = LibGit2.shortname(githead) + + name = replace(branchname, "/" => "_") + if include_commit_id + gitcommit = LibGit2.peel(LibGit2.GitCommit, githead) + commitid = string(LibGit2.GitHash(gitcommit)) + name *= "-$(commitid)" + end + catch e + if e isa LibGit2.GitError + @info "No git repo found for $(dppl_path); extracting name from package version." + name = "release-$(pkgversion(DynamicPPL))" + else + rethrow(e) + end + end + + return name +end + +function weave_benchmarks( + ; + benchmarkbody=joinpath(dirname(pathof(DynamicPPLBenchmarks)), "..", "benchmark_body.jmd"), + include_commit_id=false, + name=default_name(include_commit_id=include_commit_id), + name_old=nothing, + include_typed_code=false, + doctype="github", + outpath="results/$(name)/", + kwargs... +) + args = Dict( + :benchmarkbody => benchmarkbody, + :name => name, + :include_typed_code => include_typed_code + ) + if !isnothing(name_old) + args[:name_old] = name_old + end + @info "Storing output in $(outpath)" + mkpath(outpath) + Weave.weave("benchmarks.jmd", doctype; out_path=outpath, args=args, kwargs...) +end + +end # module From 34cfabcc71058f1177827f60cd2ba6007d72d1dd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 8 Jul 2021 17:17:41 +0100 Subject: [PATCH 104/107] updated jmd files --- benchmarks/benchmark_body.jmd | 23 ++++++++++++++++------- benchmarks/benchmarks.jmd | 6 +----- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/benchmarks/benchmark_body.jmd b/benchmarks/benchmark_body.jmd index ac9953132..f9c994dc9 100644 --- a/benchmarks/benchmark_body.jmd +++ b/benchmarks/benchmark_body.jmd @@ -8,23 +8,32 @@ m = time_model_def(model_def, data); ```julia suite = make_suite(m); -run(suite) +results = run(suite) +results +``` + +```julia; echo=false; results="hidden"; +BenchmarkTools.save(joinpath("results", WEAVE_ARGS[:name], "$(m.name)_benchmarks.json"), results) ``` ```julia; wrap=false -typed = typed_code(m) +if WEAVE_ARGS[:include_typed_code] + typed = typed_code(m) +end ``` ```julia; echo=false; results="hidden" -# Serialize the output of `typed_code` so we can compare later. -haskey(WEAVE_ARGS, :prefix) && serialize(joinpath("results", WEAVE_ARGS[:prefix],"$(m.name).jls"), string(typed)); +if WEAVE_ARGS[:include_typed_code] + # Serialize the output of `typed_code` so we can compare later. + haskey(WEAVE_ARGS, :name) && serialize(joinpath("results", WEAVE_ARGS[:name],"$(m.name).jls"), string(typed)); +end ``` -```julia; wrap=false -if haskey(WEAVE_ARGS, :prefix_old) +```julia; wrap=false; echo=false; +if haskey(WEAVE_ARGS, :name_old) # We want to compare the generated code to the previous version. import DiffUtils - typed_old = deserialize(joinpath("results", WEAVE_ARGS[:prefix_old], "$(m.name).jls")); + typed_old = deserialize(joinpath("results", WEAVE_ARGS[:name_old], "$(m.name).jls")); DiffUtils.diff(typed_old, string(typed), width=130) end ``` diff --git a/benchmarks/benchmarks.jmd b/benchmarks/benchmarks.jmd index e28cf4695..614afb2e9 100644 --- a/benchmarks/benchmarks.jmd +++ b/benchmarks/benchmarks.jmd @@ -7,11 +7,7 @@ using BenchmarkTools, DynamicPPL, Distributions, Serialization ``` ```julia -include("utils.jl") -``` - -```julia; echo=false; results="hidden"; -mkpath(joinpath("results", WEAVE_ARGS[:prefix])) +import DynamicPPLBenchmarks: time_model_def, make_suite, typed_code, weave_child ``` ## Models From abb1768ec4415fb9826ef601fce09edc8a1b31f2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 8 Jul 2021 17:30:13 +0100 Subject: [PATCH 105/107] added some docs --- benchmarks/src/DynamicPPLBenchmarks.jl | 63 ++++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 5 deletions(-) diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 60a49945c..8ba5eb838 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -40,14 +40,35 @@ function typed_code(m, vi = VarInfo(m)) return first(results) end -function make_suite(m) +""" + make_suite(model) + +Create default benchmark suite for `model`. +""" +function make_suite(model) suite = BenchmarkGroup() - benchmark_untyped_varinfo!(suite, m) - benchmark_typed_varinfo!(suite, m) + benchmark_untyped_varinfo!(suite, model) + benchmark_typed_varinfo!(suite, model) return suite end +""" + weave_child(indoc; mod, args, kwargs...) + +Weave `indoc` with scope of `mod` into markdown. + +Useful for weaving within weaving, e.g. +```julia +weave_child(child_jmd_path, mod = @__MODULE__, args = WEAVE_ARGS) +``` +together with `results="markup"` and `echo=false` will simply insert +the weaved version of `indoc`. + +# Notes +- Currently only supports `doctype == "github"`. Other outputs are "supported" + in the sense that it works but you might lose niceties such as syntax highlighting. +""" function weave_child(indoc; mod, args, kwargs...) # FIXME: Make this work for other output formats than just `github`. doc = Weave.WeaveDoc(indoc, nothing) @@ -56,11 +77,28 @@ function weave_child(indoc; mod, args, kwargs...) return display(Markdown.parse(rendered)) end +""" + pkgversion(m::Module) + +Return version of module `m` as listed in its Project.toml. +""" function pkgversion(m::Module) projecttoml_path = joinpath(dirname(pathof(m)), "..", "Project.toml") return Pkg.TOML.parsefile(projecttoml_path)["version"] end +""" + default_name(; include_commit_id=false) + +Construct a name from either repo information or package version +of `DynamicPPL`. + +If the path of `DynamicPPL` is a git-repo, return name of current branch, +joined with the commit id if `include_commit_id` is `true`. + +If path of `DynamicPPL` is _not_ a git-repo, it is assumed to be a release, +resulting in a name of the form `release-VERSION`. +""" function default_name(; include_commit_id=false) dppl_path = abspath(joinpath(dirname(pathof(DynamicPPL)), "..")) @@ -88,8 +126,23 @@ function default_name(; include_commit_id=false) return name end +""" + weave_benchmarks(input="benchmarks.jmd"; kwargs...) + +Weave benchmarks present in `benchmarks.jmd` into a single file. + +# Keyword arguments +- `benchmarkbody`: JMD-file to be rendered for each model. +- `include_commit_id=false`: specify whether to include commit-id in the default name. +- `name`: the name of directory in `results/` to use as output directory. +- `name_old=nothing`: if specified, comparisons of current run vs. the run pinted to + by `name_old` will be included in the generated document. +- `include_typed_code=false`: if `true`, output of `code_typed` for the evaluator + of the model will be included in the weaved document. +- Rest of the passed `kwargs` will be passed on to `Weave.weave`. +""" function weave_benchmarks( - ; + input="benchmarks.jmd"; benchmarkbody=joinpath(dirname(pathof(DynamicPPLBenchmarks)), "..", "benchmark_body.jmd"), include_commit_id=false, name=default_name(include_commit_id=include_commit_id), @@ -109,7 +162,7 @@ function weave_benchmarks( end @info "Storing output in $(outpath)" mkpath(outpath) - Weave.weave("benchmarks.jmd", doctype; out_path=outpath, args=args, kwargs...) + Weave.weave(input, doctype; out_path=outpath, args=args, kwargs...) end end # module From 4ea7bfc8fa43473c18be5b5b22a90b7fd6b68f1b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 8 Jul 2021 17:32:38 +0100 Subject: [PATCH 106/107] updated README --- benchmarks/README.md | 30 ++++++++++++++++++++++---- benchmarks/src/DynamicPPLBenchmarks.jl | 2 +- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/benchmarks/README.md b/benchmarks/README.md index 565217753..a377d74f3 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -1,8 +1,30 @@ To run the benchmarks, simply do: ```sh -julia --project -e 'using Weave; Weave.weave("benchmarks.jmd", doctype="github", args=Dict(:benchmarkbody => "benchmark_body.jmd"));' +julia --project -e 'using DynamicPPLBenchmarks; weave_benchmarks();' ``` -Furthermore: -- If you want to save the output of `code_typed` for the evaluator of the different models, add a `:prefix => "myprefix"` to the `args`. -- If `:prefix_old` is specified in `args`, a `diff` of the `code_typed` loaded using `:prefix_old` and the output of `code_typed` for the current run will be included in the weaved document. +```julia +help?> weave_benchmarks +search: weave_benchmarks + + weave_benchmarks(input="benchmarks.jmd"; kwargs...) + + Weave benchmarks present in benchmarks.jmd into a single file. + + Keyword arguments + ≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡ + + • benchmarkbody: JMD-file to be rendered for each model. + + • include_commit_id=false: specify whether to include commit-id in the default name. + + • name: the name of directory in results/ to use as output directory. + + • name_old=nothing: if specified, comparisons of current run vs. the run pinted to by name_old + will be included in the generated document. + + • include_typed_code=false: if true, output of code_typed for the evaluator of the model will be + included in the weaved document. + + • Rest of the passed kwargs will be passed on to Weave.weave. +``` diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 8ba5eb838..a6888dfd2 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -142,7 +142,7 @@ Weave benchmarks present in `benchmarks.jmd` into a single file. - Rest of the passed `kwargs` will be passed on to `Weave.weave`. """ function weave_benchmarks( - input="benchmarks.jmd"; + input=joinpath(dirname(pathof(DynamicPPLBenchmarks)), "..", "benchmarks.jmd"); benchmarkbody=joinpath(dirname(pathof(DynamicPPLBenchmarks)), "..", "benchmark_body.jmd"), include_commit_id=false, name=default_name(include_commit_id=include_commit_id), From 1147f640a4d5c0b196575398030fa7fdceff88c7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 8 Jul 2021 17:33:30 +0100 Subject: [PATCH 107/107] formatting --- benchmarks/src/DynamicPPLBenchmarks.jl | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index a6888dfd2..3f0e28ad7 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -3,10 +3,11 @@ module DynamicPPLBenchmarks using DynamicPPL using BenchmarkTools -import Weave -import Markdown +using Weave: Weave +using Markdown: Markdown -import LibGit2, Pkg +using LibGit2: LibGit2 +using Pkg: Pkg export weave_benchmarks @@ -31,8 +32,8 @@ function benchmark_typed_varinfo!(suite, m) return suite end -function typed_code(m, vi = VarInfo(m)) - rng = DynamicPPL.Random.MersenneTwister(42); +function typed_code(m, vi=VarInfo(m)) + rng = DynamicPPL.Random.MersenneTwister(42) spl = DynamicPPL.SampleFromPrior() ctx = DynamicPPL.SamplingContext(rng, spl, DynamicPPL.DefaultContext()) @@ -72,7 +73,7 @@ the weaved version of `indoc`. function weave_child(indoc; mod, args, kwargs...) # FIXME: Make this work for other output formats than just `github`. doc = Weave.WeaveDoc(indoc, nothing) - doc = Weave.run_doc(doc, doctype = "github", mod = mod, args = args, kwargs...) + doc = Weave.run_doc(doc; doctype="github", mod=mod, args=args, kwargs...) rendered = Weave.render_doc(doc) return display(Markdown.parse(rendered)) end @@ -143,26 +144,28 @@ Weave benchmarks present in `benchmarks.jmd` into a single file. """ function weave_benchmarks( input=joinpath(dirname(pathof(DynamicPPLBenchmarks)), "..", "benchmarks.jmd"); - benchmarkbody=joinpath(dirname(pathof(DynamicPPLBenchmarks)), "..", "benchmark_body.jmd"), + benchmarkbody=joinpath( + dirname(pathof(DynamicPPLBenchmarks)), "..", "benchmark_body.jmd" + ), include_commit_id=false, - name=default_name(include_commit_id=include_commit_id), + name=default_name(; include_commit_id=include_commit_id), name_old=nothing, include_typed_code=false, doctype="github", outpath="results/$(name)/", - kwargs... + kwargs..., ) args = Dict( :benchmarkbody => benchmarkbody, :name => name, - :include_typed_code => include_typed_code + :include_typed_code => include_typed_code, ) if !isnothing(name_old) args[:name_old] = name_old end @info "Storing output in $(outpath)" mkpath(outpath) - Weave.weave(input, doctype; out_path=outpath, args=args, kwargs...) + return Weave.weave(input, doctype; out_path=outpath, args=args, kwargs...) end end # module