|
1 | 1 | @testset "hmc.jl" begin |
| 2 | + # Set a seed |
| 3 | + rng = StableRNG(123) |
2 | 4 | @numerical_testset "constrained bounded" begin |
3 | | - # Set a seed |
4 | | - Random.seed!(5) |
5 | | - |
6 | 5 | obs = [0,1,0,1,1,1,1,1,1,1] |
7 | 6 |
|
8 | 7 | @model constrained_test(obs) = begin |
|
14 | 13 | end |
15 | 14 |
|
16 | 15 | chain = sample( |
| 16 | + rng, |
17 | 17 | constrained_test(obs), |
18 | 18 | HMC(1.5, 3),# using a large step size (1.5) |
19 | 19 | 1000) |
|
33 | 33 | end |
34 | 34 |
|
35 | 35 | chain = sample( |
| 36 | + rng, |
36 | 37 | constrained_simplex_test(obs12), |
37 | 38 | HMC(0.75, 2), |
38 | 39 | 1000) |
39 | 40 |
|
40 | 41 | check_numerical(chain, ["ps[1]", "ps[2]"], [5/16, 11/16], atol=0.015) |
41 | 42 | end |
42 | 43 | @numerical_testset "hmc reverse diff" begin |
43 | | - Random.seed!(1) |
44 | 44 | alg = HMC(0.1, 10) |
45 | | - res = sample(gdemo_default, alg, 4000) |
| 45 | + res = sample(rng, gdemo_default, alg, 4000) |
46 | 46 | check_gdemo(res, rtol=0.1) |
47 | 47 | end |
48 | 48 | @turing_testset "matrix support" begin |
|
53 | 53 | model_f = hmcmatrixsup() |
54 | 54 | n_samples = 1_000 |
55 | 55 | vs = map(1:3) do _ |
56 | | - chain = sample(model_f, HMC(0.15, 7), n_samples) |
| 56 | + chain = sample(rng, model_f, HMC(0.15, 7), n_samples) |
57 | 57 | r = reshape(Array(group(chain, :v)), n_samples, 2, 2) |
58 | 58 | reshape(mean(r; dims = 1), 2, 2) |
59 | 59 | end |
|
103 | 103 | end |
104 | 104 |
|
105 | 105 | # Sampling |
106 | | - chain = sample(bnn(ts), HMC(0.1, 5), 10) |
| 106 | + chain = sample(rng, bnn(ts), HMC(0.1, 5), 10) |
107 | 107 | end |
108 | 108 | @numerical_testset "hmcda inference" begin |
109 | | - Random.seed!(12345) |
110 | 109 | alg1 = HMCDA(1000, 0.8, 0.015) |
111 | 110 | # alg2 = Gibbs(HMCDA(200, 0.8, 0.35, :m), HMC(0.25, 3, :s)) |
112 | 111 | alg3 = Gibbs(PG(10, :s), HMCDA(200, 0.8, 0.005, :m)) |
113 | 112 | # alg3 = Gibbs(HMC(0.25, 3, :m), PG(30, 3, :s)) |
114 | 113 | # alg3 = PG(50, 2000) |
115 | 114 |
|
116 | | - res1 = sample(gdemo_default, alg1, 3000) |
| 115 | + res1 = sample(rng, gdemo_default, alg1, 3000) |
117 | 116 | check_gdemo(res1) |
118 | 117 |
|
119 | 118 | # res2 = sample(gdemo([1.5, 2.0]), alg2) |
120 | 119 | # |
121 | 120 | # @test mean(res2[:s]) ≈ 49/24 atol=0.2 |
122 | 121 | # @test mean(res2[:m]) ≈ 7/6 atol=0.2 |
123 | 122 |
|
124 | | - res3 = sample(gdemo_default, alg3, 2000) |
| 123 | + res3 = sample(rng, gdemo_default, alg3, 2000) |
125 | 124 | check_gdemo(res3) |
126 | 125 | end |
127 | 126 |
|
|
146 | 145 | end |
147 | 146 | @numerical_testset "nuts inference" begin |
148 | 147 | alg = NUTS(1000, 0.8) |
149 | | - res = sample(gdemo_default, alg, 6000) |
| 148 | + res = sample(rng, gdemo_default, alg, 6000) |
150 | 149 | check_gdemo(res) |
151 | 150 | end |
152 | 151 | @turing_testset "nuts constructor" begin |
|
165 | 164 | @turing_testset "check discard" begin |
166 | 165 | alg = NUTS(100, 0.8) |
167 | 166 |
|
168 | | - c1 = sample(gdemo_default, alg, 500, discard_adapt = true) |
169 | | - c2 = sample(gdemo_default, alg, 500, discard_adapt = false) |
| 167 | + c1 = sample(rng, gdemo_default, alg, 500, discard_adapt = true) |
| 168 | + c2 = sample(rng, gdemo_default, alg, 500, discard_adapt = false) |
170 | 169 |
|
171 | 170 | @test size(c1, 1) == 500 |
172 | 171 | @test size(c2, 1) == 500 |
|
175 | 174 | alg1 = Gibbs(PG(10, :m), NUTS(100, 0.65, :s)) |
176 | 175 | alg2 = Gibbs(PG(10, :m), HMC(0.1, 3, :s)) |
177 | 176 | alg3 = Gibbs(PG(10, :m), HMCDA(100, 0.65, 0.3, :s)) |
178 | | - @test sample(gdemo_default, alg1, 300) isa Chains |
179 | | - @test sample(gdemo_default, alg2, 300) isa Chains |
180 | | - @test sample(gdemo_default, alg3, 300) isa Chains |
| 177 | + @test sample(rng, gdemo_default, alg1, 300) isa Chains |
| 178 | + @test sample(rng, gdemo_default, alg2, 300) isa Chains |
| 179 | + @test sample(rng, gdemo_default, alg3, 300) isa Chains |
181 | 180 | end |
182 | 181 |
|
183 | 182 | @turing_testset "Regression tests" begin |
|
186 | 185 | m = Matrix{T}(undef, 2, 3) |
187 | 186 | m .~ MvNormal(zeros(2), I) |
188 | 187 | end |
189 | | - @test sample(mwe1(), HMC(0.2, 4), 1_000) isa Chains |
| 188 | + @test sample(rng, mwe1(), HMC(0.2, 4), 1_000) isa Chains |
190 | 189 |
|
191 | 190 | @model function mwe2(::Type{T} = Matrix{Float64}) where T |
192 | 191 | m = T(undef, 2, 3) |
193 | 192 | m .~ MvNormal(zeros(2), I) |
194 | 193 | end |
195 | | - @test sample(mwe2(), HMC(0.2, 4), 1_000) isa Chains |
| 194 | + @test sample(rng, mwe2(), HMC(0.2, 4), 1_000) isa Chains |
196 | 195 |
|
197 | 196 | # https://github.com/TuringLang/Turing.jl/issues/1308 |
198 | 197 | @model function mwe3(::Type{T} = Array{Float64}) where T |
199 | 198 | m = T(undef, 2, 3) |
200 | 199 | m .~ MvNormal(zeros(2), I) |
201 | 200 | end |
202 | | - @test sample(mwe3(), HMC(0.2, 4), 1_000) isa Chains |
| 201 | + @test sample(rng, mwe3(), HMC(0.2, 4), 1_000) isa Chains |
203 | 202 | end |
204 | 203 | end |
0 commit comments