Description
We currently cannot compute the log probability log p(x_i | theta) of each individual observation in Turing. Instead, we always compute sum_i log p(x_i, theta), which makes a lot of sense from an inference point of view. However, adding this functionality would allow:
- Model selection (computing the model evidence)
- Prediction using a Turing model (required for MLJ integration)
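To make the distinction concrete, here is a minimal sketch in plain Julia (no Turing involved) of per-observation log probabilities versus their sum. The helper names `normal_logpdf` and `pointwise_loglik` are made up for this illustration, and the prior term of the log joint is left out for brevity.

```julia
# Log density of Normal(mu, sigma) at x, written out by hand to avoid dependencies.
normal_logpdf(x, mu, sigma) = -0.5 * log(2π) - log(sigma) - 0.5 * ((x - mu) / sigma)^2

# Hypothetical helper: one value log p(x_i | theta) per observation.
pointwise_loglik(xs, mu, sigma) = [normal_logpdf(x, mu, sigma) for x in xs]

xs = [0.1, -0.4, 1.3]
mu, sigma = 0.0, 1.0

lls = pointwise_loglik(xs, mu, sigma)  # Vector{Float64}, one entry per observation
total = sum(lls)                       # the single aggregated number
```

The vector `lls` is what is currently unavailable; `total` corresponds to the likelihood part of what gets accumulated today.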
I started discussing this with @mohamed82008 and we had a few ideas on how to approach this, but all of them seem rather hacky in my opinion.
Here is an alternative proposal.
One of the issues (from what I see) is that we currently do not have a VarName for observe statements. However, we could easily extend the compiler to generate one, which would allow us to pass a VarName object to each observe statement:
```julia
function generate_observe(observation, dist, model_info)
    main_body_names = model_info[:main_body_names]
    vi = main_body_names[:vi]
    sampler = main_body_names[:sampler]

    varname = gensym(:varname)
    sym, idcs, csym = gensym(:sym), gensym(:idcs), gensym(:csym)
    csym_str, indexing, syms = gensym(:csym_str), gensym(:indexing), gensym(:syms)

    if observation isa Symbol
        varname_expr = quote
            $sym, $idcs, $csym = Turing.@VarName $observation
            $csym = Symbol($(QuoteNode(model_info[:name])), $csym)
            $syms = Symbol[$csym, $(QuoteNode(observation))]
            $varname = Turing.VarName($syms, "")
        end
    else
        varname_expr = quote
            $sym, $idcs, $csym = Turing.@VarName $observation
            $csym_str = string($(QuoteNode(model_info[:name])))*string($csym)
            $indexing = foldl(*, $idcs, init = "")
            $varname = Turing.VarName(Symbol($csym_str), $sym, $indexing)
        end
    end

    return quote
        $varname_expr
        isdist = if isa($dist, AbstractVector)
            # Check if the right-hand side is a vector of distributions.
            all(d -> isa(d, Distribution), $dist)
        else
            # Check if the right-hand side is a distribution.
            isa($dist, Distribution)
        end
        @assert isdist @error($(wrong_dist_errormsg(@__LINE__)))
        Turing.observe($sampler, $dist, $observation, $varname, $vi)
    end
end
```

Further, we would
- need to change the way we manipulate the `vi.logp` field, as this could now be a `Float64` in case of aggregation or a `Vector{Float64}` in case of log probability values for each observation, or
- store the logp values for each observation inside the `VarInfo`, i.e. similar to the parameter values, and treat `logp` the way we do it now.
The first option would additionally require us to write a tailored sampler that computes only the log pdf and not the log joint. This is easy but may be unnecessary overhead, and it would require re-evaluating the model for each iteration in the case of model selection.
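As a rough illustration of the first option, the sketch below lets the type of the accumulator decide whether values are aggregated or collected per observation. This is plain Julia, not Turing code: `acclogp`, `run_observes`, and `normal_logpdf` are hypothetical names, and the loop merely stands in for a model evaluation.

```julia
# Hypothetical accumulation rule, dispatching on the type of the logp field.
acclogp(logp::Float64, lp::Float64) = logp + lp                # aggregate into one number
acclogp(logp::Vector{Float64}, lp::Float64) = push!(logp, lp)  # keep one entry per observation

# Log density of Normal(mu, sigma) at x, written out by hand.
normal_logpdf(x, mu, sigma) = -0.5 * log(2π) - log(sigma) - 0.5 * ((x - mu) / sigma)^2

# Stand-in for one model evaluation: visits every observe statement once.
function run_observes(logp, xs, mu, sigma)
    for x in xs
        logp = acclogp(logp, normal_logpdf(x, mu, sigma))
    end
    return logp
end

xs = [0.1, -0.4, 1.3]
aggregated = run_observes(0.0, xs, 0.0, 1.0)        # Float64
pointwise  = run_observes(Float64[], xs, 0.0, 1.0)  # Vector{Float64}
```

Under this design, obtaining the per-observation values means running the model again with a vector-valued accumulator, which is the re-evaluation cost mentioned above.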
If we go for option 2 (which is similar to what a user can do in Stan), we would store the aggregated log joint in logp and additionally store the log pdf of each observation in the VarInfo. Storing the logp values for each observation would be disabled by default and could be enabled by setting a kwarg. In contrast to option 1, this would be memory intensive if we aim to compute the model evidence, which could be avoided by re-evaluating the model for each iteration (similar to option 1).
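A minimal sketch of what option 2 could look like, assuming a simplified container rather than the real VarInfo. The type `SketchVarInfo`, the function `observe!`, the keyword `track_pointwise`, and the string keys are all hypothetical names chosen for this example.

```julia
# Hypothetical container, NOT Turing's VarInfo.
mutable struct SketchVarInfo
    logp::Float64                     # aggregated value, handled as it is now
    pointwise::Dict{String,Float64}   # per-observation values, keyed by a name
    track_pointwise::Bool             # extra storage is off by default
end

SketchVarInfo(; track_pointwise::Bool = false) =
    SketchVarInfo(0.0, Dict{String,Float64}(), track_pointwise)

# Always accumulate into logp; store the individual value only when requested.
function observe!(vi::SketchVarInfo, name::String, lp::Float64)
    vi.logp += lp
    if vi.track_pointwise
        vi.pointwise[name] = lp
    end
    return vi
end

# Log density of Normal(mu, sigma) at x, written out by hand.
normal_logpdf(x, mu, sigma) = -0.5 * log(2π) - log(sigma) - 0.5 * ((x - mu) / sigma)^2

xs = [0.1, -0.4, 1.3]
vi = SketchVarInfo(track_pointwise = true)
for (i, x) in enumerate(xs)
    observe!(vi, "x[$i]", normal_logpdf(x, 0.0, 1.0))
end
```

Keying the stored values by a name is where a VarName for each observe statement would come in; keeping one such dictionary per iteration is the memory cost discussed above.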
I think option 2 might be the more convenient option.
(cc'ing @yebai @xukai92 @willtebbutt @ablaom )