@@ -53,9 +53,11 @@ def sample_tree_sequential(
         missing_data,
         sum_trees_output,
         mean,
+        linear_fit,
         m,
         normal,
         mu_std,
+        response,
     ):
         tree_grew = False
         if self.expansion_nodes:
@@ -73,9 +75,11 @@ def sample_tree_sequential(
                     missing_data,
                     sum_trees_output,
                     mean,
+                    linear_fit,
                     m,
                     normal,
                     mu_std,
+                    response,
                 )
                 if tree_grew:
                     new_indexes = self.tree.idx_leaf_nodes[-2:]
@@ -97,11 +101,17 @@ class PGBART(ArrayStepShared):
         Number of particles for the conditional SMC sampler. Defaults to 10
     max_stages : int
         Maximum number of iterations of the conditional SMC sampler. Defaults to 100.
-    chunk = int
-        Number of trees fitted per step. Defaults to "auto", which is the 10% of the `m` trees.
+    batch : int
+        Number of trees fitted per step. Defaults to "auto", which is 10% of the `m` trees
+        during tuning and 20% after tuning.
     model: PyMC Model
         Optional model for sampling step. Defaults to None (taken from context).
 
+    Note
+    ----
+    This sampler is inspired by the [Lakshminarayanan2015] Particle Gibbs sampler, but
+    introduces several changes. The changes will be properly documented soon.
+
     References
     ----------
     .. [Lakshminarayanan2015] Lakshminarayanan, B. and Roy, D.M. and Teh, Y. W., (2015),
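
For orientation, `batch` is what a user would pass when building the step method by hand (a minimal sketch; the model, the data `X`/`Y`, and the import path are illustrative assumptions, not part of this diff):

    import pymc3 as pm

    with pm.Model() as model:
        mu = pm.BART("mu", X, Y, m=50)
        y = pm.Normal("y", mu=mu, sigma=1, observed=Y)
        # "auto" fits 10% of the m trees per step while tuning, 20% afterwards
        step = pm.step_methods.pgbart.PGBART(batch="auto")
        trace = pm.sample(step=step)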
@@ -114,7 +124,7 @@ class PGBART(ArrayStepShared):
     generates_stats = True
     stats_dtypes = [{"variable_inclusion": np.ndarray}]
 
-    def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", model=None):
+    def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", model=None):
         _log.warning("BART is experimental. Use with caution.")
         model = modelcontext(model)
         initial_values = model.initial_point
@@ -125,6 +135,7 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", mo
         self.m = self.bart.m
         self.alpha = self.bart.alpha
         self.k = self.bart.k
+        self.response = self.bart.response
         self.split_prior = self.bart.split_prior
         if self.split_prior is None:
             self.split_prior = np.ones(self.X.shape[1])
@@ -149,6 +160,8 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", mo
             idx_data_points=np.arange(self.num_observations, dtype="int32"),
         )
         self.mean = fast_mean()
+        self.linear_fit = fast_linear_fit()
+
         self.normal = NormalSampler()
         self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
         self.ssv = SampleSplittingVariable(self.split_prior)
@@ -157,10 +170,10 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", mo
         self.idx = 0
         self.iter = 0
         self.sum_trees = []
-        self.chunk = chunk
+        self.batch = batch
 
-        if self.chunk == "auto":
-            self.chunk = max(1, int(self.m * 0.1))
+        if self.batch == "auto":
+            self.batch = max(1, int(self.m * 0.1))
         self.log_num_particles = np.log(num_particles)
         self.indices = list(range(1, num_particles))
         self.len_indices = len(self.indices)
@@ -190,7 +203,7 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
         if self.idx == self.m:
             self.idx = 0
 
-        for tree_id in range(self.idx, self.idx + self.chunk):
+        for tree_id in range(self.idx, self.idx + self.batch):
             if tree_id >= self.m:
                 break
             # Generate an initial set of SMC particles
@@ -213,9 +226,11 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
                         self.missing_data,
                         sum_trees_output,
                         self.mean,
+                        self.linear_fit,
                         self.m,
                         self.normal,
                         self.mu_std,
+                        self.response,
                     )
                     if tree_grew:
                         self.update_weight(p)
@@ -251,6 +266,7 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
                     self.split_prior[index] += 1
                 self.ssv = SampleSplittingVariable(self.split_prior)
             else:
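+                # Once tuning is over, fit a larger batch of trees per step (20% of m)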
+                self.batch = max(1, int(self.m * 0.2))
                 self.iter += 1
             self.sum_trees.append(new_tree)
             if not self.iter % self.m:
@@ -389,16 +405,20 @@ def grow_tree(
     missing_data,
     sum_trees_output,
     mean,
+    linear_fit,
     m,
     normal,
     mu_std,
+    response,
 ):
     current_node = tree.get_node(index_leaf_node)
+    idx_data_points = current_node.idx_data_points
 
     index_selected_predictor = ssv.rvs()
     selected_predictor = available_predictors[index_selected_predictor]
-    available_splitting_values = X[current_node.idx_data_points, selected_predictor]
+    available_splitting_values = X[idx_data_points, selected_predictor]
     if missing_data:
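+        # Exclude data points whose value for the selected predictor is missing (NaN)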
+        idx_data_points = idx_data_points[~np.isnan(available_splitting_values)]
         available_splitting_values = available_splitting_values[
             ~np.isnan(available_splitting_values)
         ]
@@ -407,58 +427,82 @@ def grow_tree(
         return False, None
 
     idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values))
-    selected_splitting_rule = available_splitting_values[idx_selected_splitting_values]
+    split_value = available_splitting_values[idx_selected_splitting_values]
     new_split_node = SplitNode(
         index=index_leaf_node,
         idx_split_variable=selected_predictor,
-        split_value=selected_splitting_rule,
+        split_value=split_value,
     )
 
     left_node_idx_data_points, right_node_idx_data_points = get_new_idx_data_points(
-        new_split_node, current_node.idx_data_points, X
+        split_value, idx_data_points, selected_predictor, X
     )
 
-    left_node_value = draw_leaf_value(
-        sum_trees_output[left_node_idx_data_points], mean, m, normal, mu_std
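+    # With response="mix", pick this split's response type at random (50/50)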
+    if response == "mix":
+        response = "linear" if np.random.random() >= 0.5 else "constant"
+
+    left_node_value, left_node_linear_params = draw_leaf_value(
+        sum_trees_output[left_node_idx_data_points],
+        X[left_node_idx_data_points, selected_predictor],
+        mean,
+        linear_fit,
+        m,
+        normal,
+        mu_std,
+        response,
     )
-    right_node_value = draw_leaf_value(
-        sum_trees_output[right_node_idx_data_points], mean, m, normal, mu_std
+    right_node_value, right_node_linear_params = draw_leaf_value(
+        sum_trees_output[right_node_idx_data_points],
+        X[right_node_idx_data_points, selected_predictor],
+        mean,
+        linear_fit,
+        m,
+        normal,
+        mu_std,
+        response,
     )
 
     new_left_node = LeafNode(
         index=current_node.get_idx_left_child(),
         value=left_node_value,
         idx_data_points=left_node_idx_data_points,
+        linear_params=left_node_linear_params,
     )
     new_right_node = LeafNode(
         index=current_node.get_idx_right_child(),
         value=right_node_value,
         idx_data_points=right_node_idx_data_points,
+        linear_params=right_node_linear_params,
     )
     tree.grow_tree(index_leaf_node, new_split_node, new_left_node, new_right_node)
 
     return True, index_selected_predictor
 
 
-def get_new_idx_data_points(current_split_node, idx_data_points, X):
-    idx_split_variable = current_split_node.idx_split_variable
-    split_value = current_split_node.split_value
+def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X):
 
-    left_idx = X[idx_data_points, idx_split_variable] <= split_value
+    left_idx = X[idx_data_points, selected_predictor] <= split_value
     left_node_idx_data_points = idx_data_points[left_idx]
     right_node_idx_data_points = idx_data_points[~left_idx]
 
     return left_node_idx_data_points, right_node_idx_data_points
 
 
-def draw_leaf_value(sum_trees_output_idx, mean, m, normal, mu_std):
+def draw_leaf_value(Y_mu_pred, X_mu, mean, linear_fit, m, normal, mu_std, response):
     """Draw Gaussian distributed leaf values"""
-    if sum_trees_output_idx.size == 0:
-        return 0
+    linear_params = None
+    if Y_mu_pred.size == 0:
+        return 0, linear_params
+    elif Y_mu_pred.size == 1:
+        mu_mean = Y_mu_pred.item() / m
     else:
-        mu_mean = mean(sum_trees_output_idx) / m
-        draw = normal.random() * mu_std + mu_mean
-        return draw
+        if response == "constant":
+            mu_mean = mean(Y_mu_pred) / m
+        elif response == "linear":
+            Y_fit, linear_params = linear_fit(X_mu, Y_mu_pred)
+            mu_mean = Y_fit / m
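+    # Perturb the mean with Gaussian noise scaled by mu_std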
+    draw = normal.random() * mu_std + mu_mean
+    return draw, linear_params
 
 
 def fast_mean():
@@ -479,6 +523,29 @@ def mean(a):
     return mean
 
 
+def fast_linear_fit():
+    """If available, use Numba to speed up the computation of the linear fit."""
+
+    def linear_fit(X, Y):
+
+        n = len(Y)
+        xbar = np.sum(X) / n
+        ybar = np.sum(Y) / n
+
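+        # Closed-form simple least squares: slope b = cov(X, Y) / var(X),
+        # intercept a = ybar - b * xbar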
+        b = (X @ Y - n * xbar * ybar) / (X @ X - n * xbar**2)
+        a = ybar - b * xbar
+
+        Y_fit = a + b * X
+        return Y_fit, (a, b)
+
+    try:
+        from numba import jit
+
+        return jit(linear_fit)
+    except ImportError:
+        return linear_fit
+
+
 def discrete_uniform_sampler(upper_value):
     """Draw from the uniform distribution with bounds [0, upper_value).
 
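
As a quick sanity check of the closed-form fit added above (a minimal sketch; it assumes the `fast_linear_fit` factory from this diff is importable, and the data are illustrative):

    import numpy as np

    X = np.array([0.0, 1.0, 2.0, 3.0])
    Y = 2.0 * X + 1.0  # noiseless line, so the fit should be exact

    linear_fit = fast_linear_fit()
    Y_fit, (a, b) = linear_fit(X, Y)
    # a == 1.0, b == 2.0, and Y_fit reproduces Y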