2525from aeppl .logprob import logcdf as logcdf_aeppl
2626from aeppl .logprob import logprob as logp_aeppl
2727from aeppl .tensor import MeasurableJoin
28- from aeppl .transforms import TransformValuesRewrite
28+ from aeppl .transforms import RVTransform , TransformValuesRewrite
2929from aesara import tensor as at
3030from aesara .graph .basic import graph_inputs , io_toposort
3131from aesara .tensor .random .op import RandomVariable
3232from aesara .tensor .var import TensorVariable
3333
3434from pymc .aesaraf import constant_fold , floatX
3535
36+ TOTAL_SIZE = Union [int , Sequence [int ], None ]
3637
37- def _get_scaling (
38- total_size : Optional [Union [int , Sequence [int ]]], shape , ndim : int
39- ) -> TensorVariable :
38+
39+ def _get_scaling (total_size : TOTAL_SIZE , shape , ndim : int ) -> TensorVariable :
4040 """
4141 Gets scaling constant for logp.
4242
@@ -104,12 +104,26 @@ def _get_scaling(
104104 return at .as_tensor (coef , dtype = aesara .config .floatX )
105105
106106
107- def joint_logpt (* args , ** kwargs ):
108- warnings .warn (
109- "joint_logpt has been deprecated. Use joint_logp instead." ,
110- FutureWarning ,
111- )
112- return joint_logp (* args , ** kwargs )
107+ def _check_no_rvs (logp_terms : Sequence [TensorVariable ]):
108+ # Raise if there are unexpected RandomVariables in the logp graph
109+ # Only SimulatorRVs are allowed
110+ from pymc .distributions .simulator import SimulatorRV
111+
112+ unexpected_rv_nodes = [
113+ node
114+ for node in aesara .graph .ancestors (logp_terms )
115+ if (
116+ node .owner
117+ and isinstance (node .owner .op , RandomVariable )
118+ and not isinstance (node .owner .op , SimulatorRV )
119+ )
120+ ]
121+ if unexpected_rv_nodes :
122+ raise ValueError (
123+ f"Random variables detected in the logp graph: { unexpected_rv_nodes } .\n "
124+ "This can happen when DensityDist logp or Interval transform functions "
125+ "reference nonlocal variables."
126+ )
113127
114128
115129def joint_logp (
@@ -151,6 +165,10 @@ def joint_logp(
151165 Sum the log-likelihood or return each term as a separate list item.
152166
153167 """
168+ warnings .warn (
169+ "joint_logp has been deprecated, use model.logp instead" ,
170+ FutureWarning ,
171+ )
154172 # TODO: In future when we drop support for tag.value_var most of the following
155173 # logic can be removed and logp can just be a wrapper function that calls aeppl's
156174 # joint_logprob directly.
@@ -223,33 +241,15 @@ def joint_logp(
223241 ** kwargs ,
224242 )
225243
226- # Raise if there are unexpected RandomVariables in the logp graph
227- # Only SimulatorRVs are allowed
228- from pymc .distributions .simulator import SimulatorRV
229-
230- unexpected_rv_nodes = [
231- node
232- for node in aesara .graph .ancestors (list (temp_logp_var_dict .values ()))
233- if (
234- node .owner
235- and isinstance (node .owner .op , RandomVariable )
236- and not isinstance (node .owner .op , SimulatorRV )
237- )
238- ]
239- if unexpected_rv_nodes :
240- raise ValueError (
241- f"Random variables detected in the logp graph: { unexpected_rv_nodes } .\n "
242- "This can happen when DensityDist logp or Interval transform functions "
243- "reference nonlocal variables."
244- )
245-
246244 # aeppl returns the logp for every single value term we provided to it. This includes
247245 # the extra values we plugged in above, so we filter those we actually wanted in the
248246 # same order they were given in.
249247 logp_var_dict = {}
250248 for value_var in rv_values .values ():
251249 logp_var_dict [value_var ] = temp_logp_var_dict [value_var ]
252250
251+ _check_no_rvs (list (logp_var_dict .values ()))
252+
253253 if scaling :
254254 for value_var in logp_var_dict .keys ():
255255 if value_var in rv_scalings :
@@ -263,6 +263,52 @@ def joint_logp(
263263 return logp_var
264264
265265
266+ def _joint_logp (
267+ rvs : Sequence [TensorVariable ],
268+ * ,
269+ rvs_to_values : Dict [TensorVariable , TensorVariable ],
270+ rvs_to_transforms : Dict [TensorVariable , RVTransform ],
271+ jacobian : bool = True ,
272+ rvs_to_total_sizes : Dict [TensorVariable , TOTAL_SIZE ],
273+ ** kwargs ,
274+ ) -> List [TensorVariable ]:
275+ """Thin wrapper around aeppl.factorized_joint_logprob, extended with PyMC specific
276+ concerns such as transforms, jacobian, and scaling"""
277+
278+ transform_rewrite = None
279+ values_to_transforms = {
280+ rvs_to_values [rv ]: transform
281+ for rv , transform in rvs_to_transforms .items ()
282+ if transform is not None
283+ }
284+ if values_to_transforms :
285+ # There seems to be an incorrect type hint in TransformValuesRewrite
286+ transform_rewrite = TransformValuesRewrite (values_to_transforms ) # type: ignore
287+
288+ temp_logp_terms = factorized_joint_logprob (
289+ rvs_to_values ,
290+ extra_rewrites = transform_rewrite ,
291+ use_jacobian = jacobian ,
292+ ** kwargs ,
293+ )
294+
295+ # aeppl returns the logp for every single value term we provided to it. This includes
296+ # the extra values we plugged in above, so we filter those we actually wanted in the
297+ # same order they were given in.
298+ logp_terms = {}
299+ for rv in rvs :
300+ value_var = rvs_to_values [rv ]
301+ logp_term = temp_logp_terms [value_var ]
302+ total_size = rvs_to_total_sizes .get (rv , None )
303+ if total_size is not None :
304+ scaling = _get_scaling (total_size , value_var .shape , value_var .ndim )
305+ logp_term *= scaling
306+ logp_terms [value_var ] = logp_term
307+
308+ _check_no_rvs (list (logp_terms .values ()))
309+ return list (logp_terms .values ())
310+
311+
266312def logp (rv : TensorVariable , value ) -> TensorVariable :
267313 """Return the log-probability graph of a Random Variable"""
268314
0 commit comments