Skip to content

Commit aa8cc6d

Browse files
committed
Switch to StableRNGs for broken tests.
1 parent 4c388dc commit aa8cc6d

File tree

5 files changed

+31
-30
lines changed

5 files changed

+31
-30
lines changed

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
2727
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2828
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2929
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
30+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
3031

3132
[compat]
3233
AbstractMCMC = "3.2.1"
@@ -53,4 +54,5 @@ StatsBase = "0.33"
5354
StatsFuns = "0.9.5"
5455
Tracker = "0.2.11"
5556
Zygote = "0.5.4, 0.6"
57+
StableRNGs = "1"
5658
julia = "1.3"

test/contrib/inference/sghmc.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
@test sampler isa Turing.Sampler{<:SGHMC}
1717
end
1818
@numerical_testset "sghmc inference" begin
19-
Random.seed!(54321)
19+
rng = StableRNG(123)
2020

2121
alg = SGHMC(; learning_rate=0.02, momentum_decay=0.5)
22-
chain = sample(gdemo_default, alg, 10_000)
22+
chain = sample(rng, gdemo_default, alg, 10_000)
2323
check_gdemo(chain, atol = 0.1)
2424
end
2525
end
@@ -42,9 +42,9 @@ end
4242
@test sampler isa Turing.Sampler{<:SGLD}
4343
end
4444
@numerical_testset "sgld inference" begin
45-
Random.seed!(12345)
45+
rng = StableRNG(1)
4646

47-
chain = sample(gdemo_default, SGLD(; stepsize = PolynomialStepsize(0.5)), 10_000)
47+
chain = sample(rng, gdemo_default, SGLD(; stepsize = PolynomialStepsize(0.5)), 20_000)
4848
check_gdemo(chain, atol = 0.2)
4949

5050
# Weight samples by step sizes (cf section 4.2 in the paper by Welling and Teh)

test/inference/hmc.jl

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
@testset "hmc.jl" begin
2+
# Set a seed
3+
rng = StableRNG(123)
24
@numerical_testset "constrained bounded" begin
3-
# Set a seed
4-
Random.seed!(5)
5-
65
obs = [0,1,0,1,1,1,1,1,1,1]
76

87
@model constrained_test(obs) = begin
@@ -14,6 +13,7 @@
1413
end
1514

1615
chain = sample(
16+
rng,
1717
constrained_test(obs),
1818
HMC(1.5, 3),# using a large step size (1.5)
1919
1000)
@@ -33,16 +33,16 @@
3333
end
3434

3535
chain = sample(
36+
rng,
3637
constrained_simplex_test(obs12),
3738
HMC(0.75, 2),
3839
1000)
3940

4041
check_numerical(chain, ["ps[1]", "ps[2]"], [5/16, 11/16], atol=0.015)
4142
end
4243
@numerical_testset "hmc reverse diff" begin
43-
Random.seed!(1)
4444
alg = HMC(0.1, 10)
45-
res = sample(gdemo_default, alg, 4000)
45+
res = sample(rng, gdemo_default, alg, 4000)
4646
check_gdemo(res, rtol=0.1)
4747
end
4848
@turing_testset "matrix support" begin
@@ -53,7 +53,7 @@
5353
model_f = hmcmatrixsup()
5454
n_samples = 1_000
5555
vs = map(1:3) do _
56-
chain = sample(model_f, HMC(0.15, 7), n_samples)
56+
chain = sample(rng, model_f, HMC(0.15, 7), n_samples)
5757
r = reshape(Array(group(chain, :v)), n_samples, 2, 2)
5858
reshape(mean(r; dims = 1), 2, 2)
5959
end
@@ -103,25 +103,24 @@
103103
end
104104

105105
# Sampling
106-
chain = sample(bnn(ts), HMC(0.1, 5), 10)
106+
chain = sample(rng, bnn(ts), HMC(0.1, 5), 10)
107107
end
108108
@numerical_testset "hmcda inference" begin
109-
Random.seed!(12345)
110109
alg1 = HMCDA(1000, 0.8, 0.015)
111110
# alg2 = Gibbs(HMCDA(200, 0.8, 0.35, :m), HMC(0.25, 3, :s))
112111
alg3 = Gibbs(PG(10, :s), HMCDA(200, 0.8, 0.005, :m))
113112
# alg3 = Gibbs(HMC(0.25, 3, :m), PG(30, 3, :s))
114113
# alg3 = PG(50, 2000)
115114

116-
res1 = sample(gdemo_default, alg1, 3000)
115+
res1 = sample(rng, gdemo_default, alg1, 3000)
117116
check_gdemo(res1)
118117

119118
# res2 = sample(gdemo([1.5, 2.0]), alg2)
120119
#
121120
# @test mean(res2[:s]) ≈ 49/24 atol=0.2
122121
# @test mean(res2[:m]) ≈ 7/6 atol=0.2
123122

124-
res3 = sample(gdemo_default, alg3, 2000)
123+
res3 = sample(rng, gdemo_default, alg3, 2000)
125124
check_gdemo(res3)
126125
end
127126

@@ -146,7 +145,7 @@
146145
end
147146
@numerical_testset "nuts inference" begin
148147
alg = NUTS(1000, 0.8)
149-
res = sample(gdemo_default, alg, 6000)
148+
res = sample(rng, gdemo_default, alg, 6000)
150149
check_gdemo(res)
151150
end
152151
@turing_testset "nuts constructor" begin
@@ -165,8 +164,8 @@
165164
@turing_testset "check discard" begin
166165
alg = NUTS(100, 0.8)
167166

168-
c1 = sample(gdemo_default, alg, 500, discard_adapt = true)
169-
c2 = sample(gdemo_default, alg, 500, discard_adapt = false)
167+
c1 = sample(rng, gdemo_default, alg, 500, discard_adapt = true)
168+
c2 = sample(rng, gdemo_default, alg, 500, discard_adapt = false)
170169

171170
@test size(c1, 1) == 500
172171
@test size(c2, 1) == 500
@@ -175,9 +174,9 @@
175174
alg1 = Gibbs(PG(10, :m), NUTS(100, 0.65, :s))
176175
alg2 = Gibbs(PG(10, :m), HMC(0.1, 3, :s))
177176
alg3 = Gibbs(PG(10, :m), HMCDA(100, 0.65, 0.3, :s))
178-
@test sample(gdemo_default, alg1, 300) isa Chains
179-
@test sample(gdemo_default, alg2, 300) isa Chains
180-
@test sample(gdemo_default, alg3, 300) isa Chains
177+
@test sample(rng, gdemo_default, alg1, 300) isa Chains
178+
@test sample(rng, gdemo_default, alg2, 300) isa Chains
179+
@test sample(rng, gdemo_default, alg3, 300) isa Chains
181180
end
182181

183182
@turing_testset "Regression tests" begin
@@ -186,19 +185,19 @@
186185
m = Matrix{T}(undef, 2, 3)
187186
m .~ MvNormal(zeros(2), I)
188187
end
189-
@test sample(mwe1(), HMC(0.2, 4), 1_000) isa Chains
188+
@test sample(rng, mwe1(), HMC(0.2, 4), 1_000) isa Chains
190189

191190
@model function mwe2(::Type{T} = Matrix{Float64}) where T
192191
m = T(undef, 2, 3)
193192
m .~ MvNormal(zeros(2), I)
194193
end
195-
@test sample(mwe2(), HMC(0.2, 4), 1_000) isa Chains
194+
@test sample(rng, mwe2(), HMC(0.2, 4), 1_000) isa Chains
196195

197196
# https://github.com/TuringLang/Turing.jl/issues/1308
198197
@model function mwe3(::Type{T} = Array{Float64}) where T
199198
m = T(undef, 2, 3)
200199
m .~ MvNormal(zeros(2), I)
201200
end
202-
@test sample(mwe3(), HMC(0.2, 4), 1_000) isa Chains
201+
@test sample(rng, mwe3(), HMC(0.2, 4), 1_000) isa Chains
203202
end
204203
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ using LinearAlgebra
2626
using Pkg
2727
using Random
2828
using Test
29+
using StableRNGs
2930

3031
using AdvancedPS: ResampleWithESSThreshold, resample_systematic, resample_multinomial
3132
using AdvancedVI: TruncatedADAGrad, DecayedADAGrad, apply!

test/stdlib/distributions.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,19 @@
11
@testset "distributions.jl" begin
2+
rng = StableRNG(12345)
23
@turing_testset "distributions functions" begin
34
ns = 10
4-
logitp = randn()
5+
logitp = randn(rng)
56
d1 = BinomialLogit(ns, logitp)
67
d2 = Binomial(ns, logistic(logitp))
78
k = 3
89
@test logpdf(d1, k) logpdf(d2, k)
910
end
1011

1112
@turing_testset "distributions functions" begin
12-
Random.seed!(1)
13-
1413
d = OrderedLogistic(-2, [-1, 1])
1514

1615
n = 1_000_000
17-
y = rand(d, n)
16+
y = rand(rng, d, n)
1817
K = length(d.cutpoints) + 1
1918
p = [mean(==(k), y) for k in 1:K] # empirical probs
2019
pmf = [exp(logpdf(d, k)) for k in 1:K]
@@ -31,9 +30,9 @@
3130
end
3231

3332
@numerical_testset "single distribution correctness" begin
34-
Random.seed!(12345)
33+
rng = StableRNG(1)
3534

36-
n_samples = 5_000
35+
n_samples = 10_000
3736
mean_tol = 0.1
3837
var_atol = 1.0
3938
var_tol = 0.5
@@ -113,7 +112,7 @@
113112

114113
@model m() = x ~ dist
115114

116-
chn = sample(m(), HMC(0.05, 20), n_samples)
115+
chn = sample(rng, m(), HMC(0.05, 20), n_samples)
117116

118117
# Numerical tests.
119118
check_dist_numerical(dist,

0 commit comments

Comments
 (0)