Skip to content

Commit b9f180a

Browse files
torfjeldedevmotion
andauthored
Fix for HMCs dot_assume (#1758)
* fixed dot_assume for hmc * copy-pasted tests from dynamicppl integration tests * inspecting what in the world is going on with tests * trying again * skip failing test for TrackerAD * bump patch version * fixed typo in tests * Rename `Turing.Core` to `Turing.Essential` * Deprecate Turing.Core Co-authored-by: Tor Erlend Fjelde <[email protected]> * fixed a numerical test * version bump Co-authored-by: David Widmann <[email protected]> Co-authored-by: David Widmann <[email protected]>
1 parent 9ad85c2 commit b9f180a

File tree

3 files changed

+346
-2
lines changed

3 files changed

+346
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.19.4"
3+
version = "0.19.5"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/inference/hmc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ function DynamicPPL.dot_assume(
503503
var::AbstractArray,
504504
vi,
505505
)
506-
DynamicPPL.updategid!(Ref(vi), vns, Ref(spl))
506+
DynamicPPL.updategid!.(Ref(vi), vns, Ref(spl))
507507
return DynamicPPL.dot_assume(dists, var, vns, vi)
508508
end
509509

test/inference/Inference.jl

Lines changed: 344 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,4 +151,348 @@
151151
@test range(chain) == range(6; step=2, length=10)
152152
end
153153
end
154+
155+
# Copy-paste from integration tests in DynamicPPL.
156+
@testset "assume" begin
157+
@model function test_assume()
158+
x ~ Bernoulli(1)
159+
y ~ Bernoulli(x / 2)
160+
return x, y
161+
end
162+
163+
smc = SMC()
164+
pg = PG(10)
165+
166+
res1 = sample(test_assume(), smc, 1000)
167+
res2 = sample(test_assume(), pg, 1000)
168+
169+
check_numerical(res1, [:y], [0.5]; atol=0.1)
170+
check_numerical(res2, [:y], [0.5]; atol=0.1)
171+
172+
# Check that all xs are 1.
173+
@test all(isone, res1[:x])
174+
@test all(isone, res2[:x])
175+
end
176+
@testset "beta binomial" begin
177+
prior = Beta(2, 2)
178+
obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1]
179+
exact = Beta(prior.α + sum(obs), prior.β + length(obs) - sum(obs))
180+
meanp = exact.α / (exact.α + exact.β)
181+
182+
@model function testbb(obs)
183+
p ~ Beta(2, 2)
184+
x ~ Bernoulli(p)
185+
for i in 1:length(obs)
186+
obs[i] ~ Bernoulli(p)
187+
end
188+
return p, x
189+
end
190+
191+
smc = SMC()
192+
pg = PG(10)
193+
gibbs = Gibbs(HMC(0.2, 3, :p), PG(10, :x))
194+
195+
chn_s = sample(testbb(obs), smc, 1000)
196+
chn_p = sample(testbb(obs), pg, 2000)
197+
chn_g = sample(testbb(obs), gibbs, 1500)
198+
199+
check_numerical(chn_s, [:p], [meanp]; atol=0.05)
200+
check_numerical(chn_p, [:x], [meanp]; atol=0.1)
201+
check_numerical(chn_g, [:x], [meanp]; atol=0.1)
202+
end
203+
@testset "forbid global" begin
204+
xs = [1.5 2.0]
205+
# xx = 1
206+
207+
@model function fggibbstest(xs)
208+
s ~ InverseGamma(2, 3)
209+
m ~ Normal(0, sqrt(s))
210+
# xx ~ Normal(m, sqrt(s)) # this is illegal
211+
212+
for i in 1:length(xs)
213+
xs[i] ~ Normal(m, sqrt(s))
214+
# for xx in xs
215+
# xx ~ Normal(m, sqrt(s))
216+
end
217+
return s, m
218+
end
219+
220+
gibbs = Gibbs(PG(10, :s), HMC(0.4, 8, :m))
221+
chain = sample(fggibbstest(xs), gibbs, 2)
222+
end
223+
@testset "new grammar" begin
224+
x = Float64[1 2]
225+
226+
@model function gauss(x)
227+
priors = TArray{Float64}(2)
228+
priors[1] ~ InverseGamma(2, 3) # s
229+
priors[2] ~ Normal(0, sqrt(priors[1])) # m
230+
for i in 1:length(x)
231+
x[i] ~ Normal(priors[2], sqrt(priors[1]))
232+
end
233+
return priors
234+
end
235+
236+
chain = sample(gauss(x), PG(10), 10)
237+
chain = sample(gauss(x), SMC(), 10)
238+
239+
@model function gauss2(::Type{TV}=Vector{Float64}; x) where {TV}
240+
priors = TV(undef, 2)
241+
priors[1] ~ InverseGamma(2, 3) # s
242+
priors[2] ~ Normal(0, sqrt(priors[1])) # m
243+
for i in 1:length(x)
244+
x[i] ~ Normal(priors[2], sqrt(priors[1]))
245+
end
246+
return priors
247+
end
248+
249+
chain = sample(gauss2(; x=x), PG(10), 10)
250+
chain = sample(gauss2(; x=x), SMC(), 10)
251+
252+
chain = sample(gauss2(Vector{Float64}; x=x), PG(10), 10)
253+
chain = sample(gauss2(Vector{Float64}; x=x), SMC(), 10)
254+
end
255+
@testset "new interface" begin
256+
obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1]
257+
258+
@model function newinterface(obs)
259+
p ~ Beta(2, 2)
260+
for i in 1:length(obs)
261+
obs[i] ~ Bernoulli(p)
262+
end
263+
return p
264+
end
265+
266+
chain = sample(
267+
newinterface(obs), HMC{Turing.ForwardDiffAD{2}}(0.75, 3, :p, :x), 100
268+
)
269+
end
270+
@testset "no return" begin
271+
@model function noreturn(x)
272+
s ~ InverseGamma(2, 3)
273+
m ~ Normal(0, sqrt(s))
274+
for i in 1:length(x)
275+
x[i] ~ Normal(m, sqrt(s))
276+
end
277+
end
278+
279+
chain = sample(noreturn([1.5 2.0]), HMC(0.1, 10), 4000)
280+
check_numerical(chain, [:s, :m], [49 / 24, 7 / 6])
281+
end
282+
@testset "observe" begin
283+
@model function test()
284+
z ~ Normal(0, 1)
285+
x ~ Bernoulli(1)
286+
1 ~ Bernoulli(x / 2)
287+
0 ~ Bernoulli(x / 2)
288+
return x
289+
end
290+
291+
is = IS()
292+
smc = SMC()
293+
pg = PG(10)
294+
295+
res_is = sample(test(), is, 10000)
296+
res_smc = sample(test(), smc, 1000)
297+
res_pg = sample(test(), pg, 100)
298+
299+
@test all(isone, res_is[:x])
300+
@test res_is.logevidence 2 * log(0.5)
301+
302+
@test all(isone, res_smc[:x])
303+
@test res_smc.logevidence 2 * log(0.5)
304+
305+
@test all(isone, res_pg[:x])
306+
end
307+
@testset "sample" begin
308+
alg = Gibbs(HMC(0.2, 3, :m), PG(10, :s))
309+
chn = sample(gdemo_default, alg, 1000)
310+
end
311+
@testset "vectorization @." begin
312+
# https://github.com/FluxML/Tracker.jl/issues/119
313+
if Turing.Core.ADBackend() !== Turing.Core.TrackerAD
314+
@model function vdemo1(x)
315+
s ~ InverseGamma(2, 3)
316+
m ~ Normal(0, sqrt(s))
317+
@. x ~ Normal(m, sqrt(s))
318+
return s, m
319+
end
320+
321+
alg = HMC(0.01, 5)
322+
x = randn(100)
323+
res = sample(vdemo1(x), alg, 250)
324+
325+
@model function vdemo1b(x)
326+
s ~ InverseGamma(2, 3)
327+
m ~ Normal(0, sqrt(s))
328+
@. x ~ Normal(m, $(sqrt(s)))
329+
return s, m
330+
end
331+
332+
res = sample(vdemo1b(x), alg, 250)
333+
334+
@model function vdemo2(x)
335+
μ ~ MvNormal(zeros(size(x, 1)), I)
336+
@. x ~ $(MvNormal(μ, I))
337+
end
338+
339+
D = 2
340+
alg = HMC(0.01, 5)
341+
res = sample(vdemo2(randn(D, 100)), alg, 250)
342+
343+
# Vector assumptions
344+
N = 10
345+
setchunksize(N)
346+
alg = HMC(0.2, 4)
347+
348+
@model function vdemo3()
349+
x = Vector{Real}(undef, N)
350+
for i in 1:N
351+
x[i] ~ Normal(0, sqrt(4))
352+
end
353+
end
354+
355+
t_loop = @elapsed res = sample(vdemo3(), alg, 1000)
356+
357+
# Test for vectorize UnivariateDistribution
358+
@model function vdemo4()
359+
x = Vector{Real}(undef, N)
360+
@. x ~ Normal(0, 2)
361+
end
362+
363+
t_vec = @elapsed res = sample(vdemo4(), alg, 1000)
364+
365+
@model vdemo5() = x ~ MvNormal(zeros(N), 4 * I)
366+
367+
t_mv = @elapsed res = sample(vdemo5(), alg, 1000)
368+
369+
println("Time for")
370+
println(" Loop : ", t_loop)
371+
println(" Vec : ", t_vec)
372+
println(" Mv : ", t_mv)
373+
374+
# Transformed test
375+
@model function vdemo6()
376+
x = Vector{Real}(undef, N)
377+
@. x ~ InverseGamma(2, 3)
378+
end
379+
380+
sample(vdemo6(), alg, 1000)
381+
382+
N = 3
383+
@model function vdemo7()
384+
x = Array{Real}(undef, N, N)
385+
@. x ~ [InverseGamma(2, 3) for i in 1:N]
386+
end
387+
388+
sample(vdemo7(), alg, 1000)
389+
end
390+
end
391+
@testset "vectorization .~" begin
392+
@model function vdemo1(x)
393+
s ~ InverseGamma(2, 3)
394+
m ~ Normal(0, sqrt(s))
395+
x .~ Normal(m, sqrt(s))
396+
return s, m
397+
end
398+
399+
alg = HMC(0.01, 5)
400+
x = randn(100)
401+
res = sample(vdemo1(x), alg, 250)
402+
403+
@model function vdemo2(x)
404+
μ ~ MvNormal(zeros(size(x, 1)), I)
405+
return x .~ MvNormal(μ, I)
406+
end
407+
408+
D = 2
409+
alg = HMC(0.01, 5)
410+
res = sample(vdemo2(randn(D, 100)), alg, 250)
411+
412+
# Vector assumptions
413+
N = 10
414+
setchunksize(N)
415+
alg = HMC(0.2, 4)
416+
417+
@model function vdemo3()
418+
x = Vector{Real}(undef, N)
419+
for i in 1:N
420+
x[i] ~ Normal(0, sqrt(4))
421+
end
422+
end
423+
424+
t_loop = @elapsed res = sample(vdemo3(), alg, 1000)
425+
426+
# Test for vectorize UnivariateDistribution
427+
@model function vdemo4()
428+
x = Vector{Real}(undef, N)
429+
return x .~ Normal(0, 2)
430+
end
431+
432+
t_vec = @elapsed res = sample(vdemo4(), alg, 1000)
433+
434+
@model vdemo5() = x ~ MvNormal(zeros(N), 4 * I)
435+
436+
t_mv = @elapsed res = sample(vdemo5(), alg, 1000)
437+
438+
println("Time for")
439+
println(" Loop : ", t_loop)
440+
println(" Vec : ", t_vec)
441+
println(" Mv : ", t_mv)
442+
443+
# Transformed test
444+
@model function vdemo6()
445+
x = Vector{Real}(undef, N)
446+
return x .~ InverseGamma(2, 3)
447+
end
448+
449+
sample(vdemo6(), alg, 1000)
450+
451+
@model function vdemo7()
452+
x = Array{Real}(undef, N, N)
453+
return x .~ [InverseGamma(2, 3) for i in 1:N]
454+
end
455+
456+
sample(vdemo7(), alg, 1000)
457+
end
458+
@testset "Type parameters" begin
459+
N = 10
460+
setchunksize(N)
461+
alg = HMC(0.01, 5)
462+
x = randn(1000)
463+
@model function vdemo1(::Type{T}=Float64) where {T}
464+
x = Vector{T}(undef, N)
465+
for i in 1:N
466+
x[i] ~ Normal(0, sqrt(4))
467+
end
468+
end
469+
470+
t_loop = @elapsed res = sample(vdemo1(), alg, 250)
471+
t_loop = @elapsed res = sample(vdemo1(Float64), alg, 250)
472+
473+
vdemo1kw(; T) = vdemo1(T)
474+
t_loop = @elapsed res = sample(vdemo1kw(; T=Float64), alg, 250)
475+
476+
@model function vdemo2(::Type{T}=Float64) where {T<:Real}
477+
x = Vector{T}(undef, N)
478+
@. x ~ Normal(0, 2)
479+
end
480+
481+
t_vec = @elapsed res = sample(vdemo2(), alg, 250)
482+
t_vec = @elapsed res = sample(vdemo2(Float64), alg, 250)
483+
484+
vdemo2kw(; T) = vdemo2(T)
485+
t_vec = @elapsed res = sample(vdemo2kw(; T=Float64), alg, 250)
486+
487+
@model function vdemo3(::Type{TV}=Vector{Float64}) where {TV<:AbstractVector}
488+
x = TV(undef, N)
489+
@. x ~ InverseGamma(2, 3)
490+
end
491+
492+
sample(vdemo3(), alg, 250)
493+
sample(vdemo3(Vector{Float64}), alg, 250)
494+
495+
vdemo3kw(; T) = vdemo3(T)
496+
sample(vdemo3kw(; T=Vector{Float64}), alg, 250)
497+
end
154498
end

0 commit comments

Comments
 (0)