-
Notifications
You must be signed in to change notification settings - Fork 37
Add InferenceObjects integration #465
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0ef245a
de2a321
7880a02
c28aae3
8b0e9b2
2f7b834
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
"""
    DynamicPPLInferenceObjectsExt

Extension module that loads the `InferenceObjects` integration for `DynamicPPL`.

The included files add methods taking `InferenceObjects.Dataset` and
`InferenceObjects.InferenceData` arguments to `AbstractPPL.condition`,
`DynamicPPL.generated_quantities`, `StatsBase.predict`,
`DynamicPPL.pointwise_loglikelihoods`, `DynamicPPL.setval!` and
`DynamicPPL.setval_and_resample!`.
"""
module DynamicPPLInferenceObjectsExt

using AbstractPPL: AbstractPPL
using DimensionalData: DimensionalData, Dimensions, LookupArrays
using DynamicPPL: DynamicPPL
using InferenceObjects: InferenceObjects
using Random: Random
using StatsBase: StatsBase

# Shared helpers (`concretize`, `dims2coords`) and `VarInfo` setters come first.
include("utils.jl")
include("varinfo.jl")

# Method extensions operating on `Dataset` / `InferenceData`.
include("condition.jl")
include("generated_quantities.jl")
include("predict.jl")
include("pointwise_loglikelihoods.jl")

end
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
"""
    condition(context::AbstractPPL.AbstractContext, data::InferenceObjects.Dataset)

Condition `context` on the variables stored in `data`.

`data` is converted with `NamedTuple(data)` and the result is passed on to
`AbstractPPL.condition`.
"""
function AbstractPPL.condition(
    context::AbstractPPL.AbstractContext, data::InferenceObjects.Dataset
)
    # NOTE(review): the PR author flagged this whole file as type piracy (neither
    # `AbstractPPL.condition` nor the argument types are owned by this extension).
    # Reviewers discussed hosting it in an AbstractPPL extension instead, or via an
    # optional dependency (Requires) or a subpackage, since AbstractPPL is meant to
    # stay extremely lightweight. No resolution is recorded here.
    variables = NamedTuple(data)
    return AbstractPPL.condition(context, variables)
end
"""
    condition(context::AbstractPPL.AbstractContext, data::InferenceObjects.InferenceData)

Condition `context` on the `posterior` group of `data`.

The group is read with `data.posterior`; no check is made that it exists.
"""
function AbstractPPL.condition(
    context::AbstractPPL.AbstractContext, data::InferenceObjects.InferenceData
)
    posterior = data.posterior
    return AbstractPPL.condition(context, posterior)
end
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
"""
    generated_quantities(mod::DynamicPPL.Model, data::InferenceObjects.Dataset; coords=(;), kwargs...)

Evaluate `DynamicPPL.generated_quantities` of `mod` once for every `(draw, chain)`
index of `data` and collect the results into a new dataset.

# Keywords
- `coords`: extra coordinates; the `draw`/`chain` coordinates of `data` are merged
  on top of these.
- `kwargs...`: forwarded to `InferenceObjects.convert_to_dataset`.
"""
function DynamicPPL.generated_quantities(
    mod::DynamicPPL.Model, data::InferenceObjects.Dataset; coords=(;), kwargs...
)
    # NOTE(review): DynamicPPL places no restrictions on what a model returns, so
    # reviewers suggested letting users specify the output type (or documenting the
    # constraints on returned objects) and choosing which variables to include.
    sample_dims = Dimensions.dims(data, (:draw, :chain))
    quantities = map(DimensionalData.DimIndices(sample_dims)) do index
        return DynamicPPL.generated_quantities(mod, data[index...])
    end
    # `sample_dims` is ordered (draw, chain), so each column holds one chain.
    per_chain = collect(eachcol(quantities))
    all_coords = merge(coords, dims2coords(sample_dims))
    return InferenceObjects.convert_to_dataset(per_chain; coords=all_coords, kwargs...)
end
|
|
||
"""
    generated_quantities(mod::DynamicPPL.Model, idata::InferenceObjects.InferenceData; kwargs...)

Compute generated quantities for the `posterior` and `prior` groups of `idata`
(whichever are present) and return `idata` with those groups replaced by the merge
of the generated quantities and the original group.

`kwargs...` are forwarded to the `Dataset` method.
"""
function DynamicPPL.generated_quantities(
    mod::DynamicPPL.Model, idata::InferenceObjects.InferenceData; kwargs...
)
    # The original group is the second argument to `merge`.
    augment(group) = merge(DynamicPPL.generated_quantities(mod, group; kwargs...), group)
    new_groups = Dict{Symbol,InferenceObjects.Dataset}(
        name => augment(idata[name]) for name in (:posterior, :prior) if haskey(idata, name)
    )
    return merge(idata, InferenceObjects.InferenceData(; new_groups...))
end
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,35 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
"""
    pointwise_loglikelihoods(model::DynamicPPL.Model, data::InferenceObjects.Dataset; coords=(;), kwargs...)

Evaluate `model` at every `(draw, chain)` index of `data` under a
`DynamicPPL.PointwiseLikelihoodContext` and return the recorded log-likelihoods as a
dataset, each reshaped to `(ndraws, nchains)`.

Returns `nothing` when no log-likelihoods were recorded.

# Keywords
- `coords`: extra coordinates; the `draw`/`chain` coordinates of `data` are merged
  on top of these.
- `kwargs...`: forwarded to `InferenceObjects.convert_to_dataset`.
"""
function DynamicPPL.pointwise_loglikelihoods(
    model::DynamicPPL.Model, data::InferenceObjects.Dataset; coords=(;), kwargs...
)
    # Get the data by executing the model once
    vi = DynamicPPL.VarInfo(model)
    context = DynamicPPL.PointwiseLikelihoodContext(Dict{String,Vector{Float64}}())

    # Draws vary fastest so that the reshape below yields (draw, chain) layout.
    for chain in axes(data, :chain), draw in axes(data, :draw)
        DynamicPPL.setval!(vi, data, draw, chain)
        model(vi, context)
    end

    isempty(context.loglikelihoods) && return nothing

    # TODO: optionally post-process idata to convert index variables like
    # Symbol("y[1]") to Symbol("y")
    # NOTE(review): considered important for ArviZ interoperability but deferred to a
    # future PR; a reviewer pointed at `DynamicPPL.value_iterator_from_chain` as
    # existing functionality that could help.
    ndraws, nchains = size(data, :draw), size(data, :chain)
    loglikelihoods = Dict(
        name => reshape(vals, ndraws, nchains) for (name, vals) in context.loglikelihoods
    )
    sample_coords = dims2coords(Dimensions.dims(data, (:draw, :chain)))
    return InferenceObjects.convert_to_dataset(
        loglikelihoods; group=:log_likelihood, coords=merge(coords, sample_coords), kwargs...
    )
end
"""
    pointwise_loglikelihoods(model::DynamicPPL.Model, data::InferenceObjects.InferenceData; kwargs...)

Compute pointwise log-likelihoods of `model` for the `posterior` group of `data` and
return `data` merged with a new `log_likelihood` group.

If the `Dataset` method returns `nothing` (no log-likelihoods were recorded), a
warning is emitted and `data` is returned unmodified.

`kwargs...` are forwarded to the `Dataset` method.
"""
function DynamicPPL.pointwise_loglikelihoods(
    model::DynamicPPL.Model, data::InferenceObjects.InferenceData; kwargs...
)
    log_likelihood = DynamicPPL.pointwise_loglikelihoods(model, data.posterior; kwargs...)
    # The `Dataset` method signals "nothing recorded" with `nothing`, which must not
    # be stored as a group; mirror how `predict` treats an empty result.
    if log_likelihood === nothing
        @warn "No pointwise log-likelihoods were computed. Does the model have observations?"
        return data
    end
    return merge(data, InferenceObjects.InferenceData(; log_likelihood=log_likelihood))
end
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
"""
    predict(rng::Random.AbstractRNG, model::DynamicPPL.Model, data::InferenceObjects.Dataset; coords=(;), kwargs...)

For every `(draw, chain)` index of `data`, set the variables found in `data`, mark
the remaining ones for resampling, and run `model` with `rng`. Return a dataset with
only the variables that are not already keys of `data`, or `nothing` if there are
none.

# Keywords
- `coords`: extra coordinates; the `draw`/`chain` coordinates of `data` are merged
  on top of these.
- `kwargs...`: forwarded to `InferenceObjects.convert_to_dataset`.
"""
function StatsBase.predict(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    data::InferenceObjects.Dataset;
    coords=(;),
    kwargs...,
)
    sampler = DynamicPPL.SampleFromPrior()
    varinfo = DynamicPPL.VarInfo(model)
    sample_ids = Iterators.product(axes(data, :draw), axes(data, :chain))
    draws = map(sample_ids) do (draw_id, chain_id)
        # Set variables present in `data` and mark those NOT present in data to be
        # resampled.
        DynamicPPL.setval_and_resample!(varinfo, data, draw_id, chain_id)
        model(rng, varinfo, sampler)
        return map(concretize, DynamicPPL.values_as(varinfo, NamedTuple))
    end
    sample_coords = dims2coords(Dimensions.dims(data, (:draw, :chain)))
    predictions = InferenceObjects.convert_to_dataset(
        collect(eachcol(draws));
        group=:posterior_predictive,
        coords=merge(coords, sample_coords),
        kwargs...,
    )
    # Keep only variables that were not already part of the input.
    known = keys(data)
    new_keys = filter(k -> !(k in known), keys(predictions))
    return isempty(new_keys) ? nothing : predictions[new_keys]
end
"""
    predict(model::DynamicPPL.Model, data::InferenceObjects.Dataset; kwargs...)

Call the `rng`-accepting `predict` method with `Random.default_rng()`.
"""
function StatsBase.predict(
    model::DynamicPPL.Model, data::InferenceObjects.Dataset; kwargs...
)
    rng = Random.default_rng()
    return StatsBase.predict(rng, model, data; kwargs...)
end
|
|
||
"""
    predict(rng::Random.AbstractRNG, model::DynamicPPL.Model, data::InferenceObjects.InferenceData; coords=(;), kwargs...)

Generate predictions from the `posterior` and/or `prior` groups of `data` and return
`data` merged with the resulting `posterior_predictive` / `prior_predictive` groups.

A warning is logged for each group that yields no predictions. If `data` has neither
a `posterior` nor a `prior` group, a warning is logged and `data` is returned
unmodified.

# Keywords
- `coords`: extra coordinates; when `data` has an `observed_data` group, the
  coordinates of its dimensions are merged on top of these.
- `kwargs...`: forwarded to the `Dataset` method.
"""
function StatsBase.predict(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    data::InferenceObjects.InferenceData;
    coords=(;),
    kwargs...,
)
    all_coords = if haskey(data, :observed_data)
        merge(coords, dims2coords(Dimensions.dims(data.observed_data)))
    else
        coords
    end
    # (source group, target group, warning emitted when nothing was predicted)
    specs = (
        (
            :posterior,
            :posterior_predictive,
            "No predictions were made based on posterior. Has the model been deconditioned?",
        ),
        (
            :prior,
            :prior_predictive,
            "No predictions were made based on prior. Has the model been deconditioned?",
        ),
    )
    new_groups = Dict{Symbol,InferenceObjects.Dataset}()
    for (source, target, message) in specs
        haskey(data, source) || continue
        predictive = StatsBase.predict(
            rng, model, data[source]; coords=all_coords, kwargs...
        )
        if predictive === nothing
            @warn message
        else
            new_groups[target] = predictive
        end
    end
    if !haskey(data, :posterior) && !haskey(data, :prior)
        @warn "No posterior or prior found in InferenceData. Returning unmodified input."
        return data
    end
    return merge(data, InferenceObjects.InferenceData(; new_groups...))
end
"""
    predict(model::DynamicPPL.Model, data::InferenceObjects.InferenceData; kwargs...)

Call the `rng`-accepting `predict` method with `Random.default_rng()`.
"""
function StatsBase.predict(
    model::DynamicPPL.Model, data::InferenceObjects.InferenceData; kwargs...
)
    rng = Random.default_rng()
    return StatsBase.predict(rng, model, data; kwargs...)
end
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| # adapted from MCMCChains | ||
"""
    isconcretetype_recursive(T)

Return `true` if `T` is a concrete type and following `eltype` repeatedly only ever
reaches concrete types, stopping once `eltype(T) === T`.
"""
function isconcretetype_recursive(T)
    isconcretetype(T) || return false
    E = eltype(T)
    # A type that is its own element type ends the recursion.
    return E === T || isconcretetype_recursive(E)
end
|
|
||
"""
    concretize(x)

Return `x` unchanged. This is the fallback for values that are not arrays.
"""
concretize(x) = x

"""
    concretize(x::AbstractArray)

Return `x` itself when `isconcretetype_recursive(typeof(x))` holds. Otherwise apply
`concretize` to every element and, if the promoted type `T` of the element types is
not `Union{}` and is a subtype of the resulting array's element type, convert the
result to `AbstractArray{T}`.
"""
function concretize(x::AbstractArray)
    isconcretetype_recursive(typeof(x)) && return x
    elements = map(concretize, x)
    T = mapreduce(typeof, promote_type, elements; init=Union{})
    # `T === Union{}` only if the array is empty; there is nothing to narrow then.
    if T === Union{} || !(T <: eltype(elements))
        return elements
    end
    return convert(AbstractArray{T}, elements)
end
|
|
||
"""
    dims2coords(dims)

Return a `NamedTuple` whose names are `Dimensions.dim2key(dims)` and whose values
are `Dimensions.lookup(dims)`.
"""
function dims2coords(dims)
    dim_keys = Dimensions.dim2key(dims)
    lookups = Dimensions.lookup(dims)
    return NamedTuple{dim_keys}(lookups)
end
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
"""
    setval!(vi::DynamicPPL.VarInfo, data::InferenceObjects.Dataset, draw_id::Int, chain_id::Int)

Index `data` at `draw=draw_id, chain=chain_id` and pass the result to
`DynamicPPL.setval!(vi, ...)`.
"""
function DynamicPPL.setval!(
    vi::DynamicPPL.VarInfo, data::InferenceObjects.Dataset, draw_id::Int, chain_id::Int
)
    sample = data[draw=draw_id, chain=chain_id]
    return DynamicPPL.setval!(vi, sample)
end
|
|
||
"""
    setval_and_resample!(vi, data::InferenceObjects.Dataset, draw_id::Int, chain_id::Int)

Index `data` at `draw=draw_id, chain=chain_id` and pass the result to
`DynamicPPL.setval_and_resample!(vi, ...)`.
"""
function DynamicPPL.setval_and_resample!(
    vi::DynamicPPL.VarInfoOrThreadSafeVarInfo,
    data::InferenceObjects.Dataset,
    draw_id::Int,
    chain_id::Int,
)
    sample = data[draw=draw_id, chain=chain_id]
    return DynamicPPL.setval_and_resample!(vi, sample)
end
Uh oh!
There was an error while loading. Please reload this page.