1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ # Contains code from Aeppl, Copyright (c) 2021-2022, Aesara Developers.
16+
1517# coding: utf-8
1618"""
1719A collection of common probability distributions for stochastic
2628import aesara .tensor as at
2729import numpy as np
2830
29- from aeppl .logprob import _logprob , logcdf
31+ from aeppl .logprob import _logprob , logcdf , logprob
3032from aesara .graph .basic import Apply , Variable
3133from aesara .graph .op import Op
3234from aesara .raise_op import Assert
@@ -311,9 +313,22 @@ def moment(rv, size, lower, upper):
311313 moment = at .full (size , moment )
312314 return moment
313315
316+ def logp (value , lower , upper ):
317+ res = at .switch (
318+ at .bitwise_and (at .ge (value , lower ), at .le (value , upper )),
319+ at .fill (value , - at .log (upper - lower )),
320+ - np .inf ,
321+ )
322+
323+ return check_parameters (
324+ res ,
325+ lower <= upper ,
326+ msg = "lower <= upper" ,
327+ )
328+
314329 def logcdf (value , lower , upper ):
315- return at .switch (
316- at .lt (value , lower ) | at . lt ( upper , lower ) ,
330+ res = at .switch (
331+ at .lt (value , lower ),
317332 - np .inf ,
318333 at .switch (
319334 at .lt (value , upper ),
@@ -322,6 +337,12 @@ def logcdf(value, lower, upper):
322337 ),
323338 )
324339
340+ return check_parameters (
341+ res ,
342+ lower <= upper ,
343+ msg = "lower <= upper" ,
344+ )
345+
325346
326347@_default_transform .register (Uniform )
327348def uniform_default_transform (op , rv ):
@@ -495,6 +516,14 @@ def moment(rv, size, mu, sigma):
495516 mu = at .full (size , mu )
496517 return mu
497518
519+ def logp (value , mu , sigma ):
520+ res = - 0.5 * at .pow ((value - mu ) / sigma , 2 ) - at .log (at .sqrt (2.0 * np .pi )) - at .log (sigma )
521+ return check_parameters (
522+ res ,
523+ sigma > 0 ,
524+ msg = "sigma > 0" ,
525+ )
526+
498527 def logcdf (value , mu , sigma ):
499528 return check_parameters (
500529 normal_lcdf (mu , sigma , value ),
@@ -780,6 +809,15 @@ def moment(rv, size, loc, sigma):
780809 moment = at .full (size , moment )
781810 return moment
782811
812+ def logp (value , loc , sigma ):
813+ res = - 0.5 * at .pow ((value - loc ) / sigma , 2 ) + at .log (at .sqrt (2.0 / np .pi )) - at .log (sigma )
814+ res = at .switch (at .ge (value , loc ), res , - np .inf )
815+ return check_parameters (
816+ res ,
817+ sigma > 0 ,
818+ msg = "sigma > 0" ,
819+ )
820+
783821 def logcdf (value , loc , sigma ):
784822 z = zvalue (value , mu = loc , sigma = sigma )
785823 logcdf = at .switch (
@@ -1079,6 +1117,20 @@ def get_alpha_beta(self, alpha=None, beta=None, mu=None, sigma=None):
10791117
10801118 return alpha , beta
10811119
1120+ def logp (value , alpha , beta ):
1121+ res = (
1122+ at .switch (at .eq (alpha , 1.0 ), 0.0 , (alpha - 1.0 ) * at .log (value ))
1123+ + at .switch (at .eq (beta , 1.0 ), 0.0 , (beta - 1.0 ) * at .log1p (- value ))
1124+ - (at .gammaln (alpha ) + at .gammaln (beta ) - at .gammaln (alpha + beta ))
1125+ )
1126+ res = at .switch (at .bitwise_and (at .ge (value , 0.0 ), at .le (value , 1.0 )), res , - np .inf )
1127+ return check_parameters (
1128+ res ,
1129+ alpha > 0 ,
1130+ beta > 0 ,
1131+ msg = "alpha > 0, beta > 0" ,
1132+ )
1133+
10821134 def logcdf (value , alpha , beta ):
10831135 logcdf = at .switch (
10841136 at .lt (value , 0 ),
@@ -1261,6 +1313,15 @@ def moment(rv, size, mu):
12611313 mu = at .full (size , mu )
12621314 return mu
12631315
1316+ def logp (value , mu ):
1317+ res = - at .log (mu ) - value / mu
1318+ res = at .switch (at .ge (value , 0.0 ), res , - np .inf )
1319+ return check_parameters (
1320+ res ,
1321+ mu >= 0 ,
1322+ msg = "mu >= 0" ,
1323+ )
1324+
12641325 def logcdf (value , mu ):
12651326 lam = at .reciprocal (mu )
12661327 res = at .switch (
@@ -1334,6 +1395,14 @@ def moment(rv, size, mu, b):
13341395 mu = at .full (size , mu )
13351396 return mu
13361397
1398+ def logp (value , mu , b ):
1399+ res = - at .log (2 * b ) - at .abs (value - mu ) / b
1400+ return check_parameters (
1401+ res ,
1402+ b > 0 ,
1403+ msg = "b > 0" ,
1404+ )
1405+
13371406 def logcdf (value , mu , b ):
13381407 y = (value - mu ) / b
13391408
@@ -1524,6 +1593,20 @@ def moment(rv, size, mu, sigma):
15241593 mean = at .full (size , mean )
15251594 return mean
15261595
1596+ def logp (value , mu , sigma ):
1597+ res = (
1598+ - 0.5 * at .pow ((at .log (value ) - mu ) / sigma , 2 )
1599+ - 0.5 * at .log (2.0 * np .pi )
1600+ - at .log (sigma )
1601+ - at .log (value )
1602+ )
1603+ res = at .switch (at .gt (value , 0.0 ), res , - np .inf )
1604+ return check_parameters (
1605+ res ,
1606+ sigma > 0 ,
1607+ msg = "sigma > 0" ,
1608+ )
1609+
15271610 def logcdf (value , mu , sigma ):
15281611 res = at .switch (
15291612 at .le (value , 0 ),
@@ -1732,6 +1815,16 @@ def moment(rv, size, alpha, m):
17321815 median = at .full (size , median )
17331816 return median
17341817
1818+ def logp (value , alpha , m ):
1819+ res = at .log (alpha ) + logpow (m , alpha ) - logpow (value , alpha + 1.0 )
1820+ res = at .switch (at .ge (value , m ), res , - np .inf )
1821+ return check_parameters (
1822+ res ,
1823+ alpha > 0 ,
1824+ m > 0 ,
1825+ msg = "alpha > 0, m > 0" ,
1826+ )
1827+
17351828 def logcdf (value , alpha , m ):
17361829 arg = (m / value ) ** alpha
17371830
@@ -1819,6 +1912,14 @@ def moment(rv, size, alpha, beta):
18191912 alpha = at .full (size , alpha )
18201913 return alpha
18211914
1915+ def logp (value , alpha , beta ):
1916+ res = - at .log (np .pi ) - at .log (beta ) - at .log1p (at .pow ((value - alpha ) / beta , 2 ))
1917+ return check_parameters (
1918+ res ,
1919+ beta > 0 ,
1920+ msg = "beta > 0" ,
1921+ )
1922+
18221923 def logcdf (value , alpha , beta ):
18231924 res = at .log (0.5 + at .arctan ((value - alpha ) / beta ) / np .pi )
18241925 return check_parameters (
@@ -1879,6 +1980,15 @@ def moment(rv, size, loc, beta):
18791980 beta = at .full (size , beta )
18801981 return beta
18811982
1983+ def logp (value , loc , beta ):
1984+ res = at .log (2 ) + logprob (Cauchy .dist (loc , beta ), value )
1985+ res = at .switch (at .ge (value , loc ), res , - np .inf )
1986+ return check_parameters (
1987+ res ,
1988+ beta > 0 ,
1989+ msg = "beta > 0" ,
1990+ )
1991+
18821992 def logcdf (value , loc , beta ):
18831993 res = at .switch (
18841994 at .lt (value , loc ),
@@ -1990,6 +2100,17 @@ def moment(rv, size, alpha, inv_beta):
19902100 mean = at .full (size , mean )
19912101 return mean
19922102
2103+ def logp (value , alpha , inv_beta ):
2104+ beta = at .reciprocal (inv_beta )
2105+ res = - at .gammaln (alpha ) + logpow (beta , alpha ) - beta * value + logpow (value , alpha - 1 )
2106+ res = at .switch (at .ge (value , 0.0 ), res , - np .inf )
2107+ return check_parameters (
2108+ res ,
2109+ alpha > 0 ,
2110+ beta > 0 ,
2111+ msg = "alpha > 0, beta > 0" ,
2112+ )
2113+
19932114 def logcdf (value , alpha , inv_beta ):
19942115 beta = at .reciprocal (inv_beta )
19952116 res = at .switch (
@@ -2091,6 +2212,16 @@ def _get_alpha_beta(cls, alpha, beta, mu, sigma):
20912212
20922213 return alpha , beta
20932214
2215+ def logp (value , alpha , beta ):
2216+ res = - at .gammaln (alpha ) + logpow (beta , alpha ) - beta / value + logpow (value , - alpha - 1 )
2217+ res = at .switch (at .ge (value , 0.0 ), res , - np .inf )
2218+ return check_parameters (
2219+ res ,
2220+ alpha > 0 ,
2221+ beta > 0 ,
2222+ msg = "alpha > 0, beta > 0" ,
2223+ )
2224+
20942225 def logcdf (value , alpha , beta ):
20952226 res = at .switch (
20962227 at .lt (value , 0 ),
@@ -2158,6 +2289,9 @@ def moment(rv, size, nu):
21582289 moment = at .full (size , moment )
21592290 return moment
21602291
2292+ def logp (value , nu ):
2293+ return logprob (Gamma .dist (alpha = nu / 2 , beta = 0.5 ), value )
2294+
21612295 def logcdf (value , nu ):
21622296 return logcdf (Gamma .dist (alpha = nu / 2 , beta = 0.5 ), value )
21632297
@@ -2586,6 +2720,15 @@ def moment(rv, size, mu, kappa):
25862720 mu = at .full (size , mu )
25872721 return mu
25882722
2723+ def logp (value , mu , kappa ):
2724+ res = kappa * at .cos (mu - value ) - at .log (2 * np .pi ) - at .log (at .i0 (kappa ))
2725+ res = at .switch (at .bitwise_and (at .ge (value , - np .pi ), at .le (value , np .pi )), res , - np .inf )
2726+ return check_parameters (
2727+ res ,
2728+ kappa > 0 ,
2729+ msg = "kappa > 0" ,
2730+ )
2731+
25892732
25902733class SkewNormalRV (RandomVariable ):
25912734 name = "skewnormal"
@@ -2771,6 +2914,20 @@ def moment(rv, size, lower, c, upper):
27712914 mean = at .full (size , mean )
27722915 return mean
27732916
2917+ def logp (value , lower , c , upper ):
2918+ res = at .switch (
2919+ at .lt (value , c ),
2920+ at .log (2 * (value - lower ) / ((upper - lower ) * (c - lower ))),
2921+ at .log (2 * (upper - value ) / ((upper - lower ) * (upper - c ))),
2922+ )
2923+ res = at .switch (at .bitwise_and (at .le (lower , value ), at .le (value , upper )), res , - np .inf )
2924+ return check_parameters (
2925+ res ,
2926+ lower <= c ,
2927+ c <= upper ,
2928+ msg = "lower <= c <= upper" ,
2929+ )
2930+
27742931 def logcdf (value , lower , c , upper ):
27752932 res = at .switch (
27762933 at .le (value , lower ),
@@ -2863,6 +3020,15 @@ def moment(rv, size, mu, beta):
28633020 mean = at .full (size , mean )
28643021 return mean
28653022
3023+ def logp (value , mu , beta ):
3024+ z = (value - mu ) / beta
3025+ res = - z - at .exp (- z ) - at .log (beta )
3026+ return check_parameters (
3027+ res ,
3028+ beta > 0 ,
3029+ msg = "beta > 0" ,
3030+ )
3031+
28663032 def logcdf (value , mu , beta ):
28673033 res = - at .exp (- (value - mu ) / beta )
28683034
@@ -3062,6 +3228,15 @@ def moment(rv, size, mu, s):
30623228 mu = at .full (size , mu )
30633229 return mu
30643230
3231+ def logp (value , mu , s ):
3232+ z = (value - mu ) / s
3233+ res = - z - at .log (s ) - 2.0 * at .log1p (at .exp (- z ))
3234+ return check_parameters (
3235+ res ,
3236+ s > 0 ,
3237+ msg = "s > 0" ,
3238+ )
3239+
30653240 def logcdf (value , mu , s ):
30663241 res = - at .log1pexp (- (value - mu ) / s )
30673242
0 commit comments