Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,15 @@ DynamicPPL.getlogp(t::Transition) = t.lp
# Metadata of VarInfo object
metadata(vi::AbstractVarInfo) = (lp = getlogp(vi),)

# TODO: Implement additional checks for certain samplers, e.g.
# HMC not supporting discrete parameters.

# Run DynamicPPL's model validation, turning any detected issue
# (e.g. a repeated variable name) into a hard error.
_check_model(model::DynamicPPL.Model) =
    DynamicPPL.check_model(model; error_on_failure=true)

# Algorithm-aware variant: currently ignores the algorithm and
# falls back to the generic model check above.
_check_model(model::DynamicPPL.Model, ::InferenceAlgorithm) = _check_model(model)

#########################################
# Default definitions for the interface #
#########################################
Expand All @@ -256,8 +265,10 @@ function AbstractMCMC.sample(
model::AbstractModel,
alg::InferenceAlgorithm,
N::Integer;
check_model::Bool=true,
kwargs...
)
check_model && _check_model(model, alg)
return AbstractMCMC.sample(rng, model, Sampler(alg, model), N; kwargs...)
end

Expand All @@ -280,8 +291,10 @@ function AbstractMCMC.sample(
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
N::Integer,
n_chains::Integer;
check_model::Bool=true,
kwargs...
)
check_model && _check_model(model, alg)
return AbstractMCMC.sample(rng, model, Sampler(alg, model), ensemble, N, n_chains;
kwargs...)
end
Expand Down
22 changes: 22 additions & 0 deletions test/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,28 @@ using Turing
@test all(xs[:, 1] .=== [1, missing, 3])
@test all(xs[:, 2] .=== [missing, 2, 4])
end

@testset "check model" begin
# Intentionally malformed model: `x` appears on the LHS of `~` twice,
# which DynamicPPL's model check should reject.
@model function demo_repeated_varname()
x ~ Normal(0, 1)
x ~ Normal(x, 1)
end

# With checking enabled, sampling the broken model must error out
# before any MCMC work is done.
@test_throws ErrorException sample(
demo_repeated_varname(), NUTS(), 1000; check_model=true
)
# Make sure that disabling the check also works.
# (The `(...; true)` idiom asserts only that the call completes
# without throwing, regardless of the returned chain.)
@test (sample(
demo_repeated_varname(), Prior(), 10; check_model=false
); true)

# Second malformed model: `y[1:1] ~ ...` on a `missing` observation
# is an unsupported way of expressing missing data and should also
# be caught by the model check.
@model function demo_incorrect_missing(y)
y[1:1] ~ MvNormal(zeros(1), 1)
end
@test_throws ErrorException sample(
demo_incorrect_missing([missing]), NUTS(), 1000; check_model=true
)
end
end

end
4 changes: 3 additions & 1 deletion test/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ using Turing.RandomMeasures: ChineseRestaurantProcess, DirichletProcess
Random.seed!(100)
alg = Gibbs(CSMC(15, :s), HMC(0.2, 4, :m; adtype=adbackend))
chain = sample(gdemo(1.5, 2.0), alg, 10_000)
check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.15)
check_numerical(chain, [:m], [7 / 6]; atol=0.15)
# Be more relaxed with the tolerance of the variance.
check_numerical(chain, [:s], [49 / 24]; atol=0.35)

Random.seed!(100)

Expand Down
2 changes: 1 addition & 1 deletion test/mcmc/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ using Turing

# The discrepancies in the chains are in the tails, so we can't just compare the mean, etc.
# KS will compare the empirical CDFs, which seems like a reasonable thing to do here.
@test pvalue(ApproximateTwoSampleKSTest(vec(results), vec(results_prior))) > 0.01
@test pvalue(ApproximateTwoSampleKSTest(vec(results), vec(results_prior))) > 0.001
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/mcmc/is.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ using Turing
ref = reference(n)

Random.seed!(seed)
chain = sample(model, alg, n)
chain = sample(model, alg, n; check_model=false)
sampled = get(chain, [:a, :b, :lp])

@test vec(sampled.a) == ref.as
Expand Down
21 changes: 16 additions & 5 deletions test/mcmc/mh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,30 +44,41 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var))
# c6 = sample(gdemo_default, s6, N)
end
@testset "mh inference" begin
# Set the initial parameters, because if we get unlucky with the initial state,
# these chains are too short to converge to reasonable numbers.
# `discard_initial` drops the burn-in portion of each chain;
# `initial_params` starts every chain at (s, m) = (1.0, 1.0).
discard_initial = 1000
initial_params = [1.0, 1.0]

# Plain MH with the default proposal.
Random.seed!(125)
alg = MH()
chain = sample(gdemo_default, alg, 10_000; discard_initial, initial_params)
check_gdemo(chain; atol=0.1)

Random.seed!(125)
# MH with Gaussian proposal
# (`InverseGamma(2, 3)` is an independence proposal for `s`;
# `GKernel(1.0)` is a random-walk Normal kernel for `m`.)
alg = MH((:s, InverseGamma(2, 3)), (:m, GKernel(1.0)))
chain = sample(gdemo_default, alg, 10_000; discard_initial, initial_params)
check_gdemo(chain; atol=0.1)

Random.seed!(125)
# MH within Gibbs
alg = Gibbs(MH(:m), MH(:s))
chain = sample(gdemo_default, alg, 10_000; discard_initial, initial_params)
check_gdemo(chain; atol=0.1)

Random.seed!(125)
# MoGtest
# Mixture-of-Gaussians model: discrete assignments z1–z4 handled by
# CSMC, continuous means mu1/mu2 by MH with random-walk kernels.
gibbs = Gibbs(
CSMC(15, :z1, :z2, :z3, :z4), MH((:mu1, GKernel(1)), (:mu2, GKernel(1)))
)
# Shorter chain here, so use a smaller burn-in and a wider tolerance.
# `initial_params` order follows the model's variable order:
# z1, z2, z3, z4, mu1, mu2.
chain = sample(
MoGtest_default,
gibbs,
500;
discard_initial=100,
initial_params=[1.0, 1.0, 0.0, 0.0, 1.0, 4.0],
)
check_MoGtest_default(chain; atol=0.2)
end

# Test MH shape passing.
Expand Down