Skip to content

Support for custom priors via Prior class #488

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

williambdean
Copy link
Contributor

@williambdean williambdean commented Jun 16, 2025

This using pymc-extras to allow for custom priors. Might need to adjust the remaining a bit to work.

@@ -245,10 +259,9 @@ def build_model(self, X, y, coords):
self.add_coords(coords)
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
y = pm.Data("y", y, dims="obs_ind")
beta = pm.Normal("beta", 0, 50, dims="coeffs")
sigma = pm.HalfNormal("sigma", 1)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sigma will not be in the model anymore but rather, y_hat_sigma based on the default name generation. Is that breaking change an issue? There is a workaround for this if needed

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is a big deal. I pushed a fix which make tests pass a9f821c (well looks like there is one failing doctest)

@@ -286,9 +303,8 @@ def build_model(self, X, y, coords):
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
y = pm.Data("y", y[:, 0], dims="obs_ind")
beta = pm.Dirichlet("beta", a=np.ones(n_predictors), dims="coeffs")
sigma = pm.HalfNormal("sigma", 1)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same breaking change concern

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replied above

Copy link
Collaborator

@drbenvincent drbenvincent left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very cool!

As far as I can tell, these changes are currently entirely invisible to the end user, right?

Before we do a new release I will follow up with another PR that adds some documentation. This will most likely be in the form of just giving a couple of worked examples

Main request is to add custom-ness to the Dirichlet for the WeightedSumFitter class. Like I say in a comment, I mostly see people wanting to customise the hyperparams, not the distribution itself.

@@ -237,6 +248,11 @@ class LinearRegression(PyMCModel):
Inference data...
""" # noqa: W605

default_priors = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to add @property decorator here? Or is that remembered from it being done in the PyMCModel base class?

Getting an Pylance warning: Type "dict[str, Prior]" is not assignable to declared type "property"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What line of code bring that on? Maybe having a setter will help?

@@ -245,10 +259,9 @@ def build_model(self, X, y, coords):
self.add_coords(coords)
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
y = pm.Data("y", y, dims="obs_ind")
beta = pm.Normal("beta", 0, 50, dims="coeffs")
sigma = pm.HalfNormal("sigma", 1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is a big deal. I pushed a fix which make tests pass a9f821c (well looks like there is one failing doctest)

@@ -286,9 +303,8 @@ def build_model(self, X, y, coords):
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
y = pm.Data("y", y[:, 0], dims="obs_ind")
beta = pm.Dirichlet("beta", a=np.ones(n_predictors), dims="coeffs")
sigma = pm.HalfNormal("sigma", 1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replied above

@@ -286,9 +305,8 @@ def build_model(self, X, y, coords):
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
y = pm.Data("y", y[:, 0], dims="obs_ind")
beta = pm.Dirichlet("beta", a=np.ones(n_predictors), dims="coeffs")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We definitely want a custom prior for the Dirichlet. I think the Dirichlet would be used always (or nearly always), but there are plenty of real world use cases where the user might want to change the hyper parameters (currently a=np.ones(n_predictors)).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright. since it is function of data, we will have to handle differently

Copy link

codecov bot commented Jun 20, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 94.42%. Comparing base (a626c1e) to head (91aee00).
Report is 2 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #488      +/-   ##
==========================================
+ Coverage   94.40%   94.42%   +0.01%     
==========================================
  Files          29       29              
  Lines        2075     2081       +6     
==========================================
+ Hits         1959     1965       +6     
  Misses        116      116              

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@drbenvincent
Copy link
Collaborator

tests pass :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants