From 0636f4e6bb3cee4548bb08de2ff056d4a8d948a3 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 13 Mar 2023 21:17:45 +0100 Subject: [PATCH] Allow logcdf and icdf inference --- pymc/logprob/basic.py | 18 +++++++++++++++--- pymc/logprob/transforms.py | 36 ++++++++++++++++++++++++++++++++++++ tests/logprob/test_basic.py | 4 ++-- 3 files changed, 53 insertions(+), 5 deletions(-) diff --git a/pymc/logprob/basic.py b/pymc/logprob/basic.py index f65e72f7c1..d0fc6bd3ae 100644 --- a/pymc/logprob/basic.py +++ b/pymc/logprob/basic.py @@ -80,13 +80,25 @@ def logp(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable: def logcdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable: """Create a graph for the log-CDF of a Random Variable.""" value = pt.as_tensor_variable(value, dtype=rv.dtype) - return _logcdf_helper(rv, value, **kwargs) + try: + return _logcdf_helper(rv, value, **kwargs) + except NotImplementedError: + # Try to rewrite rv + fgraph, rv_values, _ = construct_ir_fgraph({rv: value}) + [ir_rv] = fgraph.outputs + return _logcdf_helper(ir_rv, value, **kwargs) def icdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable: """Create a graph for the inverse CDF of a Random Variable.""" - value = pt.as_tensor_variable(value) - return _icdf_helper(rv, value, **kwargs) + value = pt.as_tensor_variable(value, dtype=rv.dtype) + try: + return _icdf_helper(rv, value, **kwargs) + except NotImplementedError: + # Try to rewrite rv + fgraph, rv_values, _ = construct_ir_fgraph({rv: value}) + [ir_rv] = fgraph.outputs + return _icdf_helper(ir_rv, value, **kwargs) def factorized_joint_logprob( diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 1190d7ce03..4f080f5c49 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -77,6 +77,10 @@ MeasurableElemwise, MeasurableVariable, _get_measurable_outputs, + _icdf, + _icdf_helper, + _logcdf, + _logcdf_helper, _logprob, _logprob_helper, ) @@ -387,6 +391,38 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa return pt.switch(pt.isnan(jacobian), -np.inf, input_logprob + jacobian) +@_logcdf.register(MeasurableTransform) +def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwargs): + """Compute the log-CDF graph for a `MeasurabeTransform`.""" + other_inputs = list(inputs) + measurable_input = other_inputs.pop(op.measurable_input_idx) + + backward_value = op.transform_elemwise.backward(value, *other_inputs) + + # Some transformations, like squaring may produce multiple backward values + if isinstance(backward_value, tuple): + raise NotImplementedError + + input_logcdf = _logcdf_helper(measurable_input, backward_value) + + # The jacobian is used to ensure a value in the supported domain was provided + jacobian = op.transform_elemwise.log_jac_det(value, *other_inputs) + + return pt.switch(pt.isnan(jacobian), -np.inf, input_logcdf) + + +@_icdf.register(MeasurableTransform) +def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs): + """Compute the inverse CDF graph for a `MeasurabeTransform`.""" + other_inputs = list(inputs) + measurable_input = other_inputs.pop(op.measurable_input_idx) + + input_icdf = _icdf_helper(measurable_input, value) + icdf = op.transform_elemwise.forward(input_icdf, *other_inputs) + + return icdf + + @node_rewriter([reciprocal]) def measurable_reciprocal_to_power(fgraph, node): """Convert reciprocal of `MeasurableVariable`s to power.""" diff --git a/tests/logprob/test_basic.py b/tests/logprob/test_basic.py index 59abd5fecf..0546396dec 100644 --- a/tests/logprob/test_basic.py +++ b/tests/logprob/test_basic.py @@ -432,8 +432,8 @@ def test_probability_direct_dispatch(func, scipy_func): "func, scipy_func, test_value", [ (logp, "logpdf", 5.0), - pytest.param(logcdf, "logcdf", 5.0, marks=pytest.mark.xfail(raises=NotImplementedError)), - pytest.param(icdf, "ppf", 0.7, marks=pytest.mark.xfail(raises=NotImplementedError)), + (logcdf, "logcdf", 5.0), + (icdf, "ppf", 0.7), ], ) def test_probability_inference(func, scipy_func, test_value):