Skip to content

Commit 55c8ce6

Browse files
ferrinetwiecki
authored andcommitted
ENH User model (#1525)
* Started to write Base class for pymc3.models * mode `add_var` to public api * Added some docstrings * Added some docstrings * added getitem and fixed a typo * added assertion check * added resolve var method * decided not to add resolve method * Added linear component * Docs fix * patsy's intercept is inited properly now * refactored code * updated docs * added possibility to init coefficients with random variables * added glm * refactored api, fixed formula init * refactored linear model, extended acceptable types * moved useful matrix and labels creation to utils file * code style * removed redundant evaluation of shape * refactored resolver for constructing matrix and labels * changed error message * changed signature of init * simplified utils any_to_tensor_and_labels code * tests for `any_to_tensor_and_labels` * added docstring for `any_to_tensor_and_labels` util * forgot to document return type in `any_to_tensor_and_labels` * refactored code for dict * dict tests fix(do not check labels there) * added access to random vars of model * added a shortcut for all variables so there is a unified way to get them * added default priors for linear model * update docs for linear * refactored UserModel api, made it more similar to pm.Model class * Lots of refactoring, tests for base class, more plain api design * deleted unused module variable * fixed some typos in docstring * Refactored pm.Model class, now it is ready for inheritance * Added documentation for Model class * Small typo in docstring * nested contains for treedict (needed for add_random_variable) * More accurate duplicate implementation of treedict/treelist * refactored treedict/treelist * changed `__imul__` of treelist * added `root` property and `isroot` indicator for base model * protect `parent` and `model` attributes from violation * travis' python2 did not fail on bad syntax(maybe it's too new), fixed * decided not to use functools wrapper Unfortunately functools wrapper fails when decorating built-in methods so I need to fix that improper behaviour. Some bad but needed tricks were implemented * Added models package to setup script * Refactor utils * Fix some typos in pm.model
1 parent 0ebaacd commit 55c8ce6

File tree

9 files changed

+829
-21
lines changed

9 files changed

+829
-21
lines changed

pymc3/model.py

Lines changed: 252 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import six
12
import numpy as np
23
import theano
34
import theano.tensor as tt
@@ -168,17 +169,215 @@ def logpt(self):
168169
return tt.sum(self.logp_elemwiset)
169170

170171

171-
class Model(Context, Factor):
172-
"""Encapsulates the variables and likelihood factors of a model."""
172+
class InitContextMeta(type):
173+
"""Metaclass that executes `__init__` of instance in it's context"""
174+
def __call__(cls, *args, **kwargs):
175+
instance = cls.__new__(cls, *args, **kwargs)
176+
with instance: # appends context
177+
instance.__init__(*args, **kwargs)
178+
return instance
179+
180+
181+
def withparent(meth):
182+
"""Helper wrapper that passes calls to parent's instance"""
183+
def wrapped(self, *args, **kwargs):
184+
res = meth(self, *args, **kwargs)
185+
if getattr(self, 'parent', None) is not None:
186+
getattr(self.parent, meth.__name__)(*args, **kwargs)
187+
return res
188+
# Unfortunately functools wrapper fails
189+
# when decorating built-in methods so we
190+
# need to fix that improper behaviour
191+
wrapped.__name__ = meth.__name__
192+
return wrapped
193+
194+
195+
class treelist(list):
196+
"""A list that passes mutable extending operations used in Model
197+
to parent list instance.
198+
Extending treelist you will also extend its parent
199+
"""
200+
def __init__(self, iterable=(), parent=None):
201+
super(treelist, self).__init__(iterable)
202+
assert isinstance(parent, list) or parent is None
203+
self.parent = parent
204+
if self.parent is not None:
205+
self.parent.extend(self)
206+
# typechecking here works bad
207+
append = withparent(list.append)
208+
__iadd__ = withparent(list.__iadd__)
209+
extend = withparent(list.extend)
210+
211+
def tree_contains(self, item):
212+
if isinstance(self.parent, treedict):
213+
return (list.__contains__(self, item) or
214+
self.parent.tree_contains(item))
215+
elif isinstance(self.parent, list):
216+
return (list.__contains__(self, item) or
217+
self.parent.__contains__(item))
218+
else:
219+
return list.__contains__(self, item)
220+
221+
def __setitem__(self, key, value):
222+
raise NotImplementedError('Method is removed as we are not'
223+
' able to determine '
224+
'appropriate logic for it')
225+
226+
def __imul__(self, other):
227+
t0 = len(self)
228+
list.__imul__(self, other)
229+
if self.parent is not None:
230+
self.parent.extend(self[t0:])
231+
232+
233+
class treedict(dict):
234+
"""A dict that passes mutable extending operations used in Model
235+
to parent dict instance.
236+
Extending treedict you will also extend its parent
237+
"""
238+
def __init__(self, iterable=(), parent=None, **kwargs):
239+
super(treedict, self).__init__(iterable, **kwargs)
240+
assert isinstance(parent, dict) or parent is None
241+
self.parent = parent
242+
if self.parent is not None:
243+
self.parent.update(self)
244+
# typechecking here works bad
245+
__setitem__ = withparent(dict.__setitem__)
246+
update = withparent(dict.update)
247+
248+
def tree_contains(self, item):
249+
# needed for `add_random_variable` method
250+
if isinstance(self.parent, treedict):
251+
return (dict.__contains__(self, item) or
252+
self.parent.tree_contains(item))
253+
elif isinstance(self.parent, dict):
254+
return (dict.__contains__(self, item) or
255+
self.parent.__contains__(item))
256+
else:
257+
return dict.__contains__(self, item)
258+
259+
260+
class Model(six.with_metaclass(InitContextMeta, Context, Factor)):
261+
"""Encapsulates the variables and likelihood factors of a model.
173262
174-
def __init__(self):
175-
self.named_vars = {}
176-
self.free_RVs = []
177-
self.observed_RVs = []
178-
self.deterministics = []
179-
self.potentials = []
180-
self.missing_values = []
181-
self.model = self
263+
Model class can be used for creating class based models. To create
264+
a class based model you should inherit from `Model` and
265+
override `__init__` with arbitrary definitions
266+
(do not forget to call base class `__init__` first).
267+
268+
Parameters
269+
----------
270+
name : str, default '' - name that will be used as prefix for
271+
names of all random variables defined within model
272+
model : Model, default None - instance of Model that is
273+
supposed to be a parent for the new instance. If None,
274+
context will be used. All variables defined within instance
275+
will be passed to the parent instance. So that 'nested' model
276+
contributes to the variables and likelihood factors of
277+
parent model.
278+
279+
Examples
280+
--------
281+
# How to define a custom model
282+
class CustomModel(Model):
283+
# 1) override init
284+
def __init__(self, mean=0, sd=1, name='', model=None):
285+
# 2) call super's init first, passing model and name to it
286+
# name will be prefix for all variables here
287+
# if no name specified for model there will be no prefix
288+
super(CustomModel, self).__init__(name, model)
289+
# now you are in the context of instance,
290+
# `modelcontext` will return self
291+
# you can define variables in several ways
292+
# note, that all variables will get model's name prefix
293+
294+
# 3) you can create variables with Var method
295+
self.Var('v1', Normal.dist(mu=mean, sd=sd))
296+
# this will create variable named like '{prefix_}v1'
297+
# and assign attribute 'v1' to instance
298+
# created variable can be accessed with self.v1 or self['v1']
299+
300+
# 4) this syntax will also work as we are in the context
301+
# of instance itself, names are given as usual
302+
Normal('v2', mu=mean, sd=sd)
303+
304+
# something more complex is allowed too
305+
Normal('v3', mu=mean, sd=HalfCauchy('sd', beta=10, testval=1.))
306+
307+
# Deterministic variables can be used in usual way
308+
Deterministic('v3_sq', self.v3 ** 2)
309+
# Potentials too
310+
Potential('p1', tt.constant(1))
311+
312+
# After defining a class CustomModel you can use it in several ways
313+
314+
# I:
315+
# state the model within a context
316+
with Model() as model:
317+
CustomModel()
318+
# arbitrary actions
319+
320+
# II:
321+
# use new class as entering point in context
322+
with CustomModel() as model:
323+
Normal('new_normal_var', mu=1, sd=0)
324+
325+
# III:
326+
# just get model instance with all that was defined in it
327+
model = CustomModel()
328+
329+
# IV:
330+
# use many custom models within one context
331+
with Model() as model:
332+
CustomModel(mean=1, name='first')
333+
CustomModel(mean=2, name='second')
334+
"""
335+
def __new__(cls, *args, **kwargs):
336+
# resolves the parent instance
337+
instance = object.__new__(cls)
338+
if kwargs.get('model') is not None:
339+
instance._parent = kwargs.get('model')
340+
elif cls.get_contexts():
341+
instance._parent = cls.get_contexts()[-1]
342+
else:
343+
instance._parent = None
344+
return instance
345+
346+
def __init__(self, name='', model=None):
347+
self.name = name
348+
if self.parent is not None:
349+
self.named_vars = treedict(parent=self.parent.named_vars)
350+
self.free_RVs = treelist(parent=self.parent.free_RVs)
351+
self.observed_RVs = treelist(parent=self.parent.observed_RVs)
352+
self.deterministics = treelist(parent=self.parent.deterministics)
353+
self.potentials = treelist(parent=self.parent.potentials)
354+
self.missing_values = treelist(parent=self.parent.missing_values)
355+
else:
356+
self.named_vars = treedict()
357+
self.free_RVs = treelist()
358+
self.observed_RVs = treelist()
359+
self.deterministics = treelist()
360+
self.potentials = treelist()
361+
self.missing_values = treelist()
362+
363+
@property
364+
def model(self):
365+
return self
366+
367+
@property
368+
def parent(self):
369+
return self._parent
370+
371+
@property
372+
def root(self):
373+
model = self
374+
while not model.isroot:
375+
model = model.parent
376+
return model
377+
378+
@property
379+
def isroot(self):
380+
return self.parent is None
182381

183382
@property
184383
@memoize
@@ -271,6 +470,7 @@ def Var(self, name, dist, data=None):
271470
-------
272471
FreeRV or ObservedRV
273472
"""
473+
name = self.name_for(name)
274474
if data is None:
275475
if getattr(dist, "transform", None) is None:
276476
var = FreeRV(name=name, distribution=dist, model=self)
@@ -308,15 +508,46 @@ def Var(self, name, dist, data=None):
308508

309509
def add_random_variable(self, var):
310510
"""Add a random variable to the named variables of the model."""
311-
if var.name in self.named_vars:
511+
if self.named_vars.tree_contains(var.name):
312512
raise ValueError(
313513
"Variable name {} already exists.".format(var.name))
314514
self.named_vars[var.name] = var
315-
if not hasattr(self, var.name):
316-
setattr(self, var.name, var)
515+
if not hasattr(self, self.name_of(var.name)):
516+
setattr(self, self.name_of(var.name), var)
517+
518+
@property
519+
def prefix(self):
520+
return '%s_' % self.name if self.name else ''
521+
522+
def name_for(self, name):
523+
"""Checks if name has prefix and adds if needed
524+
"""
525+
if self.prefix:
526+
if not name.startswith(self.prefix):
527+
return '{}{}'.format(self.prefix, name)
528+
else:
529+
return name
530+
else:
531+
return name
532+
533+
def name_of(self, name):
534+
"""Checks if name has prefix and deletes if needed
535+
"""
536+
if not self.prefix or not name:
537+
return name
538+
elif name.startswith(self.prefix):
539+
return name[len(self.prefix):]
540+
else:
541+
return name
317542

318543
def __getitem__(self, key):
319-
return self.named_vars[key]
544+
try:
545+
return self.named_vars[key]
546+
except KeyError as e:
547+
try:
548+
return self.named_vars[self.name_for(key)]
549+
except KeyError:
550+
raise e
320551

321552
@memoize
322553
def makefn(self, outs, mode=None, *args, **kwargs):
@@ -633,9 +864,10 @@ def Deterministic(name, var, model=None):
633864
-------
634865
n : var but with name name
635866
"""
636-
var.name = name
637-
modelcontext(model).deterministics.append(var)
638-
modelcontext(model).add_random_variable(var)
867+
model = modelcontext(model)
868+
var.name = model.name_for(name)
869+
model.deterministics.append(var)
870+
model.add_random_variable(var)
639871
return var
640872

641873

@@ -651,8 +883,9 @@ def Potential(name, var, model=None):
651883
-------
652884
var : var, with name attribute
653885
"""
654-
var.name = name
655-
modelcontext(model).potentials.append(var)
886+
model = modelcontext(model)
887+
var.name = model.name_for(name)
888+
model.potentials.append(var)
656889
return var
657890

658891

pymc3/models/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from .linear import LinearComponent, Glm
2+
3+
__all__ = [
4+
'LinearComponent',
5+
'Glm'
6+
]

0 commit comments

Comments
 (0)