Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/core/Core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ export @model,
current_trace,
getweights,
effectiveSampleSize,
increase_logweight,
inrease_logevidence,
increase_logweight!,
propagate!,
resample!,
ResampleWithESSThreshold,
ADBackend,
Expand Down
62 changes: 28 additions & 34 deletions src/core/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,10 @@ mutable struct ParticleContainer{T<:Particle, F}
vals::Vector{T}
# logarithmic weights (Trace) or incremental log-likelihoods (ParticleContainer)
logWs::Vector{Float64}
# log model evidence
logE::Float64
# helpful for rejuvenation steps, e.g. in SMC2
n_consume::Int
end

ParticleContainer(model, particles::Vector{<:Particle}) =
ParticleContainer(model, particles, zeros(length(particles)), 0.0, 0)
ParticleContainer(model, particles, zeros(length(particles)))

Base.collect(pc::ParticleContainer) = pc.vals
Base.length(pc::ParticleContainer) = length(pc.vals)
Expand All @@ -105,66 +101,64 @@ function Base.copy(pc::ParticleContainer)
# copy weights
logWs = copy(pc.logWs)

ParticleContainer(pc.model, vals, logWs, pc.logE, pc.n_consume)
ParticleContainer(pc.model, vals, logWs)
end

# run particle filter for one step, return incremental likelihood
function Libtask.consume(pc :: ParticleContainer)
"""
propagate!(pc::ParticleContainer)

Run particle filter for one step and check if the final time step is reached.
"""
function propagate!(pc::ParticleContainer)
# normalisation factor: 1/N
z1 = logZ(pc)
n = length(pc)

particles = collect(pc)
num_done = 0
for i=1:n
numdone = 0
for i in 1:n
p = particles[i]
score = Libtask.consume(p)
if score isa Real
score += getlogp(p.vi)
resetlogp!(p.vi)
increase_logweight(pc, i, Float64(score))
increase_logweight!(pc, i, Float64(score))
elseif score == Val{:done}
num_done += 1
numdone += 1
else
error("[consume]: error in running particle filter.")
end
end

if num_done == n
res = Val{:done}
elseif num_done != 0
# The posterior for models with random number of observations is not well-defined.
error("[consume]: mis-aligned execution traces, num_particles= $(n),
num_done=$(num_done). Please make sure the number of observations is NOT random.")
else
# update incremental likelihoods
z2 = logZ(pc)
res = increase_logevidence(pc, z2 - z1)
pc.n_consume += 1
# res = increase_loglikelihood(pc, z2 - z1)
# Check if all particles are propagated to the final time point.
numdone == n && return true

# The posterior for models with random number of observations is not well-defined.
if numdone != 0
error("mis-aligned execution traces: # particles = ", n,
" # completed trajectories = ", numdone,
". Please make sure the number of observations is NOT random.")
end

res
return false
end

# compute the normalized weights
getweights(pc::ParticleContainer) = softmax(pc.logWs)

# compute the log-likelihood estimate, ignoring constant term ``- \log num_particles``
logZ(pc::ParticleContainer) = logsumexp(pc.logWs)
"""
logZ(pc::ParticleContainer)

Return the estimate of the log-likelihood ``p(y_t | y_{1:(t-1)}, \\theta)``.
"""
logZ(pc::ParticleContainer) = logsumexp(pc.logWs) - log(length(pc))

# compute the effective sample size ``1 / ∑ wᵢ²``, where ``wᵢ```are the normalized weights
function effectiveSampleSize(pc :: ParticleContainer)
Ws = getweights(pc)
return inv(sum(abs2, Ws))
end

increase_logweight(pc :: ParticleContainer, t :: Int, logw :: Float64) =
(pc.logWs[t] += logw)

increase_logevidence(pc :: ParticleContainer, logw :: Float64) =
(pc.logE += logw)

increase_logweight!(pc::ParticleContainer, t::Int, logw::Float64) = (pc.logWs[t] += logw)

function resample!(
pc :: ParticleContainer,
Expand Down
22 changes: 17 additions & 5 deletions src/inference/AdvancedSMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,17 @@ function AbstractMCMC.sample_init!(
# create a new particle container
spl.state.particles = pc = ParticleContainer(model, particles)

while consume(pc) !== Val{:done}
# Run particle filter.
logevidence = zero(spl.state.average_logevidence)
isdone = false
while !isdone
resample!(pc, spl.alg.resampler)
isdone = propagate!(pc)
logevidence += logZ(pc)
end
spl.state.average_logevidence = logevidence

return
end

function AbstractMCMC.step!(
Expand All @@ -137,7 +145,7 @@ function AbstractMCMC.step!(
params = tonamedtuple(particle.vi)
lp = getlogp(particle.vi)

return ParticleTransition(params, lp, pc.logE, Ws[iteration])
return ParticleTransition(params, lp, spl.state.average_logevidence, Ws[iteration])
end

####
Expand Down Expand Up @@ -228,8 +236,12 @@ function AbstractMCMC.step!(
pc = ParticleContainer(model, particles)

# run the particle filter
while consume(pc) !== Val{:done}
resample!(pc, spl.alg.resampler, ref_particle)
logevidence = zero(spl.state.average_logevidence)
isdone = false
while !isdone
resample!(pc, spl.alg.resampler)
isdone = propagate!(pc)
logevidence += logZ(pc)
end

# pick a particle to be retained.
Expand All @@ -242,7 +254,7 @@ function AbstractMCMC.step!(
lp = getlogp(spl.state.vi)

# update the master vi.
return ParticleTransition(params, lp, pc.logE, 1.0)
return ParticleTransition(params, lp, logevidence, 1.0)
end

function AbstractMCMC.sample_end!(
Expand Down
4 changes: 3 additions & 1 deletion src/inference/Inference.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module Inference

using ..Core, ..Utilities
using ..Core
using ..Core: logZ
using ..Utilities
using DynamicPPL: Metadata, _tail, VarInfo, TypedVarInfo,
islinked, invlink!, getlogp, tonamedtuple, VarName, getsym, vectorize,
settrans!, _getvns, getdist, CACHERESET, AbstractSampler,
Expand Down
41 changes: 28 additions & 13 deletions test/core/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using Turing, Random
using Turing: ParticleContainer, getweights, resample!,
effectiveSampleSize, Trace, current_trace, VarName,
Sampler, consume, produce, copy, fork
using Turing.Core: logZ
using Turing.Core: logZ, propagate!
using Test

dir = splitdir(splitdir(pathof(Turing))[1])[1]
Expand All @@ -13,9 +13,7 @@ include(dir*"/test/test_utils/AllUtils.jl")
pc = ParticleContainer(x -> x * x, Trace[])
newpc = copy(pc)

@test newpc.logE == pc.logE
@test newpc.logWs == pc.logWs
@test newpc.n_consume == pc.n_consume
@test newpc.logWs == pc.logWs
@test typeof(pc) === typeof(newpc)
end
@turing_testset "particle container" begin
Expand Down Expand Up @@ -44,21 +42,38 @@ include(dir*"/test/test_utils/AllUtils.jl")
particles = [Trace(fpc, model, spl, Turing.VarInfo()) for _ in 1:3]
pc = ParticleContainer(fpc, particles)

@test getweights(pc) == [1/3, 1/3, 1/3]
@test logZ(pc) ≈ log(3)
@test pc.logE ≈ log(1)
# Initial weights and likelihood.
weights = getweights(pc)
lz = logZ(pc)
@test weights == [1/3, 1/3, 1/3]
@test iszero(lz)

@test consume(pc) == log(1)
# Propagate particles.
propagate!(pc)
@test getweights(pc) == weights
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm confused here. Shouldn't new weights be different from old weights, since no resampling is performed in propagate!?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I forgot that there was no observation in the model. For the purpose of correctly testing propagate!, maybe consider adding some likelihood in the test?

@test logZ(pc) == lz

# Propagate particles.
propagate!(pc)
@test getweights(pc) == weights
@test logZ(pc) == lz

# Resample particles.
resample!(pc)
@test getweights(pc) == [1/3, 1/3, 1/3]
@test logZ(pc) ≈ log(3)
@test pc.logE ≈ log(1)
@test getweights(pc) == weights
@test logZ(pc) == lz
@test effectiveSampleSize(pc) == 3

@test consume(pc) ≈ log(1)
# Propagate particles.
propagate!(pc)
@test getweights(pc) == weights
@test logZ(pc) == lz

# Resample and propagate particles.
resample!(pc)
@test consume(pc) ≈ log(1)
propagate!(pc)
@test getweights(pc) == weights
@test logZ(pc) == lz
end
@turing_testset "trace" begin
n = Ref(0)
Expand Down
8 changes: 4 additions & 4 deletions test/inference/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,16 @@ include(dir*"/test/test_utils/AllUtils.jl")
chn1_contd2 = sample(gdemo_default, alg1, 1000; resume_from=chn1, reuse_spl_n=1000)
check_gdemo(chn1_contd2)

chn2 = sample(gdemo_default, alg2, 500; save_state=true)
chn2 = sample(gdemo_default, alg2, 1000; save_state=true)
check_gdemo(chn2)

chn2_contd = sample(gdemo_default, alg2, 500; resume_from=chn2)
chn2_contd = sample(gdemo_default, alg2, 1000; resume_from=chn2)
check_gdemo(chn2_contd)

chn3 = sample(gdemo_default, alg3, 500; save_state=true)
chn3 = sample(gdemo_default, alg3, 1000; save_state=true)
check_gdemo(chn3)

chn3_contd = sample(gdemo_default, alg3, 500; resume_from=chn3)
chn3_contd = sample(gdemo_default, alg3, 1000; resume_from=chn3)
check_gdemo(chn3_contd)
end
@testset "Contexts" begin
Expand Down
4 changes: 2 additions & 2 deletions test/inference/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ include(dir*"/test/test_utils/AllUtils.jl")
PG(10, :z1, :z2, :z3, :z4),
HMC(0.15, 3, :mu1, :mu2))
chain = sample(MoGtest_default, gibbs, 1500)
check_MoGtest_default(chain, atol = 0.1)
check_MoGtest_default(chain, atol = 0.15)

setadsafe(false)

Expand All @@ -78,7 +78,7 @@ include(dir*"/test/test_utils/AllUtils.jl")
PG(10, :z1, :z2, :z3, :z4),
ESS(:mu1), ESS(:mu2))
chain = sample(MoGtest_default, gibbs, 1500)
check_MoGtest_default(chain, atol = 0.1)
check_MoGtest_default(chain, atol = 0.15)
end

@turing_testset "transitions" begin
Expand Down