
Commit 61ba6ee

Merge pull request #2345 from aseyboldt/no-theano-nuts

NUTS refactoring towards GPU support

2 parents d06260d + cf785bc, commit 61ba6ee

18 files changed: +596 −573 lines

pymc3/blocking.py

Lines changed: 22 additions & 8 deletions

@@ -23,14 +23,28 @@ class ArrayOrdering(object):

     def __init__(self, vars):
         self.vmap = []
-        dim = 0
+        self._by_name = {}
+        size = 0

         for var in vars:
-            slc = slice(dim, dim + var.dsize)
-            self.vmap.append(VarMap(str(var), slc, var.dshape, var.dtype))
-            dim += var.dsize
+            name = var.name
+            if name is None:
+                raise ValueError('Unnamed variable in ArrayOrdering.')
+            if name in self._by_name:
+                raise ValueError('Name of variable not unique: %s.' % name)
+            if not hasattr(var, 'dshape') or not hasattr(var, 'dsize'):
+                raise ValueError('Shape of variable not known %s' % name)
+
+            slc = slice(size, size + var.dsize)
+            varmap = VarMap(name, slc, var.dshape, var.dtype)
+            self.vmap.append(varmap)
+            self._by_name[name] = varmap
+            size += var.dsize
+
+        self.size = size

-        self.dimensions = dim
+    def __getitem__(self, key):
+        return self._by_name[key]


 class DictToArrayBijection(object):

@@ -58,7 +72,7 @@ def map(self, dpt):
         ----------
         dpt : dict
         """
-        apt = np.empty(self.ordering.dimensions, dtype=self.array_dtype)
+        apt = np.empty(self.ordering.size, dtype=self.array_dtype)
         for var, slc, _, _ in self.ordering.vmap:
             apt[slc] = dpt[var].ravel()
         return apt

@@ -125,7 +139,7 @@ def __init__(self, list_arrays, intype='numpy'):
             dim += array.size
             count += 1

-        self.dimensions = dim
+        self.size = dim


 class ListToArrayBijection(object):

@@ -158,7 +172,7 @@ def fmap(self, list_arrays):
             single array comprising all the input arrays
         """

-        array = np.empty(self.ordering.dimensions)
+        array = np.empty(self.ordering.size)
         for list_ind, slc, _, _, _ in self.ordering.vmap:
             array[slc] = list_arrays[list_ind].ravel()
         return array
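For orientation, a minimal sketch of the reworked ArrayOrdering interface. This is hypothetical usage, not code from the commit; `model` stands for any PyMC3 model whose free variables carry name, dshape, dsize, and dtype, and the variable name 'mu' is invented:

    from pymc3.blocking import ArrayOrdering

    ordering = ArrayOrdering(model.free_RVs)  # raises on unnamed or duplicate names
    ordering.size            # total length of the flat array (was .dimensions)
    ordering['mu'].slc       # new: look up a variable's VarMap (slice, shape, dtype) by name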

pymc3/model.py

Lines changed: 187 additions & 8 deletions
@@ -184,9 +184,12 @@ def fastd2logp(self, vars=None):
     def logpt(self):
         """Theano scalar of log-probability of the model"""
         if getattr(self, 'total_size', None) is not None:
-            return tt.sum(self.logp_elemwiset) * self.scaling
+            logp = tt.sum(self.logp_elemwiset) * self.scaling
         else:
-            return tt.sum(self.logp_elemwiset)
+            logp = tt.sum(self.logp_elemwiset)
+        if self.name is not None:
+            logp.name = '__logp_%s' % self.name
+        return logp


 class InitContextMeta(type):
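The change above is behaviorally neutral but useful: the log-probability node now carries a name, so it can be identified in Theano's debug output. A small self-contained sketch of the effect (an illustration under assumed names, not code from this commit):

    import theano
    import theano.tensor as tt

    x = tt.dvector('x')
    logp = tt.sum(x ** 2)
    logp.name = '__logp_example'        # named variables are labeled in debug output
    theano.printing.debugprint(logp)    # prints the graph, showing '__logp_example'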
@@ -277,6 +280,173 @@ def tree_contains(self, item):
         return dict.__contains__(self, item)


+class ValueGradFunction(object):
+    """Create a theano function that computes a value and its gradient.
+
+    Parameters
+    ----------
+    cost : theano variable
+        The value that we compute with its gradient.
+    grad_vars : list of named theano variables or None
+        The arguments with respect to which the gradient is computed.
+    extra_vars : list of named theano variables or None
+        Other arguments of the function that are assumed constant. They
+        are stored in shared variables and can be set using
+        `set_extra_values`.
+    dtype : str, default=theano.config.floatX
+        The dtype of the arrays.
+    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, default='no'
+        Casting rule for casting `grad_vars` to the array dtype.
+        See `numpy.can_cast` for a description of the options.
+        Keep in mind that we cast the variables to the array *and*
+        back from the array dtype to the variable dtype.
+    kwargs
+        Extra arguments are passed on to `theano.function`.
+
+    Attributes
+    ----------
+    size : int
+        The number of elements in the parameter array.
+    profile : theano profiling object or None
+        The profiling object of the theano function that computes value and
+        gradient. This is None unless `profile=True` was set in the
+        kwargs.
+    """
+    def __init__(self, cost, grad_vars, extra_vars=None, dtype=None,
+                 casting='no', **kwargs):
+        if extra_vars is None:
+            extra_vars = []
+
+        names = [arg.name for arg in grad_vars + extra_vars]
+        if any(name is None for name in names):
+            raise ValueError('Arguments must be named.')
+        if len(set(names)) != len(names):
+            raise ValueError('Names of the arguments are not unique.')
+
+        if cost.ndim > 0:
+            raise ValueError('Cost must be a scalar.')
+
+        self._grad_vars = grad_vars
+        self._extra_vars = extra_vars
+        self._extra_var_names = set(var.name for var in extra_vars)
+        self._cost = cost
+        self._ordering = ArrayOrdering(grad_vars)
+        self.size = self._ordering.size
+        self._extra_are_set = False
+        if dtype is None:
+            dtype = theano.config.floatX
+        self.dtype = dtype
+        for var in self._grad_vars:
+            if not np.can_cast(var.dtype, self.dtype, casting):
+                raise TypeError('Invalid dtype for variable %s. Can not '
+                                'cast to %s with casting rule %s.'
+                                % (var.name, self.dtype, casting))
+            if not np.issubdtype(var.dtype, float):
+                raise TypeError('Invalid dtype for variable %s. Must be '
+                                'floating point but is %s.'
+                                % (var.name, var.dtype))
+
+        givens = []
+        self._extra_vars_shared = {}
+        for var in extra_vars:
+            shared = theano.shared(var.tag.test_value, var.name + '_shared__')
+            self._extra_vars_shared[var.name] = shared
+            givens.append((var, shared))
+
+        self._vars_joined, self._cost_joined = self._build_joined(
+            self._cost, grad_vars, self._ordering.vmap)
+
+        grad = tt.grad(self._cost_joined, self._vars_joined)
+        grad.name = '__grad'
+
+        inputs = [self._vars_joined]
+
+        self._theano_function = theano.function(
+            inputs, [self._cost_joined, grad], givens=givens, **kwargs)
+
+    def set_extra_values(self, extra_vars):
+        self._extra_are_set = True
+        for var in self._extra_vars:
+            self._extra_vars_shared[var.name].set_value(extra_vars[var.name])
+
+    def get_extra_values(self):
+        if not self._extra_are_set:
+            raise ValueError('Extra values are not set.')
+
+        return {var.name: self._extra_vars_shared[var.name].get_value()
+                for var in self._extra_vars}
+
+    def __call__(self, array, grad_out=None, extra_vars=None):
+        if extra_vars is not None:
+            self.set_extra_values(extra_vars)
+
+        if not self._extra_are_set:
+            raise ValueError('Extra values are not set.')
+
+        if array.shape != (self.size,):
+            raise ValueError('Invalid shape for array. Must be %s but is %s.'
+                             % ((self.size,), array.shape))
+
+        if grad_out is None:
+            out = np.empty_like(array)
+        else:
+            out = grad_out
+
+        logp, dlogp = self._theano_function(array)
+        if grad_out is None:
+            return logp, dlogp
+        else:
+            out[...] = dlogp
+            return logp
+
+    @property
+    def profile(self):
+        """Profiling information of the underlying theano function."""
+        return self._theano_function.profile
+
+    def dict_to_array(self, point):
+        """Convert a dictionary with values for grad_vars to an array."""
+        array = np.empty(self.size, dtype=self.dtype)
+        for varmap in self._ordering.vmap:
+            array[varmap.slc] = point[varmap.var].ravel().astype(self.dtype)
+        return array
+
+    def array_to_dict(self, array):
+        """Convert an array to a dictionary containing the grad_vars."""
+        if array.shape != (self.size,):
+            raise ValueError('Array should have shape (%s,) but has %s'
+                             % (self.size, array.shape))
+        if array.dtype != self.dtype:
+            raise ValueError('Array has invalid dtype. Should be %s but is %s'
+                             % (self.dtype, array.dtype))
+        point = {}
+        for varmap in self._ordering.vmap:
+            data = array[varmap.slc].reshape(varmap.shp)
+            point[varmap.var] = data.astype(varmap.dtyp)
+
+        return point
+
+    def array_to_full_dict(self, array):
+        """Convert an array to a dictionary with grad_vars and extra_vars."""
+        point = self.array_to_dict(array)
+        for name, var in self._extra_vars_shared.items():
+            point[name] = var.get_value()
+        return point
+
+    def _build_joined(self, cost, args, vmap):
+        args_joined = tt.vector('__args_joined')
+        args_joined.tag.test_value = np.zeros(self.size, dtype=self.dtype)
+
+        joined_slices = {}
+        for varmap in vmap:
+            sliced = args_joined[varmap.slc].reshape(varmap.shp)
+            sliced.name = varmap.var
+            joined_slices[varmap.var] = sliced
+
+        replace = {var: joined_slices[var.name] for var in args}
+        return args_joined, theano.clone(cost, replace=replace)
+
+
 class Model(six.with_metaclass(InitContextMeta, Context, Factor)):
     """Encapsulates the variables and likelihood factors of a model.
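As a rough sketch of the calling convention (assuming `func` is a ValueGradFunction, e.g. one built with the Model.logp_dlogp_function helper added below, and `x` a flat float array of length func.size; both names are placeholders):

    import numpy as np

    logp, dlogp = func(x)              # allocates a fresh gradient array each call
    grad = np.empty(func.size)
    logp = func(x, grad_out=grad)      # writes the gradient into a caller-owned buffer

    point = func.array_to_dict(x)      # split the flat array back into named values
    full = func.array_to_full_dict(x)  # ... including the fixed extra_vars

Reusing a gradient buffer via grad_out is presumably what makes this attractive for tight sampler loops, where per-leapfrog allocations add up.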
@@ -419,7 +589,6 @@ def bijection(self):
         return bij

     @property
-    @memoize
     def dict_to_array(self):
         return self.bijection.map

@@ -428,23 +597,34 @@ def ndim(self):
         return sum(var.dsize for var in self.free_RVs)

     @property
-    @memoize
     def logp_array(self):
         return self.bijection.mapf(self.fastlogp)

     @property
-    @memoize
     def dlogp_array(self):
         vars = inputvars(self.cont_vars)
         return self.bijection.mapf(self.fastdlogp(vars))

+    def logp_dlogp_function(self, grad_vars=None, **kwargs):
+        if grad_vars is None:
+            grad_vars = list(typefilter(self.free_RVs, continuous_types))
+        else:
+            for var in grad_vars:
+                if var.dtype not in continuous_types:
+                    raise ValueError("Can only compute the gradient of "
+                                     "continuous types: %s" % var)
+        varnames = [var.name for var in grad_vars]
+        extra_vars = [var for var in self.free_RVs if var.name not in varnames]
+        return ValueGradFunction(self.logpt, grad_vars, extra_vars, **kwargs)
+
     @property
-    @memoize
     def logpt(self):
         """Theano scalar of log-probability of the model"""
         with self:
             factors = [var.logpt for var in self.basic_RVs] + self.potentials
-            return tt.add(*map(tt.sum, factors))
+            logp = tt.add(*map(tt.sum, factors))
+            logp.name = '__logp'
+            return logp

     @property
     def varlogpt(self):
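An end-to-end sketch of the new helper (the model and its variable names are hypothetical, not part of this commit):

    import pymc3 as pm

    with pm.Model() as model:
        mu = pm.Normal('mu', 0., 1., shape=3)
        sd = pm.HalfNormal('sd', 1.)

    func = model.logp_dlogp_function(grad_vars=[mu])  # sd becomes an extra var
    func.set_extra_values(model.test_point)           # pin the extra variables
    x = func.dict_to_array(model.test_point)          # flat array for the grad_vars
    logp, dlogp = func(x)                             # model logp and d(logp)/d(mu)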
@@ -595,7 +775,6 @@ def __getitem__(self, key):
         except KeyError:
             raise e

-    @memoize
     def makefn(self, outs, mode=None, *args, **kwargs):
         """Compiles a Theano function which returns `outs` and takes the variable
         ancestors of `outs` as inputs.

pymc3/step_methods/arraystep.py

Lines changed: 24 additions & 0 deletions
@@ -157,6 +157,30 @@ def step(self, point):
         return bij.rmap(apoint)


+class GradientSharedStep(BlockedStep):
+    def __init__(self, vars, model=None, blocked=True,
+                 dtype=None, **theano_kwargs):
+        model = modelcontext(model)
+        self.vars = vars
+        self.blocked = blocked
+
+        self._logp_dlogp_func = model.logp_dlogp_function(
+            vars, dtype=dtype, **theano_kwargs)
+
+    def step(self, point):
+        self._logp_dlogp_func.set_extra_values(point)
+        array = self._logp_dlogp_func.dict_to_array(point)
+
+        if self.generates_stats:
+            apoint, stats = self.astep(array)
+            point = self._logp_dlogp_func.array_to_full_dict(apoint)
+            return point, stats
+        else:
+            apoint = self.astep(array)
+            point = self._logp_dlogp_func.array_to_full_dict(apoint)
+            return point
+
+
 def metrop_select(mr, q, q0):
     """Perform rejection/acceptance step for Metropolis class samplers.
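To show how a sampler can build on this base class, here is a hypothetical toy subclass; an illustration only, not a step method from this commit:

    from pymc3.step_methods.arraystep import GradientSharedStep

    class GradientAscentStep(GradientSharedStep):
        # Toy deterministic step: move uphill along the gradient of the
        # model logp. The base class handles the dict <-> array plumbing.
        generates_stats = False

        def astep(self, q0):
            logp, dlogp = self._logp_dlogp_func(q0)
            return q0 + 1e-3 * dlogp  # step() maps this back to a point dict

The point of the refactoring is visible here: astep works purely on flat numpy arrays, so a subclass like NUTS never touches Theano directly, which is what opens the door to running the leapfrog arithmetic on a GPU.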
