Merged
30 changes: 22 additions & 8 deletions pymc3/blocking.py
@@ -23,14 +23,28 @@ class ArrayOrdering(object):

     def __init__(self, vars):
         self.vmap = []
-        dim = 0
+        self._by_name = {}
+        size = 0

         for var in vars:
-            slc = slice(dim, dim + var.dsize)
-            self.vmap.append(VarMap(str(var), slc, var.dshape, var.dtype))
-            dim += var.dsize
+            name = var.name
+            if name is None:
+                raise ValueError('Unnamed variable in ArrayOrdering.')
+            if name in self._by_name:
+                raise ValueError('Name of variable not unique: %s.' % name)
+            if not hasattr(var, 'dshape') or not hasattr(var, 'dsize'):
+                raise ValueError('Shape of variable not known %s' % name)
+
+            slc = slice(size, size + var.dsize)
+            varmap = VarMap(name, slc, var.dshape, var.dtype)
+            self.vmap.append(varmap)
+            self._by_name[name] = varmap
+            size += var.dsize
+
+        self.size = size

-        self.dimensions = dim
+    def __getitem__(self, key):
+        return self._by_name[key]


 class DictToArrayBijection(object):
@@ -58,7 +72,7 @@ def map(self, dpt):
         ----------
         dpt : dict
         """
-        apt = np.empty(self.ordering.dimensions, dtype=self.array_dtype)
+        apt = np.empty(self.ordering.size, dtype=self.array_dtype)
         for var, slc, _, _ in self.ordering.vmap:
             apt[slc] = dpt[var].ravel()
         return apt
@@ -125,7 +139,7 @@ def __init__(self, list_arrays, intype='numpy'):
             dim += array.size
             count += 1

-        self.dimensions = dim
+        self.size = dim


 class ListToArrayBijection(object):
@@ -158,7 +172,7 @@ def fmap(self, list_arrays):
             single array comprising all the input arrays
         """

-        array = np.empty(self.ordering.dimensions)
+        array = np.empty(self.ordering.size)
        for list_ind, slc, _, _, _ in self.ordering.vmap:
             array[slc] = list_arrays[list_ind].ravel()
         return array
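For context, a minimal sketch of how the reworked ArrayOrdering behaves after this change (the model and variable names are illustrative, not from the PR; exact slices depend on the order the variables are passed in):

    import pymc3 as pm
    from pymc3.blocking import ArrayOrdering

    with pm.Model():
        a = pm.Normal('a', mu=0., sd=1.)            # scalar, 1 element
        b = pm.Normal('b', mu=0., sd=1., shape=3)   # vector, 3 elements

    ordering = ArrayOrdering([a, b])
    print(ordering.size)       # 4 -- total element count, previously `dimensions`
    print(ordering['b'].slc)   # slice(1, 4) -- the new name-based VarMap lookup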
195 changes: 187 additions & 8 deletions pymc3/model.py
@@ -184,9 +184,12 @@ def fastd2logp(self, vars=None):
     def logpt(self):
         """Theano scalar of log-probability of the model"""
         if getattr(self, 'total_size', None) is not None:
-            return tt.sum(self.logp_elemwiset) * self.scaling
+            logp = tt.sum(self.logp_elemwiset) * self.scaling
         else:
-            return tt.sum(self.logp_elemwiset)
+            logp = tt.sum(self.logp_elemwiset)
+        if self.name is not None:
+            logp.name = '__logp_%s' % self.name
+        return logp


 class InitContextMeta(type):
@@ -277,6 +280,173 @@ def tree_contains(self, item):
         return dict.__contains__(self, item)


+class ValueGradFunction(object):
+    """Create a theano function that computes a value and its gradient.
+
+    Parameters
+    ----------
+    cost : theano variable
+        The value that we compute with its gradient.
+    grad_vars : list of named theano variables or None
+        The arguments with respect to which the gradient is computed.
+    extra_vars : list of named theano variables or None
+        Other arguments of the function that are assumed constant. They
+        are stored in shared variables and can be set using
+        `set_extra_values`.
+    dtype : str, default=theano.config.floatX
+        The dtype of the arrays.
+    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, default='no'
+        Casting rule for casting `grad_vars` to the array dtype.
+        See `numpy.can_cast` for a description of the options.
+        Keep in mind that we cast the variables to the array *and*
+        back from the array dtype to the variable dtype.
+    kwargs
+        Extra arguments are passed on to `theano.function`.
+
+    Attributes
+    ----------
+    size : int
+        The number of elements in the parameter array.
+    profile : theano profiling object or None
+        The profiling object of the theano function that computes value and
+        gradient. This is None unless `profile=True` was set in the
+        kwargs.
+    """
+    def __init__(self, cost, grad_vars, extra_vars=None, dtype=None,
+                 casting='no', **kwargs):
+        if extra_vars is None:
+            extra_vars = []
+
+        names = [arg.name for arg in grad_vars + extra_vars]
+        if any(name is None for name in names):
+            raise ValueError('Arguments must be named.')
+        if len(set(names)) != len(names):
+            raise ValueError('Names of the arguments are not unique.')
+
+        if cost.ndim > 0:
+            raise ValueError('Cost must be a scalar.')
+
+        self._grad_vars = grad_vars
+        self._extra_vars = extra_vars
+        self._extra_var_names = set(var.name for var in extra_vars)
+        self._cost = cost
+        self._ordering = ArrayOrdering(grad_vars)
+        self.size = self._ordering.size
+        self._extra_are_set = False
+        if dtype is None:
+            dtype = theano.config.floatX
+        self.dtype = dtype
+        for var in self._grad_vars:
+            if not np.can_cast(var.dtype, self.dtype, casting):
+                raise TypeError('Invalid dtype for variable %s. Can not '
+                                'cast to %s with casting rule %s.'
+                                % (var.name, self.dtype, casting))
+            if not np.issubdtype(var.dtype, float):
+                raise TypeError('Invalid dtype for variable %s. Must be '
+                                'floating point but is %s.'
+                                % (var.name, var.dtype))
+
+        givens = []
+        self._extra_vars_shared = {}
+        for var in extra_vars:
+            shared = theano.shared(var.tag.test_value, var.name + '_shared__')
+            self._extra_vars_shared[var.name] = shared
+            givens.append((var, shared))
+
+        self._vars_joined, self._cost_joined = self._build_joined(
+            self._cost, grad_vars, self._ordering.vmap)
+
+        grad = tt.grad(self._cost_joined, self._vars_joined)
+        grad.name = '__grad'
+
+        inputs = [self._vars_joined]
+
+        self._theano_function = theano.function(
+            inputs, [self._cost_joined, grad], givens=givens, **kwargs)
+
+    def set_extra_values(self, extra_vars):
+        self._extra_are_set = True
+        for var in self._extra_vars:
+            self._extra_vars_shared[var.name].set_value(extra_vars[var.name])
+
+    def get_extra_values(self):
+        if not self._extra_are_set:
+            raise ValueError('Extra values are not set.')
+
+        return {var.name: self._extra_vars_shared[var.name].get_value()
+                for var in self._extra_vars}
+
+    def __call__(self, array, grad_out=None, extra_vars=None):
+        if extra_vars is not None:
+            self.set_extra_values(extra_vars)
+
+        if not self._extra_are_set:
+            raise ValueError('Extra values are not set.')
+
+        if array.shape != (self.size,):
+            raise ValueError('Invalid shape for array. Must be %s but is %s.'
+                             % ((self.size,), array.shape))
+
+        if grad_out is None:
+            out = np.empty_like(array)
+        else:
+            out = grad_out
+
+        logp, dlogp = self._theano_function(array)
+        if grad_out is None:
+            return logp, dlogp
+        else:
+            out[...] = dlogp
+            return logp
+
+    @property
+    def profile(self):
+        """Profiling information of the underlying theano function."""
+        return self._theano_function.profile
+
+    def dict_to_array(self, point):
+        """Convert a dictionary with values for grad_vars to an array."""
+        array = np.empty(self.size, dtype=self.dtype)
+        for varmap in self._ordering.vmap:
+            array[varmap.slc] = point[varmap.var].ravel().astype(self.dtype)
+        return array
+
+    def array_to_dict(self, array):
+        """Convert an array to a dictionary containing the grad_vars."""
+        if array.shape != (self.size,):
+            raise ValueError('Array should have shape (%s,) but has %s'
+                             % (self.size, array.shape))
+        if array.dtype != self.dtype:
+            raise ValueError('Array has invalid dtype. Should be %s but is %s'
+                             % (self.dtype, array.dtype))
+        point = {}
+        for varmap in self._ordering.vmap:
+            data = array[varmap.slc].reshape(varmap.shp)
+            point[varmap.var] = data.astype(varmap.dtyp)
+
+        return point
+
+    def array_to_full_dict(self, array):
+        """Convert an array to a dictionary with grad_vars and extra_vars."""
+        point = self.array_to_dict(array)
+        for name, var in self._extra_vars_shared.items():
+            point[name] = var.get_value()
+        return point
+
+    def _build_joined(self, cost, args, vmap):
+        args_joined = tt.vector('__args_joined')
+        args_joined.tag.test_value = np.zeros(self.size, dtype=self.dtype)
+
+        joined_slices = {}
+        for vm in vmap:
+            sliced = args_joined[vm.slc].reshape(vm.shp)
+            sliced.name = vm.var
+            joined_slices[vm.var] = sliced
+
+        replace = {var: joined_slices[var.name] for var in args}
+        return args_joined, theano.clone(cost, replace=replace)
+
+
 class Model(six.with_metaclass(InitContextMeta, Context, Factor)):
     """Encapsulates the variables and likelihood factors of a model.

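A rough usage sketch of the new ValueGradFunction, constructed through the logp_dlogp_function helper added to Model further down in this diff (the model and variable names here are illustrative, not from the PR):

    import numpy as np
    import pymc3 as pm

    with pm.Model() as model:
        mu = pm.Normal('mu', mu=0., sd=1.)
        sd = pm.HalfNormal('sd', sd=1.)
        pm.Normal('y', mu=mu, sd=sd, observed=np.random.randn(100))

    # Compute logp and its gradient with respect to `mu` only; the remaining
    # free variable (the transformed `sd_log__`) becomes an extra variable
    # held in a theano shared variable.
    func = model.logp_dlogp_function(grad_vars=[model['mu']])
    func.set_extra_values(model.test_point)    # fix the extra variables
    x = func.dict_to_array(model.test_point)   # joined array, shape (func.size,)
    logp, dlogp = func(x)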
@@ -419,7 +589,6 @@ def bijection(self):
         return bij

     @property
-    @memoize
     def dict_to_array(self):
         return self.bijection.map

@@ -428,23 +597,34 @@ def ndim(self):
         return sum(var.dsize for var in self.free_RVs)

     @property
-    @memoize
     def logp_array(self):
         return self.bijection.mapf(self.fastlogp)

     @property
-    @memoize
     def dlogp_array(self):
         vars = inputvars(self.cont_vars)
         return self.bijection.mapf(self.fastdlogp(vars))

+    def logp_dlogp_function(self, grad_vars=None, **kwargs):
+        if grad_vars is None:
+            grad_vars = list(typefilter(self.free_RVs, continuous_types))
+        else:
+            for var in grad_vars:
+                if var.dtype not in continuous_types:
+                    raise ValueError("Can only compute the gradient of "
+                                     "continuous types: %s" % var)
+        varnames = [var.name for var in grad_vars]
+        extra_vars = [var for var in self.free_RVs if var.name not in varnames]
+        return ValueGradFunction(self.logpt, grad_vars, extra_vars, **kwargs)
+
     @property
-    @memoize
     def logpt(self):
         """Theano scalar of log-probability of the model"""
         with self:
             factors = [var.logpt for var in self.basic_RVs] + self.potentials
-            return tt.add(*map(tt.sum, factors))
+            logp = tt.add(*map(tt.sum, factors))
+            logp.name = '__logp'
+            return logp

     @property
     def varlogpt(self):
@@ -595,7 +775,6 @@ def __getitem__(self, key):
         except KeyError:
             raise e

-    @memoize
     def makefn(self, outs, mode=None, *args, **kwargs):
         """Compiles a Theano function which returns `outs` and takes the variable
         ancestors of `outs` as inputs.
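As a side note on the `casting` option documented in the ValueGradFunction docstring above: the rules map directly onto numpy.can_cast, per numpy's documented casting semantics. For instance:

    import numpy as np

    np.can_cast(np.float64, np.float32, casting='no')         # False: exact match required
    np.can_cast(np.float64, np.float32, casting='same_kind')  # True: both are floats
    np.can_cast(np.float32, np.float64, casting='safe')       # True: no precision loss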
24 changes: 24 additions & 0 deletions pymc3/step_methods/arraystep.py
@@ -156,6 +156,30 @@ def step(self, point):
         return bij.rmap(apoint)


+class GradientSharedStep(BlockedStep):
+    def __init__(self, vars, model=None, blocked=True,
+                 dtype=None, **theano_kwargs):
+        model = modelcontext(model)
+        self.vars = vars
+        self.blocked = blocked
+
+        self._logp_dlogp_func = model.logp_dlogp_function(
+            vars, dtype=dtype, **theano_kwargs)
+
+    def step(self, point):
+        self._logp_dlogp_func.set_extra_values(point)
+        array = self._logp_dlogp_func.dict_to_array(point)
+
+        if self.generates_stats:
+            apoint, stats = self.astep(array)
[Inline review comment on the line above]

Member: just thinking for the future, this would be easier to reason about if self.astep always returned at least an empty dictionary for stats

Member (Author): When I wrote this I thought I'd rather not change the interface for step methods, so that we don't break custom step methods. Now I guess we should probably treat that as a private part of the API, so yes, I agree.
+            point = self._logp_dlogp_func.array_to_full_dict(apoint)
+            return point, stats
+        else:
+            apoint = self.astep(array)
+            point = self._logp_dlogp_func.array_to_full_dict(apoint)
+            return point
+
+
 def metrop_select(mr, q, q0):
     """Perform rejection/acceptance step for Metropolis class samplers.

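To illustrate how GradientSharedStep is meant to be subclassed, here is a hypothetical step method built on top of it. GradientAscentStep, its step size, and the usage lines are invented for this sketch and are not part of the PR:

    from pymc3.step_methods.arraystep import GradientSharedStep

    class GradientAscentStep(GradientSharedStep):
        """Toy step method: move the point a small distance uphill in logp."""
        generates_stats = False

        def astep(self, q0):
            # q0 is the joined parameter array produced by dict_to_array;
            # the shared function returns the value and gradient together.
            logp, dlogp = self._logp_dlogp_func(q0)
            return q0 + 1e-3 * dlogp

    # usage sketch (assuming `model` is a pm.Model with continuous free RVs):
    #     step = GradientAscentStep(vars=model.free_RVs, model=model)
    #     new_point = step.step(model.test_point)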