Skip to content

Commit 8fb942c

Browse files
authored
Merge pull request #2327 from aseyboldt/nuts-adapt
Implement mass matrix adaptation
2 parents 971db07 + fe3bc1f commit 8fb942c

File tree

8 files changed

+408
-85
lines changed

8 files changed

+408
-85
lines changed

pymc3/distributions/multivariate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def random(self, point=None, size=None):
344344
tau, = draw_values([self.tau], point=point)
345345
dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau)
346346
else:
347-
chol, = draw_values([self.chol], point=point)
347+
chol, = draw_values([self.chol_cov], point=point)
348348
dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol)
349349

350350
samples = dist.random(point, size)

pymc3/sampling.py

Lines changed: 127 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
Slice, CompoundStep)
1414
from .plots.traceplot import traceplot
1515
from .util import update_start_vals
16+
from pymc3.step_methods.hmc import quadpotential
17+
from pymc3.distributions import distribution
1618
from tqdm import tqdm
1719

1820
import sys
@@ -118,20 +120,27 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
118120
A step function or collection of functions. If there are variables
119121
without a step methods, step methods for those variables will
120122
be assigned automatically.
121-
init : str {'ADVI', 'ADVI_MAP', 'MAP', 'NUTS', 'auto', None}
122-
Initialization method to use. Only works for auto-assigned step methods.
123-
124-
* ADVI: Run ADVI to estimate starting points and diagonal covariance
125-
matrix. If njobs > 1 it will sample starting points from the estimated
126-
posterior, otherwise it will use the estimated posterior mean.
127-
* ADVI_MAP: Initialize ADVI with MAP and use MAP as starting point.
128-
* MAP: Use the MAP as starting point.
129-
* NUTS: Run NUTS to estimate starting points and covariance matrix. If
130-
njobs > 1 it will sample starting points from the estimated posterior,
131-
otherwise it will use the estimated posterior mean.
132-
* auto : Auto-initialize, if possible. Currently only works when NUTS
133-
is auto-assigned as step method (default).
134-
* None: Do not initialize.
123+
init : str
124+
Initialization method to use for auto-assigned NUTS samplers.
125+
126+
* auto : Choose a default initialization method automatically.
127+
Currently, this is `'advi+adapt_diag'`, but this can change in
128+
the future. If you depend on the exact behaviour, choose an
129+
initialization method explicitly.
130+
* adapt_diag : Start with an identity mass matrix and then adapt
131+
a diagonal based on the variance of the tuning samples.
132+
* advi+adapt_diag : Run ADVI and then adapt the resulting diagonal
133+
mass matrix based on the sample variance of the tuning samples.
134+
* advi+adapt_diag_grad : Run ADVI and then adapt the resulting
135+
diagonal mass matrix based on the variance of the gradients
136+
during tuning. This is **experimental** and might be removed
137+
in a future release.
138+
* advi : Run ADVI to estimate posterior mean and diagonal mass
139+
matrix.
140+
* advi_map: Initialize ADVI with MAP and use MAP as starting point.
141+
* map : Use the MAP as starting point. This is discouraged.
142+
* nuts : Run NUTS and estimate posterior mean and mass matrix from
143+
the trace.
135144
n_init : int
136145
Number of iterations of initializer
137146
If 'ADVI', number of iterations, if 'nuts', number of draws.
@@ -220,9 +229,6 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
220229

221230
draws += tune
222231

223-
if init is not None:
224-
init = init.lower()
225-
226232
if nuts_kwargs is not None:
227233
if step_kwargs is not None:
228234
raise ValueError("Specify only one of step_kwargs and nuts_kwargs")
@@ -236,8 +242,6 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
236242
pm._log.info('Auto-assigning NUTS sampler...')
237243
args = step_kwargs if step_kwargs is not None else {}
238244
args = args.get('nuts', {})
239-
if init == 'auto':
240-
init = 'ADVI'
241245
start_, step = init_nuts(init=init, njobs=njobs, n_init=n_init,
242246
model=model, random_seed=random_seed,
243247
progressbar=progressbar, **args)
@@ -643,28 +647,42 @@ def sample_ppc_w(traces, samples=None, models=None, size=None, weights=None,
643647
return {k: np.asarray(v) for k, v in ppc.items()}
644648

645649

646-
def init_nuts(init='ADVI', njobs=1, n_init=500000, model=None,
650+
def init_nuts(init='auto', njobs=1, n_init=500000, model=None,
647651
random_seed=-1, progressbar=True, **kwargs):
648-
"""Initialize and sample from posterior of a continuous model.
652+
"""Set up the mass matrix initialization for NUTS.
649653
650-
This is a convenience function. NUTS convergence and sampling speed is extremely
651-
dependent on the choice of mass/scaling matrix. In our experience, using ADVI
652-
to estimate a diagonal covariance matrix and using this as the scaling matrix
653-
produces robust results over a wide class of continuous models.
654+
NUTS convergence and sampling speed is extremely dependent on the
655+
choice of mass/scaling matrix. This function implements different
656+
methods for choosing or adapting the mass matrix.
654657
655658
Parameters
656659
----------
657-
init : str {'ADVI', 'ADVI_MAP', 'MAP', 'NUTS'}
660+
init : str
658661
Initialization method to use.
659-
* ADVI : Run ADVI to estimate posterior mean and diagonal covariance matrix.
660-
* ADVI_MAP: Initialize ADVI with MAP and use MAP as starting point.
661-
* MAP : Use the MAP as starting point.
662-
* NUTS : Run NUTS and estimate posterior mean and covariance matrix.
662+
663+
* auto : Choose a default initialization method automatically.
664+
Currently, this is `'advi+adapt_diag'`, but this can change in
665+
the future. If you depend on the exact behaviour, choose an
666+
initialization method explicitly.
667+
* adapt_diag : Start with an identity mass matrix and then adapt
668+
a diagonal based on the variance of the tuning samples.
669+
* advi+adapt_diag : Run ADVI and then adapt the resulting diagonal
670+
mass matrix based on the sample variance of the tuning samples.
671+
* advi+adapt_diag_grad : Run ADVI and then adapt the resulting
672+
diagonal mass matrix based on the variance of the gradients
673+
during tuning. This is **experimental** and might be removed
674+
in a future release.
675+
* advi : Run ADVI to estimate posterior mean and diagonal mass
676+
matrix.
677+
* advi_map: Initialize ADVI with MAP and use MAP as starting point.
678+
* map : Use the MAP as starting point. This is discouraged.
679+
* nuts : Run NUTS and estimate posterior mean and mass matrix from
680+
the trace.
663681
njobs : int
664682
Number of parallel jobs to start.
665683
n_init : int
666684
Number of iterations of initializer
667-
If 'ADVI', number of iterations, if 'metropolis', number of draws.
685+
If 'ADVI', number of iterations, if 'nuts', number of draws.
668686
model : Model (optional if in `with` context)
669687
progressbar : bool
670688
Whether or not to display a progressbar for advi sampling.
@@ -678,20 +696,83 @@ def init_nuts(init='ADVI', njobs=1, n_init=500000, model=None,
678696
nuts_sampler : pymc3.step_methods.NUTS
679697
Instantiated and initialized NUTS sampler object
680698
"""
681-
682699
model = pm.modelcontext(model)
683700

684-
pm._log.info('Initializing NUTS using {}...'.format(init))
701+
vars = kwargs.get('vars', model.vars)
702+
if set(vars) != set(model.vars):
703+
raise ValueError('Must use init_nuts on all variables of a model.')
704+
if not pm.model.all_continuous(vars):
705+
raise ValueError('init_nuts can only be used for models with only '
706+
'continuous variables.')
685707

686-
random_seed = int(np.atleast_1d(random_seed)[0])
708+
if not isinstance(init, str):
709+
raise TypeError('init must be a string.')
687710

688711
if init is not None:
689712
init = init.lower()
713+
714+
if init == 'auto':
715+
init = 'advi+adapt_diag'
716+
717+
pm._log.info('Initializing NUTS using {}...'.format(init))
718+
719+
random_seed = int(np.atleast_1d(random_seed)[0])
720+
690721
cb = [
691722
pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff='absolute'),
692723
pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff='relative'),
693724
]
694-
if init == 'advi':
725+
726+
if init == 'adapt_diag':
727+
start = []
728+
for _ in range(njobs):
729+
vals = distribution.draw_values(model.free_RVs)
730+
point = {var.name: vals[i] for i, var in enumerate(model.free_RVs)}
731+
start.append(point)
732+
mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0)
733+
var = np.ones_like(mean)
734+
potential = quadpotential.QuadPotentialDiagAdapt(model.ndim, mean, var, 10)
735+
if njobs == 1:
736+
start = start[0]
737+
elif init == 'advi+adapt_diag_grad':
738+
approx = pm.fit(
739+
random_seed=random_seed,
740+
n=n_init, method='advi', model=model,
741+
callbacks=cb,
742+
progressbar=progressbar,
743+
obj_optimizer=pm.adagrad_window,
744+
)
745+
start = approx.sample(draws=njobs)
746+
start = list(start)
747+
stds = approx.gbij.rmap(approx.std.eval())
748+
cov = model.dict_to_array(stds) ** 2
749+
mean = approx.gbij.rmap(approx.mean.get_value())
750+
mean = model.dict_to_array(mean)
751+
weight = 50
752+
potential = quadpotential.QuadPotentialDiagAdaptGrad(
753+
model.ndim, mean, cov, weight)
754+
if njobs == 1:
755+
start = start[0]
756+
elif init == 'advi+adapt_diag':
757+
approx = pm.fit(
758+
random_seed=random_seed,
759+
n=n_init, method='advi', model=model,
760+
callbacks=cb,
761+
progressbar=progressbar,
762+
obj_optimizer=pm.adagrad_window,
763+
)
764+
start = approx.sample(draws=njobs)
765+
start = list(start)
766+
stds = approx.gbij.rmap(approx.std.eval())
767+
cov = model.dict_to_array(stds) ** 2
768+
mean = approx.gbij.rmap(approx.mean.get_value())
769+
mean = model.dict_to_array(mean)
770+
weight = 50
771+
potential = quadpotential.QuadPotentialDiagAdapt(
772+
model.ndim, mean, cov, weight)
773+
if njobs == 1:
774+
start = start[0]
775+
elif init == 'advi':
695776
approx = pm.fit(
696777
random_seed=random_seed,
697778
n=n_init, method='advi', model=model,
@@ -700,8 +781,10 @@ def init_nuts(init='ADVI', njobs=1, n_init=500000, model=None,
700781
obj_optimizer=pm.adagrad_window
701782
) # type: pm.MeanField
702783
start = approx.sample(draws=njobs)
784+
start = list(start)
703785
stds = approx.gbij.rmap(approx.std.eval())
704786
cov = model.dict_to_array(stds) ** 2
787+
potential = quadpotential.QuadPotentialDiag(cov)
705788
if njobs == 1:
706789
start = start[0]
707790
elif init == 'advi_map':
@@ -715,24 +798,31 @@ def init_nuts(init='ADVI', njobs=1, n_init=500000, model=None,
715798
obj_optimizer=pm.adagrad_window
716799
)
717800
start = approx.sample(draws=njobs)
801+
start = list(start)
718802
stds = approx.gbij.rmap(approx.std.eval())
719803
cov = model.dict_to_array(stds) ** 2
804+
potential = quadpotential.QuadPotentialDiag(cov)
720805
if njobs == 1:
721806
start = start[0]
722807
elif init == 'map':
723808
start = pm.find_MAP()
724809
cov = pm.find_hessian(point=start)
810+
start = [start] * njobs
811+
potential = quadpotential.QuadPotentialFull(cov)
812+
if njobs == 1:
813+
start = start[0]
725814
elif init == 'nuts':
726815
init_trace = pm.sample(draws=n_init, step=pm.NUTS(),
727816
tune=n_init // 2,
728817
random_seed=random_seed)
729818
cov = np.atleast_1d(pm.trace_cov(init_trace))
730-
start = np.random.choice(init_trace, njobs)
819+
start = list(np.random.choice(init_trace, njobs))
820+
potential = quadpotential.QuadPotentialFull(cov)
731821
if njobs == 1:
732822
start = start[0]
733823
else:
734824
raise NotImplementedError('Initializer {} is not supported.'.format(init))
735825

736-
step = pm.NUTS(scaling=cov, is_cov=True, **kwargs)
826+
step = pm.NUTS(potential=potential, **kwargs)
737827

738828
return start, step

pymc3/step_methods/hmc/base_hmc.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from pymc3.tuning import guess_scaling
55
from pymc3.model import modelcontext, Point
6-
from .quadpotential import quad_potential
6+
from .quadpotential import quad_potential, QuadPotentialDiagAdapt
77
from pymc3.theanof import inputvars, make_shared_replacements, floatX
88
import numpy as np
99

@@ -42,12 +42,14 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False,
4242
vars = inputvars(vars)
4343

4444
if scaling is None and potential is None:
45-
varnames = [var.name for var in vars]
46-
size = sum(v.size for k, v in model.test_point.items() if k in varnames)
47-
scaling = floatX(np.ones(size))
45+
size = sum(np.prod(var.dshape, dtype=int) for var in vars)
46+
mean = floatX(np.zeros(size))
47+
var = floatX(np.ones(size))
48+
potential = QuadPotentialDiagAdapt(size, mean, var, 10)
4849

4950
if isinstance(scaling, dict):
50-
scaling = guess_scaling(Point(scaling, model=model), model=model, vars=vars)
51+
point = Point(scaling, model=model)
52+
scaling = guess_scaling(point, model=model, vars=vars)
5153

5254
if scaling is not None and potential is not None:
5355
raise ValueError("Can not specify both potential and scaling.")
@@ -56,7 +58,7 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False,
5658
if potential is not None:
5759
self.potential = potential
5860
else:
59-
self.potential = quad_potential(scaling, is_cov, as_cov=False)
61+
self.potential = quad_potential(scaling, is_cov)
6062

6163
shared = make_shared_replacements(vars, model)
6264
if theano_kwargs is None:

0 commit comments

Comments
 (0)