@@ -17,70 +17,66 @@ using Test: @test, @test_throws, @testset
1717using Turing
1818
1919@testset " Testing inference.jl with $adbackend " for adbackend in ADUtils. adbackends
20- # Only test threading if 1.3+.
21- if VERSION > v " 1.2"
22- @testset " threaded sampling" begin
23- # Test that chains with the same seed will sample identically.
24- @testset " rng" begin
25- model = gdemo_default
26-
27- # multithreaded sampling with PG causes segfaults on Julia 1.5.4
28- # https://github.com/TuringLang/Turing.jl/issues/1571
29- samplers = @static if VERSION <= v " 1.5.3" || VERSION >= v " 1.6.0"
30- (
31- HMC (0.1 , 7 ; adtype= adbackend),
32- PG (10 ),
33- IS (),
34- MH (),
35- Gibbs (PG (3 , :s ), HMC (0.4 , 8 , :m ; adtype= adbackend)),
36- Gibbs (HMC (0.1 , 5 , :s ; adtype= adbackend), ESS (:m )),
37- )
38- else
39- (
40- HMC (0.1 , 7 ; adtype= adbackend),
41- IS (),
42- MH (),
43- Gibbs (HMC (0.1 , 5 , :s ; adtype= adbackend), ESS (:m )),
44- )
45- end
46- for sampler in samplers
47- Random. seed! (5 )
48- chain1 = sample (model, sampler, MCMCThreads (), 1000 , 4 )
20+ @testset " threaded sampling" begin
21+ # Test that chains with the same seed will sample identically.
22+ @testset " rng" begin
23+ model = gdemo_default
24+
25+ # multithreaded sampling with PG causes segfaults on Julia 1.5.4
26+ # https://github.com/TuringLang/Turing.jl/issues/1571
27+ samplers = @static if VERSION <= v " 1.5.3" || VERSION >= v " 1.6.0"
28+ (
29+ HMC (0.1 , 7 ; adtype= adbackend),
30+ PG (10 ),
31+ IS (),
32+ MH (),
33+ Gibbs (PG (3 , :s ), HMC (0.4 , 8 , :m ; adtype= adbackend)),
34+ Gibbs (HMC (0.1 , 5 , :s ; adtype= adbackend), ESS (:m )),
35+ )
36+ else
37+ (
38+ HMC (0.1 , 7 ; adtype= adbackend),
39+ IS (),
40+ MH (),
41+ Gibbs (HMC (0.1 , 5 , :s ; adtype= adbackend), ESS (:m )),
42+ )
43+ end
44+ for sampler in samplers
45+ Random. seed! (5 )
46+ chain1 = sample (model, sampler, MCMCThreads (), 1000 , 4 )
4947
50- Random. seed! (5 )
51- chain2 = sample (model, sampler, MCMCThreads (), 1000 , 4 )
48+ Random. seed! (5 )
49+ chain2 = sample (model, sampler, MCMCThreads (), 1000 , 4 )
5250
53- @test chain1. value == chain2. value
54- end
51+ @test chain1. value == chain2. value
52+ end
5553
56- # Should also be stable with am explicit RNG
57- seed = 5
58- rng = Random. MersenneTwister (seed)
59- for sampler in samplers
60- Random. seed! (rng, seed)
61- chain1 = sample (rng, model, sampler, MCMCThreads (), 1000 , 4 )
54+ # Should also be stable with am explicit RNG
55+ seed = 5
56+ rng = Random. MersenneTwister (seed)
57+ for sampler in samplers
58+ Random. seed! (rng, seed)
59+ chain1 = sample (rng, model, sampler, MCMCThreads (), 1000 , 4 )
6260
63- Random. seed! (rng, seed)
64- chain2 = sample (rng, model, sampler, MCMCThreads (), 1000 , 4 )
61+ Random. seed! (rng, seed)
62+ chain2 = sample (rng, model, sampler, MCMCThreads (), 1000 , 4 )
6563
66- @test chain1. value == chain2. value
67- end
64+ @test chain1. value == chain2. value
6865 end
66+ end
6967
70- # Smoke test for default sample call.
71- Random. seed! (100 )
72- chain = sample (
73- gdemo_default, HMC (0.1 , 7 ; adtype= adbackend), MCMCThreads (), 1000 , 4
74- )
75- check_gdemo (chain)
68+ # Smoke test for default sample call.
69+ Random. seed! (100 )
70+ chain = sample (gdemo_default, HMC (0.1 , 7 ; adtype= adbackend), MCMCThreads (), 1000 , 4 )
71+ check_gdemo (chain)
7672
77- # run sampler: progress logging should be disabled and
78- # it should return a Chains object
79- sampler = Sampler (HMC (0.1 , 7 ; adtype= adbackend), gdemo_default)
80- chains = sample (gdemo_default, sampler, MCMCThreads (), 1000 , 4 )
81- @test chains isa MCMCChains. Chains
82- end
73+ # run sampler: progress logging should be disabled and
74+ # it should return a Chains object
75+ sampler = Sampler (HMC (0.1 , 7 ; adtype= adbackend), gdemo_default)
76+ chains = sample (gdemo_default, sampler, MCMCThreads (), 1000 , 4 )
77+ @test chains isa MCMCChains. Chains
8378 end
79+
8480 @testset " chain save/resume" begin
8581 Random. seed! (1234 )
8682
0 commit comments