-
Notifications
You must be signed in to change notification settings - Fork 37
Yongchao version: extracting parameter values for a model from chain #495
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
ef0ac6f
added methods for extracting parameter values for a model from a give…
torfjelde 8f311fe
Merge branch 'master' into torfjelde/extract-model-values-from-chain
yebai 979db2d
added MCMCchains as a dep to docs
torfjelde 792d1bd
Merge branch 'torfjelde/extract-model-values-from-chain' of github.co…
torfjelde 7511e4f
attempt at fixing doctests
torfjelde 1a2dfc7
remove the doctest as it's not working for some reason
torfjelde 642be6c
added docs
torfjelde 49b6b93
Update docs/src/api.md
torfjelde 299087e
added new functions for `varnames_in_chain` and `values_from_chain`.
YongchaoHuang 545f727
cleaned up .
YongchaoHuang 20bdea7
added doctests for the new functions.
YongchaoHuang de5f684
added `test/model_utils.jl`.
YongchaoHuang 9864e27
Merge branch 'master' of https://github.com/TuringLang/DynamicPPL.jl …
YongchaoHuang e6888c8
only new functions are kept.
YongchaoHuang 77c37c9
Apply suggestions from code review
YongchaoHuang 140045b
added `using Turing` in `DynamicPPL.jl`
YongchaoHuang d75a659
update `Project.toml`.
YongchaoHuang 05085f4
applied reviewdog suggestion.
YongchaoHuang db24f4e
removed `using Turing`.
YongchaoHuang f700044
removed `USING tURING`
YongchaoHuang a4d111f
removed Turing dependency in `src/model_utils.jl`
YongchaoHuang 57cd078
removed dependencies in `src/model_utils.jl`
YongchaoHuang cf30820
removed unecessary and improper imports.
YongchaoHuang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| { | ||
| "julia.environmentPath": "/home/yongchao/.julia/dev/DynamicPPL/docs" | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,302 @@ | ||
| ### Yong ############################################################## | ||
| # Yong added the below new functions on 2023-07-04, they are doing the some functionalities as Tor's functions. Some redundancy needs to be removed? | ||
| #### 1. varname_in_chain #### | ||
| # here we just check if vn and its leaves are present in the chain; we are not checking its presence in model. So we don't need to pass model or varinfo to this function. | ||
| """ | ||
| varname_in_chain(vn::VarName, chain, chain_idx, iteration_idx) | ||
|
|
||
| Return two outputs: | ||
| - first output: logical `true` if `vn` or ANY of `vn_child` is in `chain` at `chain_idx` and `iteration_idx`; | ||
| - second output: a dictionary containing all leaf names of `vn` from the chain, if any. | ||
|
|
||
| # Example | ||
| ```julia | ||
| julia> using Turing, Distributions, DynamicPPL, MCMCChains, Random, Test | ||
|
|
||
| julia> Random.seed!(111) | ||
| MersenneTwister(111) | ||
|
|
||
| julia> @model function test_model(x) | ||
| s ~ InverseGamma(2, 3) | ||
| m ~ Normal(0, sqrt(s)) | ||
| x ~ Normal(m, sqrt(s)) | ||
| end | ||
| test_model (generic function with 2 methods) | ||
|
|
||
| julia> model = test_model(1.5) | ||
|
|
||
| julia> chain = sample(model, NUTS(), 100) | ||
|
|
||
| julia> varname_in_chain(VarName(:s), chain, 1, 1) | ||
| (true, OrderedDict{Symbol, Bool}(:s => 1)) | ||
|
|
||
| julia> varname_in_chain(VarName(:m), chain, 1, 1) | ||
| (true, OrderedDict{Symbol, Bool}(:m => 1)) | ||
|
|
||
| julia> varname_in_chain(VarName(:x), chain, 1, 1) | ||
| (false, OrderedDict{Symbol, Bool}()) | ||
|
|
||
| ``` | ||
| """ | ||
| function varname_in_chain(vn::VarName, chain, chain_idx=1, iteration_idx=1) | ||
| out = OrderedDict{Symbol,Bool}() | ||
| for vn_child in namesingroup(chain, Symbol(vn)) # namesingroup: https://github.com/TuringLang/MCMCChains.jl/blob/master/src/chains.jl | ||
| # print("\n $vn_child of $vn is in chain") | ||
| out[vn_child] = | ||
| Symbol(vn_child) ∈ names(chain) && | ||
| !ismissing(chain[iteration_idx, Symbol(vn_child), chain_idx]) | ||
| end | ||
| return !isempty(out), out | ||
| end | ||
|
|
||
| #### 2. varnames_in_chain #### | ||
| # we iteratively check whether each of keys(VarInfo(model)) is present in the chain or not using `varname_in_chain`. | ||
| """ | ||
| varnames_in_chain(model:::Model, chain) | ||
| varnames_in_chain(varinfo::VarInfo, chain) | ||
|
|
||
| Return two outputs: | ||
| - first output: logical `true` if ALL variable names in `model`/`varinfo` are in `chain`; | ||
| - second output: a dictionary containing all leaf names of `vn` presented in the chain, if any. | ||
|
|
||
| # Example | ||
| ```julia | ||
| julia> using Turing, Distributions, DynamicPPL, MCMCChains, Random, Test | ||
|
|
||
| julia> Random.seed!(111) | ||
| MersenneTwister(111) | ||
|
|
||
| julia> @model function test_model(x) | ||
| s ~ InverseGamma(2, 3) | ||
| m ~ Normal(0, sqrt(s)) | ||
| x ~ Normal(m, sqrt(s)) | ||
| end | ||
| test_model (generic function with 2 methods | ||
|
|
||
| julia> model = test_model(1.5) | ||
|
|
||
| julia> chain = sample(model, NUTS(), 100) | ||
|
|
||
| julia> varnames_in_chain(model, chain) | ||
| (true, OrderedDict{Any, Any}(:s => OrderedDict{Symbol, Bool}(:s => 1), :m => OrderedDict{Symbol, Bool}(:m => 1))) | ||
|
|
||
| ``` | ||
| """ | ||
| varnames_in_chain(model::Model, chain) = varnames_in_chain(VarInfo(model), chain) | ||
| function varnames_in_chain(varinfo::VarInfo, chain) | ||
| out_logical = OrderedDict() | ||
| out = OrderedDict() | ||
| for vn in keys(varinfo) | ||
| out_logical[Symbol(vn)], out[Symbol(vn)] = varname_in_chain(vn, chain, 1, 1) # by default, we check the first chain and the first iteration. | ||
| end | ||
| return all(values(out_logical)), out | ||
| end | ||
|
|
||
| #### 3. values_from_chain #### | ||
| """ | ||
| vn_values_from_chain(vn, chain, chain_idx, iteration_idx) | ||
|
|
||
| Return two outputs: | ||
| - first output: logical `true` if `vn` or ANY of its leaves is in `chain` | ||
| - second output: a dictionary containing all leaf names (if any) of `vn` and their values at `chain_idx`, `iteration_idx` from the chain. | ||
|
|
||
| # Example | ||
| ```julia | ||
| julia> using Turing, Distributions, DynamicPPL, MCMCChains, Random, Test | ||
|
|
||
| julia> Random.seed!(111) | ||
| MersenneTwister(111) | ||
|
|
||
| julia> @model function test_model(x) | ||
| s ~ InverseGamma(2, 3) | ||
| m ~ Normal(0, sqrt(s)) | ||
| x ~ Normal(m, sqrt(s)) | ||
| end | ||
| test_model (generic function with 2 methods) | ||
|
|
||
| julia> model = test_model(1.5) | ||
|
|
||
| julia> chain = sample(model, NUTS(), 100) | ||
|
|
||
| julia> vn_values_from_chain(VarName(:s), chain, 1, 1) | ||
| (true, OrderedDict{Any, Any}(:s => 1.385664578516751)) | ||
|
|
||
| julia> vn_values_from_chain(VarName(:m), chain, 1, 1) | ||
| (true, OrderedDict{Any, Any}(:m => 0.9529550916018266)) | ||
|
|
||
| julia> vn_values_from_chain(VarName(:x), chain, 1, 1) | ||
| (false, OrderedDict{Any, Any}()) | ||
|
|
||
| ``` | ||
| """ | ||
| function vn_values_from_chain(vn::VarName, chain, chain_idx, iteration_idx) | ||
| out = OrderedDict() | ||
| # no need to check if varname_in_chain(vn, chain)[1] - if vn is not in chain, then out will be empty. | ||
| for vn_child in namesingroup(chain, Symbol(vn)) | ||
| try | ||
| out[vn_child] = chain[iteration_idx, Symbol(vn_child), chain_idx] | ||
| catch | ||
| println( | ||
| "Error: retrieve value for $vn_child using chain[$iteration_idx, Symbol($vn_child), $chain_idx] not successful!", | ||
| ) | ||
| end | ||
| end | ||
| return !isempty(out), out | ||
| end | ||
|
|
||
| """ | ||
| values_from_chain(model:::Model, chain) | ||
| values_from_chain(varinfo::VarInfo, chain) | ||
|
|
||
| Return one output: | ||
| - a dictionary containing the (leaves_name, value) pair of ALL parameters in `model`/`varinfo` at `chain_idx`, `iteration_idx` from the chain (if ANY of the leaves of `vn` is present in the chain). | ||
|
|
||
| # Example | ||
| ```julia | ||
|
|
||
| julia> using Turing, Distributions, DynamicPPL, MCMCChains, Random, Test | ||
|
|
||
| julia> Random.seed!(111) | ||
| MersenneTwister(111) | ||
|
|
||
| julia> @model function test_model(x) | ||
| s ~ InverseGamma(2, 3) | ||
| m ~ Normal(0, sqrt(s)) | ||
| x ~ Normal(m, sqrt(s)) | ||
| end | ||
|
|
||
| julia> model = test_model(1.5) | ||
|
|
||
| julia> chain = sample(model, NUTS(), 100) | ||
|
|
||
| julia> values_from_chain(model, chain, 1, 1) | ||
| OrderedDict{Any, Any} with 2 entries: | ||
| :s => 1.38566 | ||
| :m => 0.952955 | ||
|
|
||
| ``` | ||
| """ | ||
| function values_from_chain(model::Model, chain, chain_idx, iteration_idx) | ||
| return values_from_chain(VarInfo(model), chain, chain_idx, iteration_idx) | ||
| end | ||
| function values_from_chain(varinfo::VarInfo, chain, chain_idx, iteration_idx) | ||
| out = OrderedDict() | ||
| for vn in keys(varinfo) | ||
| _, out_vn = vn_values_from_chain(vn, chain, chain_idx, iteration_idx) | ||
| merge!(out, out_vn) | ||
| end | ||
| return out | ||
| end | ||
|
|
||
| """ | ||
| values_from_chain(varinfo, chain, chain_idx_range, iteration_idx_range) | ||
|
|
||
| Return one output: | ||
| - a dictionary containing the values of all leaf names of all parameters in `model`/`varinfo` within `chain_idx_range` and `iteration_idx_range``. | ||
|
|
||
| # Example | ||
| ```julia | ||
|
|
||
| julia> using Turing, Distributions, DynamicPPL, MCMCChains, Random, Test | ||
|
|
||
| julia> Random.seed!(111) | ||
| MersenneTwister(111) | ||
|
|
||
| julia> @model function test_model(x) | ||
| s ~ InverseGamma(2, 3) | ||
| m ~ Normal(0, sqrt(s)) | ||
| x ~ Normal(m, sqrt(s)) | ||
| end | ||
|
|
||
| julia> model = test_model(1.5) | ||
|
|
||
| julia> chain = sample(model, NUTS(), 100) | ||
|
|
||
| julia> values_from_chain(model, chain, 1:2, 1:10) | ||
| Error: retrieve value for s using chain[1, Symbol(s), 2] not successful! | ||
| Error: retrieve value for s using chain[2, Symbol(s), 2] not successful! | ||
| Error: retrieve value for s using chain[3, Symbol(s), 2] not successful! | ||
| Error: retrieve value for s using chain[4, Symbol(s), 2] not successful! | ||
| Error: retrieve value for s using chain[5, Symbol(s), 2] not successful! | ||
| Error: retrieve value for s using chain[6, Symbol(s), 2] not successful! | ||
| Error: retrieve value for s using chain[7, Symbol(s), 2] not successful! | ||
| Error: retrieve value for s using chain[8, Symbol(s), 2] not successful! | ||
| Error: retrieve value for s using chain[9, Symbol(s), 2] not successful! | ||
| Error: retrieve value for s using chain[10, Symbol(s), 2] not successful! | ||
| Error: retrieve value for m using chain[1, Symbol(m), 2] not successful! | ||
| Error: retrieve value for m using chain[2, Symbol(m), 2] not successful! | ||
| Error: retrieve value for m using chain[3, Symbol(m), 2] not successful! | ||
| Error: retrieve value for m using chain[4, Symbol(m), 2] not successful! | ||
| Error: retrieve value for m using chain[5, Symbol(m), 2] not successful! | ||
| Error: retrieve value for m using chain[6, Symbol(m), 2] not successful! | ||
| Error: retrieve value for m using chain[7, Symbol(m), 2] not successful! | ||
| Error: retrieve value for m using chain[8, Symbol(m), 2] not successful! | ||
| Error: retrieve value for m using chain[9, Symbol(m), 2] not successful! | ||
| Error: retrieve value for m using chain[10, Symbol(m), 2] not successful! | ||
| OrderedDict{Any, Any} with 2 entries: | ||
| "chain_idx_1" => OrderedDict{Any, Any}(:s=>[1.38566, 1.6544, 1.36912, 1.18434, 1.33485, 1.966… | ||
| "chain_idx_2" => OrderedDict{Any, Any}() | ||
|
|
||
| ``` | ||
| """ | ||
| function values_from_chain(model::Model, chain, chain_idx_range, iteration_idx_range) | ||
| return values_from_chain(VarInfo(model), chain, chain_idx_range, iteration_idx_range) | ||
| end | ||
| function values_from_chain( | ||
| varinfo::VarInfo, chain, chain_idx_range::UnitRange, iteration_idx_range::UnitRange | ||
| ) | ||
| all_out = OrderedDict() | ||
| for chain_idx in chain_idx_range | ||
| out = OrderedDict() | ||
| for vn in keys(varinfo) | ||
| for iteration_idx in iteration_idx_range | ||
| _, out_vn = vn_values_from_chain(vn, chain, chain_idx, iteration_idx) | ||
| for key in keys(out_vn) | ||
| if haskey(out, key) | ||
| out[key] = vcat(out[key], out_vn[key]) | ||
| else | ||
| out[key] = out_vn[key] | ||
| end | ||
| end | ||
| end | ||
| end | ||
| all_out["chain_idx_" * string(chain_idx)] = out | ||
| end | ||
| return all_out | ||
| end | ||
| function values_from_chain( | ||
| varinfo::VarInfo, chain, chain_idx_range::Int, iteration_idx_range::UnitRange | ||
| ) | ||
| return values_from_chain( | ||
| varinfo, chain, chain_idx_range:chain_idx_range, iteration_idx_range | ||
| ) | ||
| end | ||
| function values_from_chain( | ||
| varinfo::VarInfo, chain, chain_idx_range::UnitRange, iteration_idx_range::Int | ||
| ) | ||
| return values_from_chain( | ||
| varinfo, chain, chain_idx_range, iteration_idx_range:iteration_idx_range | ||
| ) | ||
| end | ||
| function values_from_chain( | ||
| varinfo::VarInfo, chain, chain_idx_range::Int, iteration_idx_range::Int | ||
| ) # this is equivalent to values_from_chain(varinfo::VarInfo, chain, chain_idx, iteration_idx) | ||
| return values_from_chain( | ||
| varinfo, | ||
| chain, | ||
| chain_idx_range:chain_idx_range, | ||
| iteration_idx_range:iteration_idx_range, | ||
| ) | ||
| end | ||
| # if either chain_idx_range or iteration_idx_range is specified as `nothing`, then all chains will be included. | ||
| function values_from_chain(varinfo::VarInfo, chain, chain_idx_range, iteration_idx_range) | ||
| if chain_idx_range === nothing | ||
| print("chain_idx_range is missing!") | ||
| chain_idx_range = 1:size(chain)[3] | ||
| end | ||
| if iteration_idx_range === nothing | ||
| print("iteration_idx_range is missing!") | ||
| iteration_idx_range = 1:size(chain)[1] | ||
| end | ||
| return values_from_chain(varinfo, chain, chain_idx_range, iteration_idx_range) | ||
| end | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶