1818@author: johnsalvatier
1919"""
2020import sys
21+ import warnings
2122
22- from typing import Optional
23+ from typing import Optional , Sequence
2324
2425import aesara .gradient as tg
2526import numpy as np
2627
28+ from aesara import Variable
2729from fastprogress .fastprogress import ProgressBar , progress_bar
2830from numpy import isfinite
2931from scipy .optimize import minimize
4143
4244def find_MAP (
4345 start = None ,
44- vars = None ,
46+ vars : Optional [ Sequence [ Variable ]] = None ,
4547 method = "L-BFGS-B" ,
4648 return_raw = False ,
4749 include_transformed = True ,
@@ -61,20 +63,23 @@ def find_MAP(
6163 Parameters
6264 ----------
6365 start: `dict` of parameter values (Defaults to `model.initial_point`)
64- vars: list
65- List of variables to optimize and set to optimum (Defaults to all continuous).
66- method: string or callable
67- Optimization algorithm (Defaults to 'L-BFGS-B' unless
68- discrete variables are specified in `vars`, then
69- `Powell` which will perform better). For instructions on use of a callable,
70- refer to SciPy's documentation of `optimize.minimize`.
71- return_raw: bool
72- Whether to return the full output of scipy.optimize.minimize (Defaults to `False`)
66+ These values will be fixed and used for any free RandomVariables that are
67+ not being optimized.
68+ vars: list of TensorVariable
69+ List of free RandomVariables to optimize the posterior with respect to.
70+ Defaults to all continuous RVs in a model. The respective value variables
71+ may also be passed instead.
72+ method: string or callable, optional
73+ Optimization algorithm. Defaults to 'L-BFGS-B' unless discrete variables are
74+ specified in `vars`, then `Powell` which will perform better. For instructions
75+ on use of a callable, refer to SciPy's documentation of `optimize.minimize`.
76+ return_raw: bool, optional defaults to False
77+ Whether to return the full output of scipy.optimize.minimize
7378 include_transformed: bool, optional defaults to True
74- Flag for reporting automatically transformed variables in addition
75- to original variables.
79+ Flag for reporting automatically unconstrained transformed values in addition
80+ to the constrained values
7681 progressbar: bool, optional defaults to True
77- Whether or not to display a progress bar in the command line.
82+ Whether to display a progress bar in the command line.
7883 maxeval: int, optional, defaults to 5000
7984 The maximum number of times the posterior distribution is evaluated.
8085 model: Model (optional if in `with` context)
@@ -95,7 +100,21 @@ def find_MAP(
95100 if not vars :
96101 raise ValueError ("Model has no unobserved continuous variables." )
97102 else :
98- vars = get_value_vars_from_user_vars (vars , model )
103+ try :
104+ vars = get_value_vars_from_user_vars (vars , model )
105+ except ValueError as exc :
106+ # Accomodate case where user passed non-pure RV nodes
107+ vars = pm .inputvars (pm .aesaraf .rvs_to_value_vars (vars ))
108+ if vars :
109+ # Make sure they belong to current model again...
110+ vars = get_value_vars_from_user_vars (vars , model )
111+ warnings .warn (
112+ "Intermediate variables (such as Deterministic or Potential) were passed. "
113+ "find_MAP will optimize the underlying free_RVs instead." ,
114+ UserWarning ,
115+ )
116+ else :
117+ raise exc
99118
100119 disc_vars = list (typefilter (vars , discrete_types ))
101120 ipfn = make_initial_point_fn (
0 commit comments