diff --git a/causalpy/experiments/prepostnegd.py b/causalpy/experiments/prepostnegd.py index beec847e..3ab18968 100644 --- a/causalpy/experiments/prepostnegd.py +++ b/causalpy/experiments/prepostnegd.py @@ -82,7 +82,7 @@ class PrePostNEGD(BaseExperiment): Intercept -0.5, 94% HDI [-1, 0.2] C(group)[T.1] 2, 94% HDI [2, 2] pre 1, 94% HDI [1, 1] - sigma 0.5, 94% HDI [0.5, 0.6] + y_hat_sigma 0.5, 94% HDI [0.5, 0.6] """ supports_ols = False diff --git a/causalpy/pymc_models.py b/causalpy/pymc_models.py index ea380c1a..f95b6371 100644 --- a/causalpy/pymc_models.py +++ b/causalpy/pymc_models.py @@ -22,6 +22,7 @@ import pytensor.tensor as pt import xarray as xr from arviz import r2_score +from pymc_extras.prior import Prior from causalpy.utils import round_num @@ -68,7 +69,15 @@ class PyMCModel(pm.Model): Inference data... """ - def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None): + @property + def default_priors(self): + return {} + + def __init__( + self, + sample_kwargs: Optional[Dict[str, Any]] = None, + priors: dict[str, Any] | None = None, + ): """ :param sample_kwargs: A dictionary of kwargs that get unpacked and passed to the :func:`pymc.sample` function. Defaults to an empty dictionary. @@ -77,6 +86,8 @@ def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None): self.idata = None self.sample_kwargs = sample_kwargs if sample_kwargs is not None else {} + self.priors = {**self.default_priors, **(priors or {})} + def build_model(self, X, y, coords) -> None: """Build the model, must be implemented by subclass.""" raise NotImplementedError("This method must be implemented by a subclass") @@ -188,15 +199,15 @@ def print_row( coeffs = az.extract(self.idata.posterior, var_names="beta") # Determine the width of the longest label - max_label_length = max(len(name) for name in labels + ["sigma"]) + max_label_length = max(len(name) for name in labels + ["y_hat_sigma"]) for name in labels: coeff_samples = coeffs.sel(coeffs=name) print_row(max_label_length, name, coeff_samples, round_to) # Add coefficient for measurement std - coeff_samples = az.extract(self.idata.posterior, var_names="sigma") - name = "sigma" + coeff_samples = az.extract(self.idata.posterior, var_names="y_hat_sigma") + name = "y_hat_sigma" print_row(max_label_length, name, coeff_samples, round_to) @@ -237,6 +248,11 @@ class LinearRegression(PyMCModel): Inference data... """ # noqa: W605 + default_priors = { + "beta": Prior("Normal", mu=0, sigma=50, dims="coeffs"), + "y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1), dims="obs_ind"), + } + def build_model(self, X, y, coords): """ Defines the PyMC model @@ -245,10 +261,9 @@ def build_model(self, X, y, coords): self.add_coords(coords) X = pm.Data("X", X, dims=["obs_ind", "coeffs"]) y = pm.Data("y", y, dims="obs_ind") - beta = pm.Normal("beta", 0, 50, dims="coeffs") - sigma = pm.HalfNormal("sigma", 1) + beta = self.priors["beta"].create_variable("beta") mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind") - pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind") + self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y) class WeightedSumFitter(PyMCModel): @@ -276,6 +291,10 @@ class WeightedSumFitter(PyMCModel): Inference data... """ # noqa: W605 + default_priors = { + "y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1), dims="obs_ind"), + } + def build_model(self, X, y, coords): """ Defines the PyMC model @@ -286,9 +305,8 @@ def build_model(self, X, y, coords): X = pm.Data("X", X, dims=["obs_ind", "coeffs"]) y = pm.Data("y", y[:, 0], dims="obs_ind") beta = pm.Dirichlet("beta", a=np.ones(n_predictors), dims="coeffs") - sigma = pm.HalfNormal("sigma", 1) mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind") - pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind") + self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y) class InstrumentalVariableRegression(PyMCModel): @@ -477,13 +495,17 @@ class PropensityScore(PyMCModel): Inference... """ # noqa: W605 + default_priors = { + "b": Prior("Normal", mu=0, sigma=1, dims="coeffs"), + } + def build_model(self, X, t, coords): "Defines the PyMC propensity model" with self: self.add_coords(coords) X_data = pm.Data("X", X, dims=["obs_ind", "coeffs"]) t_data = pm.Data("t", t.flatten(), dims="obs_ind") - b = pm.Normal("b", mu=0, sigma=1, dims="coeffs") + b = self.priors["b"].create_variable("b") mu = pm.math.dot(X_data, b) p = pm.Deterministic("p", pm.math.invlogit(mu)) pm.Bernoulli("t_pred", p=p, observed=t_data, dims="obs_ind") diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg index 9975f47a..4a908d60 100644 --- a/docs/source/_static/interrogate_badge.svg +++ b/docs/source/_static/interrogate_badge.svg @@ -1,5 +1,5 @@ - interrogate: 94.9% + interrogate: 94.5% @@ -12,8 +12,8 @@ interrogate interrogate - 94.9% - 94.9% + 94.5% + 94.5% diff --git a/environment.yml b/environment.yml index 02b7f920..2bc8ed20 100644 --- a/environment.yml +++ b/environment.yml @@ -15,3 +15,4 @@ dependencies: - seaborn>=0.11.2 - statsmodels - xarray>=v2022.11.0 + - pymc-extras>=0.2.7 diff --git a/pyproject.toml b/pyproject.toml index 29f86277..bcc4bc7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "seaborn>=0.11.2", "statsmodels", "xarray>=v2022.11.0", + "pymc-extras>=0.2.7", ] # List additional groups of dependencies here (e.g. development dependencies). Users