NUTS sampler gets 'stuck' for very long periods #776

@FedericoV

Hi!

Sorry to keep opening issues - but just noticed this:

I'm working on a hierarchical model, again very similar to the standard model described by Thomas Wiecki:

        with pm.Model() as hierarchical_model:
            mu_b = pm.Normal('mu_beta', mu=0., sd=100**2)
            sigma_b = pm.Uniform('sigma_beta', lower=0, upper=100)

            b = pm.Normal('beta', mu=mu_b, sd=sigma_b, shape=3)

            # Model error prior
            eps = pm.Uniform('eps', lower = 0, upper = 2)

            # Linear model
            enrich_est = b[gs_code] * del_idx
            # Check to make sure this is right.

            # Data likelihood
            enrich_like = pm.Normal('enrich_like', mu=enrich_est, sd=eps, observed=tip_fluorescence)

        with hierarchical_model:
            start = pm.find_MAP()
            step = pm.NUTS(scaling=start)
            hierarchical_trace = pm.sample(2000, step, start=start, progressbar=True)

After a few loops that mostly finished in about this time:

[-----------------100%-----------------] 2001 of 2000 complete in 63.5 sec

It's currently stuck like so:

 [-                 3%                  ] 62 of 2000 complete in 11487.3 sec

I suspect this is largely because of this line:

 b[gs_code] * del_idx

del_idx is a boolean array of length n, while gs_code is an array with 3 possible values (0, 1, 2), and the values are quite imbalanced:

        print ((gs_code == 0) * del_idx).sum()
        print ((gs_code == 1) * del_idx).sum()
        print ((gs_code == 2) * del_idx).sum()
        print len(gs_code)

        104
        16
        38
        4925
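For anyone reading along, here is a minimal NumPy sketch of what that expression computes. The arrays are made-up stand-ins (not my real data), and the coefficient values in `b` are purely illustrative: indexing `b` with `gs_code` picks one coefficient per cell, and multiplying by the boolean mask zeroes the prediction wherever `del_idx` is False, so only the masked cells inform each coefficient.

```python
import numpy as np

# Hypothetical stand-ins for the arrays in this report (not the real data).
rng = np.random.default_rng(0)
n = 12
gs_code = rng.integers(0, 3, size=n)   # growth-stage code per cell: 0, 1 or 2
del_idx = rng.random(n) < 0.3          # boolean mask, True for the cells of interest
b = np.array([1.5, -0.5, 2.0])         # one illustrative coefficient per stage

# Fancy indexing expands b to one value per cell; multiplying by the boolean
# mask sets the predicted mean to zero for every cell where del_idx is False.
enrich_est = b[gs_code] * del_idx

# Number of masked cells that actually inform each of the three coefficients.
counts = np.bincount(gs_code[del_idx], minlength=3)
print(counts, len(gs_code))
```

With counts as lopsided as the ones above, most rows contribute nothing to `beta` and only constrain `eps`.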

I also had to simplify the model quite a bit, because when I was fitting an alpha term independently (as in the original Wiecki notebook) the iteration times were incredibly long. So I just did some independent mean centering for each condition as a pre-processing step - which is of course far less precise.

        for gs in [0, 1, 2]:
            # Get cells that are wild type and in a specific growth stage
            wt_and_gs_code = np.logical_and(wt_idx, gs_code == gs)
            # Get mean of wild type population in a specific growth stage
            m = tip_fluorescence[wt_and_gs_code].mean()
            # Subtract mean from all cells in a specific growth stage
            tip_fluorescence[gs_code == gs] -= m
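In case it helps to reproduce the pre-processing, here is a self-contained sketch of the same centering on toy data. The function name `center_by_stage` and the toy arrays are made up for illustration; unlike the snippet above, it works on a copy instead of modifying the array in place, and it skips any stage that has no wild-type cells.

```python
import numpy as np

def center_by_stage(values, gs_code, wt_idx, stages=(0, 1, 2)):
    """Subtract each stage's wild-type mean from all cells in that stage."""
    centered = np.asarray(values, dtype=float).copy()
    for gs in stages:
        in_stage = gs_code == gs
        wt_in_stage = np.logical_and(wt_idx, in_stage)
        if wt_in_stage.any():  # skip stages with no wild-type cells
            centered[in_stage] -= centered[wt_in_stage].mean()
    return centered

# Toy data, made up for illustration only.
values = np.array([1.0, 3.0, 10.0, 5.0, 7.0, 20.0])
gs_code = np.array([0, 0, 0, 1, 1, 1])
wt_idx = np.array([True, True, False, True, True, False])

centered = center_by_stage(values, gs_code, wt_idx)
print(centered)
```

After centering, the wild-type cells in each stage average to zero, so the remaining offset of the other cells is what `beta` has to explain.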
