@@ -29,9 +29,11 @@ is_typed_varinfo(::DynamicPPL.AbstractVarInfo) = false
2929is_typed_varinfo (varinfo:: DynamicPPL.TypedVarInfo ) = true
3030is_typed_varinfo (varinfo:: DynamicPPL.SimpleVarInfo{<:NamedTuple} ) = true
3131
32+ const GDEMO_DEFAULT = DynamicPPL. TestUtils. demo_assume_observe_literal ()
33+
3234@testset " model.jl" begin
3335 @testset " convenience functions" begin
34- model = gdemo_default # defined in test/test_util.jl
36+ model = GDEMO_DEFAULT # defined in test/test_util.jl
3537
3638 # sample from model and extract variables
3739 vi = VarInfo (model)
@@ -55,53 +57,26 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
5557 @test ljoint ≈ lp
5658
5759 # ### logprior, logjoint, loglikelihood for MCMC chains ####
58- for model in DynamicPPL. TestUtils. DEMO_MODELS # length(DynamicPPL.TestUtils.DEMO_MODELS)=12
59- var_info = VarInfo (model)
60- vns = DynamicPPL. TestUtils. varnames (model)
61- syms = unique (DynamicPPL. getsym .(vns))
62-
63- # generate a chain of sample parameter values.
60+ @testset " $(model. f) " for model in DynamicPPL. TestUtils. DEMO_MODELS
6461 N = 200
65- vals_OrderedDict = mapreduce (hcat, 1 : N) do _
66- rand (OrderedDict, model)
67- end
68- vals_mat = mapreduce (hcat, 1 : N) do i
69- [vals_OrderedDict[i][vn] for vn in vns]
70- end
71- i = 1
72- for col in eachcol (vals_mat)
73- col_flattened = []
74- [push! (col_flattened, x... ) for x in col]
75- if i == 1
76- chain_mat = Matrix (reshape (col_flattened, 1 , length (col_flattened)))
77- else
78- chain_mat = vcat (
79- chain_mat, reshape (col_flattened, 1 , length (col_flattened))
80- )
81- end
82- i += 1
83- end
84- chain_mat = convert (Matrix{Float64}, chain_mat)
85-
86- # devise parameter names for chain
87- sample_values_vec = collect (values (vals_OrderedDict[1 ]))
88- symbol_names = []
89- chain_sym_map = Dict ()
90- for k in 1 : length (keys (var_info))
91- vn_parent = keys (var_info)[k]
62+ chain = make_chain_from_prior (model, N)
63+ logpriors = logprior (model, chain)
64+ loglikelihoods = loglikelihood (model, chain)
65+ logjoints = logjoint (model, chain)
66+
67+ # Construct mapping of varname symbols to varname-parent symbols.
68+ # Here, varname_leaves is used to ensure compatibility with the
69+ # variables stored in the chain
70+ var_info = VarInfo (model)
71+ chain_sym_map = Dict {Symbol, Symbol} ()
72+ for vn_parent in keys (var_info)
9273 sym = DynamicPPL. getsym (vn_parent)
93- vn_children = DynamicPPL. varname_leaves (vn_parent, sample_values_vec[k]) # `varname_leaves` defined in src/utils.jl
74+ vn_children = DynamicPPL. varname_leaves (vn_parent, var_info[vn_parent])
9475 for vn_child in vn_children
9576 chain_sym_map[Symbol (vn_child)] = sym
96- symbol_names = [symbol_names; Symbol (vn_child)]
9777 end
9878 end
99- chain = Chains (chain_mat, symbol_names)
10079
101- # calculate the pointwise loglikelihoods for the whole chain using the newly written functions
102- logpriors = logprior (model, chain)
103- loglikelihoods = loglikelihood (model, chain)
104- logjoints = logjoint (model, chain)
10580 # compare them with true values
10681 for i in 1 : N
10782 samples_dict = Dict ()
@@ -115,18 +90,31 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
11590 samples = (; samples_dict... )
11691 samples = modify_value_representation (samples) # `modify_value_representation` defined in test/test_util.jl
11792 @test logpriors[i] ≈
118- DynamicPPL. TestUtils. logprior_true (model, samples[:s ], samples[:m ])
93+ DynamicPPL. TestUtils. logprior_true (model, samples[:s ], samples[:m ])
11994 @test loglikelihoods[i] ≈ DynamicPPL. TestUtils. loglikelihood_true (
12095 model, samples[:s ], samples[:m ]
12196 )
12297 @test logjoints[i] ≈
123- DynamicPPL. TestUtils. logjoint_true (model, samples[:s ], samples[:m ])
98+ DynamicPPL. TestUtils. logjoint_true (model, samples[:s ], samples[:m ])
12499 end
125100 end
126101 end
127102
103+ @testset " DynamicPPL#684: threadsafe evaluation with multiple types" begin
104+ @model function multiple_types (x)
105+ ns ~ filldist (Normal (0 , 2.0 ), 3 )
106+ m ~ Uniform (0 , 1 )
107+ return x ~ Normal (m, 1 )
108+ end
109+ model = multiple_types (1 )
110+ chain = make_chain_from_prior (model, 10 )
111+ loglikelihood (model, chain)
112+ logprior (model, chain)
113+ logjoint (model, chain)
114+ end
115+
128116 @testset " rng" begin
129- model = gdemo_default
117+ model = GDEMO_DEFAULT
130118
131119 for sampler in (SampleFromPrior (), SampleFromUniform ())
132120 for i in 1 : 10
@@ -144,13 +132,15 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
144132 end
145133
146134 @testset " defaults without VarInfo, Sampler, and Context" begin
147- model = gdemo_default
135+ model = GDEMO_DEFAULT
148136
149137 Random. seed! (100 )
150- s, m = model ()
138+ retval = model ()
151139
152140 Random. seed! (100 )
153- @test model (Random. default_rng ()) == (s, m)
141+ retval2 = model (Random. default_rng ())
142+ @test retval2. s == retval. s
143+ @test retval2. m == retval. m
154144 end
155145
156146 @testset " nameof" begin
@@ -184,7 +174,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
184174 end
185175
186176 @testset " Internal methods" begin
187- model = gdemo_default
177+ model = GDEMO_DEFAULT
188178
189179 # sample from model and extract variables
190180 vi = VarInfo (model)
@@ -224,7 +214,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
224214 end
225215
226216 @testset " rand" begin
227- model = gdemo_default
217+ model = GDEMO_DEFAULT
228218
229219 Random. seed! (1776 )
230220 s, m = model ()
@@ -293,10 +283,10 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
293283 # Ensure log-probability computations are implemented.
294284 @test logprior (model, x) ≈ DynamicPPL. TestUtils. logprior_true (model, x... )
295285 @test loglikelihood (model, x) ≈
296- DynamicPPL. TestUtils. loglikelihood_true (model, x... )
286+ DynamicPPL. TestUtils. loglikelihood_true (model, x... )
297287 @test logjoint (model, x) ≈ DynamicPPL. TestUtils. logjoint_true (model, x... )
298288 @test logjoint (model, x) !=
299- DynamicPPL. TestUtils. logjoint_true_with_logabsdet_jacobian (model, x... )
289+ DynamicPPL. TestUtils. logjoint_true_with_logabsdet_jacobian (model, x... )
300290 # Ensure `varnames` is implemented.
301291 vi = last (
302292 DynamicPPL. evaluate!! (
@@ -309,7 +299,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
309299 end
310300 end
311301
312- @testset " generated_quantities on `LKJCholesky`" begin
302+ @testset " returned() on `LKJCholesky`" begin
313303 n = 10
314304 d = 2
315305 model = DynamicPPL. TestUtils. demo_lkjchol (d)
@@ -333,7 +323,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
333323 )
334324
335325 # Test!
336- results = generated_quantities (model, chain)
326+ results = returned (model, chain)
337327 for (x_true, result) in zip (xs, results)
338328 @test x_true. UL == result. x. UL
339329 end
@@ -352,7 +342,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
352342 info= (varname_to_symbol= vns_to_syms_with_extra,),
353343 )
354344 # Test!
355- results = generated_quantities (model, chain_with_extra)
345+ results = returned (model, chain_with_extra)
356346 for (x_true, result) in zip (xs, results)
357347 @test x_true. UL == result. x. UL
358348 end
0 commit comments