|
32 | 32 | end |
33 | 33 | @testset "Initial parameters" begin |
34 | 34 | # dummy algorithm that just returns initial value and does not perform any sampling |
35 | | - struct OnlyInitAlg end |
| 35 | + abstract type OnlyInitAlg end |
| 36 | + struct OnlyInitAlgDefault <: OnlyInitAlg end |
| 37 | + struct OnlyInitAlgUniform <: OnlyInitAlg end |
36 | 38 | function DynamicPPL.initialstep( |
37 | 39 | rng::Random.AbstractRNG, |
38 | 40 | model::Model, |
39 | | - ::Sampler{OnlyInitAlg}, |
| 41 | + ::Sampler{<:OnlyInitAlg}, |
40 | 42 | vi::AbstractVarInfo; |
41 | 43 | kwargs..., |
42 | 44 | ) |
43 | 45 | return vi, nothing |
44 | 46 | end |
45 | | - DynamicPPL.getspace(::Sampler{OnlyInitAlg}) = () |
| 47 | + DynamicPPL.getspace(::Sampler{<:OnlyInitAlg}) = () |
46 | 48 |
|
47 | | - # model with one variable: initialization p = 0.2 |
48 | | - @model function coinflip() |
49 | | - p ~ Beta(1, 1) |
50 | | - 10 ~ Binomial(25, p) |
51 | | - end |
52 | | - model = coinflip() |
53 | | - sampler = Sampler(OnlyInitAlg()) |
54 | | - lptrue = logpdf(Binomial(25, 0.2), 10) |
55 | | - chain = sample(model, sampler, 1; init_params = 0.2, progress = false) |
56 | | - @test chain[1].metadata.p.vals == [0.2] |
57 | | - @test getlogp(chain[1]) == lptrue |
| 49 | + # initial samplers |
| 50 | + DynamicPPL.initialsampler(::Sampler{OnlyInitAlgUniform}) = SampleFromUniform() |
| 51 | + @test DynamicPPL.initialsampler(Sampler(OnlyInitAlgDefault())) == SampleFromPrior() |
58 | 52 |
|
59 | | - # parallel sampling |
60 | | - chains = sample( |
61 | | - model, sampler, MCMCThreads(), 1, 10; |
62 | | - init_params = 0.2, progress = false, |
63 | | - ) |
64 | | - for c in chains |
65 | | - @test c[1].metadata.p.vals == [0.2] |
66 | | - @test getlogp(c[1]) == lptrue |
67 | | - end |
| 53 | + for alg in (OnlyInitAlgDefault(), OnlyInitAlgUniform()) |
| 54 | + # model with one variable: initialization p = 0.2 |
| 55 | + @model function coinflip() |
| 56 | + p ~ Beta(1, 1) |
| 57 | + 10 ~ Binomial(25, p) |
| 58 | + end |
| 59 | + model = coinflip() |
| 60 | + sampler = Sampler(alg) |
| 61 | + lptrue = logpdf(Binomial(25, 0.2), 10) |
| 62 | + chain = sample(model, sampler, 1; init_params = 0.2, progress = false) |
| 63 | + @test chain[1].metadata.p.vals == [0.2] |
| 64 | + @test getlogp(chain[1]) == lptrue |
68 | 65 |
|
69 | | - # model with two variables: initialization s = 4, m = -1 |
70 | | - @model function twovars() |
71 | | - s ~ InverseGamma(2, 3) |
72 | | - m ~ Normal(0, sqrt(s)) |
73 | | - end |
74 | | - model = twovars() |
75 | | - lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1) |
76 | | - chain = sample(model, sampler, 1; init_params = [4, -1], progress = false) |
77 | | - @test chain[1].metadata.s.vals == [4] |
78 | | - @test chain[1].metadata.m.vals == [-1] |
79 | | - @test getlogp(chain[1]) == lptrue |
| 66 | + # parallel sampling |
| 67 | + chains = sample( |
| 68 | + model, sampler, MCMCThreads(), 1, 10; |
| 69 | + init_params = 0.2, progress = false, |
| 70 | + ) |
| 71 | + for c in chains |
| 72 | + @test c[1].metadata.p.vals == [0.2] |
| 73 | + @test getlogp(c[1]) == lptrue |
| 74 | + end |
80 | 75 |
|
81 | | - # parallel sampling |
82 | | - chains = sample( |
83 | | - model, sampler, MCMCThreads(), 1, 10; |
84 | | - init_params = [4, -1], progress = false, |
85 | | - ) |
86 | | - for c in chains |
87 | | - @test c[1].metadata.s.vals == [4] |
88 | | - @test c[1].metadata.m.vals == [-1] |
89 | | - @test getlogp(c[1]) == lptrue |
90 | | - end |
| 76 | + # model with two variables: initialization s = 4, m = -1 |
| 77 | + @model function twovars() |
| 78 | + s ~ InverseGamma(2, 3) |
| 79 | + m ~ Normal(0, sqrt(s)) |
| 80 | + end |
| 81 | + model = twovars() |
| 82 | + lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1) |
| 83 | + chain = sample(model, sampler, 1; init_params = [4, -1], progress = false) |
| 84 | + @test chain[1].metadata.s.vals == [4] |
| 85 | + @test chain[1].metadata.m.vals == [-1] |
| 86 | + @test getlogp(chain[1]) == lptrue |
| 87 | + |
| 88 | + # parallel sampling |
| 89 | + chains = sample( |
| 90 | + model, sampler, MCMCThreads(), 1, 10; |
| 91 | + init_params = [4, -1], progress = false, |
| 92 | + ) |
| 93 | + for c in chains |
| 94 | + @test c[1].metadata.s.vals == [4] |
| 95 | + @test c[1].metadata.m.vals == [-1] |
| 96 | + @test getlogp(c[1]) == lptrue |
| 97 | + end |
91 | 98 |
|
92 | | - # set only m = -1 |
93 | | - chain = sample(model, sampler, 1; init_params = [missing, -1], progress = false) |
94 | | - @test !ismissing(chain[1].metadata.s.vals[1]) |
95 | | - @test chain[1].metadata.m.vals == [-1] |
| 99 | + # set only m = -1 |
| 100 | + chain = sample(model, sampler, 1; init_params = [missing, -1], progress = false) |
| 101 | + @test !ismissing(chain[1].metadata.s.vals[1]) |
| 102 | + @test chain[1].metadata.m.vals == [-1] |
96 | 103 |
|
97 | | - # parallel sampling |
98 | | - chains = sample( |
99 | | - model, sampler, MCMCThreads(), 1, 10; |
100 | | - init_params = [missing, -1], progress = false, |
101 | | - ) |
102 | | - for c in chains |
103 | | - @test !ismissing(c[1].metadata.s.vals[1]) |
104 | | - @test c[1].metadata.m.vals == [-1] |
| 104 | + # parallel sampling |
| 105 | + chains = sample( |
| 106 | + model, sampler, MCMCThreads(), 1, 10; |
| 107 | + init_params = [missing, -1], progress = false, |
| 108 | + ) |
| 109 | + for c in chains |
| 110 | + @test !ismissing(c[1].metadata.s.vals[1]) |
| 111 | + @test c[1].metadata.m.vals == [-1] |
| 112 | + end |
105 | 113 | end |
106 | 114 | end |
107 | 115 | end |
0 commit comments