Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/core/Core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,9 @@ export @model,
forkr,
current_trace,
getweights,
getweight,
effectiveSampleSize,
increase_logweight!,
propagate!,
resample!,
sweep!,
ResampleWithESSThreshold,
ADBackend,
setadbackend,
Expand Down
215 changes: 166 additions & 49 deletions src/core/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,13 @@ function Base.copy(trace::Trace)
end

# NOTE: this function is called by `forkr`
function Trace(f::Function, m::Model, spl::AbstractSampler, vi::AbstractVarInfo)
res = Trace{typeof(spl)}(m, spl, deepcopy(vi));
ctask = CTask(() -> (res = f(); produce(Val{:done}); res))
function Trace(f, m::Model, spl::AbstractSampler, vi::AbstractVarInfo)
res = Trace{typeof(spl)}(m, spl, deepcopy(vi))
ctask = CTask() do
res = f()
produce(nothing)
return res
end
task = ctask.task
if task.storage === nothing
task.storage = IdDict()
Expand All @@ -31,10 +35,15 @@ function Trace(f::Function, m::Model, spl::AbstractSampler, vi::AbstractVarInfo)
res.ctask = ctask
return res
end

function Trace(m::Model, spl::AbstractSampler, vi::AbstractVarInfo)
res = Trace{typeof(spl)}(m, spl, deepcopy(vi));
res = Trace{typeof(spl)}(m, spl, deepcopy(vi))
reset_num_produce!(res.vi)
ctask = CTask(() -> (vi_new = m(vi, spl); produce(Val{:done}); vi_new))
ctask = CTask() do
res = m(vi, spl)
produce(nothing)
return res
end
task = ctask.task
if task.storage === nothing
task.storage = IdDict()
Expand Down Expand Up @@ -108,65 +117,70 @@ function Base.copy(pc::ParticleContainer)
end

"""
propagate!(pc::ParticleContainer)
reset_logweights!(pc::ParticleContainer)

Run particle filter for one step and check if the final time step is reached.
Reset all unnormalized logarithmic weights to zero.
"""
function propagate!(pc::ParticleContainer)
# normalisation factor: 1/N
n = length(pc)

particles = collect(pc)
numdone = 0
for i in 1:n
p = particles[i]
score = Libtask.consume(p)
if score isa Real
score += getlogp(p.vi)
resetlogp!(p.vi)
increase_logweight!(pc, i, Float64(score))
elseif score == Val{:done}
numdone += 1
else
error("[consume]: error in running particle filter.")
end
end

# Check if all particles are propagated to the final time point.
numdone == n && return true
function reset_logweights!(pc::ParticleContainer)
fill!(pc.logWs, 0.0)
return pc
end

# The posterior for models with random number of observations is not well-defined.
if numdone != 0
error("mis-aligned execution traces: # particles = ", n,
" # completed trajectories = ", numdone,
". Please make sure the number of observations is NOT random.")
end
"""
increase_logweight!(pc::ParticleContainer, i::Int, x)

return false
Increase the unnormalized logarithmic weight of the `i`th particle with `x`.
"""
function increase_logweight!(pc::ParticleContainer, i, logw)
pc.logWs[i] += logw
return pc
end

# compute the normalized weights
"""
getweights(pc::ParticleContainer)

Compute the normalized weights of the particles.
"""
getweights(pc::ParticleContainer) = softmax(pc.logWs)

"""
getweight(pc::ParticleContainer, i)

Compute the normalized weight of the `i`th particle.
"""
getweight(pc::ParticleContainer, i) = exp(pc.logWs[i] - logZ(pc))

"""
logZ(pc::ParticleContainer)

Return the estimate of the log-likelihood ``p(y_t | y_{1:(t-1)}, \\theta)``.
Return the logarithm of the normalizing constant of the unnormalized logarithmic weights.
"""
logZ(pc::ParticleContainer) = logsumexp(pc.logWs) - log(length(pc))
logZ(pc::ParticleContainer) = logsumexp(pc.logWs)

# compute the effective sample size ``1 / ∑ wᵢ²``, where ``wᵢ```are the normalized weights
function effectiveSampleSize(pc :: ParticleContainer)
"""
effectiveSampleSize(pc::ParticleContainer)

Compute the effective sample size ``1 / ∑ wᵢ²``, where ``wᵢ```are the normalized weights.
"""
function effectiveSampleSize(pc::ParticleContainer)
Ws = getweights(pc)
return inv(sum(abs2, Ws))
end

increase_logweight!(pc::ParticleContainer, t::Int, logw::Float64) = (pc.logWs[t] += logw)
"""
resample_propagate!(pc::ParticleContainer[, randcat = resample_systematic, ref = nothing;
weights = getweights(pc)])

Resample and propagate the particles in `pc`.

function resample!(
pc :: ParticleContainer,
randcat :: Function = Turing.Inference.resample_systematic,
ref :: Union{Particle, Nothing} = nothing;
Function `randcat` is used for sampling ancestor indices from the categorical distribution
of the particle `weights`. For Particle Gibbs sampling, one can provide a reference particle
`ref` that is ensured to survive the resampling step.
"""
function resample_propagate!(
pc::ParticleContainer,
randcat = Turing.Inference.resample_systematic,
ref::Union{Particle, Nothing} = nothing;
weights = getweights(pc)
)
# check that weights are not NaN
Expand Down Expand Up @@ -212,11 +226,114 @@ function resample!(

# replace particles and log weights in the container with new particles and weights
pc.vals = children
pc.logWs = zeros(n)
reset_logweights!(pc)

pc
end

"""
reweight!(pc::ParticleContainer)

Check if the final time step is reached, and otherwise reweight the particles by
considering the next observation.
"""
function reweight!(pc::ParticleContainer)
n = length(pc)

particles = collect(pc)
numdone = 0
for i in 1:n
p = particles[i]

# Obtain ``\\log p(yₜ | y₁, …, yₜ₋₁, x₁, …, xₜ, θ₁, …, θₜ)``, or `nothing` if the
# the execution of the model is finished.
# Here ``yᵢ`` are observations, ``xᵢ`` variables of the particle filter, and
# ``θᵢ`` are variables of other samplers.
score = Libtask.consume(p)

if score === nothing
numdone += 1
else
# Increase the unnormalized logarithmic weights, accounting for the variables
# of other samplers.
increase_logweight!(pc, i, score + getlogp(p.vi))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a side note, getlotp(p.vi) will always return 0, since the assume and observe functions for particle samplers does not modify vi.logp by default. This doesn't affect correctness, but worth to pay attention.

See:

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's not completely true (or maybe I misunderstand you), in this line getlogp can actually return nonzero values due to

acclogp!(vi, logpdf_with_trans(dist, r, istrans(vi, vn)))
. However, since we call resetlogp! in one of the following lines, this won't show up in the saved transitions.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I must have missed that line, thanks for the pointer!


# Reset the accumulator of the log probability in the model so that we can
# accumulate log probabilities of variables of other samplers until the next
# observation.
resetlogp!(p.vi)
end
end

# Check if all particles are propagated to the final time point.
numdone == n && return true

# The posterior for models with random number of observations is not well-defined.
if numdone != 0
error("mis-aligned execution traces: # particles = ", n,
" # completed trajectories = ", numdone,
". Please make sure the number of observations is NOT random.")
end

return false
end

"""
sweep!(pc::ParticleContainer, resampler)

Perform a particle sweep and return an unbiased estimate of the log evidence.

The resampling steps use the given `resampler`.

# Reference

Del Moral, P., Doucet, A., & Jasra, A. (2006). Sequential monte carlo samplers.
Journal of the Royal Statistical Society: Series B (Statistical Methodology), 68(3), 411-436.
"""
function sweep!(pc::ParticleContainer, resampler)
# Initial step:

# Resample and propagate particles.
resample_propagate!(pc, resampler)

# Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic
# weights.
# Usually it is equal to the number of particles in the beginning but this
# implementation covers also the unlikely case of a particle container that is
# initialized with non-zero logarithmic weights.
logZ0 = logZ(pc)

# Reweight the particles by including the first observation ``y₁``.
isdone = reweight!(pc)

# Compute the normalizing constant ``Z₁`` after reweighting.
logZ1 = logZ(pc)

# Compute the estimate of the log evidence ``\\log p(y₁)``.
logevidence = logZ1 - logZ0

# For observations ``y₂, …, yₜ``:
while !isdone
# Resample and propagate particles.
resample_propagate!(pc, resampler)

# Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic
# weights.
logZ0 = logZ(pc)

# Reweight the particles by including the next observation ``yₜ``.
isdone = reweight!(pc)

# Compute the normalizing constant ``Z₁`` after reweighting.
logZ1 = logZ(pc)

# Compute the estimate of the log evidence ``\\log p(y₁, …, yₜ)``.
logevidence += logZ1 - logZ0
end

return logevidence
end

struct ResampleWithESSThreshold{R, T<:Real}
resampler::R
threshold::T
Expand All @@ -226,7 +343,7 @@ function ResampleWithESSThreshold(resampler = Turing.Inference.resample_systemat
ResampleWithESSThreshold(resampler, 0.5)
end

function resample!(
function resample_propagate!(
pc::ParticleContainer,
resampler::ResampleWithESSThreshold,
ref::Union{Particle,Nothing} = nothing;
Expand All @@ -236,7 +353,7 @@ function resample!(
ess = inv(sum(abs2, weights))

if ess ≤ resampler.threshold * length(pc)
resample!(pc, resampler.resampler, ref; weights = weights)
resample_propagate!(pc, resampler.resampler, ref; weights = weights)
end

pc
Expand Down
Loading