diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..0217c48c2 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "julia.environmentPath": "/home/yongchao/.julia/dev/DynamicPPL/docs" +} \ No newline at end of file diff --git a/docs/src/api.md b/docs/src/api.md index 0e4012e02..03fa05d24 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -102,6 +102,12 @@ For a chain of samples, one can compute the pointwise log-likelihoods of each ob pointwise_loglikelihoods ``` +For converting a chain into a format that can more easily be fed into a `Model` again, for example using `condition`, you can use [`value_iterator_from_chain`](@ref). + +```@docs +value_iterator_from_chain +``` + ```@docs NamedDist ``` diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 04b08fb19..51c2c8761 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -120,7 +120,8 @@ export AbstractVarInfo, decondition, # Convenience macros @addlogprob!, - @submodel + @submodel, + value_iterator_from_chain # Reexport using Distributions: loglikelihood @@ -166,5 +167,6 @@ include("submodel_macro.jl") include("test_utils.jl") include("transforming.jl") include("logdensityfunction.jl") +include("model_utils.jl") end # module diff --git a/src/model_utils.jl b/src/model_utils.jl new file mode 100644 index 000000000..6710aea05 --- /dev/null +++ b/src/model_utils.jl @@ -0,0 +1,302 @@ +### Yong ############################################################## +# Yong added the below new functions on 2023-07-04, they are doing the some functionalities as Tor's functions. Some redundancy needs to be removed? +#### 1. varname_in_chain #### +# here we just check if vn and its leaves are present in the chain; we are not checking its presence in model. So we don't need to pass model or varinfo to this function. +""" + varname_in_chain(vn::VarName, chain, chain_idx, iteration_idx) + +Return two outputs: + - first output: logical `true` if `vn` or ANY of `vn_child` is in `chain` at `chain_idx` and `iteration_idx`; + - second output: a dictionary containing all leaf names of `vn` from the chain, if any. + +# Example +```julia +julia> using Turing, Distributions, DynamicPPL, MCMCChains, Random, Test + +julia> Random.seed!(111) +MersenneTwister(111) + +julia> @model function test_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x ~ Normal(m, sqrt(s)) + end +test_model (generic function with 2 methods) + +julia> model = test_model(1.5) + +julia> chain = sample(model, NUTS(), 100) + +julia> varname_in_chain(VarName(:s), chain, 1, 1) +(true, OrderedDict{Symbol, Bool}(:s => 1)) + +julia> varname_in_chain(VarName(:m), chain, 1, 1) +(true, OrderedDict{Symbol, Bool}(:m => 1)) + +julia> varname_in_chain(VarName(:x), chain, 1, 1) +(false, OrderedDict{Symbol, Bool}()) + +``` +""" +function varname_in_chain(vn::VarName, chain, chain_idx=1, iteration_idx=1) + out = OrderedDict{Symbol,Bool}() + for vn_child in namesingroup(chain, Symbol(vn)) # namesingroup: https://github.com/TuringLang/MCMCChains.jl/blob/master/src/chains.jl + # print("\n $vn_child of $vn is in chain") + out[vn_child] = + Symbol(vn_child) ∈ names(chain) && + !ismissing(chain[iteration_idx, Symbol(vn_child), chain_idx]) + end + return !isempty(out), out +end + +#### 2. varnames_in_chain #### +# we iteratively check whether each of keys(VarInfo(model)) is present in the chain or not using `varname_in_chain`. +""" + varnames_in_chain(model:::Model, chain) + varnames_in_chain(varinfo::VarInfo, chain) + +Return two outputs: + - first output: logical `true` if ALL variable names in `model`/`varinfo` are in `chain`; + - second output: a dictionary containing all leaf names of `vn` presented in the chain, if any. + +# Example +```julia +julia> using Turing, Distributions, DynamicPPL, MCMCChains, Random, Test + +julia> Random.seed!(111) +MersenneTwister(111) + +julia> @model function test_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x ~ Normal(m, sqrt(s)) +end +test_model (generic function with 2 methods + +julia> model = test_model(1.5) + +julia> chain = sample(model, NUTS(), 100) + +julia> varnames_in_chain(model, chain) +(true, OrderedDict{Any, Any}(:s => OrderedDict{Symbol, Bool}(:s => 1), :m => OrderedDict{Symbol, Bool}(:m => 1))) + +``` +""" +varnames_in_chain(model::Model, chain) = varnames_in_chain(VarInfo(model), chain) +function varnames_in_chain(varinfo::VarInfo, chain) + out_logical = OrderedDict() + out = OrderedDict() + for vn in keys(varinfo) + out_logical[Symbol(vn)], out[Symbol(vn)] = varname_in_chain(vn, chain, 1, 1) # by default, we check the first chain and the first iteration. + end + return all(values(out_logical)), out +end + +#### 3. values_from_chain #### +""" + vn_values_from_chain(vn, chain, chain_idx, iteration_idx) + +Return two outputs: + - first output: logical `true` if `vn` or ANY of its leaves is in `chain` + - second output: a dictionary containing all leaf names (if any) of `vn` and their values at `chain_idx`, `iteration_idx` from the chain. + +# Example +```julia +julia> using Turing, Distributions, DynamicPPL, MCMCChains, Random, Test + +julia> Random.seed!(111) +MersenneTwister(111) + +julia> @model function test_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x ~ Normal(m, sqrt(s)) +end +test_model (generic function with 2 methods) + +julia> model = test_model(1.5) + +julia> chain = sample(model, NUTS(), 100) + +julia> vn_values_from_chain(VarName(:s), chain, 1, 1) +(true, OrderedDict{Any, Any}(:s => 1.385664578516751)) + +julia> vn_values_from_chain(VarName(:m), chain, 1, 1) +(true, OrderedDict{Any, Any}(:m => 0.9529550916018266)) + +julia> vn_values_from_chain(VarName(:x), chain, 1, 1) +(false, OrderedDict{Any, Any}()) + +``` +""" +function vn_values_from_chain(vn::VarName, chain, chain_idx, iteration_idx) + out = OrderedDict() + # no need to check if varname_in_chain(vn, chain)[1] - if vn is not in chain, then out will be empty. + for vn_child in namesingroup(chain, Symbol(vn)) + try + out[vn_child] = chain[iteration_idx, Symbol(vn_child), chain_idx] + catch + println( + "Error: retrieve value for $vn_child using chain[$iteration_idx, Symbol($vn_child), $chain_idx] not successful!", + ) + end + end + return !isempty(out), out +end + +""" + values_from_chain(model:::Model, chain) + values_from_chain(varinfo::VarInfo, chain) + +Return one output: + - a dictionary containing the (leaves_name, value) pair of ALL parameters in `model`/`varinfo` at `chain_idx`, `iteration_idx` from the chain (if ANY of the leaves of `vn` is present in the chain). + +# Example +```julia + +julia> using Turing, Distributions, DynamicPPL, MCMCChains, Random, Test + +julia> Random.seed!(111) +MersenneTwister(111) + +julia> @model function test_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x ~ Normal(m, sqrt(s)) +end + +julia> model = test_model(1.5) + +julia> chain = sample(model, NUTS(), 100) + +julia> values_from_chain(model, chain, 1, 1) +OrderedDict{Any, Any} with 2 entries: + :s => 1.38566 + :m => 0.952955 + +``` +""" +function values_from_chain(model::Model, chain, chain_idx, iteration_idx) + return values_from_chain(VarInfo(model), chain, chain_idx, iteration_idx) +end +function values_from_chain(varinfo::VarInfo, chain, chain_idx, iteration_idx) + out = OrderedDict() + for vn in keys(varinfo) + _, out_vn = vn_values_from_chain(vn, chain, chain_idx, iteration_idx) + merge!(out, out_vn) + end + return out +end + +""" + values_from_chain(varinfo, chain, chain_idx_range, iteration_idx_range) + +Return one output: + - a dictionary containing the values of all leaf names of all parameters in `model`/`varinfo` within `chain_idx_range` and `iteration_idx_range``. + +# Example +```julia + +julia> using Turing, Distributions, DynamicPPL, MCMCChains, Random, Test + +julia> Random.seed!(111) +MersenneTwister(111) + +julia> @model function test_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x ~ Normal(m, sqrt(s)) +end + +julia> model = test_model(1.5) + +julia> chain = sample(model, NUTS(), 100) + +julia> values_from_chain(model, chain, 1:2, 1:10) +Error: retrieve value for s using chain[1, Symbol(s), 2] not successful! +Error: retrieve value for s using chain[2, Symbol(s), 2] not successful! +Error: retrieve value for s using chain[3, Symbol(s), 2] not successful! +Error: retrieve value for s using chain[4, Symbol(s), 2] not successful! +Error: retrieve value for s using chain[5, Symbol(s), 2] not successful! +Error: retrieve value for s using chain[6, Symbol(s), 2] not successful! +Error: retrieve value for s using chain[7, Symbol(s), 2] not successful! +Error: retrieve value for s using chain[8, Symbol(s), 2] not successful! +Error: retrieve value for s using chain[9, Symbol(s), 2] not successful! +Error: retrieve value for s using chain[10, Symbol(s), 2] not successful! +Error: retrieve value for m using chain[1, Symbol(m), 2] not successful! +Error: retrieve value for m using chain[2, Symbol(m), 2] not successful! +Error: retrieve value for m using chain[3, Symbol(m), 2] not successful! +Error: retrieve value for m using chain[4, Symbol(m), 2] not successful! +Error: retrieve value for m using chain[5, Symbol(m), 2] not successful! +Error: retrieve value for m using chain[6, Symbol(m), 2] not successful! +Error: retrieve value for m using chain[7, Symbol(m), 2] not successful! +Error: retrieve value for m using chain[8, Symbol(m), 2] not successful! +Error: retrieve value for m using chain[9, Symbol(m), 2] not successful! +Error: retrieve value for m using chain[10, Symbol(m), 2] not successful! +OrderedDict{Any, Any} with 2 entries: + "chain_idx_1" => OrderedDict{Any, Any}(:s=>[1.38566, 1.6544, 1.36912, 1.18434, 1.33485, 1.966… + "chain_idx_2" => OrderedDict{Any, Any}() + +``` +""" +function values_from_chain(model::Model, chain, chain_idx_range, iteration_idx_range) + return values_from_chain(VarInfo(model), chain, chain_idx_range, iteration_idx_range) +end +function values_from_chain( + varinfo::VarInfo, chain, chain_idx_range::UnitRange, iteration_idx_range::UnitRange +) + all_out = OrderedDict() + for chain_idx in chain_idx_range + out = OrderedDict() + for vn in keys(varinfo) + for iteration_idx in iteration_idx_range + _, out_vn = vn_values_from_chain(vn, chain, chain_idx, iteration_idx) + for key in keys(out_vn) + if haskey(out, key) + out[key] = vcat(out[key], out_vn[key]) + else + out[key] = out_vn[key] + end + end + end + end + all_out["chain_idx_" * string(chain_idx)] = out + end + return all_out +end +function values_from_chain( + varinfo::VarInfo, chain, chain_idx_range::Int, iteration_idx_range::UnitRange +) + return values_from_chain( + varinfo, chain, chain_idx_range:chain_idx_range, iteration_idx_range + ) +end +function values_from_chain( + varinfo::VarInfo, chain, chain_idx_range::UnitRange, iteration_idx_range::Int +) + return values_from_chain( + varinfo, chain, chain_idx_range, iteration_idx_range:iteration_idx_range + ) +end +function values_from_chain( + varinfo::VarInfo, chain, chain_idx_range::Int, iteration_idx_range::Int +) # this is equivalent to values_from_chain(varinfo::VarInfo, chain, chain_idx, iteration_idx) + return values_from_chain( + varinfo, + chain, + chain_idx_range:chain_idx_range, + iteration_idx_range:iteration_idx_range, + ) +end +# if either chain_idx_range or iteration_idx_range is specified as `nothing`, then all chains will be included. +function values_from_chain(varinfo::VarInfo, chain, chain_idx_range, iteration_idx_range) + if chain_idx_range === nothing + print("chain_idx_range is missing!") + chain_idx_range = 1:size(chain)[3] + end + if iteration_idx_range === nothing + print("iteration_idx_range is missing!") + iteration_idx_range = 1:size(chain)[1] + end + return values_from_chain(varinfo, chain, chain_idx_range, iteration_idx_range) +end \ No newline at end of file diff --git a/test/model_utils.jl b/test/model_utils.jl new file mode 100644 index 000000000..5c490234f --- /dev/null +++ b/test/model_utils.jl @@ -0,0 +1,125 @@ +#### prepare the models and chains for testing #### +# (1) manually create a chain using MCMCChains - we know what parameter names are in the chain +val = [1 2; 3 4; 5 6; 7 8; 9 10; 11 12; 13 14; 15 16; 17 18; 19 20] +val = Matrix{Int}(val) +chn_1 = Chains(val, [:s, :m]) + +# (2) sample a Turing model to create a chain +@model function gdemo(x) + mu ~ MvNormal([0, 0, 0], [1 0 0; 0 1 0; 0 0 1]) + return x ~ MvNormal(mu, [1 0 0; 0 1 0; 0 0 1]) +end +model_2 = gdemo([0, 0, 0]) # provide an initial value for `x` +chn_2 = sample(model_2, NUTS(), 100) # NB: the parameter names in an MCMCChains can be retrieved using `namechn_2.name_map[:parameters]_map`: https://github.com/TuringLang/MCMCChains.jl/blob/master/src/chains.jl + +# (3) sample the demo models in DynamicPPL to create a chain - we do not know what parameter names are in the chain beforehand +model_3 = DynamicPPL.TestUtils.DEMO_MODELS[12] +var_info = VarInfo(model_3) +vns = DynamicPPL.TestUtils.varnames(model_3) +# generate a chain. +N = 100 +vals_OrderedDict = mapreduce(hcat, 1:N) do _ + rand(OrderedDict, model_3) +end +vals_mat = mapreduce(hcat, 1:N) do i + [vals_OrderedDict[i][vn] for vn in vns] +end +i = 1 +for col in eachcol(vals_mat) + col_flattened = [] + [push!(col_flattened, x...) for x in col] + if i == 1 + chain_mat = Matrix(reshape(col_flattened, 1, length(col_flattened))) + else + chain_mat = vcat(chain_mat, reshape(col_flattened, 1, length(col_flattened))) + end + i += 1 +end +chain_mat = convert(Matrix{Float64}, chain_mat) +# devise parameter names for chain +sample_values_vec = collect(values(vals_OrderedDict[1])) +symbol_names = [] +chain_sym_map = Dict() +for k in 1:length(keys(var_info)) + vn_parent = keys(var_info)[k] + sym = DynamicPPL.getsym(vn_parent) + vn_children = DynamicPPL.varname_leaves(vn_parent, sample_values_vec[k]) + for vn_child in vn_children + chain_sym_map[Symbol(vn_child)] = sym + symbol_names = [symbol_names; Symbol(vn_child)] + end +end +chn_3 = Chains(chain_mat, symbol_names) + +#### test functions in model_utils.jl #### +@testset "model_utils.jl" begin + @testset "varname_in_chain" begin + # chn_1 + outputs = varname_in_chain(VarName(:s), chn_1, 1, 1) + @test outputs[1] == true && outputs[2][:s] == true + outputs = varname_in_chain(VarName(:m), chn_1, 1, 1) + @test outputs[1] == true && outputs[2][:m] == true + outputs = varname_in_chain(VarName(:x), chn_1, 1, 1) + @test outputs[1] == false && isempty(outputs[2]) + + # chn_2 + outputs = varname_in_chain(VarName(:mu), chn_2, 1, 1) + @test outputs[1] == true && !isempty(outputs[2]) && all(values(outputs[2]) .== 1) + + # chn_3 + outputs = varname_in_chain(VarName(:a), chn_3, 1, 1) + @test outputs[1] == false && isempty(outputs[2]) + outputs = varname_in_chain(VarName(symbol_names[1]), chn_3, 1, 1) + @test !isempty(outputs[2]) && all(values(outputs[2]) .== 1) # note: all(values(outputs[2]) .== 1) is not enough - even an empty dictionary has all(values(outputs[2]) .== 1) + end + @testset "varnames_in_chain" begin + outputs = varnames_in_chain(model_2, chn_2) + @test outputs[1] == true && all(values(outputs[2][:mu])) + outputs = varnames_in_chain(model_3, chn_3) + @test outputs[1] == true + end + @testset "vn_values_from_chain" begin + outputs = vn_values_from_chain(VarName(:mu), chn_2, 1, 1) + @test outputs[1] == true && length(values(outputs[2])) == 3 + outputs = vn_values_from_chain(VarName(:s), chn_3, 1, 1) + @test outputs[1] == true && length(values(outputs[2])) == 2 + outputs = vn_values_from_chain(VarName(Symbol("s[:,1][1]")), chn_3, 1, 2) + @test outputs[1] == true + end + @testset "values_from_chain" begin + output = values_from_chain(model_2, chn_2, 1, 1) + @test length(output["chain_idx_1"]) == 3 + output = values_from_chain(model_3, chn_3, 1, 1) + @test length(output["chain_idx_1"]) == 4 + + output = values_from_chain(model_2, chn_2, 1, 1:10) + @test length(output["chain_idx_1"]) == 3 && + all([length(vals) == 10 for vals in values(output["chain_idx_1"])]) + output = values_from_chain(model_3, chn_3, 1, 1:10) + @test length(output["chain_idx_1"]) == 4 && + all([length(vals) == 10 for vals in values(output["chain_idx_1"])]) + output = values_from_chain(model_2, chn_2, 1, 1) + @test length(output["chain_idx_1"]) == 3 && + all([length(vals) == 1 for vals in values(output["chain_idx_1"])]) + output = values_from_chain(model_2, chn_2, 1:1, 1) + @test length(keys(output)) == 1 && + length(output["chain_idx_1"]) == 3 && + all([length(vals) == 1 for vals in values(output["chain_idx_1"])]) + output = values_from_chain(model_2, chn_2, 1, 1:1) + @test length(output["chain_idx_1"]) == 3 && + all([length(vals) == 1 for vals in values(output["chain_idx_1"])]) + output = values_from_chain(model_2, chn_2, 1:1, 1:1) + @test length(keys(output)) == 1 && + length(output["chain_idx_1"]) == 3 && + all([length(vals) == 1 for vals in values(output["chain_idx_1"])]) + + output = values_from_chain(model_2, chn_2, nothing, nothing) + @test length(keys(output)) == 1 && + length(output["chain_idx_1"]) == 3 && + all([length(vals) == 100 for vals in values(output["chain_idx_1"])]) + output = values_from_chain(model_3, chn_3, nothing, nothing) + @test length(keys(output)) == 1 && + length(output["chain_idx_1"]) == 4 && + all([length(vals) == 100 for vals in values(output["chain_idx_1"])]) + end +end \ No newline at end of file diff --git a/test/turing/model.jl b/test/turing/model.jl index fcbdd88a3..599fba21b 100644 --- a/test/turing/model.jl +++ b/test/turing/model.jl @@ -9,4 +9,19 @@ test_setval!(model, MCMCChains.get_sections(chain, :parameters)) end end + + @testset "value_iterator_from_chain" begin + @testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS + chain = sample(model, Prior(), 10; progress=false) + for (i, d) in enumerate(value_iterator_from_chain(model, chain)) + for vn in keys(d) + val = DynamicPPL.getvalue(d, vn) + for vn_leaf in DynamicPPL.varname_leaves(vn, val) + val_leaf = DynamicPPL.getvalue(d, vn_leaf) + @test val_leaf == chain[i, Symbol(vn_leaf), 1] + end + end + end + end + end end