Skip to content

Commit bb1bcef

Browse files
committed
Test with Tapir
1 parent 29a1342 commit bb1bcef

File tree

8 files changed

+22
-12
lines changed

8 files changed

+22
-12
lines changed

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
3030
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
3131
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
3232
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
33+
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
3334
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3435
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
3536
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
@@ -67,6 +68,7 @@ SpecialFunctions = "0.10.3, 1, 2"
6768
StableRNGs = "1"
6869
StatsBase = "0.33, 0.34"
6970
StatsFuns = "0.9.5, 1"
71+
Tapir = "0.2.24"
7072
TimerOutputs = "0.5"
7173
Tracker = "0.2.11"
7274
Zygote = "0.5.4, 0.6"

test/mcmc/Inference.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module InferenceTests
22

33
using ..Models: gdemo_d, gdemo_default
44
using ..NumericalTests: check_gdemo, check_numerical
5+
using ..ADUtils: adbackends
56
using Distributions: Bernoulli, Beta, InverseGamma, Normal
67
using Distributions: sample
78
import DynamicPPL
@@ -11,10 +12,11 @@ using LinearAlgebra: I
1112
import MCMCChains
1213
import Random
1314
import ReverseDiff
15+
import Tapir
1416
using Test: @test, @test_throws, @testset
1517
using Turing
1618

17-
@testset "Testing inference.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
19+
@testset "Testing inference.jl with $adbackend" for adbackend in adbackends
1820
# Only test threading if 1.3+.
1921
if VERSION > v"1.2"
2022
@testset "threaded sampling" begin

test/mcmc/abstractmcmc.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module AbstractMCMCTests
22

3+
using ..ADUtils: adbackends
34
using AdvancedMH: AdvancedMH
45
using Distributions: sample
56
using Distributions.FillArrays: Zeros
@@ -11,6 +12,7 @@ using LogDensityProblemsAD: LogDensityProblemsAD
1112
using Random: Random
1213
using ReverseDiff: ReverseDiff
1314
using StableRNGs: StableRNG
15+
import Tapir
1416
using Test: @test, @test_throws, @testset
1517
using Turing
1618
using Turing.Inference: AdvancedHMC
@@ -112,8 +114,7 @@ end
112114

113115
@testset "External samplers" begin
114116
@testset "AdvancedHMC.jl" begin
115-
# Try a few different AD backends.
116-
@testset "adtype=$adtype" for adtype in [AutoForwardDiff(), AutoReverseDiff()]
117+
@testset "adtype=$adtype" for adtype in adbackends
117118
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
118119
# Need some functionality to initialize the sampler.
119120
# TODO: Remove this once the constructors in the respective packages become "lazy".

test/mcmc/gibbs.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,19 @@ module GibbsTests
22

33
using ..Models: MoGtest_default, gdemo, gdemo_default
44
using ..NumericalTests: check_MoGtest_default, check_gdemo, check_numerical
5+
using ..ADUtils: adbackends
56
using Distributions: InverseGamma, Normal
67
using Distributions: sample
78
using ForwardDiff: ForwardDiff
89
using Random: Random
910
using ReverseDiff: ReverseDiff
11+
import Tapir
1012
using Test: @test, @testset
1113
using Turing
1214
using Turing: Inference
1315
using Turing.RandomMeasures: ChineseRestaurantProcess, DirichletProcess
1416

15-
@testset "Testing gibbs.jl with $adbackend" for adbackend in (
16-
AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)
17-
)
17+
@testset "Testing gibbs.jl with $adbackend" for adbackend in adbackends
1818
@testset "gibbs constructor" begin
1919
N = 500
2020
s1 = Gibbs(HMC(0.1, 5, :s, :m; adtype=adbackend))

test/mcmc/gibbs_conditional.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module GibbsConditionalTests
22

33
using ..Models: gdemo, gdemo_default
44
using ..NumericalTests: check_gdemo, check_numerical
5+
using ..ADUtils: adbackends
56
using Clustering: Clustering
67
using Distributions: Categorical, InverseGamma, Normal, sample
78
using ForwardDiff: ForwardDiff
@@ -11,12 +12,11 @@ using ReverseDiff: ReverseDiff
1112
using StableRNGs: StableRNG
1213
using StatsBase: counts
1314
using StatsFuns: StatsFuns
15+
import Tapir
1416
using Test: @test, @testset
1517
using Turing
1618

17-
@testset "Testing gibbs conditionals.jl with $adbackend" for adbackend in (
18-
AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)
19-
)
19+
@testset "Testing gibbs conditionals.jl with $adbackend" for adbackend in adbackends
2020
Random.seed!(1000)
2121
rng = StableRNG(123)
2222

test/mcmc/hmc.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module HMCTests
33
using ..Models: gdemo_default
44
#using ..Models: gdemo
55
using ..NumericalTests: check_gdemo, check_numerical
6+
using ..ADUtils: adbackends
67
using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample
78
import DynamicPPL
89
using DynamicPPL: Sampler
@@ -13,10 +14,11 @@ using LinearAlgebra: I, dot, vec
1314
import Random
1415
using StableRNGs: StableRNG
1516
using StatsFuns: logistic
17+
import Tapir
1618
using Test: @test, @test_logs, @testset
1719
using Turing
1820

19-
@testset "Testing hmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
21+
@testset "Testing hmc.jl with $adbackend" for adbackend in adbackend
2022
# Set a seed
2123
rng = StableRNG(123)
2224
@testset "constrained bounded" begin

test/mcmc/sghmc.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,17 @@ module SGHMCTests
22

33
using ..Models: gdemo_default
44
using ..NumericalTests: check_gdemo
5+
using ..ADUtils: adbackends
56
using Distributions: sample
67
import ForwardDiff
78
using LinearAlgebra: dot
89
import ReverseDiff
910
using StableRNGs: StableRNG
11+
import Tapir
1012
using Test: @test, @testset
1113
using Turing
1214

13-
@testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
15+
@testset "Testing sghmc.jl with $adbackend" for adbackend in adbackends
1416
@testset "sghmc constructor" begin
1517
alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=adbackend)
1618
@test alg isa SGHMC
@@ -36,7 +38,7 @@ using Turing
3638
end
3739
end
3840

39-
@testset "Testing sgld.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
41+
@testset "Testing sgld.jl with $adbackend" for adbackend in adbackends
4042
@testset "sgld constructor" begin
4143
alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adbackend)
4244
@test alg isa SGLD

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import Turing
77

88
include(pkgdir(Turing) * "/test/test_utils/models.jl")
99
include(pkgdir(Turing) * "/test/test_utils/numerical_tests.jl")
10+
include(pkgdir(Turing) * "/test/test_utils/ad_utils.jl")
1011

1112
Turing.setprogress!(false)
1213

0 commit comments

Comments
 (0)