From a64b31d93dab8a2cefe7d185c417179ba6fea237 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 28 Sep 2020 01:20:01 +0100 Subject: [PATCH 1/4] added generated_quantities and some tests --- src/DynamicPPL.jl | 1 + src/model.jl | 69 +++++++++++++++++++++++++ src/varinfo.jl | 9 ++-- test/model.jl | 125 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 201 insertions(+), 3 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index eaf30feed..015317798 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -66,6 +66,7 @@ export AbstractVarInfo, Model, getmissings, getargnames, + generated_quantities, # Samplers Sampler, SampleFromPrior, diff --git a/src/model.jl b/src/model.jl index 739772b02..3118b79ee 100644 --- a/src/model.jl +++ b/src/model.jl @@ -200,3 +200,72 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) model(varinfo, SampleFromPrior(), LikelihoodContext()) return getlogp(varinfo) end + +""" + generated_quantities(model::Turing.Model, chain::AbstractChains) + +Executes `model` for each of the samples in `chain` and returns an array of the values +returned by the `model` for each sample. + +# Examples +## General +Often you might have additional quantities computed inside the model that you want to +inspect, e.g. +```julia +@model function demo(x) + # sample and observe + θ ~ Prior() + x ~ Likelihood() + return interesting_quantity(θ, x) +end +m = demo(data) +chain = sample(m, alg, n) +# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples +# from the posterior/`chain`: +generated_quantities(m, chain) # <= results in a `Vector` of returned values + # from `interesting_quantity(θ, x)` +``` +## Concrete (and simple) +```julia +julia> using DynamicPPL, Turing + +julia> @model function demo(xs) + s ~ InverseGamma(2, 3) + m_shifted ~ Normal(10, √s) + m = m_shifted - 10 + + for i in eachindex(xs) + xs[i] ~ Normal(m, √s) + end + + return (m, ) + end +demo (generic function with 1 method) + +julia> model = demo(randn(10)); + +julia> chain = sample(model, MH(), 10); + +julia> generated_quantities(model, chain) +10×1 Array{Tuple{Float64},2}: + (2.1964758025119338,) + (2.1964758025119338,) + (0.09270081916291417,) + (0.09270081916291417,) + (0.09270081916291417,) + (0.09270081916291417,) + (0.09270081916291417,) + (0.043088571494005024,) + (-0.16489786710222099,) + (-0.16489786710222099,) +``` +""" +function generated_quantities(model::Model, chain::AbstractChains) + varinfo = VarInfo(model) + spl = SampleFromPrior() + iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) + return map(iters) do (sample_idx, chain_idx) + setval!(varinfo, chain, sample_idx, chain_idx) + model(varinfo, spl) + end +end diff --git a/src/varinfo.jl b/src/varinfo.jl index 2d2351cea..1ccdac660 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1178,9 +1178,12 @@ _setval!(vi::TypedVarInfo, values, keys) = _typed_setval!(vi, vi.metadata, value end function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys) - sym = Symbol(vn) - regex = Regex("^$sym\$|^$sym\\[") - indices = findall(x -> match(regex, string(x)) !== nothing, keys) + string_vn = string(vn) + string_vn_indexing = string_vn * "[" + indices = findall(keys) do x + string_x = string(x) + return string_x == string_vn || startswith(string_x, string_vn_indexing) + end if !isempty(indices) sorted_indices = sort!(indices; by=i -> string(keys[i]), lt=NaturalSort.natural) val = mapreduce(vcat, sorted_indices) do i diff --git a/test/model.jl b/test/model.jl index c00b2a95a..43f3e626b 100644 --- a/test/model.jl +++ b/test/model.jl @@ -1,5 +1,36 @@ Random.seed!(1234) +using Test + +""" + test_setval!(model, chain; sample_idx = 1, chain_idx = 1) + +Test `setval!` on `model` and `chain`. + +Worth noting that this only supports models containing symbols of the forms +`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. +""" +function test_setval!(model, chain; sample_idx = 1, chain_idx = 1) + var_info = DynamicPPL.VarInfo(model) + spl = DynamicPPL.SampleFromPrior() + θ_old = var_info[spl] + DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) + θ_new = var_info[spl] + @test θ_old != θ_new + nt = DynamicPPL.tonamedtuple(var_info) + for (k, (vals, names)) in pairs(nt) + for (n, v) in zip(names, vals) + chain_val = if Symbol(n) ∉ MCMCChains.keys(chain) + # Assume it's a group + vec(MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]) + else + chain[sample_idx, n, chain_idx] + end + @test v == chain_val + end + end +end + @testset "model.jl" begin @testset "convenience functions" begin model = gdemo_default @@ -44,4 +75,98 @@ Random.seed!(1234) end end end + + @testset "setval! & generated_quantities" begin + @model function demo1(xs, ::Type{TV} = Vector{Float64}) where {TV} + m = TV(undef, 2) + for i in 1:2 + m[i] ~ Normal(0, 1) + end + + for i in eachindex(xs) + xs[i] ~ Normal(m[1], 1.) + end + + return (m, ) + end + + @model function demo2(xs) + m ~ MvNormal(2, 1.) + + for i in eachindex(xs) + xs[i] ~ Normal(m[1], 1.) + end + + return (m, ) + end + + xs = randn(3); + model1 = demo1(xs); + model2 = demo2(xs); + + chain1 = sample(model1, MH(), 100); + chain2 = sample(model2, MH(), 100); + + res11 = generated_quantities(model1, chain1) + res21 = generated_quantities(model2, chain1) + + res12 = generated_quantities(model1, chain2) + res22 = generated_quantities(model2, chain2) + + # Check that the two different models produce the same values for + # the same chains. + @test all(res11 .== res21) + @test all(res12 .== res22) + # Ensure that they're not all the same (some can be, because rejected samples) + @test any(res12[1:end - 1] .!= res12[2:end]) + + test_setval!(model1, chain1) + test_setval!(model2, chain2) + + # Next level + @model function demo3(xs, ::Type{TV} = Vector{Float64}) where {TV} + m = Vector{TV}(undef, 2) + for i = 1:length(m) + m[i] ~ MvNormal(2, 1.) + end + + for i in eachindex(xs) + xs[i] ~ Normal(m[1][1], 1.) + end + + return (m, ) + end + + @model function demo4(xs, ::Type{TV} = Vector{Vector{Float64}}) where {TV} + m = TV(undef, 2) + for i = 1:length(m) + m[i] ~ MvNormal(2, 1.) + end + + for i in eachindex(xs) + xs[i] ~ Normal(m[1][1], 1.) + end + + return (m, ) + end + + model3 = demo3(xs); + model4 = demo4(xs); + + chain3 = sample(model3, MH(), 100); + chain4 = sample(model4, MH(), 100); + + res33 = generated_quantities(model3, chain3) + res43 = generated_quantities(model4, chain3) + + res34 = generated_quantities(model3, chain4) + res44 = generated_quantities(model4, chain4) + + # Check that the two different models produce the same values for + # the same chains. + @test all(res33 .== res43) + @test all(res34 .== res44) + # Ensure that they're not all the same (some can be, because rejected samples) + @test any(res34[1:end - 1] .!= res34[2:end]) + end end From 49e4e480ccfb93112473aa3d4aa6722ca421078c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 28 Sep 2020 01:58:00 +0100 Subject: [PATCH 2/4] Apply suggestions from code review Co-authored-by: David Widmann --- src/model.jl | 7 +++---- test/model.jl | 6 +++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/model.jl b/src/model.jl index 3118b79ee..0a2b00606 100644 --- a/src/model.jl +++ b/src/model.jl @@ -202,9 +202,9 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) end """ - generated_quantities(model::Turing.Model, chain::AbstractChains) + generated_quantities(model::Model, chain::AbstractChains) -Executes `model` for each of the samples in `chain` and returns an array of the values +Execute `model` for each of the samples in `chain` and return an array of the values returned by the `model` for each sample. # Examples @@ -262,10 +262,9 @@ julia> generated_quantities(model, chain) """ function generated_quantities(model::Model, chain::AbstractChains) varinfo = VarInfo(model) - spl = SampleFromPrior() iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) return map(iters) do (sample_idx, chain_idx) setval!(varinfo, chain, sample_idx, chain_idx) - model(varinfo, spl) + model(varinfo) end end diff --git a/test/model.jl b/test/model.jl index 43f3e626b..ff5ded3f8 100644 --- a/test/model.jl +++ b/test/model.jl @@ -11,8 +11,8 @@ Worth noting that this only supports models containing symbols of the forms `m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. """ function test_setval!(model, chain; sample_idx = 1, chain_idx = 1) - var_info = DynamicPPL.VarInfo(model) - spl = DynamicPPL.SampleFromPrior() + var_info = VarInfo(model) + spl = SampleFromPrior() θ_old = var_info[spl] DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) θ_new = var_info[spl] @@ -20,7 +20,7 @@ function test_setval!(model, chain; sample_idx = 1, chain_idx = 1) nt = DynamicPPL.tonamedtuple(var_info) for (k, (vals, names)) in pairs(nt) for (n, v) in zip(names, vals) - chain_val = if Symbol(n) ∉ MCMCChains.keys(chain) + chain_val = if Symbol(n) ∉ keys(chain) # Assume it's a group vec(MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]) else From a9f31dc95277a875427169f7f5b59402c160c52b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 28 Sep 2020 02:00:40 +0100 Subject: [PATCH 3/4] removed redundant using statement in test --- test/model.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/model.jl b/test/model.jl index ff5ded3f8..767304ca3 100644 --- a/test/model.jl +++ b/test/model.jl @@ -1,7 +1,5 @@ Random.seed!(1234) -using Test - """ test_setval!(model, chain; sample_idx = 1, chain_idx = 1) From c5784966bf66b186a194b5e4ab9baa341b96c1dd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 28 Sep 2020 02:05:26 +0100 Subject: [PATCH 4/4] moved test_setval! --- test/model.jl | 29 ----------------------------- test/test_util.jl | 30 ++++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/test/model.jl b/test/model.jl index 767304ca3..83899769a 100644 --- a/test/model.jl +++ b/test/model.jl @@ -1,34 +1,5 @@ Random.seed!(1234) -""" - test_setval!(model, chain; sample_idx = 1, chain_idx = 1) - -Test `setval!` on `model` and `chain`. - -Worth noting that this only supports models containing symbols of the forms -`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. -""" -function test_setval!(model, chain; sample_idx = 1, chain_idx = 1) - var_info = VarInfo(model) - spl = SampleFromPrior() - θ_old = var_info[spl] - DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) - θ_new = var_info[spl] - @test θ_old != θ_new - nt = DynamicPPL.tonamedtuple(var_info) - for (k, (vals, names)) in pairs(nt) - for (n, v) in zip(names, vals) - chain_val = if Symbol(n) ∉ keys(chain) - # Assume it's a group - vec(MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]) - else - chain[sample_idx, n, chain_idx] - end - @test v == chain_val - end - end -end - @testset "model.jl" begin @testset "convenience functions" begin model = gdemo_default diff --git a/test/test_util.jl b/test/test_util.jl index d8a926656..8bfad4d3d 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -36,3 +36,33 @@ function test_model_ad(model, logp_manual) @test y ≈ lp @test back(1)[1] ≈ grad end + + +""" + test_setval!(model, chain; sample_idx = 1, chain_idx = 1) + +Test `setval!` on `model` and `chain`. + +Worth noting that this only supports models containing symbols of the forms +`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. +""" +function test_setval!(model, chain; sample_idx = 1, chain_idx = 1) + var_info = VarInfo(model) + spl = SampleFromPrior() + θ_old = var_info[spl] + DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) + θ_new = var_info[spl] + @test θ_old != θ_new + nt = DynamicPPL.tonamedtuple(var_info) + for (k, (vals, names)) in pairs(nt) + for (n, v) in zip(names, vals) + chain_val = if Symbol(n) ∉ keys(chain) + # Assume it's a group + vec(MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]) + else + chain[sample_idx, n, chain_idx] + end + @test v == chain_val + end + end +end