Skip to content
Merged
11 changes: 7 additions & 4 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -2779,15 +2779,18 @@ def dist(cls, mu=0.0, sigma=None, nu=None, sd=None, *args, **kwargs):
sigma = at.as_tensor_variable(floatX(sigma))
nu = at.as_tensor_variable(floatX(nu))

# sd = sigma
# mean = mu + nu
# variance = (sigma ** 2) + (nu ** 2)

assert_negative_support(sigma, "sigma", "ExGaussian")
assert_negative_support(nu, "nu", "ExGaussian")

return super().dist([mu, sigma, nu], *args, **kwargs)

def get_moment(rv, size, mu, sigma, nu):
mu, nu, _ = at.broadcast_arrays(mu, nu, sigma)
moment = mu + nu
if not rv_size_is_none(size):
moment = at.full(size, moment)
return moment

def logp(value, mu, sigma, nu):
"""
Calculate log-probability of ExGaussian distribution at specified value.
Expand Down
17 changes: 17 additions & 0 deletions pymc/tests/test_distributions_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
ChiSquared,
Constant,
DiscreteUniform,
ExGaussian,
Exponential,
Flat,
Gamma,
Expand Down Expand Up @@ -541,6 +542,22 @@ def test_logistic_moment(mu, s, size, expected):
assert_moment_is_expected(model, expected)


@pytest.mark.parametrize(
"mu, nu, sigma, size, expected",
[
(1, 1, None, None, 2),
(1, 1, np.ones((2, 5)), None, np.full([2, 5], 2)),
(1, 1, None, 5, np.full(5, 2)),
(1, np.arange(1, 6), None, None, np.arange(2, 7)),
(1, np.arange(1, 6), None, (2, 5), np.full((2, 5), np.arange(2, 7))),
],
)
def test_exgaussian_moment(mu, nu, sigma, size, expected):
with Model() as model:
ExGaussian("x", mu=mu, sigma=sigma, nu=nu, size=size)
assert_moment_is_expected(model, expected)


@pytest.mark.parametrize(
"p, size, expected",
[
Expand Down