@@ -245,10 +245,49 @@ def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None
245245 def predict (
246246 self ,
247247 data_prediction : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None ,
248- point_estimate : bool = True ,
249248 ):
250249 """
251- Uses model to predict on unseen data.
250+ Uses model to predict on unseen data and return point prediction of all the samples
251+
252+ Parameters
253+ ---------
254+ data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series
255+ It is the data we need to make prediction on using the model.
256+
257+ Returns
258+ -------
259+ returns dictionary of sample's mean of posterior predict.
260+
261+ Examples
262+ --------
263+ >>> data, model_config, sampler_config = LinearModel.create_sample_input()
264+ >>> model = LinearModel(model_config, sampler_config)
265+ >>> idata = model.fit(data)
266+ >>> x_pred = []
267+ >>> prediction_data = pd.DataFrame({'input':x_pred})
268+ # point predict
269+ >>> pred_mean = model.predict(prediction_data)
270+ """
271+
272+ if data_prediction is not None : # set new input data
273+ self ._data_setter (data_prediction )
274+
275+ with self .model : # sample with new input data
276+ post_pred = pm .sample_posterior_predictive (self .idata )
277+
278+ # reshape output
279+ post_pred = self ._extract_samples (post_pred )
280+ for key in post_pred :
281+ post_pred [key ] = post_pred [key ].mean (axis = 0 )
282+
283+ return post_pred
284+
285+ def predict_posterior (
286+ self ,
287+ data_prediction : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None ,
288+ ):
289+ """
290+ Uses model to predict samples on unseen data.
252291
253292 Parameters
254293 ---------
@@ -268,10 +307,8 @@ def predict(
268307 >>> idata = model.fit(data)
269308 >>> x_pred = []
270309 >>> prediction_data = pd.DataFrame({'input':x_pred})
271- # only point estimate
272- >>> pred_mean = model.predict(prediction_data)
273310 # samples
274- >>> pred_samples = model.predict (prediction_data, point_estimate=False )
311+ >>> pred_mean = model.predict_posterior (prediction_data)
275312 """
276313
277314 if data_prediction is not None : # set new input data
@@ -282,9 +319,6 @@ def predict(
282319
283320 # reshape output
284321 post_pred = self ._extract_samples (post_pred )
285- if point_estimate : # average, if point-like estimate desired
286- for key in post_pred :
287- post_pred [key ] = post_pred [key ].mean (axis = 0 )
288322
289323 return post_pred
290324
0 commit comments