2626from pymc .distributions .logprob import logp
2727from pymc .distributions .shape_utils import to_tuple
2828from pymc .model import modelcontext
29+ from pymc .util import check_dist_not_registered
2930
3031__all__ = ["Bound" ]
3132
@@ -144,8 +145,9 @@ class Bound:
144145
145146 Parameters
146147 ----------
147- distribution: pymc distribution
148- Distribution to be transformed into a bounded distribution.
148+ dist: PyMC unnamed distribution
149+ Distribution to be transformed into a bounded distribution created via the
150+ `.dist()` API.
149151 lower: float or array like, optional
150152 Lower bound of the distribution.
151153 upper: float or array like, optional
@@ -156,15 +158,15 @@ class Bound:
156158 .. code-block:: python
157159
158160 with pm.Model():
159- normal_dist = Normal.dist(mu=0.0, sigma=1.0, initval=-0.5 )
160- negative_normal = pm.Bound(normal_dist, upper=0.0)
161+ normal_dist = pm. Normal.dist(mu=0.0, sigma=1.0)
162+ negative_normal = pm.Bound("negative_normal", normal_dist, upper=0.0)
161163
162164 """
163165
164166 def __new__ (
165167 cls ,
166168 name ,
167- distribution ,
169+ dist ,
168170 lower = None ,
169171 upper = None ,
170172 size = None ,
@@ -174,7 +176,7 @@ def __new__(
174176 ** kwargs ,
175177 ):
176178
177- cls ._argument_checks (distribution , ** kwargs )
179+ cls ._argument_checks (dist , ** kwargs )
178180
179181 if dims is not None :
180182 model = modelcontext (None )
@@ -185,12 +187,12 @@ def __new__(
185187 raise ValueError ("Given dims do not exist in model coordinates." )
186188
187189 lower , upper , initval = cls ._set_values (lower , upper , size , shape , initval )
188- distribution .tag .ignore_logprob = True
190+ dist .tag .ignore_logprob = True
189191
190- if isinstance (distribution .owner .op , Continuous ):
192+ if isinstance (dist .owner .op , Continuous ):
191193 res = _ContinuousBounded (
192194 name ,
193- [distribution , lower , upper ],
195+ [dist , lower , upper ],
194196 initval = floatX (initval ),
195197 size = size ,
196198 shape = shape ,
@@ -199,7 +201,7 @@ def __new__(
199201 else :
200202 res = _DiscreteBounded (
201203 name ,
202- [distribution , lower , upper ],
204+ [dist , lower , upper ],
203205 initval = intX (initval ),
204206 size = size ,
205207 shape = shape ,
@@ -210,28 +212,28 @@ def __new__(
210212 @classmethod
211213 def dist (
212214 cls ,
213- distribution ,
215+ dist ,
214216 lower = None ,
215217 upper = None ,
216218 size = None ,
217219 shape = None ,
218220 ** kwargs ,
219221 ):
220222
221- cls ._argument_checks (distribution , ** kwargs )
223+ cls ._argument_checks (dist , ** kwargs )
222224 lower , upper , initval = cls ._set_values (lower , upper , size , shape , initval = None )
223- distribution .tag .ignore_logprob = True
224- if isinstance (distribution .owner .op , Continuous ):
225+ dist .tag .ignore_logprob = True
226+ if isinstance (dist .owner .op , Continuous ):
225227 res = _ContinuousBounded .dist (
226- [distribution , lower , upper ],
228+ [dist , lower , upper ],
227229 size = size ,
228230 shape = shape ,
229231 ** kwargs ,
230232 )
231233 res .tag .test_value = floatX (initval )
232234 else :
233235 res = _DiscreteBounded .dist (
234- [distribution , lower , upper ],
236+ [dist , lower , upper ],
235237 size = size ,
236238 shape = shape ,
237239 ** kwargs ,
@@ -240,7 +242,7 @@ def dist(
240242 return res
241243
242244 @classmethod
243- def _argument_checks (cls , distribution , ** kwargs ):
245+ def _argument_checks (cls , dist , ** kwargs ):
244246 if "observed" in kwargs :
245247 raise ValueError (
246248 "Observed Bound distributions are not supported. "
@@ -249,34 +251,22 @@ def _argument_checks(cls, distribution, **kwargs):
249251 "with the cumulative probability function."
250252 )
251253
252- if not isinstance (distribution , TensorVariable ):
254+ if not isinstance (dist , TensorVariable ):
253255 raise ValueError (
254256 "Passing a distribution class to `Bound` is no longer supported.\n "
255257 "Please pass the output of a distribution instantiated via the "
256258 "`.dist()` API such as:\n "
257259 '`pm.Bound("bound", pm.Normal.dist(0, 1), lower=0)`'
258260 )
259261
260- try :
261- model = modelcontext (None )
262- except TypeError :
263- pass
264- else :
265- if distribution in model .basic_RVs :
266- raise ValueError (
267- f"The distribution passed into `Bound` was already registered "
268- f"in the current model.\n You should pass an unregistered "
269- f"(unnamed) distribution created via the `.dist()` API, such as:\n "
270- f'`pm.Bound("bound", pm.Normal.dist(0, 1), lower=0)`'
271- )
272-
273- if distribution .owner .op .ndim_supp != 0 :
262+ check_dist_not_registered (dist )
263+
264+ if dist .owner .op .ndim_supp != 0 :
274265 raise NotImplementedError ("Bounding of MultiVariate RVs is not yet supported." )
275266
276- if not isinstance (distribution .owner .op , (Discrete , Continuous )):
267+ if not isinstance (dist .owner .op , (Discrete , Continuous )):
277268 raise ValueError (
278- f"`distribution` { distribution } must be a Discrete or Continuous"
279- " distribution subclass"
269+ f"`distribution` { dist } must be a Discrete or Continuous" " distribution subclass"
280270 )
281271
282272 @classmethod
0 commit comments