Support for custom priors via Prior class #488
base: main
Changes from all commits: b35001b, b7300e7, a60035e, 367c922, a9f821c, 91aee00

@@ -22,6 +22,7 @@
 import pytensor.tensor as pt
 import xarray as xr
 from arviz import r2_score
+from pymc_extras.prior import Prior

 from causalpy.utils import round_num

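The newly imported `Prior` class from pymc-extras is a declarative description of a distribution that is only materialized inside a model later. A minimal sketch of the idea, using only the calls that appear elsewhere in this diff (`create_variable` shows up further down):

```python
from pymc_extras.prior import Prior

# Declare the prior outside any model context; nothing is sampled yet.
beta_prior = Prior("Normal", mu=0, sigma=50, dims="coeffs")

# Later, inside a pm.Model with a "coeffs" coordinate registered, the same
# object is turned into a concrete random variable:
#     beta = beta_prior.create_variable("beta")
```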
@@ -68,7 +69,15 @@ class PyMCModel(pm.Model):
     Inference data...
     """

-    def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
+    @property
+    def default_priors(self):
+        return {}
+
+    def __init__(
+        self,
+        sample_kwargs: Optional[Dict[str, Any]] = None,
+        priors: dict[str, Any] | None = None,
+    ):
         """
         :param sample_kwargs: A dictionary of kwargs that get unpacked and passed to the
             :func:`pymc.sample` function. Defaults to an empty dictionary.

@@ -77,6 +86,8 @@ def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None):
         self.idata = None
         self.sample_kwargs = sample_kwargs if sample_kwargs is not None else {}

+        self.priors = {**self.default_priors, **(priors or {})}
+
     def build_model(self, X, y, coords) -> None:
         """Build the model, must be implemented by subclass."""
         raise NotImplementedError("This method must be implemented by a subclass")

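A short sketch of the merge semantics introduced in `__init__` above: user-supplied entries override the class defaults key by key, while unspecified keys fall back to `default_priors`. The class and key names below come from the diff; the concrete hyperparameter values are illustrative only.

```python
from pymc_extras.prior import Prior

default_priors = {
    "beta": Prior("Normal", mu=0, sigma=50, dims="coeffs"),
    "y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1), dims="obs_ind"),
}
user_priors = {"beta": Prior("Normal", mu=0, sigma=10, dims="coeffs")}

# Same expression as in __init__ above: per-key override, defaults kept otherwise.
priors = {**default_priors, **(user_priors or {})}
assert set(priors) == {"beta", "y_hat"}  # "beta" overridden, "y_hat" default retained
```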
@@ -188,15 +199,15 @@ def print_row(
         coeffs = az.extract(self.idata.posterior, var_names="beta")

         # Determine the width of the longest label
-        max_label_length = max(len(name) for name in labels + ["sigma"])
+        max_label_length = max(len(name) for name in labels + ["y_hat_sigma"])

         for name in labels:
             coeff_samples = coeffs.sel(coeffs=name)
             print_row(max_label_length, name, coeff_samples, round_to)

         # Add coefficient for measurement std
-        coeff_samples = az.extract(self.idata.posterior, var_names="sigma")
-        name = "sigma"
+        coeff_samples = az.extract(self.idata.posterior, var_names="y_hat_sigma")
+        name = "y_hat_sigma"
         print_row(max_label_length, name, coeff_samples, round_to)

@@ -237,6 +248,11 @@ class LinearRegression(PyMCModel):
     Inference data...
     """  # noqa: W605

+    default_priors = {
+        "beta": Prior("Normal", mu=0, sigma=50, dims="coeffs"),
+        "y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1), dims="obs_ind"),
+    }
+
     def build_model(self, X, y, coords):
         """
         Defines the PyMC model

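To illustrate how these defaults interact with the new `priors` argument, here is a hedged usage sketch. The import path and hyperparameter values are assumptions for illustration, not part of this diff.

```python
from pymc_extras.prior import Prior

# Assumed import path; LinearRegression is the class changed in this diff.
from causalpy.pymc_models import LinearRegression

# Keep the default y_hat likelihood prior, override only the slope prior.
model = LinearRegression(
    sample_kwargs={"progressbar": False},
    priors={"beta": Prior("Normal", mu=0, sigma=5, dims="coeffs")},
)
```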
@@ -245,10 +261,9 @@ def build_model(self, X, y, coords):
         self.add_coords(coords)
         X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
         y = pm.Data("y", y, dims="obs_ind")
-        beta = pm.Normal("beta", 0, 50, dims="coeffs")
-        sigma = pm.HalfNormal("sigma", 1)
+        beta = self.priors["beta"].create_variable("beta")
         mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind")
-        pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind")
+        self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y)


 class WeightedSumFitter(PyMCModel):

Review comment on the removed `sigma = pm.HalfNormal("sigma", 1)` line: sigma will no longer be in the model; instead, y_hat_sigma will appear, based on the default name generation. Is that breaking change an issue? There is a workaround for this if needed.

Reply: I don't think this is a big deal. I pushed a fix which makes the tests pass: a9f821c (although it looks like there is one failing doctest).

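To make the naming point from the thread above concrete, here is a small standalone sketch. The expected output is based on the thread and on the `y_hat_sigma` rename in the summary code earlier in this diff, not on running this PR's code.

```python
import numpy as np
import pymc as pm
from pymc_extras.prior import Prior

y_hat_prior = Prior("Normal", sigma=Prior("HalfNormal", sigma=1), dims="obs_ind")

with pm.Model(coords={"obs_ind": np.arange(5)}) as m:
    mu = pm.Normal("mu", 0, 1)
    # Same call as in build_model above, with toy data standing in for y.
    y_hat_prior.create_likelihood_variable("y_hat", mu=mu, observed=np.zeros(5))

# Expected to list "y_hat_sigma" (not "sigma") among the free variables.
print([rv.name for rv in m.free_RVs])
```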
@@ -276,6 +291,10 @@ class WeightedSumFitter(PyMCModel):
     Inference data...
     """  # noqa: W605

+    default_priors = {
+        "y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1), dims="obs_ind"),
+    }
+
     def build_model(self, X, y, coords):
         """
         Defines the PyMC model

@@ -286,9 +305,8 @@ def build_model(self, X, y, coords):
         X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
         y = pm.Data("y", y[:, 0], dims="obs_ind")
         beta = pm.Dirichlet("beta", a=np.ones(n_predictors), dims="coeffs")
-        sigma = pm.HalfNormal("sigma", 1)
         mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind")
-        pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind")
+        self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y)


 class InstrumentalVariableRegression(PyMCModel):

Review comment on the `beta = pm.Dirichlet(...)` line: We definitely want a custom prior for the Dirichlet. I think the Dirichlet would be used always (or nearly always), but there are plenty of real-world use cases where the user might want to change the hyperparameters (currently …)

Reply: Alright. Since it is a function of the data, we will have to handle it differently.

Review comment on the removed `sigma = pm.HalfNormal("sigma", 1)` line: same breaking change concern.

Reply: Replied above.

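As a purely hypothetical follow-up to the Dirichlet discussion above (nothing below is implemented in this PR, and it assumes `Prior` can wrap a Dirichlet with an explicit concentration vector): if `build_model` were later changed to read `self.priors["beta"]`, a user who knows the number of control units up front could pass the concentration directly. The open question from the thread remains how to build the default `a=np.ones(n_predictors)` when that size only becomes known from the data.

```python
import numpy as np
from pymc_extras.prior import Prior

# Assumed import path for the class changed in this diff.
from causalpy.pymc_models import WeightedSumFitter

n_predictors = 8  # assumed known by the user in this sketch

custom_priors = {
    # Hypothetical: only takes effect if build_model consumes self.priors["beta"].
    "beta": Prior("Dirichlet", a=np.full(n_predictors, 0.5), dims="coeffs"),
    "y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1), dims="obs_ind"),
}

wsf = WeightedSumFitter(priors=custom_priors)
```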
@@ -477,13 +495,17 @@ class PropensityScore(PyMCModel):
     Inference...
     """  # noqa: W605

+    default_priors = {
+        "b": Prior("Normal", mu=0, sigma=1, dims="coeffs"),
+    }
+
     def build_model(self, X, t, coords):
         "Defines the PyMC propensity model"
         with self:
             self.add_coords(coords)
             X_data = pm.Data("X", X, dims=["obs_ind", "coeffs"])
             t_data = pm.Data("t", t.flatten(), dims="obs_ind")
-            b = pm.Normal("b", mu=0, sigma=1, dims="coeffs")
+            b = self.priors["b"].create_variable("b")
             mu = pm.math.dot(X_data, b)
             p = pm.Deterministic("p", pm.math.invlogit(mu))
             pm.Bernoulli("t_pred", p=p, observed=t_data, dims="obs_ind")

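The same override pattern applies to the propensity model; a brief hedged sketch (import path and hyperparameters are assumptions for illustration):

```python
from pymc_extras.prior import Prior

# Assumed import path for the class changed in this diff.
from causalpy.pymc_models import PropensityScore

# A tighter, more regularizing prior on the propensity coefficients.
ps = PropensityScore(priors={"b": Prior("Normal", mu=0, sigma=0.5, dims="coeffs")})
```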
The second changed file adds pymc-extras to the project's dependencies:

@@ -15,3 +15,4 @@ dependencies:
   - seaborn>=0.11.2
   - statsmodels
   - xarray>=v2022.11.0
+  - pymc-extras>=0.2.7

Review comment: Do we need to add the @property decorator here? Or is that remembered from it being done in the PyMCModel base class? I am getting a Pylance warning: Type "dict[str, Prior]" is not assignable to declared type "property".

Reply: What line of code brings that on? Maybe having a setter will help?
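For context on that Pylance warning, here is a self-contained sketch of the mismatch plus one possible workaround. The workaround is an assumption on my part, not something this PR adopts.

```python
from typing import Any, Dict

from pymc_extras.prior import Prior


class Base:
    @property
    def default_priors(self) -> Dict[str, Any]:
        return {}


class Child(Base):
    # Pylance: Type "dict[str, Prior]" is not assignable to declared type "property",
    # because the subclass shadows a property with a plain class attribute.
    default_priors = {"beta": Prior("Normal", mu=0, sigma=50, dims="coeffs")}


class BaseAlt:
    # One possible workaround: declare a plain, typed class attribute in the base
    # instead of a property; subclasses can then override it without a type clash.
    default_priors: Dict[str, Any] = {}
```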