Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def random(self, point=None, size=None):
tau, = draw_values([self.tau], point=point)
dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau)
else:
chol, = draw_values([self.chol], point=point)
chol, = draw_values([self.chol_cov], point=point)
dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol)

samples = dist.random(point, size)
Expand Down
164 changes: 127 additions & 37 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
Slice, CompoundStep)
from .plots.traceplot import traceplot
from .util import update_start_vals
from pymc3.step_methods.hmc import quadpotential
from pymc3.distributions import distribution
from tqdm import tqdm

import sys
Expand Down Expand Up @@ -118,20 +120,27 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
A step function or collection of functions. If there are variables
without step methods, step methods for those variables will
be assigned automatically.
init : str {'ADVI', 'ADVI_MAP', 'MAP', 'NUTS', 'auto', None}
Initialization method to use. Only works for auto-assigned step methods.

* ADVI: Run ADVI to estimate starting points and diagonal covariance
matrix. If njobs > 1 it will sample starting points from the estimated
posterior, otherwise it will use the estimated posterior mean.
* ADVI_MAP: Initialize ADVI with MAP and use MAP as starting point.
* MAP: Use the MAP as starting point.
* NUTS: Run NUTS to estimate starting points and covariance matrix. If
njobs > 1 it will sample starting points from the estimated posterior,
otherwise it will use the estimated posterior mean.
* auto : Auto-initialize, if possible. Currently only works when NUTS
is auto-assigned as step method (default).
* None: Do not initialize.
init : str
Initialization method to use for auto-assigned NUTS samplers.

* auto : Choose a default initialization method automatically.
Currently, this is `'advi+adapt_diag'`, but this can change in
the future. If you depend on the exact behaviour, choose an
initialization method explicitly.
* adapt_diag : Start with an identity mass matrix and then adapt
a diagonal based on the variance of the tuning samples.
* advi+adapt_diag : Run ADVI and then adapt the resulting diagonal
mass matrix based on the sample variance of the tuning samples.
* advi+adapt_diag_grad : Run ADVI and then adapt the resulting
diagonal mass matrix based on the variance of the gradients
during tuning. This is **experimental** and might be removed
in a future release.
* advi : Run ADVI to estimate posterior mean and diagonal mass
matrix.
* advi_map : Initialize ADVI with MAP and use MAP as starting point.
* map : Use the MAP as starting point. This is discouraged.
* nuts : Run NUTS and estimate posterior mean and mass matrix from
the trace.
n_init : int
Number of iterations of initializer
If 'ADVI', number of iterations, if 'nuts', number of draws.
Expand Down Expand Up @@ -220,9 +229,6 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,

draws += tune

if init is not None:
init = init.lower()

if nuts_kwargs is not None:
if step_kwargs is not None:
raise ValueError("Specify only one of step_kwargs and nuts_kwargs")
Expand All @@ -236,8 +242,6 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
pm._log.info('Auto-assigning NUTS sampler...')
args = step_kwargs if step_kwargs is not None else {}
args = args.get('nuts', {})
if init == 'auto':
init = 'ADVI'
start_, step = init_nuts(init=init, njobs=njobs, n_init=n_init,
model=model, random_seed=random_seed,
progressbar=progressbar, **args)
Expand Down Expand Up @@ -643,28 +647,42 @@ def sample_ppc_w(traces, samples=None, models=None, size=None, weights=None,
return {k: np.asarray(v) for k, v in ppc.items()}


def init_nuts(init='ADVI', njobs=1, n_init=500000, model=None,
def init_nuts(init='auto', njobs=1, n_init=500000, model=None,
random_seed=-1, progressbar=True, **kwargs):
"""Initialize and sample from posterior of a continuous model.
"""Set up the mass matrix initialization for NUTS.

This is a convenience function. NUTS convergence and sampling speed is extremely
dependent on the choice of mass/scaling matrix. In our experience, using ADVI
to estimate a diagonal covariance matrix and using this as the scaling matrix
produces robust results over a wide class of continuous models.
NUTS convergence and sampling speed are extremely dependent on the
choice of mass/scaling matrix. This function implements different
methods for choosing or adapting the mass matrix.

Parameters
----------
init : str {'ADVI', 'ADVI_MAP', 'MAP', 'NUTS'}
init : str
Initialization method to use.
* ADVI : Run ADVI to estimate posterior mean and diagonal covariance matrix.
* ADVI_MAP: Initialize ADVI with MAP and use MAP as starting point.
* MAP : Use the MAP as starting point.
* NUTS : Run NUTS and estimate posterior mean and covariance matrix.

* auto : Choose a default initialization method automatically.
Currently, this is `'advi+adapt_diag'`, but this can change in
the future. If you depend on the exact behaviour, choose an
initialization method explicitly.
* adapt_diag : Start with an identity mass matrix and then adapt
a diagonal based on the variance of the tuning samples.
* advi+adapt_diag : Run ADVI and then adapt the resulting diagonal
mass matrix based on the sample variance of the tuning samples.
* advi+adapt_diag_grad : Run ADVI and then adapt the resulting
diagonal mass matrix based on the variance of the gradients
during tuning. This is **experimental** and might be removed
in a future release.
* advi : Run ADVI to estimate posterior mean and diagonal mass
matrix.
* advi_map : Initialize ADVI with MAP and use MAP as starting point.
* map : Use the MAP as starting point. This is discouraged.
* nuts : Run NUTS and estimate posterior mean and mass matrix from
the trace.
njobs : int
Number of parallel jobs to start.
n_init : int
Number of iterations of initializer
If 'ADVI', number of iterations, if 'metropolis', number of draws.
If 'ADVI', number of iterations, if 'nuts', number of draws.
model : Model (optional if in `with` context)
progressbar : bool
Whether or not to display a progressbar for advi sampling.
Expand All @@ -678,20 +696,83 @@ def init_nuts(init='ADVI', njobs=1, n_init=500000, model=None,
nuts_sampler : pymc3.step_methods.NUTS
Instantiated and initialized NUTS sampler object
"""

model = pm.modelcontext(model)

pm._log.info('Initializing NUTS using {}...'.format(init))
vars = kwargs.get('vars', model.vars)
if set(vars) != set(model.vars):
raise ValueError('Must use init_nuts on all variables of a model.')
if not pm.model.all_continuous(vars):
raise ValueError('init_nuts can only be used for models with only '
'continuous variables.')

random_seed = int(np.atleast_1d(random_seed)[0])
if not isinstance(init, str):
raise TypeError('init must be a string.')

if init is not None:
init = init.lower()

if init == 'auto':
init = 'advi+adapt_diag'

pm._log.info('Initializing NUTS using {}...'.format(init))

random_seed = int(np.atleast_1d(random_seed)[0])

cb = [
pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff='absolute'),
pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff='relative'),
]
if init == 'advi':

if init == 'adapt_diag':
start = []
for _ in range(njobs):
vals = distribution.draw_values(model.free_RVs)
point = {var.name: vals[i] for i, var in enumerate(model.free_RVs)}
start.append(point)
mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0)
var = np.ones_like(mean)
potential = quadpotential.QuadPotentialDiagAdapt(model.ndim, mean, var, 10)
if njobs == 1:
start = start[0]
elif init == 'advi+adapt_diag_grad':
approx = pm.fit(
random_seed=random_seed,
n=n_init, method='advi', model=model,
callbacks=cb,
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
)
start = approx.sample(draws=njobs)
start = list(start)
stds = approx.gbij.rmap(approx.std.eval())
cov = model.dict_to_array(stds) ** 2
mean = approx.gbij.rmap(approx.mean.get_value())
mean = model.dict_to_array(mean)
weight = 50
potential = quadpotential.QuadPotentialDiagAdaptGrad(
model.ndim, mean, cov, weight)
if njobs == 1:
start = start[0]
elif init == 'advi+adapt_diag':
approx = pm.fit(
random_seed=random_seed,
n=n_init, method='advi', model=model,
callbacks=cb,
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
)
start = approx.sample(draws=njobs)
start = list(start)
stds = approx.gbij.rmap(approx.std.eval())
cov = model.dict_to_array(stds) ** 2
mean = approx.gbij.rmap(approx.mean.get_value())
mean = model.dict_to_array(mean)
weight = 50
potential = quadpotential.QuadPotentialDiagAdapt(
model.ndim, mean, cov, weight)
if njobs == 1:
start = start[0]
elif init == 'advi':
approx = pm.fit(
random_seed=random_seed,
n=n_init, method='advi', model=model,
Expand All @@ -700,8 +781,10 @@ def init_nuts(init='ADVI', njobs=1, n_init=500000, model=None,
obj_optimizer=pm.adagrad_window
) # type: pm.MeanField
start = approx.sample(draws=njobs)
start = list(start)
stds = approx.gbij.rmap(approx.std.eval())
cov = model.dict_to_array(stds) ** 2
potential = quadpotential.QuadPotentialDiag(cov)
if njobs == 1:
start = start[0]
elif init == 'advi_map':
Expand All @@ -715,24 +798,31 @@ def init_nuts(init='ADVI', njobs=1, n_init=500000, model=None,
obj_optimizer=pm.adagrad_window
)
start = approx.sample(draws=njobs)
start = list(start)
stds = approx.gbij.rmap(approx.std.eval())
cov = model.dict_to_array(stds) ** 2
potential = quadpotential.QuadPotentialDiag(cov)
if njobs == 1:
start = start[0]
elif init == 'map':
start = pm.find_MAP()
cov = pm.find_hessian(point=start)
start = [start] * njobs
potential = quadpotential.QuadPotentialFull(cov)
if njobs == 1:
start = start[0]
elif init == 'nuts':
init_trace = pm.sample(draws=n_init, step=pm.NUTS(),
tune=n_init // 2,
random_seed=random_seed)
cov = np.atleast_1d(pm.trace_cov(init_trace))
start = np.random.choice(init_trace, njobs)
start = list(np.random.choice(init_trace, njobs))
potential = quadpotential.QuadPotentialFull(cov)
if njobs == 1:
start = start[0]
else:
raise NotImplementedError('Initializer {} is not supported.'.format(init))

step = pm.NUTS(scaling=cov, is_cov=True, **kwargs)
step = pm.NUTS(potential=potential, **kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍


return start, step
14 changes: 8 additions & 6 deletions pymc3/step_methods/hmc/base_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pymc3.tuning import guess_scaling
from pymc3.model import modelcontext, Point
from .quadpotential import quad_potential
from .quadpotential import quad_potential, QuadPotentialDiagAdapt
from pymc3.theanof import inputvars, make_shared_replacements, floatX
import numpy as np

Expand Down Expand Up @@ -42,12 +42,14 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False,
vars = inputvars(vars)

if scaling is None and potential is None:
varnames = [var.name for var in vars]
size = sum(v.size for k, v in model.test_point.items() if k in varnames)
scaling = floatX(np.ones(size))
size = sum(np.prod(var.dshape, dtype=int) for var in vars)
mean = floatX(np.zeros(size))
var = floatX(np.ones(size))
potential = QuadPotentialDiagAdapt(size, mean, var, 10)

if isinstance(scaling, dict):
scaling = guess_scaling(Point(scaling, model=model), model=model, vars=vars)
point = Point(scaling, model=model)
scaling = guess_scaling(point, model=model, vars=vars)

if scaling is not None and potential is not None:
raise ValueError("Can not specify both potential and scaling.")
Expand All @@ -56,7 +58,7 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False,
if potential is not None:
self.potential = potential
else:
self.potential = quad_potential(scaling, is_cov, as_cov=False)
self.potential = quad_potential(scaling, is_cov)

shared = make_shared_replacements(vars, model)
if theano_kwargs is None:
Expand Down
Loading