Skip to content

Commit 1e7adcb

Browse files
committed
Add Censored distributions
1 parent e8700e8 commit 1e7adcb

File tree

8 files changed

+255
-2
lines changed

8 files changed

+255
-2
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
119119
- With `pm.Data(..., mutable=True/False)`, or by using `pm.MutableData` vs. `pm.ConstantData` one can now create `TensorConstant` data variables. They can be more performant and compatible in situations where a variable doesn't need to be changed via `pm.set_data()`. See [#5295](https://github.com/pymc-devs/pymc/pull/5295).
120120
- New named dimensions can be introduced to the model via `pm.Data(..., dims=...)`. For mutable data variables (see above) the lengths of these dimensions are symbolic, so they can be re-sized via `pm.set_data()`.
121121
- `pm.Data` now passes additional kwargs to `aesara.shared`/`at.as_tensor`. [#5098](https://github.com/pymc-devs/pymc/pull/5098).
122+
- Univariate censored distributions are now available via `pm.Censored`. [#5169](https://github.com/pymc-devs/pymc/pull/5169).
122123
- ...
123124

124125

docs/source/api/distributions.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@ Distributions
66

77
distributions/continuous
88
distributions/discrete
9-
distributions/logprob
109
distributions/multivariate
1110
distributions/mixture
12-
distributions/simulator
1311
distributions/timeseries
12+
distributions/censored
13+
distributions/simulator
1414
distributions/transforms
15+
distributions/logprob
1516
distributions/utilities
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
********
2+
Censored
3+
********
4+
5+
.. currentmodule:: pymc
6+
.. autosummary::
7+
:toctree: generated
8+
9+
Censored

pymc/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222

2323
from pymc.distributions.bound import Bound
24+
from pymc.distributions.censored import Censored
2425
from pymc.distributions.continuous import (
2526
AsymmetricLaplace,
2627
Beta,
@@ -187,6 +188,7 @@
187188
"Rice",
188189
"Moyal",
189190
"Simulator",
191+
"Censored",
190192
"CAR",
191193
"PolyaGamma",
192194
"logpt",

pymc/distributions/censored.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# Copyright 2020 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import aesara.tensor as at
15+
import numpy as np
16+
17+
from aesara.scalar import Clip
18+
from aesara.tensor import TensorVariable
19+
from aesara.tensor.random.op import RandomVariable
20+
21+
from pymc.distributions.distribution import SymbolicDistribution, _get_moment
22+
from pymc.util import check_dist_not_registered
23+
24+
25+
class Censored(SymbolicDistribution):
    r"""
    Censored distribution

    The pdf of a censored distribution is

    .. math::

        \begin{cases}
            0 & \text{for } x < lower, \\
            \text{CDF}(lower, dist) & \text{for } x = lower, \\
            \text{PDF}(x, dist) & \text{for } lower < x < upper, \\
            1-\text{CDF}(upper, dist) & \text{for } x = upper, \\
            0 & \text{for } x > upper,
        \end{cases}


    Parameters
    ----------
    dist: PyMC unnamed distribution
        PyMC distribution created via the `.dist()` API, which will be censored. This
        distribution must be univariate and have a logcdf method implemented.
    lower: float or None
        Lower (left) censoring point. If `None`, the distribution will not be left censored.
    upper: float or None
        Upper (right) censoring point. If `None`, the distribution will not be right censored.


    Examples
    --------
    .. code-block:: python

        with pm.Model():
            normal_dist = pm.Normal.dist(mu=0.0, sigma=1.0)
            censored_normal = pm.Censored("censored_normal", normal_dist, lower=-1, upper=1)
    """

    @classmethod
    def dist(cls, dist, lower, upper, **kwargs):
        """Validate the base distribution and build the unnamed censored variable.

        Parameters
        ----------
        dist
            Base distribution to censor; must be a ``TensorVariable`` that is the
            output of a ``RandomVariable`` op with ``ndim_supp == 0``.
        lower, upper
            Censoring bounds, forwarded unchanged together with ``dist``.
        **kwargs
            Forwarded to the parent class ``dist``.

        Raises
        ------
        ValueError
            If ``dist`` is not a ``TensorVariable`` produced by a ``RandomVariable``
            op. ``check_dist_not_registered`` may raise as well.
        NotImplementedError
            If the base ``RandomVariable`` op reports ``ndim_supp > 0``.
        """
        # A variable that is not the output of any Op (e.g. a plain graph input)
        # has ``owner is None``; reject it here with the intended ValueError
        # rather than failing with an AttributeError on ``owner.op``.
        owner = getattr(dist, "owner", None)
        if (
            not isinstance(dist, TensorVariable)
            or owner is None
            or not isinstance(owner.op, RandomVariable)
        ):
            raise ValueError(
                f"Censoring dist must be a distribution created via the `.dist()` API, got {type(dist)}"
            )
        if owner.op.ndim_supp > 0:
            raise NotImplementedError(
                "Censoring of multivariate distributions has not been implemented yet"
            )
        check_dist_not_registered(dist)
        return super().dist([dist, lower, upper], **kwargs)

    @classmethod
    def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None):
        """Build the censored variable by clipping ``dist`` to ``[lower, upper]``.

        A ``None`` bound is replaced by a constant ``-inf`` (lower) or ``inf``
        (upper). When ``size`` or ``rngs`` is given, the result is rebuilt through
        ``change_size`` / ``change_rngs`` respectively.
        """
        if lower is None:
            lower = at.constant(-np.inf)
        if upper is None:
            upper = at.constant(np.inf)

        # Censoring is achieved by clipping the base distribution between lower and upper
        rv_out = at.clip(dist, lower, upper)

        # Reference nodes to facilitate identification in other classmethods, without
        # worrying about possible dimshuffles
        rv_out.tag.dist = dist
        rv_out.tag.lower = lower
        rv_out.tag.upper = upper

        if size is not None:
            rv_out = cls.change_size(rv_out, size)
        if rngs is not None:
            rv_out = cls.change_rngs(rv_out, rngs)

        return rv_out

    @classmethod
    def ndim_supp(cls, *dist_params):
        """Return the support dimensionality, which is always 0 for this class."""
        return 0

    @classmethod
    def change_size(cls, rv, new_size):
        """Return a new censored variable whose base distribution uses ``new_size``.

        The base node (read from ``rv.tag.dist``) is recreated with the same rng,
        dtype and parameters, then clipped again with the tagged bounds.
        """
        dist_node = rv.tag.dist.owner
        lower = rv.tag.lower
        upper = rv.tag.upper
        # Node inputs are unpacked as (rng, size, dtype, *params); only size is replaced.
        rng, _old_size, dtype, *dist_params = dist_node.inputs
        new_dist = dist_node.op.make_node(rng, new_size, dtype, *dist_params).default_output()
        return cls.rv_op(new_dist, lower, upper)

    @classmethod
    def change_rngs(cls, rv, new_rngs):
        """Return a new censored variable whose base distribution uses a new rng.

        ``new_rngs`` must contain exactly one rng; unpacking fails otherwise.
        """
        (new_rng,) = new_rngs
        dist_node = rv.tag.dist.owner
        lower = rv.tag.lower
        upper = rv.tag.upper
        # Node inputs are unpacked as (rng, size, dtype, *params); only rng is replaced.
        _old_rng, size, dtype, *dist_params = dist_node.inputs
        new_dist = dist_node.op.make_node(new_rng, size, dtype, *dist_params).default_output()
        return cls.rv_op(new_dist, lower, upper)

    @classmethod
    def graph_rvs(cls, rv):
        """Return a one-element tuple holding the base distribution tagged on ``rv``."""
        return (rv.tag.dist,)
124+
125+
126+
@_get_moment.register(Clip)
def get_moment_censored(op, rv, dist, lower, upper):
    """Return a moment for a clipped (censored) variable, based only on its bounds.

    The value chosen is:

    * ``0`` when ``lower`` is ``-inf`` and ``upper`` is infinite,
    * ``upper - 1`` when only ``lower`` is ``-inf``,
    * ``lower + 1`` when only ``upper`` is ``inf``,
    * the midpoint ``(lower + upper) / 2`` otherwise.

    The result is expanded with ``at.full_like`` using ``dist`` as the template.
    """
    lower_is_neg_inf = at.eq(lower, -np.inf)

    # Branch taken when lower = -inf.
    # NOTE(review): this inner check uses ``isinf`` while the other branch compares
    # against +inf explicitly; kept as-is to preserve behaviour.
    left_open_moment = at.switch(
        at.isinf(upper),
        0,  # lower = -inf, upper = inf
        upper - 1,  # lower = -inf, upper = y
    )

    # Branch taken when lower is finite.
    left_closed_moment = at.switch(
        at.eq(upper, np.inf),
        lower + 1,  # lower = x, upper = inf
        (lower + upper) / 2,  # lower = x, upper = y
    )

    bounds_moment = at.switch(lower_is_neg_inf, left_open_moment, left_closed_moment)
    return at.full_like(dist, bounds_moment)

pymc/distributions/distribution.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from aeppl.logprob import _logcdf, _logprob
2727
from aesara import tensor as at
2828
from aesara.tensor.basic import as_tensor_variable
29+
from aesara.tensor.elemwise import Elemwise
2930
from aesara.tensor.random.op import RandomVariable
3031
from aesara.tensor.random.var import RandomStateSharedVariable
3132
from aesara.tensor.var import TensorVariable
@@ -628,6 +629,12 @@ def get_moment(rv: TensorVariable) -> TensorVariable:
628629
return _get_moment(rv.owner.op, rv, *rv.owner.inputs).astype(rv.dtype)
629630

630631

632+
@_get_moment.register(Elemwise)
def _get_moment_elemwise(op, rv, *dist_params):
    """Re-dispatch the moment lookup of an ``Elemwise`` Op on its ``scalar_op``."""
    scalar_op = op.scalar_op
    return _get_moment(scalar_op, rv, *dist_params)
636+
637+
631638
class Discrete(Distribution):
632639
"""Base class for discrete distributions"""
633640

pymc/tests/test_distributions.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3275,3 +3275,73 @@ def logp(value, mu):
32753275
).shape
32763276
== to_tuple(size)
32773277
)
3278+
3279+
3280+
class TestCensored:
    """Tests for ``pm.Censored``: a sampling workflow and input validation."""

    @pytest.mark.parametrize("censored", (False, True))
    def test_censored_workflow(self, censored):
        """Fit a Normal model to clipped data, with and without censoring bounds.

        The range checks on the sampled means are expected to hold only when the
        bounds are passed to ``pm.Censored``.
        """
        # Based on pymc-examples/censored_data
        true_mu, true_sigma = 13.0, 5.0
        size = 500

        # Censoring limits
        low, high = 3.0, 16.0

        # Draw samples and clip them to the censoring limits
        rng = np.random.default_rng(1234)
        data = rng.normal(true_mu, true_sigma, size)
        data[data <= low] = low
        data[data >= high] = high

        half_width = (high - low) / 2.0
        lower_bound = low if censored else None
        upper_bound = high if censored else None

        with pm.Model(rng_seeder=17092021):
            mu = pm.Normal(
                "mu",
                mu=half_width + low,
                sigma=half_width,
                initval="moment",
            )
            sigma = pm.HalfNormal("sigma", sigma=half_width, initval="moment")
            pm.Censored(
                "observed",
                pm.Normal.dist(mu=mu, sigma=sigma),
                lower=lower_bound,
                upper=upper_bound,
                observed=data,
            )

            prior_pred = pm.sample_prior_predictive()
            posterior = pm.sample(tune=500, draws=500)
            posterior_pred = pm.sample_posterior_predictive(posterior)

        expected = bool(censored)
        assert (9 < prior_pred.prior_predictive.mean() < 10) == expected
        assert (13 < posterior.posterior["mu"].mean() < 14) == expected
        assert (4.5 < posterior.posterior["sigma"].mean() < 5.5) == expected
        assert (12 < posterior_pred.posterior_predictive.mean() < 13) == expected

    def test_censored_invalid_dist(self):
        """Check the errors raised for each kind of unsupported base distribution."""
        # A distribution class instead of a variable built with `.dist()`
        with pm.Model():
            with pytest.raises(
                ValueError,
                match=r"Censoring dist must be a distribution created via the",
            ):
                pm.Censored("x", pm.Normal, lower=None, upper=None)

        # A multivariate base distribution
        with pm.Model():
            multivariate = pm.Dirichlet.dist(a=[1, 1, 1])
            with pytest.raises(
                NotImplementedError,
                match="Censoring of multivariate distributions has not been implemented yet",
            ):
                pm.Censored("x", multivariate, lower=None, upper=None)

        # A distribution that was already registered (named) in the model
        with pm.Model():
            already_named = pm.Normal("dist")
            with pytest.raises(
                ValueError,
                match="The dist dist was already registered in the current model",
            ):
                pm.Censored("x", already_named, lower=None, upper=None)

pymc/util.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,3 +332,20 @@ def cf(self):
332332
return cf
333333

334334
return cachedmethod(self_cache_fn(f.__name__), key=hash_key)(f)
335+
336+
337+
def check_dist_not_registered(dist, model=None):
    """Check that a dist is not registered in the model already.

    Parameters
    ----------
    dist
        Variable to look up among the model's ``basic_RVs``.
    model : optional
        Model to check against. When ``None``, the model is resolved through
        ``modelcontext``; if that raises ``TypeError`` the check is skipped.

    Raises
    ------
    ValueError
        If ``dist`` is found in ``model.basic_RVs``.
    """
    # Imported here rather than at module level, as in the original code.
    from pymc.model import modelcontext

    try:
        # Pass the caller's model through instead of always forcing a context
        # lookup, so an explicitly supplied model is honoured.
        model = modelcontext(model)
    except TypeError:
        # No model could be resolved: there is nothing to check against.
        return

    if dist in model.basic_RVs:
        raise ValueError(
            f"The dist {dist} was already registered in the current model.\n"
            f"You should use an unregistered (unnamed) distribution created via "
            f"the `.dist()` API instead, such as:\n`dist=pm.Normal.dist(0, 1)`"
        )

0 commit comments

Comments
 (0)