|
14 | 14 |
|
15 | 15 | import numpy as np |
16 | 16 |
|
17 | | -from pandas import DataFrame, Series |
| 17 | +from aesara.tensor.random.op import RandomVariable, default_shape_from_params |
18 | 18 |
|
19 | 19 | from pymc3.distributions.distribution import NoDistribution |
20 | | -from pymc3.distributions.tree import LeafNode, SplitNode, Tree |
21 | 20 |
|
22 | 21 | __all__ = ["BART"] |
23 | 22 |
|
24 | 23 |
|
25 | | -class BaseBART(NoDistribution): |
26 | | - def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None, *args, **kwargs): |
27 | | - |
28 | | - self.X, self.Y, self.missing_data = self.preprocess_XY(X, Y) |
29 | | - |
30 | | - super().__init__(shape=X.shape[0], dtype="float64", initval=0, *args, **kwargs) |
31 | | - |
32 | | - if self.X.ndim != 2: |
33 | | - raise ValueError("The design matrix X must have two dimensions") |
34 | | - |
35 | | - if self.Y.ndim != 1: |
36 | | - raise ValueError("The response matrix Y must have one dimension") |
37 | | - if self.X.shape[0] != self.Y.shape[0]: |
38 | | - raise ValueError( |
39 | | - "The design matrix X and the response matrix Y must have the same number of elements" |
40 | | - ) |
41 | | - if not isinstance(m, int): |
42 | | - raise ValueError("The number of trees m type must be int") |
43 | | - if m < 1: |
44 | | - raise ValueError("The number of trees m must be greater than zero") |
45 | | - |
46 | | - if alpha <= 0 or 1 <= alpha: |
47 | | - raise ValueError( |
48 | | - "The value for the alpha parameter for the tree structure " |
49 | | - "must be in the interval (0, 1)" |
50 | | - ) |
51 | | - |
52 | | - self.num_observations = X.shape[0] |
53 | | - self.num_variates = X.shape[1] |
54 | | - self.available_predictors = list(range(self.num_variates)) |
55 | | - self.ssv = SampleSplittingVariable(split_prior, self.num_variates) |
56 | | - self.m = m |
57 | | - self.alpha = alpha |
58 | | - self.trees = self.init_list_of_trees() |
59 | | - self.all_trees = [] |
60 | | - self.mean = fast_mean() |
61 | | - self.prior_prob_leaf_node = compute_prior_probability(alpha) |
62 | | - |
63 | | - def preprocess_XY(self, X, Y): |
64 | | - if isinstance(Y, (Series, DataFrame)): |
65 | | - Y = Y.to_numpy() |
66 | | - if isinstance(X, (Series, DataFrame)): |
67 | | - X = X.to_numpy() |
68 | | - missing_data = np.any(np.isnan(X)) |
69 | | - X = np.random.normal(X, np.std(X, 0) / 100) |
70 | | - return X, Y, missing_data |
71 | | - |
72 | | - def init_list_of_trees(self): |
73 | | - initial_value_leaf_nodes = self.Y.mean() / self.m |
74 | | - initial_idx_data_points_leaf_nodes = np.array(range(self.num_observations), dtype="int32") |
75 | | - list_of_trees = [] |
76 | | - for i in range(self.m): |
77 | | - new_tree = Tree.init_tree( |
78 | | - tree_id=i, |
79 | | - leaf_node_value=initial_value_leaf_nodes, |
80 | | - idx_data_points=initial_idx_data_points_leaf_nodes, |
81 | | - ) |
82 | | - list_of_trees.append(new_tree) |
83 | | - # Diff trick to speed computation of residuals. From Section 3.1 of Kapelner, A and Bleich, J. |
84 | | - # bartMachine: A Powerful Tool for Machine Learning in R. ArXiv e-prints, 2013 |
85 | | - # The sum_trees_output will contain the sum of the predicted output for all trees. |
86 | | - # When R_j is needed we subtract the current predicted output for tree T_j. |
87 | | - self.sum_trees_output = np.full_like(self.Y, self.Y.mean()) |
88 | | - |
89 | | - return list_of_trees |
90 | | - |
91 | | - def __iter__(self): |
92 | | - return iter(self.trees) |
93 | | - |
94 | | - def __repr_latex(self): |
95 | | - raise NotImplementedError |
96 | | - |
97 | | - def get_available_splitting_rules(self, idx_data_points_split_node, idx_split_variable): |
98 | | - x_j = self.X[idx_data_points_split_node, idx_split_variable] |
99 | | - if self.missing_data: |
100 | | - x_j = x_j[~np.isnan(x_j)] |
101 | | - values = np.unique(x_j) |
102 | | - # The last value is never available as it would leave the right subtree empty. |
103 | | - return values[:-1] |
104 | | - |
105 | | - def grow_tree(self, tree, index_leaf_node): |
106 | | - current_node = tree.get_node(index_leaf_node) |
107 | | - |
108 | | - index_selected_predictor = self.ssv.rvs() |
109 | | - selected_predictor = self.available_predictors[index_selected_predictor] |
110 | | - available_splitting_rules = self.get_available_splitting_rules( |
111 | | - current_node.idx_data_points, selected_predictor |
112 | | - ) |
113 | | - # This can be unsuccessful when there are not available splitting rules |
114 | | - if available_splitting_rules.size == 0: |
115 | | - return False, None |
116 | | - |
117 | | - index_selected_splitting_rule = discrete_uniform_sampler(len(available_splitting_rules)) |
118 | | - selected_splitting_rule = available_splitting_rules[index_selected_splitting_rule] |
119 | | - new_split_node = SplitNode( |
120 | | - index=index_leaf_node, |
121 | | - idx_split_variable=selected_predictor, |
122 | | - split_value=selected_splitting_rule, |
123 | | - ) |
124 | | - |
125 | | - left_node_idx_data_points, right_node_idx_data_points = self.get_new_idx_data_points( |
126 | | - new_split_node, current_node.idx_data_points |
127 | | - ) |
128 | | - |
129 | | - left_node_value = self.draw_leaf_value(left_node_idx_data_points) |
130 | | - right_node_value = self.draw_leaf_value(right_node_idx_data_points) |
131 | | - |
132 | | - new_left_node = LeafNode( |
133 | | - index=current_node.get_idx_left_child(), |
134 | | - value=left_node_value, |
135 | | - idx_data_points=left_node_idx_data_points, |
136 | | - ) |
137 | | - new_right_node = LeafNode( |
138 | | - index=current_node.get_idx_right_child(), |
139 | | - value=right_node_value, |
140 | | - idx_data_points=right_node_idx_data_points, |
141 | | - ) |
142 | | - tree.grow_tree(index_leaf_node, new_split_node, new_left_node, new_right_node) |
143 | | - |
144 | | - return True, index_selected_predictor |
145 | | - |
146 | | - def get_new_idx_data_points(self, current_split_node, idx_data_points): |
147 | | - idx_split_variable = current_split_node.idx_split_variable |
148 | | - split_value = current_split_node.split_value |
149 | | - |
150 | | - left_idx = self.X[idx_data_points, idx_split_variable] <= split_value |
151 | | - left_node_idx_data_points = idx_data_points[left_idx] |
152 | | - right_node_idx_data_points = idx_data_points[~left_idx] |
153 | | - |
154 | | - return left_node_idx_data_points, right_node_idx_data_points |
155 | | - |
156 | | - def get_residuals(self): |
157 | | - """Compute the residuals.""" |
158 | | - R_j = self.Y - self.sum_trees_output |
159 | | - return R_j |
160 | | - |
161 | | - def get_residuals_loo(self, tree): |
162 | | - """Compute the residuals without leaving the passed tree out.""" |
163 | | - R_j = self.Y - (self.sum_trees_output - tree.predict_output(self.num_observations)) |
164 | | - return R_j |
165 | | - |
166 | | - def draw_leaf_value(self, idx_data_points): |
167 | | - """Draw the residual mean.""" |
168 | | - R_j = self.get_residuals()[idx_data_points] |
169 | | - draw = self.mean(R_j) |
170 | | - return draw |
171 | | - |
172 | | - def predict(self, X_new): |
173 | | - """Compute out of sample predictions evaluated at X_new""" |
174 | | - trees = self.all_trees |
175 | | - num_observations = X_new.shape[0] |
176 | | - pred = np.zeros((len(trees), num_observations)) |
177 | | - np.random.randint(len(trees)) |
178 | | - for draw, trees_to_sum in enumerate(trees): |
179 | | - new_Y = np.zeros(num_observations) |
180 | | - for tree in trees_to_sum: |
181 | | - new_Y += [tree.predict_out_of_sample(x) for x in X_new] |
182 | | - pred[draw] = new_Y |
183 | | - return pred |
184 | | - |
185 | | - |
186 | | -def compute_prior_probability(alpha): |
| 24 | +class BARTRV(RandomVariable): |
187 | 25 | """ |
188 | | - Calculate the probability of the node being a LeafNode (1 - p(being SplitNode)). |
189 | | - Taken from equation 19 in [Rockova2018]. |
190 | | -
|
191 | | - Parameters |
192 | | - ---------- |
193 | | - alpha : float |
194 | | -
|
195 | | - Returns |
196 | | - ------- |
197 | | - list with probabilities for leaf nodes |
198 | | -
|
199 | | - References |
200 | | - ---------- |
201 | | - .. [Rockova2018] Veronika Rockova, Enakshi Saha (2018). On the theory of BART. |
202 | | - arXiv, `link <https://arxiv.org/abs/1810.00787>`__ |
| 26 | +    Base class for BART. |
203 | 27 | """ |
204 | | - prior_leaf_prob = [0] |
205 | | - depth = 1 |
206 | | - while prior_leaf_prob[-1] < 1: |
207 | | - prior_leaf_prob.append(1 - alpha ** depth) |
208 | | - depth += 1 |
209 | | - return prior_leaf_prob |
210 | | - |
211 | | - |
212 | | -def fast_mean(): |
213 | | - """If available use Numba to speed up the computation of the mean.""" |
214 | | - try: |
215 | | - from numba import jit |
216 | | - except ImportError: |
217 | | - return np.mean |
218 | | - |
219 | | - @jit |
220 | | - def mean(a): |
221 | | - count = a.shape[0] |
222 | | - suma = 0 |
223 | | - for i in range(count): |
224 | | - suma += a[i] |
225 | | - return suma / count |
226 | | - |
227 | | - return mean |
228 | 28 |
|
| 29 | + name = "BART" |
| 30 | + ndim_supp = 1 |
| 31 | +    ndims_params = [2, 1, 0, 0, 0, 1]  # one per parameter: X, Y, m, alpha, k, split_prior |
| 32 | + dtype = "floatX" |
| 33 | + _print_name = ("BART", "\\operatorname{BART}") |
| 34 | + all_trees = None |
| 35 | + |
| 36 | + def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): |
| 37 | + return default_shape_from_params(self.ndim_supp, dist_params, rep_param_idx, param_shapes) |
| 38 | + |
| 39 | + @classmethod |
| 40 | +    def rng_fn(cls, rng=np.random.default_rng(), *args, **kwargs):  # default Generator is created once, at definition time |
| 41 | + size = kwargs.pop("size", None) |
| 42 | + X_new = kwargs.pop("X_new", None) |
| 43 | + all_trees = cls.all_trees |
| 44 | + if all_trees: |
| 45 | + |
| 46 | + if size is None: |
| 47 | + size = () |
| 48 | + elif isinstance(size, int): |
| 49 | + size = [size] |
| 50 | + |
| 51 | + flatten_size = 1 |
| 52 | + for s in size: |
| 53 | + flatten_size *= s |
| 54 | + |
| 55 | +            idx = rng.choice(len(all_trees), size=flatten_size)  # choice exists on both RandomState and Generator; randint is RandomState-only |
| 56 | + |
| 57 | + if X_new is None: |
| 58 | +            pred = np.zeros((flatten_size, all_trees[0][0].num_observations))  # in-sample: predict at the training points |
| 59 | + for ind, p in enumerate(pred): |
| 60 | + for tree in all_trees[idx[ind]]: |
| 61 | + p += tree.predict_output() |
| 62 | + else: |
| 63 | +            pred = np.zeros((flatten_size, X_new.shape[0]))  # out-of-sample: predict at X_new |
| 64 | + for ind, p in enumerate(pred): |
| 65 | + for tree in all_trees[idx[ind]]: |
| 66 | + p += np.array([tree.predict_out_of_sample(x) for x in X_new]) |
| 67 | +        return pred.reshape((*size, -1))  # restore the requested size; observations on the last axis |
| 68 | + else: |
| 69 | +            return np.full_like(cls.Y, cls.Y.mean())  # no trees yet (e.g. before sampling); fall back to the data mean |
229 | 70 |
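The size handling in `rng_fn` above flattens an arbitrary `size` tuple, picks one stored posterior draw per flat sample, and reshapes at the end. A minimal sketch of just that logic, with random stand-ins for the per-draw sums of trees (the names and sizes here are made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
n_draws, n_obs = 100, 25                           # toy numbers of stored draws and data points
stored_preds = rng.normal(size=(n_draws, n_obs))   # stand-in for the per-draw sums of trees

size = (2, 3)                                      # requested sample shape
flatten_size = int(np.prod(size))                  # 6 flat samples
idx = rng.integers(n_draws, size=flatten_size)     # one stored draw per flat sample
pred = stored_preds[idx]                           # shape (6, 25)
print(pred.reshape((*size, -1)).shape)             # (2, 3, 25), matching rng_fn's return
```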
|
230 | | -def discrete_uniform_sampler(upper_value): |
231 | | - """Draw from the uniform distribution with bounds [0, upper_value).""" |
232 | | - return int(np.random.random() * upper_value) |
233 | | - |
234 | | - |
235 | | -class SampleSplittingVariable: |
236 | | - def __init__(self, prior, num_variates): |
237 | | - self.prior = prior |
238 | | - self.num_variates = num_variates |
239 | | - |
240 | | - if self.prior is not None: |
241 | | - self.prior = np.asarray(self.prior) |
242 | | - self.prior = self.prior / self.prior.sum() |
243 | | - if self.prior.size != self.num_variates: |
244 | | - raise ValueError( |
245 | | - f"The size of split_prior ({self.prior.size}) should be the " |
246 | | - f"same as the number of covariates ({self.num_variates})" |
247 | | - ) |
248 | | - self.enu = list(enumerate(np.cumsum(self.prior))) |
249 | 71 |
|
250 | | - def rvs(self): |
251 | | - if self.prior is None: |
252 | | - return int(np.random.random() * self.num_variates) |
253 | | - else: |
254 | | - r = np.random.random() |
255 | | - for i, v in self.enu: |
256 | | - if r <= v: |
257 | | - return i |
| 72 | +bart = BARTRV() |
258 | 73 |
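`BART.__new__` below builds a fresh `BARTRV` subclass per named variable with `type()`, so each random variable carries its own data and list of trees. A minimal illustration of that idiom (all names here are made up):

```python
class BaseOp:
    data = None

# Each call creates an independent subclass with its own class attributes,
# mirroring how BART.__new__ attaches X, Y and all_trees to a BARTRV subclass.
op_a = type("BaseOp_a", (BaseOp,), dict(data=[1, 2]))()
op_b = type("BaseOp_b", (BaseOp,), dict(data=[3]))()
print(type(op_a).__name__, op_a.data)  # BaseOp_a [1, 2]
print(type(op_b).__name__, op_b.data)  # BaseOp_b [3]
```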
|
259 | 74 |
|
260 | | -class BART(BaseBART): |
| 75 | +class BART(NoDistribution): |
261 | 76 | """ |
262 | | - BART distribution. |
| 77 | + Bayesian Additive Regression Tree distribution. |
263 | 78 |
|
264 | 79 |     Distribution representing a sum over trees. |
265 | 80 |
|
266 | 81 | Parameters |
267 | 82 | ---------- |
268 | 83 | X : array-like |
269 | | - The design matrix. |
| 84 | + The covariate matrix. |
270 | 85 | Y : array-like |
271 | 86 | The response vector. |
272 | 87 | m : int |
273 | 88 |         Number of trees. |
274 | 89 | alpha : float |
275 | | - Control the prior probability over the depth of the trees. Must be in the interval (0, 1), |
276 | | - altought it is recomenned to be in the interval (0, 0.5]. |
| 90 | +    Controls the prior probability over the depth of the trees. Although it can take values in |
| 91 | +    the interval (0, 1), values in the interval (0, 0.5] are recommended. |
| 92 | + k : float |
| 93 | +        Scale parameter for the values of the leaf nodes. Defaults to 2. Recommended to be between 1 |
| 94 | +        and 3. |
277 | 95 | split_prior : array-like |
278 | | - Each element of split_prior should be in the [0, 1] interval and the elements should sum |
279 | | - to 1. Otherwise they will be normalized. |
280 | | - Defaults to None, all variable have the same a prior probability |
| 96 | + Each element of split_prior should be in the [0, 1] interval and the elements should sum to |
| 97 | + 1. Otherwise they will be normalized. |
| 98 | +        Defaults to None, i.e. all covariates have the same prior probability of being selected. |
281 | 99 | """ |
282 | 100 |
|
283 | | - def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None): |
284 | | - super().__init__(X, Y, m, alpha, split_prior) |
| 101 | + def __new__( |
| 102 | + cls, |
| 103 | + name, |
| 104 | + X, |
| 105 | + Y, |
| 106 | + m=50, |
| 107 | + alpha=0.25, |
| 108 | + k=2, |
| 109 | + split_prior=None, |
| 110 | + **kwargs, |
| 111 | + ): |
| 112 | + |
| 113 | + cls.all_trees = [] |
| 114 | + |
| 115 | + bart_op = type( |
| 116 | + f"BART_{name}", |
| 117 | + (BARTRV,), |
| 118 | + dict( |
| 119 | + name="BART", |
| 120 | + all_trees=cls.all_trees, |
| 121 | + inplace=False, |
| 122 | + initval=Y.mean(), |
| 123 | + X=X, |
| 124 | + Y=Y, |
| 125 | + m=m, |
| 126 | + alpha=alpha, |
| 127 | + k=k, |
| 128 | + split_prior=split_prior, |
| 129 | + ), |
| 130 | + )() |
| 131 | + |
| 132 | + NoDistribution.register(BARTRV) |
| 133 | + |
| 134 | + cls.rv_op = bart_op |
| 135 | +        params = [X, Y, m, alpha, k]  # split_prior is attached to the op above rather than passed as an RV parameter |
| 136 | + return super().__new__(cls, name, *params, **kwargs) |
| 137 | + |
| 138 | + @classmethod |
| 139 | + def dist(cls, *params, **kwargs): |
| 140 | + return super().dist(params, **kwargs) |
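For reference, the refactored class keeps the usual pymc3 modeling pattern. A hypothetical usage sketch, assuming the pymc3 build from this PR; the data and variable names are made up:

```python
import numpy as np
import pymc3 as pm

rng = np.random.default_rng(42)
X = rng.normal(size=(100, 2))                     # toy covariate matrix
Y = X[:, 0] - X[:, 1] + rng.normal(0, 0.1, 100)   # toy response vector

with pm.Model() as model:
    mu = pm.BART("mu", X, Y, m=50)                # sum-of-trees prior on the mean
    sigma = pm.HalfNormal("sigma", 1)
    y = pm.Normal("y", mu=mu, sigma=sigma, observed=Y)
    idata = pm.sample()                           # BART variables get a specialized tree-sampling step
```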