Fix log evidence computation #1266
Conversation
Codecov Report
```diff
@@            Coverage Diff             @@
##           master    #1266      +/-   ##
==========================================
+ Coverage   66.84%   67.08%   +0.24%
==========================================
  Files          25       25
  Lines        1327     1343      +16
==========================================
+ Hits          887      901      +14
- Misses        440      442       +2
```
Continue to review the full report at Codecov.
The test error on Windows is still the HMC error that #1264 is supposed to fix.
yebai left a comment:
Thanks @devmotion for fixing this - it is a very subtle issue and I'm glad that we now got it correct.
```julia
else
    # Increase the unnormalized logarithmic weights, accounting for the variables
    # of other samplers.
    increase_logweight!(pc, i, score + getlogp(p.vi))
```
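For context, a minimal sketch of what this update does: it adds the increment to particle `i`'s unnormalized logarithmic weight. The struct and its layout are hypothetical stand-ins, not Turing's actual `ParticleContainer`.

```julia
# Toy stand-in for the particle container; only the field `logWs` (the
# unnormalized logarithmic weights) matters for this sketch.
mutable struct ToyParticleContainer
    logWs::Vector{Float64}
end

# Add `logw` to the unnormalized logarithmic weight of particle `i`.
increase_logweight!(pc::ToyParticleContainer, i, logw) = (pc.logWs[i] += logw)

pc = ToyParticleContainer(zeros(3))
increase_logweight!(pc, 2, -0.7)  # e.g. score + getlogp(p.vi) in the diff above
```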
As a side note, `getlogp(p.vi)` will always return 0, since the `assume` and `observe` functions for particle samplers do not modify `vi.logp` by default. This doesn't affect correctness, but it's worth paying attention to.
See:
That's not completely true (or maybe I misunderstand you): in this line `getlogp` can actually return nonzero values due to

Turing.jl/src/inference/AdvancedSMC.jl, line 286 in 35e3fe8:

```julia
acclogp!(vi, logpdf_with_trans(dist, r, istrans(vi, vn)))
```

However, since `resetlogp!` is called in one of the following lines, this won't show up in the saved transitions.
I must have missed that line, thanks for the pointer!
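To make the accumulate-then-reset behavior discussed in this thread concrete, here is a minimal toy sketch; `ToyVarInfo` is a hypothetical stand-in, and only the function names mirror DynamicPPL's `acclogp!`, `resetlogp!`, and `getlogp`.

```julia
mutable struct ToyVarInfo
    logp::Float64
end

acclogp!(vi::ToyVarInfo, logp) = (vi.logp += logp; vi)  # accumulate into vi.logp
resetlogp!(vi::ToyVarInfo) = (vi.logp = 0.0; vi)        # reset to zero
getlogp(vi::ToyVarInfo) = vi.logp

vi = ToyVarInfo(0.0)
acclogp!(vi, -1.2)  # e.g. the `logpdf_with_trans` contribution quoted above
getlogp(vi)         # nonzero here, as pointed out in the reply
resetlogp!(vi)      # reset later in the sweep ...
getlogp(vi)         # ... so 0 is what ends up in the saved transitions
```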
```julia
params = tonamedtuple(particle.vi)

# This is pretty useless since we reset the log probability continuously in the
# particle sweep.
```
Thanks for the note; also see my comment above about the `assume` and `observe` functions.
* Fix particle filters with adaptive resampling and add documentation (fixes TuringLang/DynamicPPL.jl#104)
* Fix and extend tests of `ParticleContainer`
* Move logevidence tests from DynamicPPL
* Add more convenient constructors for Particle Gibbs
* Relax type annotations
* Check for approximate equality only
* Add docstring and reference
This PR fixes TuringLang/DynamicPPL.jl#104 and transfers the corresponding tests to Turing. The main problem is that the formula on which I based the computation of the log evidence in #1237 is only valid if resampling is performed in every time step, or rather if the weights are reset to 1/N before every reweighting step. A reference for the more general formula (which was used before) is eq. (14) in "Sequential Monte Carlo Samplers" by P. Del Moral, A. Doucet, and A. Jasra. The fix is particularly important since by default we don't resample in every step but only adaptively, based on the estimated effective sample size (ESS), as sketched below.
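To make the adaptive criterion concrete, here is a minimal sketch of an ESS-based resampling decision computed from the unnormalized logarithmic weights. The function names and the 0.5 threshold are illustrative, not Turing's internals.

```julia
# Numerically stable log-sum-exp (StatsFuns.jl also provides one).
logsumexp(x) = (m = maximum(x); m + log(sum(exp.(x .- m))))

# Effective sample size of the normalized weights, ESS = 1 / sum(W_i^2),
# computed from the unnormalized logarithmic weights.
function ess(logweights)
    logw = logweights .- logsumexp(logweights)  # normalize in log space
    return exp(-logsumexp(2 .* logw))           # 1 / sum(W_i^2)
end

# Resample only when the ESS drops below a fraction of the particle count.
should_resample(logweights; threshold=0.5) =
    ess(logweights) < threshold * length(logweights)
```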
More formally, we save only the unnormalized logarithmic weights and accumulate them until the next resampling is performed, at which point they are reset to 0. Let `logw_k^i` denote the unnormalized logarithmic weight of the `i`th particle in the `k`th step. Hence, in the notation of Del Moral, Doucet, and Jasra, we have

    logw_k^i = logw_{k-1}^i + log(w_k^i),

and hence by induction normalizing `logw_k^i` yields

    exp(logw_k^i) / \sum_{j=1}^N exp(logw_k^j)
        = w_k^i * exp(logw_{k-1}^i) / \sum_{j=1}^N w_k^j * exp(logw_{k-1}^j)
        = w_k^i * W_{k-1}^i / \sum_{j=1}^N w_k^j * W_{k-1}^j
        = W_k^i,

i.e., the normalized weights in algorithm 3.1.1, as desired. Hence from eq. (14) we can compute the increase of the log evidence by

    log(Z_k) - log(Z_{k-1}) = log(Z_k / Z_{k-1})
        = log(\sum_{i=1}^N W_{k-1}^i * w_k^i)
        = log(\sum_{i=1}^N exp(logw_{k-1}^i) * w_k^i) - log(\sum_{i=1}^N exp(logw_{k-1}^i))
        = log(\sum_{i=1}^N exp(logw_{k-1}^i + log(w_k^i))) - log(\sum_{i=1}^N exp(logw_{k-1}^i))
        = log(\sum_{i=1}^N exp(logw_k^i)) - log(\sum_{i=1}^N exp(logw_{k-1}^i)).

Thus only if the `logw_{k-1}^i` are all 0 (such as in the initial step and after resampling) do we obtain the formula

    log(Z_k) - log(Z_{k-1}) = log(\sum_{i=1}^N exp(logw_k^i)) - log(N)

from the reference on which the linked PR was based.

As a final remark, I'm a bit unsatisfied with the function name `logZ` since, as shown above, it does not compute `log(Z)` but the logarithm of the normalization factor of the current unnormalized weights.
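The general update derived above translates directly into code. The following is a minimal sketch under the notation of this PR description; `update_logevidence` is a hypothetical name, not Turing's actual implementation.

```julia
# Numerically stable log-sum-exp (StatsFuns.jl also provides one).
logsumexp(x) = (m = maximum(x); m + log(sum(exp.(x .- m))))

# `logw` holds the unnormalized logarithmic weights logw_{k-1}^i and `logincr`
# the incremental log-weights log(w_k^i) of the current step.
function update_logevidence(logZ, logw, logincr)
    logw_new = logw .+ logincr  # logw_k^i = logw_{k-1}^i + log(w_k^i)
    logZ_new = logZ + logsumexp(logw_new) - logsumexp(logw)
    return logZ_new, logw_new
end

# Right after resampling all logw_{k-1}^i are 0 and logsumexp(logw) = log(N),
# recovering the special case logsumexp(logw_new) - log(N).
logZ, logw = update_logevidence(0.0, zeros(4), [-1.0, -2.0, -1.5, -0.5])
```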