From 86bac0884c2c23861aec84b05ec71466fbf1a963 Mon Sep 17 00:00:00 2001 From: jnetzel1 Date: Fri, 26 Sep 2025 15:16:18 +0200 Subject: [PATCH 1/2] added logit normal icdf and test --- pymc/distributions/continuous.py | 6 ++++++ tests/distributions/test_continuous.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index dc98ed3144..1379077d24 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -3700,6 +3700,12 @@ def logp(value, mu, sigma): msg="tau > 0", ) + def icdf(value, mu, sigma): + # F^{-1}_{LogitNormal}(q) = sigmoid( mu + sigma * Phi^{-1}(q) ) + # where Phi^{-1} is the Normal icdf + res = invlogit(icdf(Normal.dist(mu, sigma), value)) + res = check_icdf_value(res, value) + return check_icdf_parameters(res, sigma > 0, msg="sigma > 0") def _interpolated_argcdf(p, pdf, cdf, x): if np.prod(cdf.shape[:-1]) != 1 or np.prod(pdf.shape[:-1]) != 1 or np.prod(x.shape[:-1]) != 1: diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index 7209382666..e1e9b467d5 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -872,6 +872,12 @@ def test_logitnormal(self): ), decimal=select_by_precision(float64=6, float32=1), ) + check_icdf( + pm.LogitNormal, + {"mu": R, "sigma": Rplus}, + lambda q, mu, sigma: sp.expit(mu + sigma * st.norm.ppf(q)), + decimal=select_by_precision(float64=12, float32=5), + ) @pytest.mark.skipif( condition=(pytensor.config.floatX == "float32"), From b5dabcbfc5a3adb995506893d757b99cc3d367f3 Mon Sep 17 00:00:00 2001 From: jnetzel1 Date: Fri, 26 Sep 2025 16:04:03 +0200 Subject: [PATCH 2/2] fix ruff-format --- pymc/distributions/continuous.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 1379077d24..b4ca91057c 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -3707,6 +3707,7 @@ def icdf(value, mu, sigma): res = check_icdf_value(res, value) return check_icdf_parameters(res, sigma > 0, msg="sigma > 0") + def _interpolated_argcdf(p, pdf, cdf, x): if np.prod(cdf.shape[:-1]) != 1 or np.prod(pdf.shape[:-1]) != 1 or np.prod(x.shape[:-1]) != 1: raise NotImplementedError(