
Commit 634458d

move IS presentation to top, add info about ISState
1 parent f95d7a2 commit 634458d


docs/src/for-developers/how_turing_implements_abstractmcmc.md

Lines changed: 38 additions & 30 deletions
@@ -8,7 +8,9 @@

 Prerequisite: [Interface guide](https://turing.ml/dev/docs/for-developers/interface).

-Consider the following code:
+## Introduction
+
+Consider the following Turing code block:

 ```julia
 @model function gdemo(x, y)
@@ -25,9 +27,22 @@ n_samples = 1000
 chn = sample(mod, alg, n_samples)
 ```

-The function `sample` is part of the AbstractMCMC interface. As explained in the interface guide, building a a sampling method that can be used by `sample` consists in overloading the structs and functions in `AbstractMCMC`. The interface guide also gives a standalone example of their implementation, [`AdvancedMH.jl`]().
+The function `sample` is part of the AbstractMCMC interface. As explained in the [interface guide](https://turing.ml/dev/docs/for-developers/interface), building a sampling method that can be used by `sample` consists in overloading the structs and functions in `AbstractMCMC`. The interface guide also gives a standalone example of their implementation, [`AdvancedMH.jl`](https://github.com/TuringLang/AdvancedMH.jl).

-Turing sampling methods (most of which are written [here](https://github.com/TuringLang/Turing.jl/tree/master/src/inference)) also implement `AbstractMCMC`. Turing defines a particular architecture for `AbstractMCMC` implementations, that enables working with models defined by the `@model` macro, and uses DynamicPPL as a backend. The goal of this page is to describe this architecture, and how you would go about implementing your own sampling method in Turing. I don't go into all the details: for instance, I don't address selectors or parallelism.
+Turing sampling methods (most of which are written [here](https://github.com/TuringLang/Turing.jl/tree/master/src/inference)) also implement `AbstractMCMC`. Turing defines a particular architecture for `AbstractMCMC` implementations that enables working with models defined by the `@model` macro and uses DynamicPPL as a backend. The goal of this page is to describe this architecture, and how you would go about implementing your own sampling method in Turing, using Importance Sampling as an example. I don't go into all the details: for instance, I don't address selectors or parallelism.
+
+First, we explain how Importance Sampling works in the abstract. Consider the model defined in the first code block. Mathematically, it can be written:
+$$
+\begin{align}
+s &\sim \text{InverseGamma}(2, 3) \\
+m &\sim \text{Normal}(0, \sqrt{s}) \\
+x &\sim \text{Normal}(m, \sqrt{s}) \\
+y &\sim \text{Normal}(m, \sqrt{s})
+\end{align}
+$$
+The **latent** variables are $s$ and $m$, the **observed** variables are $x$ and $y$. The model **joint** distribution $p(s,m,x,y)$ decomposes into the **prior** $p(s,m)$ and the **likelihood** $p(x,y \mid s,m)$. Since $x = 1.5$ and $y = 2$ are observed, the goal is to infer the **posterior** distribution $p(s,m \mid x,y)$.
+
+Importance Sampling produces independent samples $(s_i, m_i)$ from the prior distribution. It also outputs unnormalized weights $w_i = \frac {p(x,y,s_i,m_i)} {p(s_i, m_i)} = p(x,y \mid s_i, m_i)$ such that the empirical distribution $\sum\limits_{i=1}^N \frac {w_i} {\sum\limits_{j=1}^N w_j} \delta_{(s_i, m_i)}$ is a good approximation of the posterior.
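
To make these formulas concrete, here is a minimal, self-contained sketch of vanilla Importance Sampling for this model, written directly against Distributions.jl rather than Turing's internals (the variable names, such as `logweights` and `posterior_mean_m`, are purely illustrative):

```julia
# Illustrative only: vanilla Importance Sampling for the gdemo model,
# written against Distributions.jl rather than Turing's internals.
using Distributions

x, y = 1.5, 2.0
N = 1_000

ss = Vector{Float64}(undef, N)         # draws of s from the prior
ms = Vector{Float64}(undef, N)         # draws of m from the prior
logweights = Vector{Float64}(undef, N) # unnormalized log weights log p(x, y | s_i, m_i)

for i in 1:N
    ss[i] = rand(InverseGamma(2, 3))
    ms[i] = rand(Normal(0, sqrt(ss[i])))
    logweights[i] = logpdf(Normal(ms[i], sqrt(ss[i])), x) +
                    logpdf(Normal(ms[i], sqrt(ss[i])), y)
end

# Self-normalized weights w_i / sum_j w_j, computed in log space for stability.
weights = exp.(logweights .- maximum(logweights))
weights ./= sum(weights)

# Example use: posterior expectation of m under the weighted empirical distribution.
posterior_mean_m = sum(weights .* ms)
```

Turing's `is.jl` does essentially this bookkeeping, but spread across the `step!` and `sample_end!` callbacks described below.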

 ## 1. Define a `Sampler`

@@ -102,6 +117,23 @@ end

 ### States

+The `vi` field contains all the important information about sampling: first and foremost, the values of all the samples, but also the distributions from which they are sampled, the names of model parameters, and other metadata. As we will see below, many important steps during sampling correspond to queries or updates to `spl.state.vi`.
+
+By default, you can use `SamplerState`, a concrete type defined in `inference/Inference.jl`, which extends `AbstractSamplerState` and has no field except for `vi`:
+
+```julia
+mutable struct SamplerState{VIType<:VarInfo} <: AbstractSamplerState
+    vi :: VIType
+end
+```
+
+When doing Importance Sampling, we care not only about the values of the samples but also their weights. We will see below that the weight of each sample is also added to `spl.state.vi`. Moreover, the average $\frac 1 N \sum\limits_{i=1}^N w_i = \frac 1 N \sum\limits_{i=1}^N p(x,y \mid s_i, m_i)$ of the sample weights is a particularly important quantity:
+
+* it is used to **normalize** the **empirical approximation** of the posterior distribution (TODO: link to formula)
+* its logarithm is the importance sampling **estimate** of the **log evidence** $\log p(x, y)$
+
+To avoid having to compute it over and over again, `is.jl` defines an IS-specific concrete type `ISState` for sampler states, with an additional field `final_logevidence` containing $\log \left( \frac 1 N \sum\limits_{i=1}^N w_i \right)$.
+
 ```julia
 mutable struct ISState{V<:VarInfo, F<:AbstractFloat} <: AbstractSamplerState
     vi :: V
@@ -112,11 +144,7 @@ end
 ISState(model::Model) = ISState(VarInfo(model), 0.0)
 ```

-VarInfo contains all the important information about sampling: names of model parameters, the distributions from which they are sampled, the value of the samples, and other metadata.
-
-As we will see below, many important steps during sampling correspond to queries or updates to `spl.state.vi`.
-
-By default, you can use `SamplerState`, a concrete type extending `AbstractSamplerState` which has no field apart from `vi`.
+The following diagram summarizes the hierarchy presented above.

 ![Untitled Diagram(1)](/Users/js/Downloads/Untitled Diagram(1).png)

@@ -146,8 +174,6 @@ A crude summary, which ignores things like parallelism, is the following. `sampl

 you can of course implement all of these functions, but `AbstractMCMC` as well as Turing also provide default implementations for simple cases.

-
-
 ## 3. Overload `assume` and `observe`

The functions mentioned above, such as `sample_init!`, `step!`, etc., must of course use information about the model in order to generate samples! In particular, these functions may need **samples from distributions** defined in the model, or to **evaluate the density of these distributions** at some values of the corresponding parameters or observations.
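
As a rough picture of how this division of labour plays out for Importance Sampling, the toy stand-ins below illustrate the idea (they are not Turing's actual `assume`/`observe` methods, and `my_assume`/`my_observe` are made-up names): each latent variable is drawn from its prior, and the log density of each observation accumulates into the sample's log weight.

```julia
# Toy stand-ins for the assume/observe split, NOT Turing's DynamicPPL methods:
# for Importance Sampling, "assume" draws a latent variable from its prior and
# "observe" returns a log-likelihood contribution that goes into the weight.
using Distributions

my_assume(dist::Distribution) = rand(dist)
my_observe(dist::Distribution, value) = logpdf(dist, value)

# One Importance Sampling iteration for the gdemo model:
s = my_assume(InverseGamma(2, 3))
m = my_assume(Normal(0, sqrt(s)))
logw = my_observe(Normal(m, sqrt(s)), 1.5) + my_observe(Normal(m, sqrt(s)), 2.0)
```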
@@ -188,24 +214,7 @@ end

 It simply returns the density (probability, in the discrete case) of the observed value under the distribution `dist`.

-## 4. Example: Importance Sampling
-
-### Quick description of Importance Sampling
-
-Consider the model defined in the first code block. Mathematically, it can be written:
-$$
-\begin{align}
-s &\sim \text{InverseGamma}(2, 3) \\
-m &\sim \text{Normal}(0, \sqrt{s}) \\
-x &\sim \text{Normal}(m, \sqrt{s}) \\
-y &\sim \text{Normal}(m, \sqrt{s})
-\end{align}
-$$
-The **latent** variables are $s$ and $m$, the **observed** variables are $x$ and $y$. The model **joint** distribution $p(s,m,x,y)$ decomposes into the **prior** $p(s,m)$ and the **likelihood** $p(x,y \mid s,m)$. Since $x = 1.5$ and $y = 2$ are observed, the goal is to infer the **posterior** distribution $p(s,m \mid x,y)$.
-
-Importance Sampling produces samples $(s_i, m_i)$ that are independent and all distributed according to the prior distribution. It also outputs unnormalized weights $w_i = \frac {p(x,y,s_i,m_i)} {p(s_i, m_i)} = p(x,y \mid s_i, m_i)$ such that the empirical distribution $\frac 1 N \sum\limits_{i =1}^N \frac {w_i} {\sum\limits_{j=1}^N w_j} \delta_{(s_i, m_i)}$ is a good approximation of the posterior.
-
-### Understanding `is.jl` step by step
+## 4. Summary: Importance Sampling step by step

 We focus on the AbstractMCMC functions that are overridden in `is.jl` and executed inside `mcmcsample`: `step!`, which is called `n_samples` times, and `sample_end!`, which is executed once after those `n_samples` iterations.


@@ -220,5 +229,4 @@ We focus on the AbstractMCMC functions that are overriden in `is.jl` and execute
 * the transition's `vi` field is simply `spl.state.vi`
 * the `lp` field contains the likelihood `spl.state.vi.logp[]`
 * When the `n_samples` iterations are completed, `sample_end!` fills the `final_logevidence` field of `spl.state`
-  * the **true log evidence** is $\log p(x, y)$
-  * its **importance sampling estimate** is the logarithm of the average of the likelihoods $p(x,y \mid s_i, m_i)$
+  * it simply takes the logarithm of the average of the sample weights, using the log weights for numerical stability, as sketched below
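
For instance, given the vector of accumulated log weights, the quantity $\log \left( \frac 1 N \sum_{i=1}^N w_i \right)$ can be computed stably by subtracting the maximum log weight before exponentiating. The sketch below illustrates the idea; the helper name `estimate_logevidence` is made up here, not a function defined in `is.jl`.

```julia
# Sketch: numerically stable log of the average weight, log((1/N) * sum_i exp(logw_i)).
# `estimate_logevidence` is an illustrative name, not a function from is.jl.
function estimate_logevidence(logweights::AbstractVector{<:Real})
    c = maximum(logweights)   # subtract the max before exponentiating
    return c + log(sum(exp.(logweights .- c))) - log(length(logweights))
end

estimate_logevidence([-2.3, -1.7, -3.1])   # importance sampling estimate of log p(x, y)
```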
