diff --git a/HISTORY.md b/HISTORY.md
index 29bc56493..38a0baa93 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -61,6 +61,10 @@ The `resume_from=chn` keyword argument to `sample` has been removed; please use
 
 **Other changes**
 
+### `predict(model, chain; include_all)`
+
+The `include_all` keyword argument for `predict` now works even when no RNG is specified (previously it would only work when an RNG was explicitly passed).
+
 ### `setleafcontext(model, context)`
 
 This convenience method has been added to quickly modify the leaf context of a model.
diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl
index 7b9322254..7886ad468 100644
--- a/ext/DynamicPPLMCMCChainsExt.jl
+++ b/ext/DynamicPPLMCMCChainsExt.jl
@@ -116,7 +116,19 @@ function DynamicPPL.predict(
     include_all=false,
 )
     parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
-    varinfo = DynamicPPL.VarInfo(model)
+
+    # Set up a VarInfo with the right accumulators
+    varinfo = DynamicPPL.setaccs!!(
+        DynamicPPL.VarInfo(),
+        (
+            DynamicPPL.LogPriorAccumulator(),
+            DynamicPPL.LogJacobianAccumulator(),
+            DynamicPPL.LogLikelihoodAccumulator(),
+            DynamicPPL.ValuesAsInModelAccumulator(false),
+        ),
+    )
+    _, varinfo = DynamicPPL.init!!(model, varinfo)
+    varinfo = DynamicPPL.typed_varinfo(varinfo)
 
     iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
     predictive_samples = map(iters) do (sample_idx, chain_idx)
@@ -129,7 +141,7 @@ function DynamicPPL.predict(
             varinfo,
             DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
         )
-        vals = DynamicPPL.values_as_in_model(model, false, varinfo)
+        vals = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values
         varname_vals = mapreduce(
            collect,
            vcat,
@@ -156,6 +168,13 @@ function DynamicPPL.predict(
     end
     return chain_result[parameter_names]
 end
+function DynamicPPL.predict(
+    model::DynamicPPL.Model, chain::MCMCChains.Chains; include_all=false
+)
+    return DynamicPPL.predict(
+        DynamicPPL.Random.default_rng(), model, chain; include_all=include_all
+    )
+end
 
 function _predictive_samples_to_arrays(predictive_samples)
     variable_names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
diff --git a/test/model.jl b/test/model.jl
index 7374f73aa..6ba3bca2a 100644
--- a/test/model.jl
+++ b/test/model.jl
@@ -519,6 +519,23 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
             @test Set(keys(predictions)) == Set([Symbol("y[1]"), Symbol("y[2]")])
         end
 
+        @testset "include_all=true" begin
+            inc_predictions = DynamicPPL.predict(
+                m_lin_reg_test, β_chain; include_all=true
+            )
+            @test Set(keys(inc_predictions)) ==
+                Set([:β, Symbol("y[1]"), Symbol("y[2]")])
+            @test inc_predictions[:β] == β_chain[:β]
+            # check rng is respected
+            inc_predictions1 = DynamicPPL.predict(
+                Xoshiro(468), m_lin_reg_test, β_chain; include_all=true
+            )
+            inc_predictions2 = DynamicPPL.predict(
+                Xoshiro(468), m_lin_reg_test, β_chain; include_all=true
+            )
+            @test all(Array(inc_predictions1) .== Array(inc_predictions2))
+        end
+
         @testset "accuracy" begin
             ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))
             @test ys_pred[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
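
As a usage sketch of the behaviour described in the HISTORY.md entry above: `predict` can now be called without an RNG, while passing one explicitly still controls reproducibility. The `linreg` model, the data, and the hand-built chain below are illustrative placeholders only, not part of the patch.

```julia
using DynamicPPL, Distributions, MCMCChains, Random

# Illustrative model: linear regression with a single coefficient β.
@model function linreg(x, y)
    β ~ Normal(0, 1)
    for i in eachindex(y)
        y[i] ~ Normal(β * x[i], 1)
    end
end

# Stand-in for a chain produced by a sampler: 100 draws of β in one chain,
# wrapped as an MCMCChains.Chains with dimensions (iterations, params, chains).
β_chain = MCMCChains.Chains(reshape(rand(Normal(1, 0.1), 100), 100, 1, 1), [:β])

# Instantiate the model with missing responses so that y is predicted.
m_test = linreg([3.0, 4.0], Vector{Union{Missing,Float64}}(missing, 2))

# New behaviour: no RNG argument is needed; include_all=true keeps β in the
# returned chain alongside the predicted y[1] and y[2].
preds = DynamicPPL.predict(m_test, β_chain; include_all=true)

# Passing an RNG explicitly still works and makes the predictions reproducible.
preds1 = DynamicPPL.predict(Xoshiro(468), m_test, β_chain; include_all=true)
preds2 = DynamicPPL.predict(Xoshiro(468), m_test, β_chain; include_all=true)
@assert Array(preds1) == Array(preds2)
```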