1+ import six
12import numpy as np
23import theano
34import theano .tensor as tt
@@ -168,17 +169,215 @@ def logpt(self):
168169 return tt .sum (self .logp_elemwiset )
169170
170171
171- class Model (Context , Factor ):
172- """Encapsulates the variables and likelihood factors of a model."""
172+ class InitContextMeta (type ):
173+ """Metaclass that executes `__init__` of instance in it's context"""
174+ def __call__ (cls , * args , ** kwargs ):
175+ instance = cls .__new__ (cls , * args , ** kwargs )
176+ with instance : # appends context
177+ instance .__init__ (* args , ** kwargs )
178+ return instance
179+
180+
181+ def withparent (meth ):
182+ """Helper wrapper that passes calls to parent's instance"""
183+ def wrapped (self , * args , ** kwargs ):
184+ res = meth (self , * args , ** kwargs )
185+ if getattr (self , 'parent' , None ) is not None :
186+ getattr (self .parent , meth .__name__ )(* args , ** kwargs )
187+ return res
188+ # Unfortunately functools wrapper fails
189+ # when decorating built-in methods so we
190+ # need to fix that improper behaviour
191+ wrapped .__name__ = meth .__name__
192+ return wrapped
193+
194+
195+ class treelist (list ):
196+ """A list that passes mutable extending operations used in Model
197+ to parent list instance.
198+ Extending treelist you will also extend its parent
199+ """
200+ def __init__ (self , iterable = (), parent = None ):
201+ super (treelist , self ).__init__ (iterable )
202+ assert isinstance (parent , list ) or parent is None
203+ self .parent = parent
204+ if self .parent is not None :
205+ self .parent .extend (self )
206+ # typechecking here works bad
207+ append = withparent (list .append )
208+ __iadd__ = withparent (list .__iadd__ )
209+ extend = withparent (list .extend )
210+
211+ def tree_contains (self , item ):
212+ if isinstance (self .parent , treedict ):
213+ return (list .__contains__ (self , item ) or
214+ self .parent .tree_contains (item ))
215+ elif isinstance (self .parent , list ):
216+ return (list .__contains__ (self , item ) or
217+ self .parent .__contains__ (item ))
218+ else :
219+ return list .__contains__ (self , item )
220+
221+ def __setitem__ (self , key , value ):
222+ raise NotImplementedError ('Method is removed as we are not'
223+ ' able to determine '
224+ 'appropriate logic for it' )
225+
226+ def __imul__ (self , other ):
227+ t0 = len (self )
228+ list .__imul__ (self , other )
229+ if self .parent is not None :
230+ self .parent .extend (self [t0 :])
231+
232+
233+ class treedict (dict ):
234+ """A dict that passes mutable extending operations used in Model
235+ to parent dict instance.
236+ Extending treedict you will also extend its parent
237+ """
238+ def __init__ (self , iterable = (), parent = None , ** kwargs ):
239+ super (treedict , self ).__init__ (iterable , ** kwargs )
240+ assert isinstance (parent , dict ) or parent is None
241+ self .parent = parent
242+ if self .parent is not None :
243+ self .parent .update (self )
244+ # typechecking here works bad
245+ __setitem__ = withparent (dict .__setitem__ )
246+ update = withparent (dict .update )
247+
248+ def tree_contains (self , item ):
249+ # needed for `add_random_variable` method
250+ if isinstance (self .parent , treedict ):
251+ return (dict .__contains__ (self , item ) or
252+ self .parent .tree_contains (item ))
253+ elif isinstance (self .parent , dict ):
254+ return (dict .__contains__ (self , item ) or
255+ self .parent .__contains__ (item ))
256+ else :
257+ return dict .__contains__ (self , item )
258+
259+
260+ class Model (six .with_metaclass (InitContextMeta , Context , Factor )):
261+ """Encapsulates the variables and likelihood factors of a model.
173262
174- def __init__ (self ):
175- self .named_vars = {}
176- self .free_RVs = []
177- self .observed_RVs = []
178- self .deterministics = []
179- self .potentials = []
180- self .missing_values = []
181- self .model = self
263+ Model class can be used for creating class based models. To create
264+ a class based model you should inherit from `Model` and
265+ override `__init__` with arbitrary definitions
266+ (do not forget to call base class `__init__` first).
267+
268+ Parameters
269+ ----------
270+ name : str, default '' - name that will be used as prefix for
271+ names of all random variables defined within model
272+ model : Model, default None - instance of Model that is
273+ supposed to be a parent for the new instance. If None,
274+ context will be used. All variables defined within instance
275+ will be passed to the parent instance. So that 'nested' model
276+ contributes to the variables and likelihood factors of
277+ parent model.
278+
279+ Examples
280+ --------
281+ # How to define a custom model
282+ class CustomModel(Model):
283+ # 1) override init
284+ def __init__(self, mean=0, sd=1, name='', model=None):
285+ # 2) call super's init first, passing model and name to it
286+ # name will be prefix for all variables here
287+ # if no name specified for model there will be no prefix
288+ super(CustomModel, self).__init__(name, model)
289+ # now you are in the context of instance,
290+ # `modelcontext` will return self
291+ # you can define variables in several ways
292+ # note, that all variables will get model's name prefix
293+
294+ # 3) you can create variables with Var method
295+ self.Var('v1', Normal.dist(mu=mean, sd=sd))
296+ # this will create variable named like '{prefix_}v1'
297+ # and assign attribute 'v1' to instance
298+ # created variable can be accessed with self.v1 or self['v1']
299+
300+ # 4) this syntax will also work as we are in the context
301+ # of instance itself, names are given as usual
302+ Normal('v2', mu=mean, sd=sd)
303+
304+ # something more complex is allowed too
305+ Normal('v3', mu=mean, sd=HalfCauchy('sd', beta=10, testval=1.))
306+
307+ # Deterministic variables can be used in usual way
308+ Deterministic('v3_sq', self.v3 ** 2)
309+ # Potentials too
310+ Potential('p1', tt.constant(1))
311+
312+ # After defining a class CustomModel you can use it in several ways
313+
314+ # I:
315+ # state the model within a context
316+ with Model() as model:
317+ CustomModel()
318+ # arbitrary actions
319+
320+ # II:
321+ # use new class as entering point in context
322+ with CustomModel() as model:
323+ Normal('new_normal_var', mu=1, sd=0)
324+
325+ # III:
326+ # just get model instance with all that was defined in it
327+ model = CustomModel()
328+
329+ # IV:
330+ # use many custom models within one context
331+ with Model() as model:
332+ CustomModel(mean=1, name='first')
333+ CustomModel(mean=2, name='second')
334+ """
335+ def __new__ (cls , * args , ** kwargs ):
336+ # resolves the parent instance
337+ instance = object .__new__ (cls )
338+ if kwargs .get ('model' ) is not None :
339+ instance ._parent = kwargs .get ('model' )
340+ elif cls .get_contexts ():
341+ instance ._parent = cls .get_contexts ()[- 1 ]
342+ else :
343+ instance ._parent = None
344+ return instance
345+
346+ def __init__ (self , name = '' , model = None ):
347+ self .name = name
348+ if self .parent is not None :
349+ self .named_vars = treedict (parent = self .parent .named_vars )
350+ self .free_RVs = treelist (parent = self .parent .free_RVs )
351+ self .observed_RVs = treelist (parent = self .parent .observed_RVs )
352+ self .deterministics = treelist (parent = self .parent .deterministics )
353+ self .potentials = treelist (parent = self .parent .potentials )
354+ self .missing_values = treelist (parent = self .parent .missing_values )
355+ else :
356+ self .named_vars = treedict ()
357+ self .free_RVs = treelist ()
358+ self .observed_RVs = treelist ()
359+ self .deterministics = treelist ()
360+ self .potentials = treelist ()
361+ self .missing_values = treelist ()
362+
363+ @property
364+ def model (self ):
365+ return self
366+
367+ @property
368+ def parent (self ):
369+ return self ._parent
370+
371+ @property
372+ def root (self ):
373+ model = self
374+ while not model .isroot :
375+ model = model .parent
376+ return model
377+
378+ @property
379+ def isroot (self ):
380+ return self .parent is None
182381
183382 @property
184383 @memoize
@@ -271,6 +470,7 @@ def Var(self, name, dist, data=None):
271470 -------
272471 FreeRV or ObservedRV
273472 """
473+ name = self .name_for (name )
274474 if data is None :
275475 if getattr (dist , "transform" , None ) is None :
276476 var = FreeRV (name = name , distribution = dist , model = self )
@@ -308,15 +508,46 @@ def Var(self, name, dist, data=None):
308508
309509 def add_random_variable (self , var ):
310510 """Add a random variable to the named variables of the model."""
311- if var . name in self .named_vars :
511+ if self .named_vars . tree_contains ( var . name ) :
312512 raise ValueError (
313513 "Variable name {} already exists." .format (var .name ))
314514 self .named_vars [var .name ] = var
315- if not hasattr (self , var .name ):
316- setattr (self , var .name , var )
515+ if not hasattr (self , self .name_of (var .name )):
516+ setattr (self , self .name_of (var .name ), var )
517+
518+ @property
519+ def prefix (self ):
520+ return '%s_' % self .name if self .name else ''
521+
522+ def name_for (self , name ):
523+ """Checks if name has prefix and adds if needed
524+ """
525+ if self .prefix :
526+ if not name .startswith (self .prefix ):
527+ return '{}{}' .format (self .prefix , name )
528+ else :
529+ return name
530+ else :
531+ return name
532+
533+ def name_of (self , name ):
534+ """Checks if name has prefix and deletes if needed
535+ """
536+ if not self .prefix or not name :
537+ return name
538+ elif name .startswith (self .prefix ):
539+ return name [len (self .prefix ):]
540+ else :
541+ return name
317542
318543 def __getitem__ (self , key ):
319- return self .named_vars [key ]
544+ try :
545+ return self .named_vars [key ]
546+ except KeyError as e :
547+ try :
548+ return self .named_vars [self .name_for (key )]
549+ except KeyError :
550+ raise e
320551
321552 @memoize
322553 def makefn (self , outs , mode = None , * args , ** kwargs ):
@@ -633,9 +864,10 @@ def Deterministic(name, var, model=None):
633864 -------
634865 n : var but with name name
635866 """
636- var .name = name
637- modelcontext (model ).deterministics .append (var )
638- modelcontext (model ).add_random_variable (var )
867+ model = modelcontext (model )
868+ var .name = model .name_for (name )
869+ model .deterministics .append (var )
870+ model .add_random_variable (var )
639871 return var
640872
641873
@@ -651,8 +883,9 @@ def Potential(name, var, model=None):
651883 -------
652884 var : var, with name attribute
653885 """
654- var .name = name
655- modelcontext (model ).potentials .append (var )
886+ model = modelcontext (model )
887+ var .name = model .name_for (name )
888+ model .potentials .append (var )
656889 return var
657890
658891
0 commit comments