Skip to content
2 changes: 2 additions & 0 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ export InferenceAlgorithm,
ESS,
Emcee,
Gibbs, # classic sampling
GibbsConditional, # conditional sampling
HMC,
SGLD,
PolynomialStepsize,
Expand Down Expand Up @@ -438,6 +439,7 @@ include("mh.jl")
include("is.jl")
include("particle_mcmc.jl")
include("gibbs.jl")
include("gibbs_conditional.jl")
include("sghmc.jl")
include("emcee.jl")
include("prior.jl")
Expand Down
239 changes: 239 additions & 0 deletions src/mcmc/gibbs_conditional.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
using DynamicPPL: VarName
using Random: Random
import AbstractMCMC

# These functions provide specialized methods for GibbsConditional that extend the generic implementations in gibbs.jl

"""
    GibbsConditional(sym::Symbol, conditional)

A Gibbs sampler component that samples a variable according to a user-provided
analytical conditional distribution.

The `conditional` function should take a `NamedTuple` of conditioned variables and return
a `Distribution` from which to sample the variable `sym`.

!!! note
    `sym` is accepted by the constructor but is not stored on the struct. In the example
    below the variable is named again on the left-hand side of the `Gibbs` pair.

# Examples

```julia
# Define a model
@model function inverse_gdemo(x)
    λ ~ Gamma(2, inv(3))
    m ~ Normal(0, sqrt(1 / λ))
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(1 / λ))
    end
end

# Define analytical conditionals
function cond_λ(c::NamedTuple)
    a = 2.0  # prior shape
    # Prior *rate*. `Gamma(2, inv(3))` is parameterised by shape and scale, so the
    # rate that enters the conjugate update is 1 / inv(3) = 3.
    b = 3.0
    m = c.m
    x = c.x
    n = length(x)
    a_new = a + (n + 1) / 2
    # This is the full conditional of λ *given m*, so the residuals are taken around
    # the current value of `m`, not around the sample mean of `x`.
    b_new = b + sum(abs2, x .- m) / 2 + m^2 / 2
    return Gamma(a_new, 1 / b_new)  # convert the updated rate back to a scale
end

function cond_m(c::NamedTuple)
    λ = c.λ
    x = c.x
    n = length(x)
    m_mean = sum(x) / (n + 1)
    m_var = 1 / (λ * (n + 1))
    return Normal(m_mean, sqrt(m_var))
end

# Sample using GibbsConditional
model = inverse_gdemo([1.0, 2.0, 3.0])
chain = sample(model, Gibbs(
    :λ => GibbsConditional(:λ, cond_λ),
    :m => GibbsConditional(:m, cond_m)
), 1000)
```
"""
struct GibbsConditional{C} <: InferenceAlgorithm
    # User-supplied callable mapping a `NamedTuple` of values to a distribution.
    conditional::C

    # `sym` is part of the public constructor signature but is intentionally not kept:
    # only the callable is stored.
    function GibbsConditional(sym::Symbol, conditional::C) where {C}
        return new{C}(conditional)
    end
end

# A `GibbsConditional` reports itself as an acceptable component sampler for `Gibbs`.
function isgibbscomponent(::GibbsConditional)
    return true
end

# A single `GibbsConditional` always counts as one element (it handles one variable).
# NOTE(review): described as required by the Gibbs constructor; confirm it is still called.
function Base.length(::GibbsConditional)
    return 1
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this being called somewhere? Might be, but I don't remember having a need for length of samplers.


"""
    find_global_varinfo(context, fallback_vi)

Walk down the chain of parent contexts starting at `context` and return the varinfo
exposed by the most relevant one, or `fallback_vi` if none can be obtained.

While descending, the last `GibbsContext`, `DynamicPPL.ConditionContext` and
`DynamicPPL.FixedContext` encountered are remembered. A `GibbsContext` takes priority and
is queried with `get_global_varinfo`. Otherwise the condition context, or failing that the
fixed context, is queried with `DynamicPPL.getvarinfo`, but only when such a method exists
for its type.

The walk is best effort: it stops after at most 20 contexts, at the first context that is
not a parent, or at the first error. Errors are reported via `@debug` and never propagate.
"""
function find_global_varinfo(context, fallback_vi)
    latest_gibbs = nothing
    latest_condition = nothing
    latest_fixed = nothing

    node = context
    depth = 0
    # The hop limit guards against a context chain that never terminates.
    while !(node === nothing || depth >= 20)
        depth += 1

        try
            # Only parent contexts have a child to descend into.
            DynamicPPL.NodeTrait(node) isa DynamicPPL.IsParent || break

            # Later (deeper) matches overwrite earlier ones of the same kind.
            if node isa GibbsContext
                latest_gibbs = node
            elseif node isa DynamicPPL.ConditionContext
                latest_condition = node
            elseif node isa DynamicPPL.FixedContext
                latest_fixed = node
            end

            node = DynamicPPL.childcontext(node)
        catch e
            # Any failure while traversing ends the walk; the fallback path handles it.
            @debug "Error traversing context at depth $depth: $e"
            break
        end
    end

    try
        if latest_gibbs !== nothing
            return get_global_varinfo(latest_gibbs)
        end

        # The condition context shadows the fixed context: when both were seen, only the
        # condition context is ever queried.
        candidate = latest_condition !== nothing ? latest_condition : latest_fixed
        if candidate !== nothing &&
            hasmethod(DynamicPPL.getvarinfo, (typeof(candidate),))
            return DynamicPPL.getvarinfo(candidate)
        end
    catch e
        @debug "Error accessing varinfo from context: $e"
    end

    # Nothing usable was found in the context chain.
    return fallback_vi
end

"""
    DynamicPPL.initialstep(rng, model, sampler::DynamicPPL.Sampler{<:GibbsConditional}, vi)

Start a `GibbsConditional` chain.

No sampler-specific setup is performed. The returned tuple is `(nothing, vi)`: `vi` is
handed back untouched and serves as the state passed to later `AbstractMCMC.step` calls.
`rng`, `model` and any keyword arguments are not used.
"""
function DynamicPPL.initialstep(
    ::Random.AbstractRNG,
    ::DynamicPPL.Model,
    ::DynamicPPL.Sampler{<:GibbsConditional},
    vi::DynamicPPL.AbstractVarInfo;
    kwargs...,
)
    return (nothing, vi)
end

"""
    AbstractMCMC.step(rng, model, sampler::DynamicPPL.Sampler{<:GibbsConditional}, state)

Draw a fresh value from the user-supplied conditional and write it into `state`.

The values handed to the conditional are gathered from the varinfo returned by
`find_global_varinfo(model.context, state)` and merged with `model.args`; on a name clash
the model argument wins. The draw is then stored in `state` as a one-element parameter
vector. If `state` is reported as linked, it is unlinked, updated and relinked.

Returns `(nothing, new_vi)`. Any error is logged with `@error` and rethrown.
"""
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    sampler::DynamicPPL.Sampler{<:GibbsConditional},
    state::DynamicPPL.AbstractVarInfo;
    kwargs...,
)
    alg = sampler.alg

    try
        # `state` is only the fallback; a varinfo found via the model context is preferred.
        full_vi = find_global_varinfo(model.context, state)

        # Best effort: try to unlink before reading values, otherwise read them as they are.
        readable_vi = try
            DynamicPPL.invlink(full_vi, model)
        catch e
            @debug "Failed to invlink global_vi, using as-is: $e"
            full_vi
        end

        # Model arguments are merged last, so they override same-named variables.
        variable_values = DynamicPPL.values_as(readable_vi, NamedTuple)
        argument_values = NamedTuple{keys(model.args)}(model.args)
        draw = rand(rng, alg.conditional(merge(variable_values, argument_values)))

        # If the linked status cannot be determined, treat the state as not linked.
        linked = try
            DynamicPPL.islinked(state, model)
        catch e
            @debug "Error checking if state is linked: $e"
            false
        end

        # The draw is written as a one-element vector.
        next_vi = if linked
            try
                # Unlink, overwrite, then relink.
                plain_state = DynamicPPL.invlink(state, model)
                DynamicPPL.link(DynamicPPL.unflatten(plain_state, [draw]), model)
            catch e
                # NOTE(review): this fallback writes the raw draw into the still-linked
                # `state`; confirm that is the intended behaviour.
                @debug "Error in linked state update path: $e, falling back to direct update"
                DynamicPPL.unflatten(state, [draw])
            end
        else
            DynamicPPL.unflatten(state, [draw])
        end

        return nothing, next_vi
    catch e
        # Surface the failure in the log, then let it propagate unchanged.
        @error "Error in GibbsConditional step: $e"
        rethrow(e)
    end
end

"""
    setparams_varinfo!!(model, sampler::DynamicPPL.Sampler{<:GibbsConditional}, state, params)

Return `params` unchanged as the updated value for a `GibbsConditional` component.

Nothing from `state` is carried over: the incoming varinfo `params` is returned as-is, and
`model`, `sampler` and `state` are not inspected.
"""
setparams_varinfo!!(
    ::DynamicPPL.Model,
    ::DynamicPPL.Sampler{<:GibbsConditional},
    state,
    params::DynamicPPL.AbstractVarInfo,
) = params
114 changes: 114 additions & 0 deletions test/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,120 @@ end
sampler = Gibbs(:w => HMC(0.05, 10))
@test (sample(model, sampler, 10); true)
end

@testset "GibbsConditional" begin
    # Normal–Gamma model: one precision λ is shared by the prior on m and the likelihood.
    @model function inverse_gdemo(x)
        λ ~ Gamma(2, inv(3))
        m ~ Normal(0, sqrt(1 / λ))
        for i in eachindex(x)
            x[i] ~ Normal(m, sqrt(1 / λ))
        end
    end

    # Full conditional of λ given m and the data.
    function cond_λ(c::NamedTuple)
        a = 2.0  # prior shape
        # Prior *rate*. `Gamma(2, inv(3))` is shape–scale, so the rate is 1 / inv(3) = 3.
        b = 3.0
        m = c.m
        x = c.x
        n = length(x)
        a_new = a + (n + 1) / 2
        # Residuals are around the current `m`, because we condition on `m`.
        b_new = b + sum(abs2, x .- m) / 2 + m^2 / 2
        return Gamma(a_new, 1 / b_new)
    end

    # Full conditional of m given λ and the data.
    function cond_m(c::NamedTuple)
        λ = c.λ
        x = c.x
        n = length(x)
        m_mean = sum(x) / (n + 1)
        m_var = 1 / (λ * (n + 1))
        return Normal(m_mean, sqrt(m_var))
    end

    # Test basic functionality
    @testset "basic sampling" begin
        Random.seed!(42)
        x_obs = [1.0, 2.0, 3.0, 2.5, 1.5]
        model = inverse_gdemo(x_obs)

        # Each component is paired with the variable it samples, as in the docstring.
        sampler = Gibbs(
            :λ => GibbsConditional(:λ, cond_λ), :m => GibbsConditional(:m, cond_m)
        )
        chain = sample(model, sampler, 1000)

        # Check that we got the expected variables
        @test :λ in names(chain)
        @test :m in names(chain)

        λ_samples = vec(chain[:λ])
        m_samples = vec(chain[:m])

        # Basic sanity checks
        @test mean(λ_samples) > 0 # λ should be positive
        @test minimum(λ_samples) > 0
        @test std(m_samples) < 2.0 # m should be relatively well-constrained

        # Compare against the analytic posterior means of this conjugate model:
        #   E[m | x] = n x̄ / (n + 1)
        #   λ | x ~ Gamma(shape = 2 + n / 2,
        #                 rate = 3 + Σ(xᵢ - x̄)² / 2 + n x̄² / (2 (n + 1)))
        # The λ check fails if the prior scale is mistaken for the prior rate.
        n = length(x_obs)
        x̄ = mean(x_obs)
        post_shape = 2 + n / 2
        post_rate = 3 + sum(abs2, x_obs .- x̄) / 2 + n * x̄^2 / (2 * (n + 1))
        @test mean(m_samples) ≈ n * x̄ / (n + 1) atol=0.2
        @test mean(λ_samples) ≈ post_shape / post_rate atol=0.15
    end

    # Test mixing with other samplers
    @testset "mixed samplers" begin
        Random.seed!(42)
        x_obs = [1.0, 2.0, 3.0]
        model = inverse_gdemo(x_obs)

        # Mix GibbsConditional with standard samplers
        sampler = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => MH())
        chain = sample(model, sampler, 500)

        @test :λ in names(chain)
        @test :m in names(chain)
        @test size(chain, 1) == 500
    end

    # Test with a simpler model
    @testset "simple normal model" begin
        @model function simple_normal(x)
            μ ~ Normal(0, 10)
            σ ~ truncated(Normal(1, 1); lower=0.01)
            for i in eachindex(x)
                x[i] ~ Normal(μ, σ)
            end
        end

        # Conditional for μ given σ and x
        function cond_μ(c::NamedTuple)
            σ = c.σ
            x = c.x
            n = length(x)
            # Prior: μ ~ Normal(0, 10)
            # Likelihood: x[i] ~ Normal(μ, σ)
            # Posterior: μ ~ Normal(post_mean, sqrt(post_var))
            prior_var = 100.0 # 10^2
            post_var = 1 / (1 / prior_var + n / σ^2)
            post_mean = post_var * (0 / prior_var + sum(x) / σ^2)
            return Normal(post_mean, sqrt(post_var))
        end

        Random.seed!(42)
        x_obs = randn(10) .+ 2.0 # Data centered around 2
        model = simple_normal(x_obs)

        sampler = Gibbs(:μ => GibbsConditional(:μ, cond_μ), :σ => MH())

        chain = sample(model, sampler, 1000)

        μ_samples = vec(chain[:μ])
        @test abs(mean(μ_samples) - 2.0) < 0.5 # Should be close to true mean
    end

    # Test that GibbsConditional is marked as a valid component
    @testset "isgibbscomponent" begin
        gc = GibbsConditional(:x, c -> Normal(0, 1))
        @test Turing.Inference.isgibbscomponent(gc)
    end
end
end

end
Loading