Commit a13ec14

Add a page on predictive distributions (#658)
* Add a page on predictive distributions
* Text
* More explanation
* a word
* Changes from review
* add notes
* Fix data plurality
1 parent 8a756f6 commit a13ec14

File tree

2 files changed: +155 -0 lines changed


_quarto.yml

Lines changed: 2 additions & 0 deletions

@@ -80,6 +80,7 @@ website:
       - usage/probability-interface/index.qmd
       - usage/modifying-logprob/index.qmd
       - usage/tracking-extra-quantities/index.qmd
+      - usage/predictive-distributions/index.qmd
       - usage/mode-estimation/index.qmd
       - usage/performance-tips/index.qmd
       - usage/sampler-visualisation/index.qmd
@@ -209,6 +210,7 @@ usage-external-samplers: usage/external-samplers
 usage-mode-estimation: usage/mode-estimation
 usage-modifying-logprob: usage/modifying-logprob
 usage-performance-tips: usage/performance-tips
+usage-predictive-distributions: usage/predictive-distributions
 usage-probability-interface: usage/probability-interface
 usage-sampler-visualisation: usage/sampler-visualisation
 usage-sampling-options: usage/sampling-options
usage/predictive-distributions/index.qmd

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
---
title: Predictive Distributions
engine: julia
---

```{julia}
#| echo: false
#| output: false
using Pkg;
Pkg.instantiate();
```

Standard MCMC sampling methods return values of the parameters of the model.
However, it is often also useful to generate new data points using the model, given a distribution of the parameters.
Turing.jl allows you to do this using the `predict` function, along with conditioning syntax.

Consider the following simple model, where we observe some normally-distributed data `X` and want to learn about its mean `m`.

```{julia}
using Turing

@model function f(N)
    m ~ Normal()
    X ~ filldist(Normal(m), N)
end
```

Notice first how we have not specified `X` as an argument to the model.
This allows us to use Turing's conditioning syntax to specify whether we want to provide observed data or not.

::: {.callout-note}
If you want to specify `X` as an argument to the model, then to mark it as unobserved you have to instantiate the model again with `X = missing` or `X = fill(missing, N)`.
Whether you use `missing` or `fill(missing, N)` depends on whether `X` is treated as a single distribution (e.g. with `filldist` or `product_distribution`) or as multiple independent distributions (e.g. with `.~` or a for loop over `eachindex(X)`); see the sketch below.
This is rather finicky, so we recommend the conditioning approach used here: conditioning and deconditioning `X` as a whole works regardless of how `X` is defined in the model.
:::
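
For completeness, here is a minimal sketch of that argument-based alternative; the model name `g` is purely illustrative:

```{julia}
#| eval: false
@model function g(X, N)
    m ~ Normal()
    X ~ filldist(Normal(m), N)
end

g(X, N)        # instantiate with observed data X
g(missing, N)  # mark X as unobserved: a single `missing` suffices here
               # because filldist treats X as one multivariate distribution
```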

```{julia}
# Generate some synthetic data
N = 5
true_m = 3.0
X = rand(Normal(true_m), N)

# Instantiate the model with observed data
model = f(N) | (; X = X)

# Sample from the posterior
chain = sample(model, NUTS(), 1_000; progress=false)
mean(chain[:m])
```

## Posterior predictive distribution

`chain[:m]` now contains samples from the posterior distribution of `m`.
If we use these samples of the parameters to generate new data points, we obtain the *posterior predictive distribution*.
Statistically, this is defined as

$$
p(\tilde{x} | \mathbf{X}) = \int p(\tilde{x} | \theta) p(\theta | \mathbf{X}) d\theta,
$$

where $\tilde{x}$ are the new data which you wish to draw, $\theta$ are the model parameters, and $\mathbf{X}$ are the observed data.
$p(\tilde{x} | \theta)$ is the distribution of the new data given the parameters, which is specified in the Turing.jl model (the `X ~ ...` line); and $p(\theta | \mathbf{X})$ is the posterior distribution, as given by the Markov chain.

To obtain samples of $\tilde{x}$, we first need to remove the observed data from the model (i.e. 'decondition' it).
This means that when the model is evaluated, it will sample a new value for `X`.
If you don't decondition the model, then `X` will remain fixed to the observed data, and no new samples will be generated.

```{julia}
predictive_model = decondition(model)
```

::: {.callout-tip}
## Selective deconditioning

If you only want to decondition a single variable `X`, you can use `decondition(model, @varname(X))`.
:::

To demonstrate how this deconditioned model can generate new data, we can fix the value of `m` to be its mean and evaluate the model:

```{julia}
predictive_model_with_mean_m = predictive_model | (; m = mean(chain[:m]))
rand(predictive_model_with_mean_m)
```

This has given us a single sample of `X` given the mean value of `m`.
Of course, to take our Bayesian uncertainty into account, we want to use the full posterior distribution of `m`, not just its mean.
To do so, we use `predict`, which _effectively_ does the same as above, but for every sample in the chain:

```{julia}
predictive_samples = predict(predictive_model, chain)
```

::: {.callout-tip}
## Reproducibility

`predict`, like many other Julia functions that involve randomness, takes an optional `rng` as its first argument.
This controls the generation of the new `X` samples and makes your results reproducible; see the example below.
:::
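
For instance, a minimal sketch (the seed value `468` is arbitrary):

```{julia}
#| eval: false
using Random: Xoshiro

# Passing an explicit, seeded RNG makes the predicted values reproducible.
predict(Xoshiro(468), predictive_model, chain)
```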

::: {.callout-note}
`predict` itself returns a `Chains` object, which will only contain the newly predicted variables.
If you also want to retain the original parameters, pass `include_all=true`, as sketched below.
:::
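
A minimal sketch of that call (without an explicit `rng` this time):

```{julia}
#| eval: false
# include_all=true keeps the original chain's parameters (here, m)
# alongside the new predictions in the returned Chains object.
predict(predictive_model, chain; include_all=true)
```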

We can visualise the predictive distribution by combining all the samples and making a density plot:

```{julia}
using StatsPlots: density, density!, vline!

predicted_X = vcat([predictive_samples[Symbol("X[$i]")] for i in 1:N]...)
density(predicted_X, label="Posterior predictive")
```

Depending on your data, you may naturally want to create different visualisations.
For example, perhaps `X` contains some time-series data, in which case you can plot each prediction individually as a line against time.
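
As a rough sketch of that idea, treating the `N` entries of `X` as successive time points:

```{julia}
#| eval: false
using StatsPlots: plot

# Each row of this matrix is one predicted draw, assuming columns
# are ordered X[1], ..., X[N].
draws = Array(predictive_samples)

# Overlay every draw as a faint line against its time index.
plot(1:N, draws'; color=:grey, alpha=0.1, label="", xlabel="time", ylabel="X")
```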

## Prior predictive distribution

Alternatively, if we use the prior distribution of the parameters $p(\theta)$, we obtain the *prior predictive distribution*:

$$
p(\tilde{x}) = \int p(\tilde{x} | \theta) p(\theta) d\theta.
$$

In an exactly analogous fashion to above, you could sample from the prior distribution of the conditioned model, and _then_ pass that to `predict`:

```{julia}
prior_params = sample(model, Prior(), 1_000; progress=false)
prior_predictive_samples = predict(predictive_model, prior_params)
```

In fact, there is a simpler way: you can directly sample from the deconditioned model using Turing's `Prior` sampler.
This will, in a single call, generate prior samples for both the parameters and the new data.

```{julia}
prior_predictive_samples = sample(predictive_model, Prior(), 1_000; progress=false)
```

We can visualise the prior predictive distribution in the same way as before.
Let's compare the two predictive distributions:

```{julia}
prior_predicted_X = vcat([prior_predictive_samples[Symbol("X[$i]")] for i in 1:N]...)
density(prior_predicted_X, label="Prior predictive")
density!(predicted_X, label="Posterior predictive")
vline!([true_m], label="True mean", linestyle=:dash, color=:black)
```

We can see here that the prior predictive distribution is:

1. wider than the posterior predictive distribution; and
2. centred on the prior mean of `m` (which is 0), rather than the posterior mean (which is close to the true mean of `3`).

Both of these differences arise because the posterior predictive distribution has been informed by the observed data.
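
As a quick numerical check of both points (using `mean` and `std` from the Statistics standard library):

```{julia}
using Statistics: mean, std

# The prior predictive should be wider (larger std) and centred near 0;
# the posterior predictive should be narrower and centred near true_m = 3.
(mean(prior_predicted_X), std(prior_predicted_X)), (mean(predicted_X), std(predicted_X))
```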
