
Make it easier to sample from posterior predictive #1475

@bgroenks96

Description


It would be really nice if we could make it easier to sample from the posterior predictive distribution. It's technically possible right now (I think, please check my example below), but it's a bit of a pain.

Consider the linear regression example in the documentation:

# Bayesian linear regression.
@model function linear_regression(x, y)
    # Set variance prior.
    σ₂ ~ truncated(Normal(0, 100), 0, Inf)
    
    # Set intercept prior.
    intercept ~ Normal(0, sqrt(3))
    
    # Set the priors on our coefficients.
    nfeatures = size(x, 2)
    coefficients ~ MvNormal(nfeatures, sqrt(10))
    
    # Calculate all the mu terms.
    mu = intercept .+ x * coefficients
    y ~ MvNormal(mu, sqrt(σ₂))
end
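For concreteness, a chain for this model can be drawn in the usual way — the sampler choice and draw count below are illustrative, not taken from the original docs:

```julia
using Turing

# Sample the posterior; NUTS and 1_000 draws are illustrative choices.
model = linear_regression(x, y)
chain = sample(model, NUTS(), 1_000)
```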

The documentation gives the following function for estimating the posterior mean:

# Make a prediction given an input vector.
function prediction(chain, x)
    p = get_params(chain[200:end, :, :])  # discard burn-in draws
    targets = p.intercept' .+ x * reduce(hcat, p.coefficients)'
    return vec(mean(targets; dims = 2))
end

But this is less than ideal because we are basically reproducing part of the model in a separate function. It also ignores observation noise.
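Observation noise could be folded back in by hand, drawing from the Gaussian likelihood for each retained posterior draw — a sketch only, reusing the same `get_params` chain layout as `prediction` above:

```julia
# Like `prediction`, but returns one noisy predictive draw of y per
# posterior sample instead of the posterior mean (sketch; assumes the
# same chain layout as `prediction` above).
function predictive_samples(chain, x)
    p = get_params(chain[200:end, :, :])
    mu = p.intercept' .+ x * reduce(hcat, p.coefficients)'  # n_obs × n_draws
    σ = sqrt.(vec(p.σ₂))'                                   # 1 × n_draws
    return mu .+ randn(size(mu)) .* σ                       # add likelihood noise
end
```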

We can get proper samples from the posterior predictive by modifying the model spec:

# Bayesian linear regression.
@model function linear_regression(x, y; σ₂=missing, intercept=missing, coefficients=missing)
    # Set variance prior.
    σ₂ ~ truncated(Normal(0, 100), 0, Inf)
    
    # Set intercept prior.
    intercept ~ Normal(0, sqrt(3))
    
    # Set the priors on our coefficients.
    nfeatures = size(x, 2)
    coefficients ~ MvNormal(nfeatures, sqrt(10))
    
    # Calculate all the mu terms.
    mu = intercept .+ x * coefficients
    y ~ MvNormal(mu, sqrt(σ₂))
end
chain = ... # assume we have a sampled chain
pp_samples = []
nvars = size(x, 2)  # number of coefficients, one per feature
for sample in eachrow(DataFrame(chain))
    coeffs = [sample["coefficients[$i]"] for i in 1:nvars]
    model = linear_regression(x, missing; σ₂=sample["σ₂"], intercept=sample["intercept"],
                              coefficients=coeffs)
    push!(pp_samples, model())
end

But this is somewhat cumbersome. Could we make the @model macro implicitly define the random variables in the generated function signature? And maybe add some utility function for sampling from the chain?
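For example, such a utility might look something like this — the name `posterior_predict` and its signature are invented here for illustration, not an existing Turing API:

```julia
# Hypothetical convenience wrapper (not an existing Turing function):
# condition the model on each posterior draw and run it forward with
# y = missing to obtain posterior predictive samples.
pp_samples = posterior_predict(linear_regression(x_new, missing), chain)
```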

Please let me know if I'm missing something here! Maybe there is already an easier way to do this that I am not aware of...?

Note that this is related to a suggestion in #638 ("posterior predictive checks").
