|
17 | 17 | import pytest |
18 | 18 | import scipy.stats.distributions as sp |
19 | 19 |
|
| 20 | +from aeppl.abstract import get_measurable_outputs |
20 | 21 | from aesara.graph.basic import ancestors |
21 | 22 | from aesara.tensor.random.op import RandomVariable |
22 | 23 | from aesara.tensor.subtensor import ( |
|
32 | 33 | from pymc.aesaraf import floatX, walk_model |
33 | 34 | from pymc.distributions.continuous import HalfFlat, Normal, TruncatedNormal, Uniform |
34 | 35 | from pymc.distributions.discrete import Bernoulli |
35 | | -from pymc.distributions.logprob import joint_logpt, logcdf, logp |
| 36 | +from pymc.distributions.logprob import ignore_logprob, joint_logpt, logcdf, logp |
36 | 37 | from pymc.model import Model, Potential |
37 | 38 | from pymc.tests.helpers import select_by_precision |
38 | 39 |
|
@@ -227,3 +228,38 @@ def test_unexpected_rvs(): |
227 | 228 |
|
228 | 229 | with pytest.raises(ValueError, match="^Random variables detected in the logp graph"): |
229 | 230 | model.logpt() |
| 231 | + |
| 232 | + |
| 233 | +def test_ignore_logprob_basic(): |
| 234 | + x = Normal.dist() |
| 235 | + (measurable_x_out,) = get_measurable_outputs(x.owner.op, x.owner) |
| 236 | + assert measurable_x_out is x.owner.outputs[1] |
| 237 | + |
| 238 | + new_x = ignore_logprob(x) |
| 239 | + assert new_x is not x |
| 240 | + assert isinstance(new_x.owner.op, Normal) |
| 241 | + assert type(new_x.owner.op).__name__ == "UnmeasurableNormalRV" |
| 242 | + # Confirm that it does not have measurable output |
| 243 | + assert get_measurable_outputs(new_x.owner.op, new_x.owner) is None |
| 244 | + |
| 245 | + # Test that it will not clone a variable that is already unmeasurable |
| 246 | + new_new_x = ignore_logprob(new_x) |
| 247 | + assert new_new_x is new_x |
| 248 | + |
| 249 | + |
| 250 | +def test_ignore_logprob_model(): |
| 251 | + # logp that does not depend on input |
| 252 | + def logp(value, x): |
| 253 | + return value |
| 254 | + |
| 255 | + with Model() as m: |
| 256 | + x = Normal.dist() |
| 257 | + y = DensityDist("y", x, logp=logp) |
| 258 | + # Aeppl raises a KeyError when it finds an unexpected RV |
| 259 | + with pytest.raises(KeyError): |
| 260 | + joint_logpt([y], {y: y.type()}) |
| 261 | + |
| 262 | + with Model() as m: |
| 263 | + x = ignore_logprob(Normal.dist()) |
| 264 | + y = DensityDist("y", x, logp=logp) |
| 265 | + assert joint_logpt([y], {y: y.type()}) |
0 commit comments