diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl
index eaf30feed..015317798 100644
--- a/src/DynamicPPL.jl
+++ b/src/DynamicPPL.jl
@@ -66,6 +66,7 @@ export AbstractVarInfo,
     Model,
     getmissings,
     getargnames,
+    generated_quantities,
 # Samplers
     Sampler,
     SampleFromPrior,
diff --git a/src/model.jl b/src/model.jl
index 739772b02..0a2b00606 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -200,3 +200,71 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo)
     model(varinfo, SampleFromPrior(), LikelihoodContext())
     return getlogp(varinfo)
 end
+
+"""
+    generated_quantities(model::Model, chain::AbstractChains)
+
+Execute `model` for each of the samples in `chain` and return a matrix of the values
+returned by the `model`, with one row per sample and one column per chain.
+
+# Examples
+## General
+Often you might have additional quantities computed inside the model that you want to
+inspect, e.g.
+```julia
+@model function demo(x)
+    # sample and observe
+    θ ~ Prior()
+    x ~ Likelihood()
+    return interesting_quantity(θ, x)
+end
+m = demo(data)
+chain = sample(m, alg, n)
+# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples
+# from the posterior/`chain`:
+generated_quantities(m, chain) # <= results in a `Matrix` of returned values
+                               #    from `interesting_quantity(θ, x)`
+```
+## Concrete (and simple)
+```julia
+julia> using DynamicPPL, Turing
+
+julia> @model function demo(xs)
+           s ~ InverseGamma(2, 3)
+           m_shifted ~ Normal(10, √s)
+           m = m_shifted - 10
+
+           for i in eachindex(xs)
+               xs[i] ~ Normal(m, √s)
+           end
+
+           return (m, )
+       end
+demo (generic function with 1 method)
+
+julia> model = demo(randn(10));
+
+julia> chain = sample(model, MH(), 10);
+
+julia> generated_quantities(model, chain)
+10×1 Array{Tuple{Float64},2}:
+ (2.1964758025119338,)
+ (2.1964758025119338,)
+ (0.09270081916291417,)
+ (0.09270081916291417,)
+ (0.09270081916291417,)
+ (0.09270081916291417,)
+ (0.09270081916291417,)
+ (0.043088571494005024,)
+ (-0.16489786710222099,)
+ (-0.16489786710222099,)
+```
+"""
+function generated_quantities(model::Model, chain::AbstractChains)
+    varinfo = VarInfo(model)
+    iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
+    return map(iters) do (sample_idx, chain_idx)
+        setval!(varinfo, chain, sample_idx, chain_idx)
+        model(varinfo)
+    end
+end
diff --git a/src/varinfo.jl b/src/varinfo.jl
index 2d2351cea..1ccdac660 100644
--- a/src/varinfo.jl
+++ b/src/varinfo.jl
@@ -1178,9 +1178,12 @@ _setval!(vi::TypedVarInfo, values, keys) = _typed_setval!(vi, vi.metadata, value
 end
 
 function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys)
-    sym = Symbol(vn)
-    regex = Regex("^$sym\$|^$sym\\[")
-    indices = findall(x -> match(regex, string(x)) !== nothing, keys)
+    string_vn = string(vn)
+    string_vn_indexing = string_vn * "["
+    indices = findall(keys) do x
+        string_x = string(x)
+        return string_x == string_vn || startswith(string_x, string_vn_indexing)
+    end
     if !isempty(indices)
         sorted_indices = sort!(indices; by=i -> string(keys[i]), lt=NaturalSort.natural)
         val = mapreduce(vcat, sorted_indices) do i
diff --git a/test/model.jl b/test/model.jl
index c00b2a95a..83899769a 100644
--- a/test/model.jl
+++ b/test/model.jl
@@ -44,4 +44,98 @@ Random.seed!(1234)
             end
         end
     end
+
+    @testset "setval! & generated_quantities" begin
+        @model function demo1(xs, ::Type{TV} = Vector{Float64}) where {TV}
+            m = TV(undef, 2)
+            for i in 1:2
+                m[i] ~ Normal(0, 1)
+            end
+
+            for i in eachindex(xs)
+                xs[i] ~ Normal(m[1], 1.)
+            end
+
+            return (m, )
+        end
+
+        @model function demo2(xs)
+            m ~ MvNormal(2, 1.)
+
+            for i in eachindex(xs)
+                xs[i] ~ Normal(m[1], 1.)
+            end
+
+            return (m, )
+        end
+
+        xs = randn(3);
+        model1 = demo1(xs);
+        model2 = demo2(xs);
+
+        chain1 = sample(model1, MH(), 100);
+        chain2 = sample(model2, MH(), 100);
+
+        res11 = generated_quantities(model1, chain1)
+        res21 = generated_quantities(model2, chain1)
+
+        res12 = generated_quantities(model1, chain2)
+        res22 = generated_quantities(model2, chain2)
+
+        # Check that the two different models produce the same values for
+        # the same chains.
+        @test all(res11 .== res21)
+        @test all(res12 .== res22)
+        # Ensure that they're not all the same (some can be, because rejected samples)
+        @test any(res12[1:end - 1] .!= res12[2:end])
+
+        test_setval!(model1, chain1)
+        test_setval!(model2, chain2)
+
+        # Next level
+        @model function demo3(xs, ::Type{TV} = Vector{Float64}) where {TV}
+            m = Vector{TV}(undef, 2)
+            for i in eachindex(m)
+                m[i] ~ MvNormal(2, 1.)
+            end
+
+            for i in eachindex(xs)
+                xs[i] ~ Normal(m[1][1], 1.)
+            end
+
+            return (m, )
+        end
+
+        @model function demo4(xs, ::Type{TV} = Vector{Vector{Float64}}) where {TV}
+            m = TV(undef, 2)
+            for i in eachindex(m)
+                m[i] ~ MvNormal(2, 1.)
+            end
+
+            for i in eachindex(xs)
+                xs[i] ~ Normal(m[1][1], 1.)
+            end
+
+            return (m, )
+        end
+
+        model3 = demo3(xs);
+        model4 = demo4(xs);
+
+        chain3 = sample(model3, MH(), 100);
+        chain4 = sample(model4, MH(), 100);
+
+        res33 = generated_quantities(model3, chain3)
+        res43 = generated_quantities(model4, chain3)
+
+        res34 = generated_quantities(model3, chain4)
+        res44 = generated_quantities(model4, chain4)
+
+        # Check that the two different models produce the same values for
+        # the same chains.
+        @test all(res33 .== res43)
+        @test all(res34 .== res44)
+        # Ensure that they're not all the same (some can be, because rejected samples)
+        @test any(res34[1:end - 1] .!= res34[2:end])
+    end
 end
diff --git a/test/test_util.jl b/test/test_util.jl
index d8a926656..8bfad4d3d 100644
--- a/test/test_util.jl
+++ b/test/test_util.jl
@@ -36,3 +36,33 @@ function test_model_ad(model, logp_manual)
     @test y ≈ lp
     @test back(1)[1] ≈ grad
 end
+
+
+"""
+    test_setval!(model, chain; sample_idx = 1, chain_idx = 1)
+
+Test `setval!` on `model` and `chain`.
+
+Worth noting that this only supports models containing symbols of the forms
+`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc.
+"""
+function test_setval!(model, chain; sample_idx = 1, chain_idx = 1)
+    var_info = VarInfo(model)
+    spl = SampleFromPrior()
+    θ_old = var_info[spl]
+    DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx)
+    θ_new = var_info[spl]
+    @test θ_old != θ_new
+    nt = DynamicPPL.tonamedtuple(var_info)
+    for (vals, names) in values(nt)
+        for (n, v) in zip(names, vals)
+            chain_val = if Symbol(n) ∉ keys(chain)
+                # Assume it's a group
+                vec(MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx])
+            else
+                chain[sample_idx, n, chain_idx]
+            end
+            @test v == chain_val
+        end
+    end
+end