@@ -184,9 +184,12 @@ def fastd2logp(self, vars=None):
     def logpt(self):
         """Theano scalar of log-probability of the model"""
         if getattr(self, 'total_size', None) is not None:
-            return tt.sum(self.logp_elemwiset) * self.scaling
+            logp = tt.sum(self.logp_elemwiset) * self.scaling
         else:
-            return tt.sum(self.logp_elemwiset)
+            logp = tt.sum(self.logp_elemwiset)
+        if self.name is not None:
+            logp.name = '__logp_%s' % self.name
+        return logp


 class InitContextMeta(type):
@@ -277,6 +280,173 @@ def tree_contains(self, item):
         return dict.__contains__(self, item)


+class ValueGradFunction(object):
+    """Create a theano function that computes a value and its gradient.
+
+    Parameters
+    ----------
+    cost : theano variable
+        The value that we compute with its gradient.
+    grad_vars : list of named theano variables or None
+        The arguments with respect to which the gradient is computed.
+    extra_vars : list of named theano variables or None
+        Other arguments of the function that are assumed constant. They
+        are stored in shared variables and can be set using
+        `set_extra_values`.
+    dtype : str, default=theano.config.floatX
+        The dtype of the arrays.
+    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, default='no'
+        Casting rule for casting `grad_vars` to the array dtype.
+        See `numpy.can_cast` for a description of the options.
+        Keep in mind that we cast the variables to the array *and*
+        back from the array dtype to the variable dtype.
+    kwargs
+        Extra arguments are passed on to `theano.function`.
+
+    Attributes
+    ----------
+    size : int
+        The number of elements in the parameter array.
+    profile : theano profiling object or None
+        The profiling object of the theano function that computes value and
+        gradient. This is None unless `profile=True` was set in the
+        kwargs.
+    """
+    def __init__(self, cost, grad_vars, extra_vars=None, dtype=None,
+                 casting='no', **kwargs):
+        if extra_vars is None:
+            extra_vars = []
+
+        names = [arg.name for arg in grad_vars + extra_vars]
+        if any(name is None for name in names):
+            raise ValueError('Arguments must be named.')
+        if len(set(names)) != len(names):
+            raise ValueError('Names of the arguments are not unique.')
+
+        if cost.ndim > 0:
+            raise ValueError('Cost must be a scalar.')
+
+        self._grad_vars = grad_vars
+        self._extra_vars = extra_vars
+        self._extra_var_names = set(var.name for var in extra_vars)
+        self._cost = cost
+        self._ordering = ArrayOrdering(grad_vars)
+        self.size = self._ordering.size
+        self._extra_are_set = False
+        if dtype is None:
+            dtype = theano.config.floatX
+        self.dtype = dtype
+        for var in self._grad_vars:
+            if not np.can_cast(var.dtype, self.dtype, casting):
+                raise TypeError('Invalid dtype for variable %s. Can not '
+                                'cast to %s with casting rule %s.'
+                                % (var.name, self.dtype, casting))
+            if not np.issubdtype(var.dtype, float):
+                raise TypeError('Invalid dtype for variable %s. Must be '
+                                'floating point but is %s.'
+                                % (var.name, var.dtype))
+
+        givens = []
+        self._extra_vars_shared = {}
+        for var in extra_vars:
+            shared = theano.shared(var.tag.test_value, var.name + '_shared__')
+            self._extra_vars_shared[var.name] = shared
+            givens.append((var, shared))
+
+        self._vars_joined, self._cost_joined = self._build_joined(
+            self._cost, grad_vars, self._ordering.vmap)
+
+        grad = tt.grad(self._cost_joined, self._vars_joined)
+        grad.name = '__grad'
+
+        inputs = [self._vars_joined]
+
+        self._theano_function = theano.function(
+            inputs, [self._cost_joined, grad], givens=givens, **kwargs)
+
+    def set_extra_values(self, extra_vars):
+        self._extra_are_set = True
+        for var in self._extra_vars:
+            self._extra_vars_shared[var.name].set_value(extra_vars[var.name])
+
+    def get_extra_values(self):
+        if not self._extra_are_set:
+            raise ValueError('Extra values are not set.')
+
+        return {var.name: self._extra_vars_shared[var.name].get_value()
+                for var in self._extra_vars}
+
+    def __call__(self, array, grad_out=None, extra_vars=None):
+        if extra_vars is not None:
+            self.set_extra_values(extra_vars)
+
+        if not self._extra_are_set:
+            raise ValueError('Extra values are not set.')
+
+        if array.shape != (self.size,):
+            raise ValueError('Invalid shape for array. Must be %s but is %s.'
+                             % ((self.size,), array.shape))
+
+        if grad_out is None:
+            out = np.empty_like(array)
+        else:
+            out = grad_out
+
+        logp, dlogp = self._theano_function(array)
+        if grad_out is None:
+            return logp, dlogp
+        else:
+            out[...] = dlogp
+            return logp
+
+    @property
+    def profile(self):
+        """Profiling information of the underlying theano function."""
+        return self._theano_function.profile
+
+    def dict_to_array(self, point):
+        """Convert a dictionary with values for grad_vars to an array."""
+        array = np.empty(self.size, dtype=self.dtype)
+        for varmap in self._ordering.vmap:
+            array[varmap.slc] = point[varmap.var].ravel().astype(self.dtype)
+        return array
+
+    def array_to_dict(self, array):
+        """Convert an array to a dictionary containing the grad_vars."""
+        if array.shape != (self.size,):
+            raise ValueError('Array should have shape (%s,) but has %s'
+                             % (self.size, array.shape))
+        if array.dtype != self.dtype:
+            raise ValueError('Array has invalid dtype. Should be %s but is %s'
+                             % (self.dtype, array.dtype))
+        point = {}
+        for varmap in self._ordering.vmap:
+            data = array[varmap.slc].reshape(varmap.shp)
+            point[varmap.var] = data.astype(varmap.dtyp)
+
+        return point
+
+    def array_to_full_dict(self, array):
+        """Convert an array to a dictionary with grad_vars and extra_vars."""
+        point = self.array_to_dict(array)
+        for name, var in self._extra_vars_shared.items():
+            point[name] = var.get_value()
+        return point
+
+    def _build_joined(self, cost, args, vmap):
+        args_joined = tt.vector('__args_joined')
+        args_joined.tag.test_value = np.zeros(self.size, dtype=self.dtype)
+
+        joined_slices = {}
+        for vmap in vmap:
+            sliced = args_joined[vmap.slc].reshape(vmap.shp)
+            sliced.name = vmap.var
+            joined_slices[vmap.var] = sliced
+
+        replace = {var: joined_slices[var.name] for var in args}
+        return args_joined, theano.clone(cost, replace=replace)
+
+
 class Model(six.with_metaclass(InitContextMeta, Context, Factor)):
     """Encapsulates the variables and likelihood factors of a model.

@@ -419,7 +589,6 @@ def bijection(self):
         return bij

     @property
-    @memoize
     def dict_to_array(self):
         return self.bijection.map

@@ -428,23 +597,34 @@ def ndim(self):
         return sum(var.dsize for var in self.free_RVs)

     @property
-    @memoize
     def logp_array(self):
         return self.bijection.mapf(self.fastlogp)

     @property
-    @memoize
     def dlogp_array(self):
         vars = inputvars(self.cont_vars)
         return self.bijection.mapf(self.fastdlogp(vars))

+    def logp_dlogp_function(self, grad_vars=None, **kwargs):
+        if grad_vars is None:
+            grad_vars = list(typefilter(self.free_RVs, continuous_types))
+        else:
+            for var in grad_vars:
+                if var.dtype not in continuous_types:
+                    raise ValueError("Can only compute the gradient of "
+                                     "continuous types: %s" % var)
+        varnames = [var.name for var in grad_vars]
+        extra_vars = [var for var in self.free_RVs if var.name not in varnames]
+        return ValueGradFunction(self.logpt, grad_vars, extra_vars, **kwargs)
+
     @property
-    @memoize
     def logpt(self):
         """Theano scalar of log-probability of the model"""
         with self:
             factors = [var.logpt for var in self.basic_RVs] + self.potentials
-            return tt.add(*map(tt.sum, factors))
+            logp = tt.add(*map(tt.sum, factors))
+            logp.name = '__logp'
+            return logp

     @property
     def varlogpt(self):
@@ -595,7 +775,6 @@ def __getitem__(self, key):
             except KeyError:
                 raise e

-    @memoize
     def makefn(self, outs, mode=None, *args, **kwargs):
         """Compiles a Theano function which returns `outs` and takes the variable
         ancestors of `outs` as inputs.
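Usage note (a reviewer sketch, not part of the diff): the snippet below shows how the `Model.logp_dlogp_function` / `ValueGradFunction` API added above is intended to be called. The model, variable names, and data are invented for illustration, and it assumes theano's default float64 configuration.

```python
import numpy as np
import pymc3 as pm

# A toy model with a single continuous free variable.
with pm.Model() as model:
    mu = pm.Normal('mu', mu=0., sd=1.)
    pm.Normal('obs', mu=mu, sd=1., observed=np.array([0.1, -0.3, 0.2]))

# Compile one theano function that returns the log-probability and its
# gradient with respect to the continuous free variables.
func = model.logp_dlogp_function()

# Extra (non-gradient) variables live in shared storage and must be set
# before the first call; this model has none, so an empty dict suffices.
func.set_extra_values({})

# Evaluate at the model's test point, packed into a single flat array.
x = func.dict_to_array(model.test_point)
logp, dlogp = func(x)
```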