@@ -2,7 +2,6 @@ module InferenceTests
22
33using .. Models: gdemo_d, gdemo_default
44using .. NumericalTests: check_gdemo, check_numerical
5- import .. ADUtils
65using Distributions: Bernoulli, Beta, InverseGamma, Normal
76using Distributions: sample
87import DynamicPPL
@@ -17,8 +16,9 @@ import Mooncake
1716using Test: @test , @test_throws , @testset
1817using Turing
1918
20- @testset " Testing inference.jl with $adbackend " for adbackend in ADUtils. adbackends
21- @info " Starting Inference.jl tests with $adbackend "
19+ @testset verbose = true " Testing Inference.jl" begin
20+ @info " Starting Inference.jl tests"
21+
2222 seed = 23
2323
2424 @testset " threaded sampling" begin
@@ -27,12 +27,12 @@ using Turing
2727 model = gdemo_default
2828
2929 samplers = (
30- HMC (0.1 , 7 ; adtype = adbackend ),
30+ HMC (0.1 , 7 ),
3131 PG (10 ),
3232 IS (),
3333 MH (),
34- Gibbs (:s => PG (3 ), :m => HMC (0.4 , 8 ; adtype = adbackend )),
35- Gibbs (:s => HMC (0.1 , 5 ; adtype = adbackend ), :m => ESS ()),
34+ Gibbs (:s => PG (3 ), :m => HMC (0.4 , 8 )),
35+ Gibbs (:s => HMC (0.1 , 5 ), :m => ESS ()),
3636 )
3737 for sampler in samplers
3838 Random. seed! (5 )
@@ -44,7 +44,7 @@ using Turing
4444 @test chain1. value == chain2. value
4545 end
4646
47- # Should also be stable with am explicit RNG
47+ # Should also be stable with an explicit RNG
4848 seed = 5
4949 rng = Random. MersenneTwister (seed)
5050 for sampler in samplers
@@ -61,27 +61,22 @@ using Turing
6161 # Smoke test for default sample call.
6262 @testset " gdemo_default" begin
6363 chain = sample (
64- StableRNG (seed),
65- gdemo_default,
66- HMC (0.1 , 7 ; adtype= adbackend),
67- MCMCThreads (),
68- 1_000 ,
69- 4 ,
64+ StableRNG (seed), gdemo_default, HMC (0.1 , 7 ), MCMCThreads (), 1_000 , 4
7065 )
7166 check_gdemo (chain)
7267
7368 # run sampler: progress logging should be disabled and
7469 # it should return a Chains object
75- sampler = Sampler (HMC (0.1 , 7 ; adtype = adbackend ))
70+ sampler = Sampler (HMC (0.1 , 7 ))
7671 chains = sample (StableRNG (seed), gdemo_default, sampler, MCMCThreads (), 10 , 4 )
7772 @test chains isa MCMCChains. Chains
7873 end
7974 end
8075
8176 @testset " chain save/resume" begin
82- alg1 = HMCDA (1000 , 0.65 , 0.15 ; adtype = adbackend )
77+ alg1 = HMCDA (1000 , 0.65 , 0.15 )
8378 alg2 = PG (20 )
84- alg3 = Gibbs (:s => PG (30 ), :m => HMC (0.2 , 4 ; adtype = adbackend ))
79+ alg3 = Gibbs (:s => PG (30 ), :m => HMC (0.2 , 4 ))
8580
8681 chn1 = sample (StableRNG (seed), gdemo_default, alg1, 10_000 ; save_state= true )
8782 check_gdemo (chn1)
@@ -260,7 +255,7 @@ using Turing
260255
261256 smc = SMC ()
262257 pg = PG (10 )
263- gibbs = Gibbs (:p => HMC (0.2 , 3 ; adtype = adbackend ), :x => PG (10 ))
258+ gibbs = Gibbs (:p => HMC (0.2 , 3 ), :x => PG (10 ))
264259
265260 chn_s = sample (StableRNG (seed), testbb (obs), smc, 200 )
266261 chn_p = sample (StableRNG (seed), testbb (obs), pg, 200 )
@@ -273,22 +268,17 @@ using Turing
273268
274269 @testset " forbid global" begin
275270 xs = [1.5 2.0 ]
276- # xx = 1
277271
278272 @model function fggibbstest (xs)
279273 s ~ InverseGamma (2 , 3 )
280274 m ~ Normal (0 , sqrt (s))
281- # xx ~ Normal(m, sqrt(s)) # this is illegal
282-
283275 for i in 1 : length (xs)
284276 xs[i] ~ Normal (m, sqrt (s))
285- # for xx in xs
286- # xx ~ Normal(m, sqrt(s))
287277 end
288278 return s, m
289279 end
290280
291- gibbs = Gibbs (:s => PG (10 ), :m => HMC (0.4 , 8 ; adtype = adbackend ))
281+ gibbs = Gibbs (:s => PG (10 ), :m => HMC (0.4 , 8 ))
292282 chain = sample (StableRNG (seed), fggibbstest (xs), gibbs, 2 )
293283 end
294284
@@ -353,7 +343,7 @@ using Turing
353343 )
354344 end
355345
356- # TODO (mhauru) What is this testing? Why does it not use the looped-over adbackend?
346+ # TODO (mhauru) What is this testing? Why does it use a different adbackend?
357347 @testset " new interface" begin
358348 obs = [0 , 1 , 0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ]
359349
@@ -382,9 +372,7 @@ using Turing
382372 end
383373 end
384374
385- chain = sample (
386- StableRNG (seed), noreturn ([1.5 2.0 ]), HMC (0.1 , 10 ; adtype= adbackend), 4000
387- )
375+ chain = sample (StableRNG (seed), noreturn ([1.5 2.0 ]), HMC (0.1 , 10 ), 4000 )
388376 check_numerical (chain, [:s , :m ], [49 / 24 , 7 / 6 ])
389377 end
390378
@@ -415,7 +403,7 @@ using Turing
415403 end
416404
417405 @testset " sample" begin
418- alg = Gibbs (:m => HMC (0.2 , 3 ; adtype = adbackend ), :s => PG (10 ))
406+ alg = Gibbs (:m => HMC (0.2 , 3 ), :s => PG (10 ))
419407 chn = sample (StableRNG (seed), gdemo_default, alg, 10 )
420408 end
421409
@@ -427,7 +415,7 @@ using Turing
427415 return s, m
428416 end
429417
430- alg = HMC (0.01 , 5 ; adtype = adbackend )
418+ alg = HMC (0.01 , 5 )
431419 x = randn (100 )
432420 res = sample (StableRNG (seed), vdemo1 (x), alg, 10 )
433421
@@ -442,7 +430,7 @@ using Turing
442430
443431 # Vector assumptions
444432 N = 10
445- alg = HMC (0.2 , 4 ; adtype = adbackend )
433+ alg = HMC (0.2 , 4 )
446434
447435 @model function vdemo3 ()
448436 x = Vector {Real} (undef, N)
@@ -497,7 +485,7 @@ using Turing
497485 return s, m
498486 end
499487
500- alg = HMC (0.01 , 5 ; adtype = adbackend )
488+ alg = HMC (0.01 , 5 )
501489 x = randn (100 )
502490 res = sample (StableRNG (seed), vdemo1 (x), alg, 10 )
503491
@@ -507,12 +495,12 @@ using Turing
507495 end
508496
509497 D = 2
510- alg = HMC (0.01 , 5 ; adtype = adbackend )
498+ alg = HMC (0.01 , 5 )
511499 res = sample (StableRNG (seed), vdemo2 (randn (D, 100 )), alg, 10 )
512500
513501 # Vector assumptions
514502 N = 10
515- alg = HMC (0.2 , 4 ; adtype = adbackend )
503+ alg = HMC (0.2 , 4 )
516504
517505 @model function vdemo3 ()
518506 x = Vector {Real} (undef, N)
@@ -559,7 +547,7 @@ using Turing
559547
560548 @testset " Type parameters" begin
561549 N = 10
562- alg = HMC (0.01 , 5 ; adtype = adbackend )
550+ alg = HMC (0.01 , 5 )
563551 x = randn (1000 )
564552 @model function vdemo1 (:: Type{T} = Float64) where {T}
565553 x = Vector {T} (undef, N)
0 commit comments