1818
1919from pymc .aesaraf import floatX
2020from pymc .backends .report import SamplerWarning , WarningType
21- from pymc .math import logbern , logdiffexp_numpy
21+ from pymc .math import logbern
2222from pymc .step_methods .arraystep import Competence
2323from pymc .step_methods .hmc .base_hmc import BaseHMC , DivergenceInfo , HMCStepData
2424from pymc .step_methods .hmc .integration import IntegrationError
@@ -78,6 +78,12 @@ class NUTS(BaseHMC):
7878 by the python standard library `time.perf_counter` (wall time).
7979 - `perf_counter_start`: The value of `time.perf_counter` at the beginning
8080 of the computation of the draw.
81+ - `index_in_trajectory`: This is usually only interesting for debugging
82+ purposes. This indicates the position of the posterior draw in the
83+ trajectory. Eg a -4 would indicate that the draw was the result of the
84+ fourth leapfrog step in negative direction.
85+ - `largest_eigval` and `smallest_eigval`: Experimental statistics for
86+ some mass matrix adaptation algorithms. This is nan if it is not used.
8187
8288 References
8389 ----------
@@ -105,6 +111,9 @@ class NUTS(BaseHMC):
105111 "process_time_diff" : np .float64 ,
106112 "perf_counter_diff" : np .float64 ,
107113 "perf_counter_start" : np .float64 ,
114+ "largest_eigval" : np .float64 ,
115+ "smallest_eigval" : np .float64 ,
116+ "index_in_trajectory" : np .int64 ,
108117 }
109118 ]
110119
@@ -219,12 +228,12 @@ def warnings(self):
219228
220229
221230# A proposal for the next position
222- Proposal = namedtuple ("Proposal" , "q, q_grad, energy, log_p_accept_weighted, logp " )
231+ Proposal = namedtuple ("Proposal" , "q, q_grad, energy, logp, index_in_trajectory " )
223232
224233# A subtree of the binary tree built by nuts.
225234Subtree = namedtuple (
226235 "Subtree" ,
227- "left, right, p_sum, proposal, log_size, log_weighted_accept_sum, n_proposals " ,
236+ "left, right, p_sum, proposal, log_size" ,
228237)
229238
230239
@@ -252,10 +261,10 @@ def __init__(self, ndim, integrator, start, step_size, Emax):
252261 self .start_energy = np .array (start .energy )
253262
254263 self .left = self .right = start
255- self .proposal = Proposal (start .q .data , start .q_grad , start .energy , 1.0 , start .model_logp )
264+ self .proposal = Proposal (start .q .data , start .q_grad , start .energy , start .model_logp , 0 )
256265 self .depth = 0
257266 self .log_size = 0
258- self .log_weighted_accept_sum = - np .inf
267+ self .log_accept_sum = - np .inf
259268 self .mean_tree_accept = 0.0
260269 self .n_proposals = 0
261270 self .p_sum = start .p .data .copy ()
@@ -279,7 +288,7 @@ def extend(self, direction):
279288 )
280289 leftmost_begin , leftmost_end = self .left , self .right
281290 rightmost_begin , rightmost_end = tree .left , tree .right
282- leftmost_p_sum = self .p_sum
291+ leftmost_p_sum = self .p_sum . copy ()
283292 rightmost_p_sum = tree .p_sum
284293 self .right = tree .right
285294 else :
@@ -289,11 +298,10 @@ def extend(self, direction):
289298 leftmost_begin , leftmost_end = tree .right , tree .left
290299 rightmost_begin , rightmost_end = self .left , self .right
291300 leftmost_p_sum = tree .p_sum
292- rightmost_p_sum = self .p_sum
301+ rightmost_p_sum = self .p_sum . copy ()
293302 self .left = tree .right
294303
295304 self .depth += 1
296- self .n_proposals += tree .n_proposals
297305
298306 if diverging or turning :
299307 return diverging , turning
@@ -303,9 +311,6 @@ def extend(self, direction):
303311 self .proposal = tree .proposal
304312
305313 self .log_size = np .logaddexp (self .log_size , tree .log_size )
306- self .log_weighted_accept_sum = np .logaddexp (
307- self .log_weighted_accept_sum , tree .log_weighted_accept_sum
308- )
309314 self .p_sum [:] += tree .p_sum
310315
311316 # Additional turning check only when tree depth > 0 to avoid redundant work
@@ -336,30 +341,30 @@ def _single_step(self, left, epsilon):
336341 if np .isnan (energy_change ):
337342 energy_change = np .inf
338343
344+ self .log_accept_sum = np .logaddexp (self .log_accept_sum , min (0 , - energy_change ))
345+
339346 if np .abs (energy_change ) > np .abs (self .max_energy_change ):
340347 self .max_energy_change = energy_change
341- if np . abs ( energy_change ) < self .Emax :
348+ if energy_change < self .Emax :
342349 # Acceptance statistic
343350 # e^{H(q_0, p_0) - H(q_n, p_n)} max(1, e^{H(q_0, p_0) - H(q_n, p_n)})
344351 # Saturated Metropolis accept probability with Boltzmann weight
345- # if h - H0 < 0
346- log_p_accept_weighted = - energy_change + min (0.0 , - energy_change )
347352 log_size = - energy_change
348353 proposal = Proposal (
349354 right .q .data ,
350355 right .q_grad ,
351356 right .energy ,
352- log_p_accept_weighted ,
353357 right .model_logp ,
358+ right .index_in_trajectory ,
354359 )
355- tree = Subtree (
356- right , right , right .p .data , proposal , log_size , log_p_accept_weighted , 1
357- )
360+ tree = Subtree (right , right , right .p .data , proposal , log_size )
358361 return tree , None , False
359362 else :
360363 error_msg = f"Energy change in leapfrog step is too large: { energy_change } ."
361364 error = None
362- tree = Subtree (None , None , None , None , - np .inf , - np .inf , 1 )
365+ finally :
366+ self .n_proposals += 1
367+ tree = Subtree (None , None , None , None , - np .inf )
363368 divergance_info = DivergenceInfo (error_msg , error , left , right )
364369 return tree , divergance_info , False
365370
@@ -387,31 +392,20 @@ def _build_subtree(self, left, depth, epsilon):
387392 turning = turning | turning1 | turning2
388393
389394 log_size = np .logaddexp (tree1 .log_size , tree2 .log_size )
390- log_weighted_accept_sum = np .logaddexp (
391- tree1 .log_weighted_accept_sum , tree2 .log_weighted_accept_sum
392- )
393395 if logbern (tree2 .log_size - log_size ):
394396 proposal = tree2 .proposal
395397 else :
396398 proposal = tree1 .proposal
397399 else :
398400 p_sum = tree1 .p_sum
399401 log_size = tree1 .log_size
400- log_weighted_accept_sum = tree1 .log_weighted_accept_sum
401402 proposal = tree1 .proposal
402403
403- n_proposals = tree1 .n_proposals + tree2 .n_proposals
404-
405- tree = Subtree (left , right , p_sum , proposal , log_size , log_weighted_accept_sum , n_proposals )
404+ tree = Subtree (left , right , p_sum , proposal , log_size )
406405 return tree , diverging , turning
407406
408407 def stats (self ):
409- # Update accept stat if any subtrees were accepted
410- if self .log_size > 0 :
411- # Remove contribution from initial state which is always a perfect
412- # accept
413- log_sum_weight = logdiffexp_numpy (self .log_size , 0.0 )
414- self .mean_tree_accept = np .exp (self .log_weighted_accept_sum - log_sum_weight )
408+ self .mean_tree_accept = np .exp (self .log_accept_sum ) / self .n_proposals
415409 return {
416410 "depth" : self .depth ,
417411 "mean_tree_accept" : self .mean_tree_accept ,
@@ -420,4 +414,5 @@ def stats(self):
420414 "tree_size" : self .n_proposals ,
421415 "max_energy_error" : self .max_energy_change ,
422416 "model_logp" : self .proposal .logp ,
417+ "index_in_trajectory" : self .proposal .index_in_trajectory ,
423418 }
0 commit comments