Skip to content

Commit 971db07

Browse files
ctm22396aseyboldt
authored andcommitted
Fixed error introduced by 6f58dbf (#2419)
* Fixed error introduced by 6f58dbf in which the scaling variable is initialized as an array instead of a dict. This commit maintains the intentions of the 6f58dbf commit by initializing 'scaling' to a dict of arrays of identity with type floatX. * Array of ones with shape of selected vars This commit is actually preserves the intention of the previous commit as it avoids the guess_scaling function. It still maintains that the shape of scaling must be the size of the specified variables instead of the whole model. * Removed unnecessary function imports to namespace that were previously used for dict to array bijection
1 parent fa3aa1b commit 971db07

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

pymc3/step_methods/hmc/base_hmc.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ def __init__(self, vars=None, scaling=None, step_scale=0.25, is_cov=False,
4242
vars = inputvars(vars)
4343

4444
if scaling is None and potential is None:
45-
scaling = floatX(np.ones(model.dict_to_array(model.test_point).size))
45+
varnames = [var.name for var in vars]
46+
size = sum(v.size for k, v in model.test_point.items() if k in varnames)
47+
scaling = floatX(np.ones(size))
4648

4749
if isinstance(scaling, dict):
4850
scaling = guess_scaling(Point(scaling, model=model), model=model, vars=vars)

0 commit comments

Comments
 (0)