Skip to content

Commit b9db77c

Browse files
cpfifferdevmotion
andauthored
Make MH proposal generation more sane (#1557)
Co-authored-by: David Widmann <[email protected]>
1 parent a4d2778 commit b9db77c

File tree

4 files changed

+54
-10
lines changed

4 files changed

+54
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.15.11"
3+
version = "0.15.12"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/inference/mh.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@ struct MH{space, P} <: InferenceAlgorithm
77
end
88

99
proposal(p::AdvancedMH.Proposal) = p
10+
proposal(f::Function) = AdvancedMH.StaticProposal(f)
11+
proposal(d::Distribution) = AdvancedMH.StaticProposal(d)
1012
proposal(cov::AbstractMatrix) = AdvancedMH.RandomWalkProposal(MvNormal(cov))
13+
proposal(x) = error("proposals of type ", typeof(x), " are not supported")
1114

1215
"""
1316
MH(space...)
@@ -162,14 +165,7 @@ function MH(space...)
162165
# Check to see whether it's a pair that specifies a kernel
163166
# or a specific proposal distribution.
164167
push!(prop_syms, s[1])
165-
166-
if s[2] isa AMH.Proposal
167-
push!(props, s[2])
168-
elseif s[2] isa Distribution
169-
push!(props, AMH.StaticProposal(s[2]))
170-
elseif s[2] isa Function
171-
push!(props, AMH.StaticProposal(s[2]))
172-
end
168+
push!(props, proposal(s[2]))
173169
elseif length(space) == 1
174170
# If we hit this block, check to see if it's
175171
# a run-of-the-mill proposal or covariance
@@ -178,11 +174,17 @@ function MH(space...)
178174

179175
# Return early, we got a covariance matrix.
180176
return MH{(), typeof(prop)}(prop)
177+
else
178+
# Try to convert it to a proposal anyways,
179+
# throw an error if not acceptable.
180+
prop = proposal(s)
181+
push!(props, prop)
181182
end
182183
end
183184

184185
proposals = NamedTuple{tuple(prop_syms...)}(tuple(props...))
185186
syms = vcat(syms, prop_syms)
187+
186188
return MH{tuple(syms...), typeof(proposals)}(proposals)
187189
end
188190

test/inference/gibbs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@
127127
end
128128
end
129129
model = imm(randn(100), 1.0);
130-
sample(model, Gibbs(MH(10, :z), HMC(0.01, 4, :m)), 100);
130+
sample(model, Gibbs(MH(:z), HMC(0.01, 4, :m)), 100);
131131
sample(model, Gibbs(PG(10, :z), HMC(0.01, 4, :m)), 100);
132132
end
133133
end

test/inference/mh.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,48 @@
106106
check_gdemo(chain2)
107107
end
108108

109+
@turing_testset "gibbs MH proposal matrix" begin
110+
# https://github.com/TuringLang/Turing.jl/issues/1556
111+
112+
# generate data
113+
x = rand(Normal(5, 10), 20)
114+
y = rand(LogNormal(-3, 2), 20)
115+
116+
# Turing model
117+
@model function twomeans(x, y)
118+
# Set Priors
119+
μ ~ MvNormal(2, 3)
120+
σ ~ filldist(Exponential(1), 2)
121+
122+
# Distributions of supplied data
123+
x .~ Normal(μ[1], σ[1])
124+
y .~ LogNormal(μ[2], σ[2])
125+
126+
end
127+
mod = twomeans(x, y)
128+
129+
# generate covariance matrix for RWMH
130+
# with small-valued VC matrix to check if we only see very small steps
131+
vc_μ = convert(Array, 1e-4*I(2))
132+
vc_σ = convert(Array, 1e-4*I(2))
133+
134+
chn = sample(mod,
135+
Gibbs(
136+
MH((, vc_μ)),
137+
MH((, vc_σ)),
138+
), 3_000 # draws
139+
)
140+
141+
142+
chn2 = sample(mod, MH(), 3_000)
143+
144+
# Test that the small variance version is actually smaller.
145+
v1 = var(diff(Array(chn["μ[1]"]), dims=1))
146+
v2 = var(diff(Array(chn2["μ[1]"]), dims=1))
147+
148+
@test v1 < v2
149+
end
150+
109151
@turing_testset "vector of multivariate distributions" begin
110152
@model function test(k)
111153
T = Vector{Vector{Float64}}(undef, k)

0 commit comments

Comments
 (0)