Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ export AbstractVarInfo,
Model,
getmissings,
getargnames,
generated_quantities,
# Samplers
Sampler,
SampleFromPrior,
Expand Down
68 changes: 68 additions & 0 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,71 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo)
model(varinfo, SampleFromPrior(), LikelihoodContext())
return getlogp(varinfo)
end

"""
generated_quantities(model::Model, chain::AbstractChains)

Execute `model` for each of the samples in `chain` and return an array of the values
returned by the `model` for each sample.

# Examples
## General
Often you might have additional quantities computed inside the model that you want to
inspect, e.g.
```julia
@model function demo(x)
# sample and observe
θ ~ Prior()
x ~ Likelihood()
return interesting_quantity(θ, x)
end
m = demo(data)
chain = sample(m, alg, n)
# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples
# from the posterior/`chain`:
generated_quantities(m, chain) # <= results in a `Vector` of returned values
# from `interesting_quantity(θ, x)`
```
## Concrete (and simple)
```julia
julia> using DynamicPPL, Turing

julia> @model function demo(xs)
s ~ InverseGamma(2, 3)
m_shifted ~ Normal(10, √s)
m = m_shifted - 10

for i in eachindex(xs)
xs[i] ~ Normal(m, √s)
end

return (m, )
end
demo (generic function with 1 method)

julia> model = demo(randn(10));

julia> chain = sample(model, MH(), 10);

julia> generated_quantities(model, chain)
10×1 Array{Tuple{Float64},2}:
(2.1964758025119338,)
(2.1964758025119338,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.043088571494005024,)
(-0.16489786710222099,)
(-0.16489786710222099,)
```
"""
function generated_quantities(model::Model, chain::AbstractChains)
varinfo = VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
return map(iters) do (sample_idx, chain_idx)
setval!(varinfo, chain, sample_idx, chain_idx)
model(varinfo)
end
end
9 changes: 6 additions & 3 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1178,9 +1178,12 @@ _setval!(vi::TypedVarInfo, values, keys) = _typed_setval!(vi, vi.metadata, value
end

function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys)
sym = Symbol(vn)
regex = Regex("^$sym\$|^$sym\\[")
indices = findall(x -> match(regex, string(x)) !== nothing, keys)
string_vn = string(vn)
string_vn_indexing = string_vn * "["
indices = findall(keys) do x
string_x = string(x)
return string_x == string_vn || startswith(string_x, string_vn_indexing)
end
if !isempty(indices)
sorted_indices = sort!(indices; by=i -> string(keys[i]), lt=NaturalSort.natural)
val = mapreduce(vcat, sorted_indices) do i
Expand Down
94 changes: 94 additions & 0 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,98 @@ Random.seed!(1234)
end
end
end

@testset "setval! & generated_quantities" begin
@model function demo1(xs, ::Type{TV} = Vector{Float64}) where {TV}
m = TV(undef, 2)
for i in 1:2
m[i] ~ Normal(0, 1)
end

for i in eachindex(xs)
xs[i] ~ Normal(m[1], 1.)
end

return (m, )
end

@model function demo2(xs)
m ~ MvNormal(2, 1.)

for i in eachindex(xs)
xs[i] ~ Normal(m[1], 1.)
end

return (m, )
end

xs = randn(3);
model1 = demo1(xs);
model2 = demo2(xs);

chain1 = sample(model1, MH(), 100);
chain2 = sample(model2, MH(), 100);

res11 = generated_quantities(model1, chain1)
res21 = generated_quantities(model2, chain1)

res12 = generated_quantities(model1, chain2)
res22 = generated_quantities(model2, chain2)

# Check that the two different models produce the same values for
# the same chains.
@test all(res11 .== res21)
@test all(res12 .== res22)
# Ensure that they're not all the same (some can be, because rejected samples)
@test any(res12[1:end - 1] .!= res12[2:end])

test_setval!(model1, chain1)
test_setval!(model2, chain2)

# Next level
@model function demo3(xs, ::Type{TV} = Vector{Float64}) where {TV}
m = Vector{TV}(undef, 2)
for i = 1:length(m)
m[i] ~ MvNormal(2, 1.)
end

for i in eachindex(xs)
xs[i] ~ Normal(m[1][1], 1.)
end

return (m, )
end

@model function demo4(xs, ::Type{TV} = Vector{Vector{Float64}}) where {TV}
m = TV(undef, 2)
for i = 1:length(m)
m[i] ~ MvNormal(2, 1.)
end

for i in eachindex(xs)
xs[i] ~ Normal(m[1][1], 1.)
end

return (m, )
end

model3 = demo3(xs);
model4 = demo4(xs);

chain3 = sample(model3, MH(), 100);
chain4 = sample(model4, MH(), 100);

res33 = generated_quantities(model3, chain3)
res43 = generated_quantities(model4, chain3)

res34 = generated_quantities(model3, chain4)
res44 = generated_quantities(model4, chain4)

# Check that the two different models produce the same values for
# the same chains.
@test all(res33 .== res43)
@test all(res34 .== res44)
# Ensure that they're not all the same (some can be, because rejected samples)
@test any(res34[1:end - 1] .!= res34[2:end])
end
end
30 changes: 30 additions & 0 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,33 @@ function test_model_ad(model, logp_manual)
@test y ≈ lp
@test back(1)[1] ≈ grad
end


"""
test_setval!(model, chain; sample_idx = 1, chain_idx = 1)

Test `setval!` on `model` and `chain`.

Worth noting that this only supports models containing symbols of the forms
`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc.
"""
function test_setval!(model, chain; sample_idx = 1, chain_idx = 1)
var_info = VarInfo(model)
spl = SampleFromPrior()
θ_old = var_info[spl]
DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx)
θ_new = var_info[spl]
@test θ_old != θ_new
nt = DynamicPPL.tonamedtuple(var_info)
for (k, (vals, names)) in pairs(nt)
for (n, v) in zip(names, vals)
chain_val = if Symbol(n) ∉ keys(chain)
# Assume it's a group
vec(MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx])
else
chain[sample_idx, n, chain_idx]
end
@test v == chain_val
end
end
end