The function `sample` is part of the AbstractMCMC interface. As explained in the [interface guide](https://turing.ml/dev/docs/for-developers/interface), building a sampling method that can be used by `sample` consists of overloading the structs and functions in `AbstractMCMC`. The interface guide also gives a standalone example of their implementation, [`AdvancedMH.jl`](https://github.com/TuringLang/AdvancedMH.jl).
Turing sampling methods (most of which are written [here](https://github.com/TuringLang/Turing.jl/tree/master/src/inference)) also implement `AbstractMCMC`. Turing defines a particular architecture for `AbstractMCMC` implementations that enables working with models defined by the `@model` macro, and uses DynamicPPL as a backend. The goal of this page is to describe this architecture, and how you would go about implementing your own sampling method in Turing, using Importance Sampling as an example. I don't go into all the details: for instance, I don't address selectors or parallelism.
First, we explain how Importance Sampling works in the abstract. Consider the model defined in the first code block. Mathematically, it can be written:

$$
\begin{align}
s &\sim \text{InverseGamma}(2, 3) \\
m &\sim \text{Normal}(0, \sqrt{s}) \\
x &\sim \text{Normal}(m, \sqrt{s}) \\
y &\sim \text{Normal}(m, \sqrt{s})
\end{align}
$$

The **latent** variables are $s$ and $m$, the **observed** variables are $x$ and $y$. The model **joint** distribution $p(s,m,x,y)$ decomposes into the **prior** $p(s,m)$ and the **likelihood** $p(x,y \mid s,m)$. Since $x = 1.5$ and $y = 2$ are observed, the goal is to infer the **posterior** distribution $p(s,m \mid x,y)$.
Importance Sampling produces independent samples $(s_i, m_i)$ from the prior distribution. It also outputs unnormalized weights $w_i = \frac {p(x,y,s_i,m_i)} {p(s_i, m_i)} = p(x,y \mid s_i, m_i)$ such that the empirical distribution $\frac 1 N \sum\limits_{i =1}^N \frac {w_i} {\sum\limits_{j=1}^N w_j} \delta_{(s_i, m_i)}$ is a good approximation of the posterior.
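
To make the weights concrete, here is a minimal standalone sketch of Importance Sampling for this model, written directly against Distributions.jl rather than Turing; the function name is illustrative, and the observed values default to $x = 1.5$ and $y = 2$ as above:

```julia
# Minimal standalone sketch of Importance Sampling for the model above,
# using Distributions.jl directly (not Turing's implementation).
using Distributions

function importance_sampling(N::Int; x = 1.5, y = 2.0)
    s_samples  = Vector{Float64}(undef, N)
    m_samples  = Vector{Float64}(undef, N)
    logweights = Vector{Float64}(undef, N)
    for i in 1:N
        # sample (sᵢ, mᵢ) from the prior p(s, m)
        s = rand(InverseGamma(2, 3))
        m = rand(Normal(0, sqrt(s)))
        # unnormalized log weight: log p(x, y | sᵢ, mᵢ)
        logweights[i] = logpdf(Normal(m, sqrt(s)), x) + logpdf(Normal(m, sqrt(s)), y)
        s_samples[i], m_samples[i] = s, m
    end
    return s_samples, m_samples, logweights
end
```

Normalizing `exp.(logweights)` gives exactly the weights $\frac {w_i} {\sum_j w_j}$ of the empirical approximation above.
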
## 1. Define a `Sampler`
### States
The `vi` field contains all the important information about sampling: first and foremost, the values of all the samples, but also the distributions from which they are sampled, the names of model parameters, and other metadata. As we will see below, many important steps during sampling correspond to queries or updates to `spl.state.vi`.
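
As a rough illustration of such queries (the helper name `peek_at_state` is hypothetical, and the exact DynamicPPL function names and signatures vary between versions):

```julia
using DynamicPPL: getlogp

# Illustrative sketch only (not code from is.jl): the kind of queries a sampling
# step makes on the VarInfo stored in its state. `spl` is assumed to be a Sampler.
function peek_at_state(spl)
    vi = spl.state.vi
    θ  = vi[spl]        # current values of the parameters handled by this sampler
    ℓ  = getlogp(vi)    # accumulated log density at those values
    return θ, ℓ
end
```
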
By default, you can use `SamplerState`, a concrete type defined in `inference/Inference.jl`, which extends `AbstractSamplerState` and has no field except for `vi`:
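
For concreteness, here is a minimal sketch of such a state type; the field is exactly the one described above, but the parametric signature is an illustrative assumption, so consult `inference/Inference.jl` for the exact definition:

```julia
# Sketch of a default sampler state: the only field is the VarInfo.
# Assumes VarInfo and AbstractSamplerState are in scope (e.g. via Turing/DynamicPPL).
mutable struct SamplerState{VIType<:VarInfo} <: AbstractSamplerState
    vi::VIType
end
```
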
When doing Importance Sampling, we care not only about the values of the samples but also their weights. We will see below that the weight of each sample is also added to `spl.state.vi`. Moreover, the average $\frac 1 N \sum\limits_{i=1}^N w_i = \frac 1 N \sum\limits_{i=1}^N p(x,y \mid s_i, m_i)$ of the sample weights is a particularly important quantity:
* it is used to **normalize** the **empirical approximation** of the posterior distribution (the weighted empirical distribution given above)
* its logarithm is the importance sampling **estimate** of the **log evidence** $\log p(x, y)$
To avoid having to compute it over and over again, `is.jl` defines an IS-specific concrete type `ISState` for sampler states, with an additional field `final_logevidence` containing $\log \left( \frac 1 N \sum\limits_{i=1}^N w_i \right)$.
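
A sketch of what this state type looks like, based only on the description above (the type parameters and the convenience constructor are illustrative assumptions; see `is.jl` for the exact definition):

```julia
# Sketch of an IS-specific state: the VarInfo plus a cache for the log evidence
# estimate, filled in by sample_end! once sampling is finished.
mutable struct ISState{V<:VarInfo, F<:AbstractFloat} <: AbstractSamplerState
    vi                :: V
    final_logevidence :: F
end

# convenience constructor: start from the model's VarInfo with a zero log evidence
ISState(model::Model) = ISState(VarInfo(model), 0.0)
```
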
The following diagram summarizes the hierarchy presented above.
You can of course implement all of these functions, but `AbstractMCMC` as well as Turing also provide default implementations for simple cases.
## 3. Overload `assume` and `observe`
The functions mentioned above, such as `sample_init!`, `step!`, etc., must of course use information about the model in order to generate samples! In particular, these functions may need **samples from distributions** defined in the model, or to **evaluate the density of these distributions** at some values of the corresponding parameters or observations.
It simply returns the density (probability, in the discrete case) of the observed value under the distribution `dist`.
## 4. Summary: Importance Sampling step by step
We focus on the AbstractMCMC functions that are overridden in `is.jl` and executed inside `mcmcsample`: `step!`, which is called `n_samples` times, and `sample_end!`, which is executed once after those `n_samples` iterations.
* the transition's `vi` field is simply `spl.state.vi`
* the `lp` field contains the log-likelihood `spl.state.vi.logp[]`
* When the `n_samples` iterations are completed, `sample_end!` fills the `final_logevidence` field of `spl.state`
* it simply takes the logarithm of the average of the sample weights, using the log weights for numerical stability, as in the sketch below
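
Here is a minimal sketch of that last computation, assuming the log weights have been collected into a vector (in `is.jl` they are read from the transitions rather than passed as a vector, and the function name below is only illustrative):

```julia
# Illustrative sketch: log of the average weight, computed from the log weights
# with the log-sum-exp trick for numerical stability.
function logevidence_estimate(logweights::AbstractVector{<:Real})
    N = length(logweights)
    c = maximum(logweights)
    return c + log(sum(exp.(logweights .- c))) - log(N)
end
```
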