2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.9.1"
version = "0.9.2"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
@@ -90,6 +90,7 @@ export AbstractVarInfo,
# Convenience functions
logprior,
logjoint,
elementwise_loglikelihoods,
# Convenience macros
@addlogprob!

@@ -118,5 +119,6 @@ include("context_implementations.jl")
include("compiler.jl")
include("prob_macro.jl")
include("compat/ad.jl")
include("loglikelihoods.jl")

end # module
131 changes: 131 additions & 0 deletions src/loglikelihoods.jl
@@ -0,0 +1,131 @@
# Context version
struct ElementwiseLikelihoodContext{A, Ctx} <: AbstractContext
loglikelihoods::A
ctx::Ctx
end

function ElementwiseLikelihoodContext(
likelihoods = Dict{VarName, Vector{Float64}}(),
ctx::AbstractContext = LikelihoodContext()
)
return ElementwiseLikelihoodContext{typeof(likelihoods),typeof(ctx)}(likelihoods, ctx)
end

function Base.push!(
ctx::ElementwiseLikelihoodContext{Dict{VarName, Vector{Float64}}},
vn::VarName,
logp::Real
)
lookup = ctx.loglikelihoods
ℓ = get!(lookup, vn, Float64[])
push!(ℓ, logp)
end

function Base.push!(
ctx::ElementwiseLikelihoodContext{Dict{VarName, Float64}},
vn::VarName,
logp::Real
)
ctx.loglikelihoods[vn] = logp
end
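
For orientation, here is a minimal sketch of how the two `push!` specializations above behave; the values are made up, and `@varname` is DynamicPPL's macro for constructing the `VarName` keys:

```julia
# Chain use: Dict{VarName, Vector{Float64}} accumulates one log-likelihood per sample.
ctx_chain = ElementwiseLikelihoodContext()
vn = @varname(y)
push!(ctx_chain, vn, -1.2)    # first sample
push!(ctx_chain, vn, -0.8)    # second sample
ctx_chain.loglikelihoods[vn]  # == [-1.2, -0.8]

# Single-evaluation use: Dict{VarName, Float64} keeps exactly one value per observation.
ctx_single = ElementwiseLikelihoodContext(Dict{VarName, Float64}())
push!(ctx_single, vn, -1.2)
ctx_single.loglikelihoods[vn] # == -1.2
```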


function tilde_assume(rng, ctx::ElementwiseLikelihoodContext, sampler, right, vn, inds, vi)
return tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, vi)
end

function dot_tilde_assume(rng, ctx::ElementwiseLikelihoodContext, sampler, right, left, vn, inds, vi)
value, logp = dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi)
acclogp!(vi, logp)
return value
end


function tilde_observe(ctx::ElementwiseLikelihoodContext, sampler, right, left, vname, vinds, vi)
# This is slightly unfortunate since it is not completely generic...
# Ideally we would call `tilde_observe` recursively but then we don't get the
# loglikelihood value.
logp = tilde(ctx.ctx, sampler, right, left, vi)
acclogp!(vi, logp)

# track loglikelihood value
push!(ctx, vname, logp)

return left
end
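
To make the comment above concrete, here is a sketch of the recursive delegation it alludes to; this is hypothetical and intentionally not what the method does, because the inner call returns the observed value rather than the per-observation log-likelihood that needs to be recorded:

```julia
# Hypothetical alternative (NOT the implementation above):
function tilde_observe(ctx::ElementwiseLikelihoodContext, sampler, right, left, vname, vinds, vi)
    # Delegating like this would reuse the wrapped context's observe logic,
    # but the return value is `left`, so the individual `logp` would be lost.
    return tilde_observe(ctx.ctx, sampler, right, left, vname, vinds, vi)
end
```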


"""
elementwise_loglikelihoods(model::Model, chain::Chains)

Runs `model` on each sample in `chain`, returning a dictionary that maps the `VarName`
of each observation to a vector of log-likelihood values, the i-th of which corresponds
to the likelihood of that observation under the i-th sample in `chain`.

# Notes
Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ` both
being `<:Real`. Then the *observe* statements (i.e. those whose left-hand side is an
*observation*) can be written in two ways:
```julia
for i in eachindex(y)
y[i] ~ Normal(μ, σ)
end
```
or
```julia
y ~ MvNormal(fill(μ, n), fill(σ, n))
```
Unfortunately, from the latter statement alone it is impossible to tell whether this is
one *single* `n`-dimensional observation or `n` *separate* 1-dimensional observations.
Therefore, `elementwise_loglikelihoods` only works with the first form.

# Examples
```julia-repl
julia> using DynamicPPL, Turing

julia> @model function demo(xs, y)
s ~ InverseGamma(2, 3)
m ~ Normal(0, √s)
for i in eachindex(xs)
xs[i] ~ Normal(m, √s)
end

y ~ Normal(m, √s)
end
demo (generic function with 1 method)

julia> model = demo(randn(3), randn());

julia> chain = sample(model, MH(), 10);

julia> DynamicPPL.elementwise_loglikelihoods(model, chain)
Dict{String,Array{Float64,1}} with 4 entries:
"xs[3]" => [-1.02616, -1.26931, -1.05003, -5.05458, -1.33825, -1.02904, -1.23761, -1.30128, -1.04872, -2.03716]
"xs[1]" => [-2.08205, -2.51387, -3.03175, -2.5981, -2.31322, -2.62284, -2.70874, -1.18617, -1.36281, -4.39839]
"xs[2]" => [-2.20604, -2.63495, -3.22802, -2.48785, -2.40941, -2.78791, -2.85013, -1.24081, -1.46019, -4.59025]
"y" => [-1.36627, -1.21964, -1.03342, -7.46617, -1.3234, -1.14536, -1.14781, -2.48912, -2.23705, -1.26267]
```
"""
function elementwise_loglikelihoods(model::Model, chain)
# Get the data by executing the model once
ctx = ElementwiseLikelihoodContext()
spl = SampleFromPrior()
vi = VarInfo(model)

iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
for (sample_idx, chain_idx) in iters
# Update the values
setval!(vi, chain, sample_idx, chain_idx)
Comment on lines +118 to +119
Member:

So here empty! is not needed?

Member Author:

Actually, empty! will ruin it! I did that initially, but empty! also made it so that values would be resampled. So it ended up sampling from the prior instead.

Member:

Hmm but why would it resample values in this case? Shouldn't setval! fix them? There's something going on with this empty!/setval! thing that I don't understand 🤔

Member Author:

Ah, okay then there's something weird. I thought I just had misunderstood something, but if you also don't know why that's the case then there's something going on 😅

Can it be the fact that empty! clears the "del" flag + setval! does NOT set it to false? So then when you run the model again, it will resample?

Member Author:

Yeah, something is wrong:

julia> using DynamicPPL, Turing

julia> import DynamicPPL: setval!

julia> @model function demo(xs)
           m ~ MvNormal(2, 1.)
           for i in eachindex(xs)
               xs[i] ~ Normal(m[1], 1.)
           end
       end
demo (generic function with 1 method)

julia> model = demo(randn(3));

julia> chain = sample(model, MH(), 10);

julia> var_info = VarInfo(model);

julia> θ_old = var_info[SampleFromPrior()]
2-element Array{Float64,1}:
 0.41438744434831887
 0.373145757716783

julia> θ_chain = vec(MCMCChains.group(chain, :m)[1, :, 1].value)
2-element reshape(::AxisArrays.AxisArray{Float64,3,Array{Float64,3},Tuple{AxisArrays.Axis{:iter,StepRange{Int64,Int64}},AxisArrays.Axis{:var,Array{Symbol,1}},AxisArrays.Axis{:chain,UnitRange{Int64}}}}, 2) with eltype Float64:
  0.3939341450590006
 -1.1020030439893758

julia> empty!(var_info)
VarInfo{NamedTuple{(:m,),Tuple{DynamicPPL.Metadata{Dict{VarName{:m,Tuple{}},Int64},Array{MvNormal{Float64,PDMats.ScalMat{Float64},FillArrays.Zeros{Float64,1,Tuple{Base.OneTo{Int64}}}},1},Array{VarName{:m,Tuple{}},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}}},Float64}((m = DynamicPPL.Metadata{Dict{VarName{:m,Tuple{}},Int64},Array{MvNormal{Float64,PDMats.ScalMat{Float64},FillArrays.Zeros{Float64,1,Tuple{Base.OneTo{Int64}}}},1},Array{VarName{:m,Tuple{}},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}(Dict{VarName{:m,Tuple{}},Int64}(), VarName{:m,Tuple{}}[], UnitRange{Int64}[], Float64[], MvNormal{Float64,PDMats.ScalMat{Float64},FillArrays.Zeros{Float64,1,Tuple{Base.OneTo{Int64}}}}[], Set{DynamicPPL.Selector}[], Int64[], Dict{String,BitArray{1}}("del" => [],"trans" => [])),), Base.RefValue{Float64}(0.0), Base.RefValue{Int64}(0))

julia> setval!(var_info, chain, 1, 1)
VarInfo{NamedTuple{(:m,),Tuple{DynamicPPL.Metadata{Dict{VarName{:m,Tuple{}},Int64},Array{MvNormal{Float64,PDMats.ScalMat{Float64},FillArrays.Zeros{Float64,1,Tuple{Base.OneTo{Int64}}}},1},Array{VarName{:m,Tuple{}},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}}},Float64}((m = DynamicPPL.Metadata{Dict{VarName{:m,Tuple{}},Int64},Array{MvNormal{Float64,PDMats.ScalMat{Float64},FillArrays.Zeros{Float64,1,Tuple{Base.OneTo{Int64}}}},1},Array{VarName{:m,Tuple{}},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}(Dict{VarName{:m,Tuple{}},Int64}(), VarName{:m,Tuple{}}[], UnitRange{Int64}[], Float64[], MvNormal{Float64,PDMats.ScalMat{Float64},FillArrays.Zeros{Float64,1,Tuple{Base.OneTo{Int64}}}}[], Set{DynamicPPL.Selector}[], Int64[], Dict{String,BitArray{1}}("del" => [],"trans" => [])),), Base.RefValue{Float64}(0.0), Base.RefValue{Int64}(0))

julia> θ_new = var_info[SampleFromPrior()]
Float64[]

julia> model(var_info)

julia> var_info[SampleFromPrior()]
2-element Array{Float64,1}:
 -0.7971105060749638
  0.9046609763240063

Member Author:

Figured it out: empty! clears vi.metadata.$n.vns for every $n in names, so the following never touches _setval_kernel!:

DynamicPPL.jl/src/varinfo.jl

Lines 1167 to 1171 in 334cb98

quote
    for vn in metadata.$n.vns
        _setval_kernel!(vi, vn, values, keys)
    end
end
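
A condensed sketch of the failure mode demonstrated in the REPL session above (same assumptions as there, i.e. `setval!` imported from DynamicPPL):

```julia
vi = VarInfo(model)       # vi.metadata.m.vns is populated by this first model run
empty!(vi)                # clears vns (among other things), so ...
setval!(vi, chain, 1, 1)  # ... the generated `for vn in metadata.m.vns` loop never runs
model(vi)                 # no values were set, hence the model resamples from the prior
```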


# Execute model
model(vi, spl, ctx)
end
return ctx.loglikelihoods
end

function elementwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo)
ctx = ElementwiseLikelihoodContext(Dict{VarName, Float64}())
model(varinfo, SampleFromPrior(), ctx)
return ctx.loglikelihoods
end
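
As a possible follow-up to the docstring example, this is one way the returned dictionary might be reshaped downstream; purely illustrative, and the names `loglikes`, `names_ordered`, and `loglike_matrix` are made up:

```julia
# Collect the per-observation vectors into a (n_samples × n_observations) matrix,
# with columns ordered by the string form of each VarName.
loglikes = DynamicPPL.elementwise_loglikelihoods(model, chain)
names_ordered = sort(collect(keys(loglikes)); by = string)
loglike_matrix = reduce(hcat, [loglikes[k] for k in names_ordered])
```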
40 changes: 40 additions & 0 deletions test/loglikelihoods.jl
@@ -0,0 +1,40 @@
using .Turing

@testset "loglikelihoods" begin
@model function demo(xs, y)
s ~ InverseGamma(2, 3)
m ~ Normal(0, √s)
for i in eachindex(xs)
xs[i] ~ Normal(m, √s)
end

y ~ Normal(m, √s)
end

xs = randn(3);
y = randn();
model = demo(xs, y);
chain = sample(model, MH(), 100);
results = elementwise_loglikelihoods(model, chain)
var_to_likelihoods = Dict(string(varname) => logliks for (varname, logliks) in results)
@test haskey(var_to_likelihoods, "xs[1]")
@test haskey(var_to_likelihoods, "xs[2]")
@test haskey(var_to_likelihoods, "xs[3]")
@test haskey(var_to_likelihoods, "y")

for (i, (s, m)) in enumerate(zip(chain[:s], chain[:m]))
@test logpdf(Normal(m, √s), xs[1]) == var_to_likelihoods["xs[1]"][i]
@test logpdf(Normal(m, √s), xs[2]) == var_to_likelihoods["xs[2]"][i]
@test logpdf(Normal(m, √s), xs[3]) == var_to_likelihoods["xs[3]"][i]
@test logpdf(Normal(m, √s), y) == var_to_likelihoods["y"][i]
end

var_info = VarInfo(model)
results = DynamicPPL.elementwise_loglikelihoods(model, var_info)
var_to_likelihoods = Dict(string(vn) => ℓ for (vn, ℓ) in results)
s, m = var_info[SampleFromPrior()]
@test logpdf(Normal(m, √s), xs[1]) == var_to_likelihoods["xs[1]"]
@test logpdf(Normal(m, √s), xs[2]) == var_to_likelihoods["xs[2]"]
@test logpdf(Normal(m, √s), xs[3]) == var_to_likelihoods["xs[3]"]
@test logpdf(Normal(m, √s), y) == var_to_likelihoods["y"]
end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -28,6 +28,7 @@ include("test_util.jl")
include("independence.jl")
include("distribution_wrappers.jl")
include("context_implementations.jl")
include("loglikelihoods.jl")

include("threadsafe.jl")
