Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 96 additions & 63 deletions src/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -409,52 +409,74 @@
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
spl::DynamicPPL.Sampler{<:Gibbs},
vi_base::DynamicPPL.AbstractVarInfo;
vi::DynamicPPL.AbstractVarInfo;
initial_params=nothing,
kwargs...,
)
alg = spl.alg
varnames = alg.varnames
samplers = alg.samplers

# Run the model once to get the varnames present + initial values to condition on.
vi = DynamicPPL.VarInfo(rng, model)
if initial_params !== nothing
vi = DynamicPPL.unflatten(vi, initial_params)
end
vi, states = gibbs_initialstep_recursive(
rng, model, varnames, samplers, vi; initial_params=initial_params, kwargs...
)
return Transition(model, vi), GibbsState(vi, states)
end

# Initialise each component sampler in turn, collect all their states.
states = []
for (varnames_local, sampler_local) in zip(varnames, samplers)
# Get the initial values for this component sampler.
initial_params_local = if initial_params === nothing
nothing
else
DynamicPPL.subset(vi, varnames_local)[:]
end
"""
Take the first step of MCMC for the first component sampler, and call the same function
recursively on the remaining samplers, until no samplers remain. Return the global VarInfo
and a tuple of initial states for all component samplers.
"""
function gibbs_initialstep_recursive(
rng, model, varname_vecs, samplers, vi, states=(); initial_params=nothing, kwargs...
)
# End recursion
if isempty(varname_vecs) && isempty(samplers)
return vi, states
end

# Construct the conditioned model.
model_local, context_local = make_conditional(model, varnames_local, vi)
varnames, varname_vecs_tail... = varname_vecs
sampler, samplers_tail... = samplers

# Take initial step.
_, new_state_local = AbstractMCMC.step(
rng,
model_local,
sampler_local;
# FIXME: This will cause issues if the sampler expects initial params in unconstrained space.
# This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
initial_params=initial_params_local,
kwargs...,
)
new_vi_local = varinfo(new_state_local)
# Merge in any new variables that were introduced during the step, but that
# were not in the domain of the current sampler.
vi = merge(vi, get_global_varinfo(context_local))
# Merge the new values for all the variables sampled by the current sampler.
vi = merge(vi, new_vi_local)
push!(states, new_state_local)
# Get the initial values for this component sampler.
initial_params_local = if initial_params === nothing
nothing
else
DynamicPPL.subset(vi, varnames)[:]

Check warning on line 446 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L446

Added line #L446 was not covered by tests
end
return Transition(model, vi), GibbsState(vi, states)

# Construct the conditioned model.
conditioned_model, context = make_conditional(model, varnames, vi)

# Take initial step with the current sampler.
_, new_state = AbstractMCMC.step(
rng,
conditioned_model,
sampler;
# FIXME: This will cause issues if the sampler expects initial params in unconstrained space.
# This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
initial_params=initial_params_local,
kwargs...,
)
new_vi_local = varinfo(new_state)
# Merge in any new variables that were introduced during the step, but that
# were not in the domain of the current sampler.
vi = merge(vi, get_global_varinfo(context))
# Merge the new values for all the variables sampled by the current sampler.
vi = merge(vi, new_vi_local)

states = (states..., new_state)
return gibbs_initialstep_recursive(
rng,
model,
varname_vecs_tail,
samplers_tail,
vi,
states;
initial_params=initial_params,
kwargs...,
)
end

function AbstractMCMC.step(
Expand All @@ -471,17 +493,7 @@
states = state.states
@assert length(samplers) == length(state.states)

# TODO: move this into a recursive function so we can unroll when reasonable?
for index in 1:length(samplers)
# Take the inner step.
sampler_local = samplers[index]
state_local = states[index]
varnames_local = varnames[index]
vi, new_state_local = gibbs_step_inner(
rng, model, varnames_local, sampler_local, state_local, vi; kwargs...
)
states = Accessors.setindex(states, new_state_local, index)
end
vi, states = gibbs_step_recursive(rng, model, varnames, samplers, states, vi; kwargs...)
return Transition(model, vi), GibbsState(vi, states)
end

Expand Down Expand Up @@ -605,19 +617,33 @@
return varinfo_local
end

function gibbs_step_inner(
"""
Run a Gibbs step for the first varname/sampler/state tuple, and recursively call the same
function on the tail, until there are no more samplers left.
"""
function gibbs_step_recursive(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
varnames_local,
sampler_local,
state_local,
global_vi;
varname_vecs,
samplers,
states,
global_vi,
new_states=();
kwargs...,
)
# End recursion.
if isempty(varname_vecs) && isempty(samplers) && isempty(states)
return global_vi, new_states
end

varnames, varname_vecs_tail... = varname_vecs
sampler, samplers_tail... = samplers
state, states_tail... = states

# Construct the conditional model and the varinfo that this sampler should use.
model_local, context_local = make_conditional(model, varnames_local, global_vi)
varinfo_local = subset(global_vi, varnames_local)
varinfo_local = match_linking!!(varinfo_local, state_local, model)
conditioned_model, context = make_conditional(model, varnames, global_vi)
vi = subset(global_vi, varnames)
vi = match_linking!!(vi, state, model)

# TODO(mhauru) The below may be overkill. If the varnames for this sampler are not
# sampled by other samplers, we don't need to `setparams`, but could rather simply
Expand All @@ -628,18 +654,25 @@
# going to be a significant expense anyway.
# Set the state of the current sampler, accounting for any changes made by other
# samplers.
state_local = setparams_varinfo!!(
model_local, sampler_local, state_local, varinfo_local
)
state = setparams_varinfo!!(conditioned_model, sampler, state, vi)

# Take a step with the local sampler.
new_state_local = last(
AbstractMCMC.step(rng, model_local, sampler_local, state_local; kwargs...)
)
new_state = last(AbstractMCMC.step(rng, conditioned_model, sampler, state; kwargs...))

new_vi_local = varinfo(new_state_local)
new_vi_local = varinfo(new_state)
# Merge the latest values for all the variables in the current sampler.
new_global_vi = merge(get_global_varinfo(context_local), new_vi_local)
new_global_vi = merge(get_global_varinfo(context), new_vi_local)
new_global_vi = setlogp!!(new_global_vi, getlogp(new_vi_local))
return new_global_vi, new_state_local

new_states = (new_states..., new_state)
return gibbs_step_recursive(
rng,
model,
varname_vecs_tail,
samplers_tail,
states_tail,
new_global_vi,
new_states;
kwargs...,
)
end
Loading