@@ -55,50 +55,8 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
5555 @test ljoint ≈ lp
5656
5757 # ### logprior, logjoint, loglikelihood for MCMC chains ####
58- for model in DynamicPPL. TestUtils. DEMO_MODELS # length(DynamicPPL.TestUtils.DEMO_MODELS)=12
59- var_info = VarInfo (model)
60- vns = DynamicPPL. TestUtils. varnames (model)
61- syms = unique (DynamicPPL. getsym .(vns))
62-
63- # generate a chain of sample parameter values.
64- N = 200
65- vals_OrderedDict = mapreduce (hcat, 1 : N) do _
66- rand (OrderedDict, model)
67- end
68- vals_mat = mapreduce (hcat, 1 : N) do i
69- [vals_OrderedDict[i][vn] for vn in vns]
70- end
71- i = 1
72- for col in eachcol (vals_mat)
73- col_flattened = []
74- [push! (col_flattened, x... ) for x in col]
75- if i == 1
76- chain_mat = Matrix (reshape (col_flattened, 1 , length (col_flattened)))
77- else
78- chain_mat = vcat (
79- chain_mat, reshape (col_flattened, 1 , length (col_flattened))
80- )
81- end
82- i += 1
83- end
84- chain_mat = convert (Matrix{Float64}, chain_mat)
85-
86- # devise parameter names for chain
87- sample_values_vec = collect (values (vals_OrderedDict[1 ]))
88- symbol_names = []
89- chain_sym_map = Dict ()
90- for k in 1 : length (keys (var_info))
91- vn_parent = keys (var_info)[k]
92- sym = DynamicPPL. getsym (vn_parent)
93- vn_children = DynamicPPL. varname_leaves (vn_parent, sample_values_vec[k]) # `varname_leaves` defined in src/utils.jl
94- for vn_child in vn_children
95- chain_sym_map[Symbol (vn_child)] = sym
96- symbol_names = [symbol_names; Symbol (vn_child)]
97- end
98- end
99- chain = Chains (chain_mat, symbol_names)
100-
101- # calculate the pointwise loglikelihoods for the whole chain using the newly written functions
58+ for model in DynamicPPL. TestUtils. DEMO_MODELS
59+ chain = make_chain_from_prior (model, 200 )
10260 logpriors = logprior (model, chain)
10361 loglikelihoods = loglikelihood (model, chain)
10462 logjoints = logjoint (model, chain)
@@ -125,6 +83,19 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
12583 end
12684 end
12785
86+ @testset " DynamicPPL#684: threadsafe evaluation with multiple types" begin
87+ @model function multiple_types (x)
88+ ns ~ filldist (Normal (0 , 2.0 ), 3 )
89+ m ~ Uniform (0 , 1 )
90+ return x ~ Normal (m, 1 )
91+ end
92+ model = multiple_types (1 )
93+ chain = make_chain_from_prior (model, 10 )
94+ loglikelihood (model, chain)
95+ logprior (model, chain)
96+ logjoint (model, chain)
97+ end
98+
12899 @testset " rng" begin
129100 model = gdemo_default
130101
0 commit comments