It would be really nice if we could make it easier to sample from the posterior predictive distribution. It's technically possible right now (I think, please check my example below), but it's a bit of a pain.
Consider the linear regression example in the documentation:
```julia
# Bayesian linear regression.
@model function linear_regression(x, y)
    # Set variance prior.
    σ₂ ~ truncated(Normal(0, 100), 0, Inf)
    # Set intercept prior.
    intercept ~ Normal(0, sqrt(3))
    # Set the priors on our coefficients.
    nfeatures = size(x, 2)
    coefficients ~ MvNormal(nfeatures, sqrt(10))
    # Calculate all the mu terms.
    mu = intercept .+ x * coefficients
    y ~ MvNormal(mu, sqrt(σ₂))
end
```

The documentation gives the following function for estimating the posterior mean:
```julia
# Make a prediction given an input vector.
function prediction(chain, x)
    p = get_params(chain[200:end, :, :])
    targets = p.intercept' .+ x * reduce(hcat, p.coefficients)'
    return vec(mean(targets; dims = 2))
end
```

But this is less than ideal because we are basically reproducing part of the model in a separate function. It also ignores observation noise.
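To make that distinction concrete, here is a minimal base-Julia sketch (no Turing dependency; the parameter draws below are hypothetical stand-ins for chain samples) contrasting the posterior-mean prediction with posterior predictive draws that include observation noise:

```julia
# Hypothetical posterior draws standing in for samples from a chain.
intercepts   = [0.9, 1.0, 1.1]
coefficients = [[2.0], [2.1], [1.9]]
sigmas       = [0.5, 0.4, 0.6]   # draws of sqrt(σ₂)

x = [1.0]  # a single input row

# Posterior-mean prediction (what `prediction` computes): average of mu only.
mus = [b + sum(x .* w) for (b, w) in zip(intercepts, coefficients)]
post_mean = sum(mus) / length(mus)

# Posterior predictive draws: one y per parameter draw, with observation noise.
pp = [mu + s * randn() for (mu, s) in zip(mus, sigmas)]
```

The posterior predictive draws have the extra spread from the noise term, which the posterior-mean estimate discards.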
We can get proper samples from the posterior predictive by modifying the model spec:
```julia
# Bayesian linear regression.
@model function linear_regression(x, y; σ₂=missing, intercept=missing, coefficients=missing)
    # Set variance prior.
    σ₂ ~ truncated(Normal(0, 100), 0, Inf)
    # Set intercept prior.
    intercept ~ Normal(0, sqrt(3))
    # Set the priors on our coefficients.
    nfeatures = size(x, 2)
    coefficients ~ MvNormal(nfeatures, sqrt(10))
    # Calculate all the mu terms.
    mu = intercept .+ x * coefficients
    y ~ MvNormal(mu, sqrt(σ₂))
end
```

```julia
chain = ... # assume we have a sampled chain
pp_samples = []
for sample in eachrow(DataFrame(chain))
    # nvars = number of coefficients, e.g. size(x, 2)
    coeffs = [sample["coefficients[$i]"] for i in 1:nvars]
    model = linear_regression(x, missing; σ₂=sample["σ₂"], intercept=sample["intercept"], coefficients=coeffs)
    push!(pp_samples, model())
end
```

But this is somewhat cumbersome. Could we make the @model macro implicitly define the random variables in the generated function signature? And maybe add some utility function for sampling from the chain?
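For what it's worth, the utility imagined above could look roughly like the following (names and signature are entirely hypothetical; plain Julia, with `simulate` standing in for calling the model conditioned on one draw):

```julia
# Hypothetical sketch of the proposed utility: map each parameter draw from a
# chain to one posterior predictive sample by rerunning a conditioned model.
function posterior_predictive(simulate, draws)
    return [simulate(d) for d in draws]
end

# Toy usage with named-tuple draws for the linear-regression example;
# `simulate` plays the role of `model()` above.
x = [1.0]
draws = [(intercept = 1.0, coefficients = [2.0], σ₂ = 0.25),
         (intercept = 1.1, coefficients = [1.9], σ₂ = 0.16)]
simulate(d) = d.intercept + sum(x .* d.coefficients) + sqrt(d.σ₂) * randn()

pp_samples = posterior_predictive(simulate, draws)
```

The point is just that the per-draw loop is boilerplate the library could own, with the user supplying only the model and the chain.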
Please let me know if I'm missing something here! Maybe there is already an easier way to do this that I am not aware of...?
Note that this is related to a suggestion in #638 ("posterior predictive checks").