Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
225 commits
Select commit Hold shift + click to select a range
8a4b894
extended methods for `logprior`, `loglikelihood`, `logposterior` for …
YongchaoHuang Dec 7, 2022
7fcebc5
accept Github Actions.
YongchaoHuang Dec 8, 2022
63d1970
Update src/logp.jl
YongchaoHuang Dec 8, 2022
3cc4912
Update src/logp.jl
YongchaoHuang Dec 8, 2022
0df18e0
Update src/logp.jl
YongchaoHuang Dec 8, 2022
390265d
Update src/logp.jl
YongchaoHuang Dec 8, 2022
ee329a2
Update src/logp.jl
YongchaoHuang Dec 8, 2022
e7bb1b8
typed `AbstractChains`;
YongchaoHuang Dec 16, 2022
f6ec7e0
re-formatting.
YongchaoHuang Dec 16, 2022
d7fa16f
removed comments to pass formatting test.
YongchaoHuang Dec 16, 2022
5e1b144
Update src/logp.jl
YongchaoHuang Dec 16, 2022
4b114de
Update src/logp.jl
YongchaoHuang Dec 16, 2022
1ca088d
Update src/logp.jl
YongchaoHuang Dec 16, 2022
be54a5a
1. removed the import statements in `lop.jl`;
YongchaoHuang Dec 18, 2022
aa47024
Update src/logp.jl
YongchaoHuang Dec 18, 2022
18a8831
Update src/logp.jl
YongchaoHuang Dec 18, 2022
063eed0
Update src/logp.jl
YongchaoHuang Dec 18, 2022
d29f80d
Update src/logp.jl
YongchaoHuang Dec 19, 2022
0e773ab
Update src/logp.jl
YongchaoHuang Dec 19, 2022
efd0c87
Update src/logp.jl
YongchaoHuang Dec 19, 2022
075ac81
modified src/logp.jl; added test/logp.jl
YongchaoHuang Jan 11, 2023
17aee2b
Update src/logp.jl
YongchaoHuang Jan 11, 2023
d1add25
Update src/logp.jl
YongchaoHuang Jan 11, 2023
9623302
Update src/logp.jl
YongchaoHuang Jan 11, 2023
d31a7c3
Update src/logp.jl
YongchaoHuang Jan 11, 2023
374cfde
Update test/logp.jl
YongchaoHuang Jan 11, 2023
94c3864
Modified Docstrings;
YongchaoHuang Jan 12, 2023
35683c7
Update src/logp.jl
YongchaoHuang Jan 12, 2023
189550c
Update src/logp.jl
YongchaoHuang Jan 12, 2023
20ac6c0
Update src/logp.jl
YongchaoHuang Jan 12, 2023
5303e6f
Update src/logp.jl
YongchaoHuang Jan 12, 2023
34e9383
renamed `chain_logprior`,`chain_loglikelihood',`chain_logposterior' t…
YongchaoHuang Jan 16, 2023
3a86ff0
Merge branch 'master' into yongchao/logp.jl
YongchaoHuang Jan 16, 2023
db984b3
added `include("logdensityfunction.jl")` to `DynamicPPL.jl`
YongchaoHuang Jan 17, 2023
417d3d0
formatted `test/logp.jl`.
YongchaoHuang Jan 17, 2023
94d981b
Update test/logp.jl
YongchaoHuang Jan 17, 2023
d3b8d5a
Update test/logp.jl
YongchaoHuang Jan 17, 2023
254544d
formatted `scr/logp.jl`.
YongchaoHuang Jan 17, 2023
58cf492
Update src/logp.jl
YongchaoHuang Jan 17, 2023
7050e0b
Update src/logp.jl
YongchaoHuang Jan 17, 2023
51b3c49
Update src/logp.jl
YongchaoHuang Jan 17, 2023
6070637
Update src/logp.jl
YongchaoHuang Jan 17, 2023
9137dc8
Update src/logp.jl
YongchaoHuang Jan 17, 2023
ee04a46
Update src/logp.jl
YongchaoHuang Jan 17, 2023
6138b25
Update src/logp.jl
YongchaoHuang Jan 17, 2023
c1e9d09
Update src/logp.jl
YongchaoHuang Jan 17, 2023
60d4e2c
Update src/logp.jl
YongchaoHuang Jan 17, 2023
c54ddd5
formatted `scr/logp.jl`.
YongchaoHuang Jan 17, 2023
f91c4ea
Removed comments.
YongchaoHuang Jan 17, 2023
78cc213
Update src/logp.jl
YongchaoHuang Jan 18, 2023
ef2110e
Update src/logp.jl
YongchaoHuang Jan 18, 2023
262b6c2
Update src/logp.jl
YongchaoHuang Jan 18, 2023
63e486a
Update src/logp.jl
YongchaoHuang Jan 18, 2023
6d9a884
Update src/logp.jl
YongchaoHuang Jan 18, 2023
3e954be
Update src/logp.jl
YongchaoHuang Jan 18, 2023
7ebcf10
Update test/logp.jl
YongchaoHuang Jan 18, 2023
ee82b65
removed redundant methods (NamedTuples and Array inputs).
YongchaoHuang Jan 18, 2023
9adb6e7
added REPL examples to docstrings.
YongchaoHuang Jan 18, 2023
68d76ae
Update src/logp.jl
YongchaoHuang Jan 18, 2023
3244d2f
Update src/logp.jl
YongchaoHuang Jan 18, 2023
4a28c2a
added `start_idx` into `src/logp.jl`; rewrite `test/logp.jl` using `m…
YongchaoHuang Jan 19, 2023
96bfbf1
added `start_idx` to the 3 methods.
YongchaoHuang Jan 19, 2023
8b75a8c
Reduced chainn size in the docstrings example.
YongchaoHuang Jan 19, 2023
97d4e94
applied formatting.
YongchaoHuang Jan 19, 2023
cfc76d2
Update src/logp.jl
YongchaoHuang Jan 19, 2023
5217a91
Update src/logp.jl
YongchaoHuang Jan 19, 2023
f426418
Update test/logp.jl
YongchaoHuang Jan 19, 2023
ac939ae
Update test/logp.jl
YongchaoHuang Jan 19, 2023
9e383da
Update test/logp.jl
YongchaoHuang Jan 19, 2023
ee0b63d
Update src/logp.jl
YongchaoHuang Jan 19, 2023
d4706b5
Update src/logp.jl
YongchaoHuang Jan 19, 2023
398e0d2
Update src/logp.jl
YongchaoHuang Jan 19, 2023
3b95070
Update src/logp.jl
YongchaoHuang Jan 19, 2023
dc8808b
upated signatures in docstrings.
YongchaoHuang Jan 19, 2023
5c05964
applied formatting.
YongchaoHuang Jan 19, 2023
8f40c13
formatted again.
YongchaoHuang Jan 19, 2023
e8e813c
Update src/logp.jl
YongchaoHuang Jan 19, 2023
af6571d
Update src/logp.jl
YongchaoHuang Jan 19, 2023
8799426
Update src/logp.jl
YongchaoHuang Jan 19, 2023
05ae602
fix doctests setup
yebai Jan 20, 2023
a604da9
Update docs/make.jl
YongchaoHuang Jan 20, 2023
abb83fd
Update test/runtests.jl
YongchaoHuang Jan 20, 2023
8ce9422
applied formatting.
YongchaoHuang Jan 20, 2023
560b96b
Formatting.
YongchaoHuang Jan 20, 2023
49032dd
Update src/logp.jl
YongchaoHuang Jan 20, 2023
303f931
Update src/logp.jl
YongchaoHuang Jan 20, 2023
e101303
Update src/logp.jl
YongchaoHuang Jan 20, 2023
3a29f6f
Update src/logp.jl
YongchaoHuang Jan 20, 2023
e9e515f
Update src/logp.jl
YongchaoHuang Jan 20, 2023
63596e2
Update src/logp.jl
YongchaoHuang Jan 20, 2023
2c5ab9e
Fix doc tests again.
yebai Jan 20, 2023
4c62f07
Fixed formatting.
yebai Jan 20, 2023
9673f1a
Merged `logp.jl` into `model.jl`
yebai Jan 20, 2023
ef85267
CompatHelper: bump compat for Turing to 0.24 for package turing, (kee…
github-actions[bot] Jan 20, 2023
bbff92d
CompatHelper: bump compat for Turing to 0.23 for package turing, (kee…
github-actions[bot] Jan 20, 2023
824dcb6
Fixed obsolete `TArray` reference.
yebai Jan 20, 2023
97028a5
Fixed incorrect code.
yebai Jan 20, 2023
895384e
More bugfixes in logp tests.
yebai Jan 20, 2023
28fdf7d
Avoid calling Turing sampler.
yebai Jan 20, 2023
a456147
Apply suggestions from code review
yebai Jan 20, 2023
45c0141
Replace SampleFromPrior with synthetic chain.
yebai Jan 21, 2023
e4558dd
Update test/model.jl
yebai Jan 21, 2023
d8a4d32
Minor bugfix.
yebai Jan 21, 2023
5bf860b
Update src/model.jl
YongchaoHuang Jan 21, 2023
d63c185
Update src/model.jl
YongchaoHuang Jan 21, 2023
b9bae11
Update src/model.jl
YongchaoHuang Jan 21, 2023
2d58d18
Update src/model.jl
YongchaoHuang Jan 21, 2023
a5fd292
Merge branch 'master' into yongchao/logp.jl
yebai Jan 22, 2023
a34d5e4
Update Project.toml
yebai Jan 22, 2023
394967c
Update src/model.jl
YongchaoHuang Jan 22, 2023
6c8d253
Update src/model.jl
YongchaoHuang Jan 22, 2023
a364e3b
Update src/model.jl
YongchaoHuang Jan 22, 2023
1f16544
Update src/model.jl
YongchaoHuang Jan 23, 2023
732a1ed
Update src/model.jl
YongchaoHuang Jan 23, 2023
b925158
Update src/model.jl
YongchaoHuang Jan 23, 2023
5af2d52
Update src/model.jl
YongchaoHuang Jan 23, 2023
364cfc7
Update src/model.jl
YongchaoHuang Jan 23, 2023
8e9ee49
Update src/model.jl
YongchaoHuang Jan 23, 2023
d53a79d
Update src/model.jl
YongchaoHuang Jan 23, 2023
bf05924
Update src/model.jl
YongchaoHuang Jan 23, 2023
376a604
Update src/model.jl
YongchaoHuang Jan 23, 2023
c4da915
Update src/model.jl
YongchaoHuang Jan 23, 2023
221805b
Update src/model.jl
YongchaoHuang Jan 23, 2023
b7ebade
Update src/model.jl
YongchaoHuang Jan 23, 2023
e94f2f3
Update src/model.jl
YongchaoHuang Jan 23, 2023
08c6dc7
Update src/model.jl
YongchaoHuang Jan 23, 2023
01569fc
Update src/model.jl
YongchaoHuang Jan 23, 2023
57a0671
Update src/model.jl
YongchaoHuang Jan 23, 2023
e74b0fe
Update src/model.jl
YongchaoHuang Jan 23, 2023
13b9a7f
Update src/model.jl
YongchaoHuang Jan 23, 2023
0abde5a
Update src/model.jl
YongchaoHuang Jan 23, 2023
825a2d2
Update src/model.jl
YongchaoHuang Jan 23, 2023
78e511b
Update src/model.jl
YongchaoHuang Jan 23, 2023
aea5fdf
Update src/model.jl
YongchaoHuang Jan 23, 2023
c544a06
Update src/model.jl
YongchaoHuang Jan 23, 2023
9e3d260
Update src/model.jl
YongchaoHuang Jan 23, 2023
a0dbb13
Update src/model.jl
YongchaoHuang Jan 23, 2023
b06a374
Update src/model.jl
YongchaoHuang Jan 23, 2023
33fb855
Update src/model.jl
YongchaoHuang Jan 23, 2023
f7f68b8
Update src/model.jl
YongchaoHuang Jan 23, 2023
bf95fb0
Update src/model.jl
YongchaoHuang Jan 23, 2023
88662d1
Update src/model.jl
YongchaoHuang Jan 23, 2023
de3720f
Update src/model.jl
YongchaoHuang Jan 23, 2023
7cf464d
Update src/model.jl
YongchaoHuang Jan 23, 2023
94cff03
Update test/model.jl
YongchaoHuang Jan 23, 2023
a96fc43
Update src/model.jl
YongchaoHuang Jan 23, 2023
322ad7b
Update src/model.jl
YongchaoHuang Jan 23, 2023
b46bb44
Update src/model.jl
YongchaoHuang Jan 23, 2023
c6bda2b
Update src/model.jl
YongchaoHuang Jan 23, 2023
f3c67b1
Added `logprior_true(model,NamedTuple)' and
YongchaoHuang Jan 27, 2023
6e9639b
Update test/model.jl
YongchaoHuang Jan 27, 2023
339ef0d
Apply suggestions from code review
YongchaoHuang Jan 27, 2023
0984762
Fixed missing prefix and imports.
yebai Jan 27, 2023
2df3e65
Move tests into convenience functions.
yebai Jan 27, 2023
37c477c
Update test/model.jl
yebai Jan 27, 2023
a717419
Removed constraints on floating number precision.
yebai Jan 27, 2023
8eeed1c
Fix type constraint again.
yebai Jan 27, 2023
f34dd29
Apply suggestions from code review
YongchaoHuang Jan 27, 2023
4400f48
1. removed `StableRNGs`;
YongchaoHuang Jan 27, 2023
9a6ff47
Merge branch 'master' into yongchao/logp.jl
yebai Jan 27, 2023
33e5ee5
Bugfix.
yebai Jan 27, 2023
a5d3671
Import TestUtils -- it is not exported by DPPL.
yebai Jan 27, 2023
be2c2ed
Specialise on model type.
yebai Jan 27, 2023
0a93a21
Improve test.
yebai Jan 28, 2023
2f62fad
Update src/test_utils.jl
yebai Jan 28, 2023
1a8fa89
Apply suggestions from code review
storopoli Jan 29, 2023
6945b4e
Apply suggestions from code review
YongchaoHuang Jan 30, 2023
d36b9ca
Apply suggestions from code review
YongchaoHuang Jan 30, 2023
fd225a5
midified the way chain value was extracted in all 3 methods.
YongchaoHuang Feb 7, 2023
f28e5e1
Merge branch 'yongchao/logp.jl' of github.com:TuringLang/DynamicPPL.j…
YongchaoHuang Feb 7, 2023
c46cf56
Apply suggestions from code review
YongchaoHuang Feb 7, 2023
6fc5738
Apply suggestions from code review
YongchaoHuang Feb 7, 2023
585e896
Update src/model.jl
yebai Feb 7, 2023
11bef7f
Update src/model.jl
yebai Feb 7, 2023
c6eb9c2
Update src/utils.jl
yebai Feb 7, 2023
3249919
rewrote the tests (mainly the way extracting parameter values from ch…
YongchaoHuang Feb 12, 2023
9f64941
removed BangBang from doctest setup; fixed imcomplete end in test/mod…
YongchaoHuang Feb 12, 2023
2ea8de8
Apply suggestions from code review
YongchaoHuang Feb 12, 2023
0ad2fc5
fixed a naming bug (argvals_mat_dict) in src/model.jl.
YongchaoHuang Feb 12, 2023
c3c7a6a
fixed a typo - missing `var_info`.
YongchaoHuang Feb 12, 2023
3dbbdae
Update test/model.jl
YongchaoHuang Feb 12, 2023
3c617d8
Apply suggestions from code review
YongchaoHuang Feb 13, 2023
bf86218
Apply suggestions from code review
YongchaoHuang Feb 13, 2023
a150215
Explicitly added `using Distributions` in doctests; Accepted suggesti…
YongchaoHuang Feb 14, 2023
0bbf948
Apply suggestions from code review
YongchaoHuang Feb 14, 2023
ab624e3
Merge branch 'master' into yongchao/logp.jl
yebai Feb 24, 2023
7b7f13c
rm unnecessary deps
yebai Feb 24, 2023
b6a6097
replace contains with subsumes.
yebai Feb 24, 2023
276c76c
rm redundant deps in docs build script.
yebai Feb 24, 2023
c7717b0
Update test/model.jl
yebai Feb 24, 2023
23e2ab1
Fix format.
yebai Feb 24, 2023
84e01b9
Replaced `subsumes` by `contains`.
YongchaoHuang Apr 3, 2023
1dd81f9
Update test/model.jl
YongchaoHuang Apr 3, 2023
9b37927
replaced 'contains' by a new, temporary method 'subsumes_sym', just f…
YongchaoHuang Apr 12, 2023
da0f39c
Update test/model.jl
YongchaoHuang Apr 12, 2023
0180426
Update test/model.jl
YongchaoHuang Apr 12, 2023
e015429
modified `/test/model.jl`:
YongchaoHuang Jun 22, 2023
b522358
Update test/model.jl
YongchaoHuang Jun 22, 2023
32ee32f
Fixed a mistake in `modify_value_representation`.
YongchaoHuang Jun 22, 2023
c29554b
fixed `gdemo_default`
YongchaoHuang Jun 22, 2023
c6edb50
assigned `model=gdemo_default`.
YongchaoHuang Jun 22, 2023
02b9a43
src/model.jl: added `DynamicPPL.` to `logprior` and `logjoint`.
YongchaoHuang Jun 22, 2023
4e1f424
commented out `gdemo_d()` as a trial test.
YongchaoHuang Jun 22, 2023
9a0404b
used `Symbol(vn_child)` as keys in `chain_sym_map`.
YongchaoHuang Jun 23, 2023
a8af236
Update test/model.jl
YongchaoHuang Jun 23, 2023
aa5737d
Merge branch 'master' into yongchao/logp.jl
YongchaoHuang Jun 23, 2023
c57f276
explicitly loaded `varname_leaves` and `values_from_chain`.
YongchaoHuang Jun 23, 2023
02e3610
added `print` statements for temporary diagnosis purpose.
YongchaoHuang Jun 23, 2023
fa2fb6a
added 'print` statements for temporary diagnostics purpose.
YongchaoHuang Jun 23, 2023
57585ff
Update test/model.jl
YongchaoHuang Jun 23, 2023
80f23f9
diagnostics again.
YongchaoHuang Jun 23, 2023
673fb6a
Removed some `print` statements as it's working.
YongchaoHuang Jun 24, 2023
e515943
Update test/model.jl
YongchaoHuang Jul 4, 2023
d4edf58
Update test/model.jl
YongchaoHuang Jul 4, 2023
3e4770b
Update test/model.jl
YongchaoHuang Jul 4, 2023
5565267
Update test/model.jl
YongchaoHuang Jul 4, 2023
6e8c848
1. moved helper functions to `test_util.jl`; 2. re-wrote the way `cha…
YongchaoHuang Jul 6, 2023
e109bbf
Update test/model.jl
YongchaoHuang Jul 6, 2023
6007a90
formatting.
YongchaoHuang Jul 6, 2023
7617237
formatting.
YongchaoHuang Jul 6, 2023
3b1d9e7
Merge branch 'master' into yongchao/logp.jl
yebai Jul 21, 2023
dbfc171
Update utils.jl
yebai Jul 21, 2023
7c6888d
Update test_util.jl
yebai Jul 21, 2023
1397dfc
Update Project.toml
yebai Jul 21, 2023
ac8d9e5
replaced 'varname_leaves' by 'DynamicPPL.varname_leaves'.
YongchaoHuang Jul 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.23.7"
version = "0.23.8"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand All @@ -17,3 +18,4 @@ LogDensityProblems = "2"
MLUtils = "0.3, 0.4"
Setfield = "0.7.1, 0.8, 1"
StableRNGs = "1"
MCMCChains = "5"
4 changes: 1 addition & 3 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ using DynamicPPL: AbstractPPL
using Distributions

# Doctest setup
DocMeta.setdocmeta!(
DynamicPPL, :DocTestSetup, :(using DynamicPPL, Distributions); recursive=true
)
DocMeta.setdocmeta!(DynamicPPL, :DocTestSetup, :(using DynamicPPL); recursive=true)

makedocs(;
sitename="DynamicPPL",
Expand Down
108 changes: 108 additions & 0 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,42 @@ function logjoint(model::Model, varinfo::AbstractVarInfo)
return getlogp(last(evaluate!!(model, varinfo, DefaultContext())))
end

"""
logjoint(model::Model, chain::AbstractMCMC.AbstractChains)

Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`.

# Examples

```jldoctest
julia> using MCMCChains, Distributions

julia> @model function demo_model(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
for i in eachindex(x)
x[i] ~ Normal(m, sqrt(s))
end
end;

julia> # construct a chain of samples using MCMCChains
chain = Chains(rand(10, 2, 3), [:s, :m]);

julia> logjoint(demo_model([1., 2.]), chain);
```
"""
function logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = OrderedDict(
vn_parent =>
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
vn_parent in keys(var_info)
)
DynamicPPL.logjoint(model, argvals_dict)
end
end

"""
logprior(model::Model, varinfo::AbstractVarInfo)

Expand All @@ -1070,6 +1106,42 @@ function logprior(model::Model, varinfo::AbstractVarInfo)
return getlogp(last(evaluate!!(model, varinfo, PriorContext())))
end

"""
logprior(model::Model, chain::AbstractMCMC.AbstractChains)

Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`.

# Examples

```jldoctest
julia> using MCMCChains, Distributions

julia> @model function demo_model(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
for i in eachindex(x)
x[i] ~ Normal(m, sqrt(s))
end
end;

julia> # construct a chain of samples using MCMCChains
chain = Chains(rand(10, 2, 3), [:s, :m]);

julia> logprior(demo_model([1., 2.]), chain);
```
"""
function logprior(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = OrderedDict(
vn_parent =>
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
vn_parent in keys(var_info)
)
DynamicPPL.logprior(model, argvals_dict)
end
end

"""
loglikelihood(model::Model, varinfo::AbstractVarInfo)

Expand All @@ -1081,6 +1153,42 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo)
return getlogp(last(evaluate!!(model, varinfo, LikelihoodContext())))
end

"""
loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains)

Return an array of log likelihoods evaluated at each sample in an MCMC `chain`.

# Examples

```jldoctest
julia> using MCMCChains, Distributions

julia> @model function demo_model(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
for i in eachindex(x)
x[i] ~ Normal(m, sqrt(s))
end
end;

julia> # construct a chain of samples using MCMCChains
chain = Chains(rand(10, 2, 3), [:s, :m]);

julia> loglikelihood(demo_model([1., 2.]), chain);
```
"""
function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = OrderedDict(
vn_parent =>
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
vn_parent in keys(var_info)
)
loglikelihood(model, argvals_dict)
end
end

"""
generated_quantities(model::Model, chain::AbstractChains)

Expand Down
73 changes: 72 additions & 1 deletion test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ end

@testset "model.jl" begin
@testset "convenience functions" begin
model = gdemo_default
model = gdemo_default # defined in test/test_util.jl

# sample from model and extract variables
vi = VarInfo(model)
Expand All @@ -49,6 +49,77 @@ end
ljoint = logjoint(model, vi)
@test ljoint ≈ lprior + llikelihood
@test ljoint ≈ lp

#### logprior, logjoint, loglikelihood for MCMC chains ####
for model in DynamicPPL.TestUtils.DEMO_MODELS # length(DynamicPPL.TestUtils.DEMO_MODELS)=12
var_info = VarInfo(model)
vns = DynamicPPL.TestUtils.varnames(model)
syms = unique(DynamicPPL.getsym.(vns))

# generate a chain of sample parameter values.
N = 200
vals_OrderedDict = mapreduce(hcat, 1:N) do _
rand(OrderedDict, model)
end
vals_mat = mapreduce(hcat, 1:N) do i
[vals_OrderedDict[i][vn] for vn in vns]
end
i = 1
for col in eachcol(vals_mat)
col_flattened = []
[push!(col_flattened, x...) for x in col]
if i == 1
chain_mat = Matrix(reshape(col_flattened, 1, length(col_flattened)))
else
chain_mat = vcat(
chain_mat, reshape(col_flattened, 1, length(col_flattened))
)
end
i += 1
end
chain_mat = convert(Matrix{Float64}, chain_mat)

# devise parameter names for chain
sample_values_vec = collect(values(vals_OrderedDict[1]))
symbol_names = []
chain_sym_map = Dict()
for k in 1:length(keys(var_info))
vn_parent = keys(var_info)[k]
sym = DynamicPPL.getsym(vn_parent)
vn_children = DynamicPPL.varname_leaves(vn_parent, sample_values_vec[k]) # `varname_leaves` defined in src/utils.jl
for vn_child in vn_children
chain_sym_map[Symbol(vn_child)] = sym
symbol_names = [symbol_names; Symbol(vn_child)]
end
end
chain = Chains(chain_mat, symbol_names)

# calculate the pointwise loglikelihoods for the whole chain using the newly written functions
logpriors = logprior(model, chain)
loglikelihoods = loglikelihood(model, chain)
logjoints = logjoint(model, chain)
# compare them with true values
for i in 1:N
samples_dict = Dict()
for chain_key in keys(chain)
value = chain[i, chain_key, 1]
key = chain_sym_map[chain_key]
existing_value = get(samples_dict, key, Float64[])
push!(existing_value, value)
samples_dict[key] = existing_value
end
samples = (; samples_dict...)
samples = modify_value_representation(samples) # `modify_value_representation` defined in test/test_util.jl
@test logpriors[i] ≈
DynamicPPL.TestUtils.logprior_true(model, samples[:s], samples[:m])
@test loglikelihoods[i] ≈ DynamicPPL.TestUtils.loglikelihood_true(
model, samples[:s], samples[:m]
)
@test logjoints[i] ≈
DynamicPPL.TestUtils.logjoint_true(model, samples[:s], samples[:m])
end
println("\n model $(model) passed !!! \n")
end
end

@testset "rng" begin
Expand Down
15 changes: 15 additions & 0 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,18 @@ short_varinfo_name(::TypedVarInfo) = "TypedVarInfo"
short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo"
short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}"
short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}"

# convenient functions for testing model.jl
# function to modify the representation of values based on their length
function modify_value_representation(nt::NamedTuple)
modified_nt = NamedTuple()
for (key, value) in zip(keys(nt), values(nt))
if length(value) == 1 # Scalar value
modified_value = value[1]
else # Non-scalar value
modified_value = value
end
modified_nt = merge(modified_nt, (key => modified_value,))
end
return modified_nt
end