-
Notifications
You must be signed in to change notification settings - Fork 228
Gibbs sampler #2647
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
AoifeHughes
wants to merge
11
commits into
main
Choose a base branch
from
gibbs-sampler
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Gibbs sampler #2647
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
c0158ea
Add GibbsConditional sampler and corresponding tests
a972b5a
clarified comment
bdb7f73
Merge branch 'main' into gibbs-sampler
mhauru c3cc773
add MHs suggestions
714c1e8
formatter
97c571d
Merge branch 'gibbs-sampler' of github.com:TuringLang/Turing.jl into …
94b723d
fixed exporting thing
891ac14
Merge branch 'main' into gibbs-sampler
AoifeHughes 2058ae5
Refactor Gibbs sampler to use inverse of parameters for Gamma distrib…
b0812a3
removed file added by mistake
d910312
Add safety checks and error handling in find_global_varinfo and Abstr…
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,239 @@ | ||
using DynamicPPL: VarName | ||
using Random: Random | ||
import AbstractMCMC | ||
|
||
# These functions provide specialized methods for GibbsConditional that extend the generic implementations in gibbs.jl | ||
|
||
""" | ||
GibbsConditional(sym::Symbol, conditional) | ||
|
||
A Gibbs sampler component that samples a variable according to a user-provided | ||
analytical conditional distribution. | ||
|
||
The `conditional` function should take a `NamedTuple` of conditioned variables and return | ||
a `Distribution` from which to sample the variable `sym`. | ||
|
||
# Examples | ||
|
||
```julia | ||
# Define a model | ||
@model function inverse_gdemo(x) | ||
λ ~ Gamma(2, inv(3)) | ||
m ~ Normal(0, sqrt(1 / λ)) | ||
for i in 1:length(x) | ||
x[i] ~ Normal(m, sqrt(1 / λ)) | ||
end | ||
end | ||
|
||
# Define analytical conditionals | ||
function cond_λ(c::NamedTuple) | ||
a = 2.0 | ||
b = inv(3) | ||
m = c.m | ||
x = c.x | ||
n = length(x) | ||
a_new = a + (n + 1) / 2 | ||
b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2 | ||
return Gamma(a_new, 1 / b_new) | ||
end | ||
|
||
function cond_m(c::NamedTuple) | ||
λ = c.λ | ||
x = c.x | ||
n = length(x) | ||
m_mean = sum(x) / (n + 1) | ||
m_var = 1 / (λ * (n + 1)) | ||
return Normal(m_mean, sqrt(m_var)) | ||
end | ||
|
||
# Sample using GibbsConditional | ||
model = inverse_gdemo([1.0, 2.0, 3.0]) | ||
chain = sample(model, Gibbs( | ||
:λ => GibbsConditional(:λ, cond_λ), | ||
:m => GibbsConditional(:m, cond_m) | ||
), 1000) | ||
``` | ||
""" | ||
struct GibbsConditional{C} <: InferenceAlgorithm | ||
conditional::C | ||
|
||
function GibbsConditional(sym::Symbol, conditional::C) where {C} | ||
return new{C}(conditional) | ||
end | ||
end | ||
|
||
# Mark GibbsConditional as a valid Gibbs component | ||
isgibbscomponent(::GibbsConditional) = true | ||
|
||
# Required methods for Gibbs constructor | ||
Base.length(::GibbsConditional) = 1 # Each GibbsConditional handles one variable | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this being called somewhere? Might be, but I don't remember having a need for |
||
|
||
""" | ||
find_global_varinfo(context, fallback_vi) | ||
|
||
Traverse the context stack to find global variable information from | ||
GibbsContext, ConditionContext, FixedContext, etc. | ||
""" | ||
function find_global_varinfo(context, fallback_vi) | ||
# Traverse the entire context stack to find relevant contexts | ||
current_context = context | ||
gibbs_context = nothing | ||
condition_context = nothing | ||
fixed_context = nothing | ||
|
||
# Safety check: avoid infinite loops with a maximum depth | ||
max_depth = 20 | ||
depth = 0 | ||
|
||
while current_context !== nothing && depth < max_depth | ||
depth += 1 | ||
|
||
try | ||
# Use NodeTrait for robust context checking | ||
if DynamicPPL.NodeTrait(current_context) isa DynamicPPL.IsParent | ||
if current_context isa GibbsContext | ||
gibbs_context = current_context | ||
elseif current_context isa DynamicPPL.ConditionContext | ||
condition_context = current_context | ||
elseif current_context isa DynamicPPL.FixedContext | ||
fixed_context = current_context | ||
end | ||
# Move to child context | ||
current_context = DynamicPPL.childcontext(current_context) | ||
else | ||
break | ||
end | ||
catch e | ||
# If there's an error traversing contexts, break and use fallback | ||
@debug "Error traversing context at depth $depth: $e" | ||
break | ||
end | ||
end | ||
|
||
# Return the most relevant context's varinfo with error handling | ||
try | ||
if gibbs_context !== nothing | ||
return get_global_varinfo(gibbs_context) | ||
elseif condition_context !== nothing | ||
# Check if getvarinfo method exists for ConditionContext | ||
if hasmethod(DynamicPPL.getvarinfo, (typeof(condition_context),)) | ||
return DynamicPPL.getvarinfo(condition_context) | ||
end | ||
elseif fixed_context !== nothing | ||
# Check if getvarinfo method exists for FixedContext | ||
if hasmethod(DynamicPPL.getvarinfo, (typeof(fixed_context),)) | ||
return DynamicPPL.getvarinfo(fixed_context) | ||
end | ||
end | ||
catch e | ||
@debug "Error accessing varinfo from context: $e" | ||
end | ||
|
||
# Fall back to the provided fallback_vi | ||
return fallback_vi | ||
end | ||
|
||
""" | ||
DynamicPPL.initialstep(rng, model, sampler::GibbsConditional, vi) | ||
|
||
Initialize the GibbsConditional sampler. | ||
""" | ||
function DynamicPPL.initialstep( | ||
rng::Random.AbstractRNG, | ||
model::DynamicPPL.Model, | ||
sampler::DynamicPPL.Sampler{<:GibbsConditional}, | ||
vi::DynamicPPL.AbstractVarInfo; | ||
kwargs..., | ||
) | ||
# GibbsConditional doesn't need any special initialization | ||
# Just return the initial state | ||
return nothing, vi | ||
end | ||
|
||
""" | ||
AbstractMCMC.step(rng, model, sampler::GibbsConditional, state) | ||
|
||
Perform a step of GibbsConditional sampling. | ||
""" | ||
function AbstractMCMC.step( | ||
rng::Random.AbstractRNG, | ||
model::DynamicPPL.Model, | ||
sampler::DynamicPPL.Sampler{<:GibbsConditional}, | ||
state::DynamicPPL.AbstractVarInfo; | ||
kwargs..., | ||
) | ||
alg = sampler.alg | ||
|
||
try | ||
# For GibbsConditional within Gibbs, we need to get all variable values | ||
# Model always has a context field, so we can traverse it directly | ||
global_vi = find_global_varinfo(model.context, state) | ||
|
||
# Extract conditioned values as a NamedTuple | ||
# Include both random variables and observed data | ||
# Use a safe approach for invlink to avoid linking conflicts | ||
invlinked_global_vi = try | ||
DynamicPPL.invlink(global_vi, model) | ||
catch e | ||
@debug "Failed to invlink global_vi, using as-is: $e" | ||
global_vi | ||
end | ||
|
||
condvals_vars = DynamicPPL.values_as(invlinked_global_vi, NamedTuple) | ||
condvals_obs = NamedTuple{keys(model.args)}(model.args) | ||
condvals = merge(condvals_vars, condvals_obs) | ||
|
||
# Get the conditional distribution | ||
conddist = alg.conditional(condvals) | ||
|
||
# Sample from the conditional distribution | ||
updated = rand(rng, conddist) | ||
|
||
# Update the variable in state, handling linking properly | ||
# The Gibbs sampler ensures that state only contains one variable | ||
state_is_linked = try | ||
DynamicPPL.islinked(state, model) | ||
catch e | ||
@debug "Error checking if state is linked: $e" | ||
false | ||
end | ||
|
||
if state_is_linked | ||
# If state is linked, we need to unlink, update, then relink | ||
try | ||
unlinked_state = DynamicPPL.invlink(state, model) | ||
updated_state = DynamicPPL.unflatten(unlinked_state, [updated]) | ||
new_vi = DynamicPPL.link(updated_state, model) | ||
catch e | ||
@debug "Error in linked state update path: $e, falling back to direct update" | ||
new_vi = DynamicPPL.unflatten(state, [updated]) | ||
end | ||
else | ||
# State is not linked, we can update directly | ||
new_vi = DynamicPPL.unflatten(state, [updated]) | ||
end | ||
|
||
return nothing, new_vi | ||
|
||
catch e | ||
# If there's any error in the step, log it and rethrow | ||
@error "Error in GibbsConditional step: $e" | ||
rethrow(e) | ||
end | ||
end | ||
|
||
""" | ||
setparams_varinfo!!(model, sampler::GibbsConditional, state, params::AbstractVarInfo) | ||
|
||
Update the variable info with new parameters for GibbsConditional. | ||
""" | ||
function setparams_varinfo!!( | ||
model::DynamicPPL.Model, | ||
sampler::DynamicPPL.Sampler{<:GibbsConditional}, | ||
state, | ||
params::DynamicPPL.AbstractVarInfo, | ||
) | ||
# For GibbsConditional, we just return the params as-is since | ||
# the state is nothing and we don't need to update anything | ||
return params | ||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Likewise comparing to https://github.com/TuringLang/Turing.jl/blob/v0.35.5/src/mcmc/gibbs_conditional.jl, should the
m
in the variance term rather be the mean ofx
?