|
39 | 39 | from collections import deque |
40 | 40 | from typing import Dict, List, Optional, Sequence, Union |
41 | 41 |
|
| 42 | +import numpy as np |
42 | 43 | import pytensor |
43 | 44 | import pytensor.tensor as pt |
44 | 45 |
|
45 | 46 | from pytensor import config |
46 | | -from pytensor.graph.basic import graph_inputs, io_toposort |
| 47 | +from pytensor.graph.basic import Variable, graph_inputs, io_toposort |
47 | 48 | from pytensor.graph.op import compute_test_value |
48 | 49 | from pytensor.graph.rewriting.basic import GraphRewriter, NodeRewriter |
49 | 50 | from pytensor.tensor.random.op import RandomVariable |
50 | 51 | from pytensor.tensor.var import TensorVariable |
51 | | - |
52 | | -from pymc.logprob.abstract import _logprob, get_measurable_outputs |
53 | | -from pymc.logprob.abstract import logprob as logp_logprob |
| 52 | +from typing_extensions import TypeAlias |
| 53 | + |
| 54 | +from pymc.logprob.abstract import ( |
| 55 | + _icdf_helper, |
| 56 | + _logcdf_helper, |
| 57 | + _logprob, |
| 58 | + _logprob_helper, |
| 59 | + get_measurable_outputs, |
| 60 | +) |
54 | 61 | from pymc.logprob.rewriting import construct_ir_fgraph |
55 | 62 | from pymc.logprob.transforms import RVTransform, TransformValuesRewrite |
56 | 63 | from pymc.logprob.utils import rvs_to_value_vars |
57 | 64 |
|
| 65 | +TensorLike: TypeAlias = Union[Variable, float, np.ndarray] |
| 66 | + |
58 | 67 |
|
59 | | -def logp(rv: TensorVariable, value: TensorVariable, **kwargs) -> TensorVariable: |
| 68 | +def logp(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable: |
60 | 69 | """Return the log-probability graph of a Random Variable""" |
61 | 70 |
|
62 | 71 | value = pt.as_tensor_variable(value, dtype=rv.dtype) |
63 | 72 | try: |
64 | | - return logp_logprob(rv, value, **kwargs) |
| 73 | + return _logprob_helper(rv, value, **kwargs) |
65 | 74 | except NotImplementedError: |
66 | 75 | fgraph, _, _ = construct_ir_fgraph({rv: value}) |
67 | 76 | [(ir_rv, ir_value)] = fgraph.preserve_rv_mappings.rv_values.items() |
68 | | - return logp_logprob(ir_rv, ir_value, **kwargs) |
| 77 | + return _logprob_helper(ir_rv, ir_value, **kwargs) |
| 78 | + |
| 79 | + |
| 80 | +def logcdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable: |
| 81 | + """Create a graph for the log-CDF of a Random Variable.""" |
| 82 | + value = pt.as_tensor_variable(value, dtype=rv.dtype) |
| 83 | + return _logcdf_helper(rv, value, **kwargs) |
| 84 | + |
| 85 | + |
| 86 | +def icdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable: |
| 87 | + """Create a graph for the inverse CDF of a Random Variable.""" |
| 88 | + value = pt.as_tensor_variable(value) |
| 89 | + return _icdf_helper(rv, value, **kwargs) |
69 | 90 |
|
70 | 91 |
|
71 | 92 | def factorized_joint_logprob( |
|
0 commit comments