-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Implement mass matrix adaptation #2327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
pymc3/sampling.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
Should we actually use this for tuning (unrelated to init) in NUTS? |
dffba11
to
fec319e
Compare
I think this is ready to be merged now. |
pymc3/sampling.py
Outdated
return {k: np.asarray(v) for k, v in ppc.items()} | ||
|
||
|
||
def init_nuts(init='ADVI', njobs=1, n_init=500000, model=None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldnt init='auto' as default?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, fixed now.
It was the default in pm.sample
.
This implements adaptation for diagonal mass matrices during tuning.
sample
gets a new option forinit
, namelyadvi+adapt_diag
, which first runs advi to get an initial estimate for the mass matrix and then further adapts it using the covariance of the tuning samples.The approach for computing adapting the mass matrix differs a bit from that in stan. Stan keeps the mass matrix constant within a window and then uses the variances in the latest window as mass matrix for the next window.
In this PR I use an online algorithm for computing the variance and change it continuously. It starts with the estimate from advi with a weight of 30 and updates the variance for each new sample. After 200 samples it starts computing new variances from ground up, but still uses the first variance for sampling. after 400 samples it switches to the second set of variances. This should prevent bad initialization from advi to ruin the whole trace. (see
quadpotential.QuadpotentialDiagAdapt.adapt
)Still needs a lot of unit tests/doc/tinkering and we should compare it to the stan approach carefully.
It would probably also be an option to change weights over time and therefore achieve a similar effect as using windows.