Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 252 additions & 19 deletions pymc3/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import six
import numpy as np
import theano
import theano.tensor as tt
Expand Down Expand Up @@ -168,17 +169,215 @@ def logpt(self):
return tt.sum(self.logp_elemwiset)


class Model(Context, Factor):
"""Encapsulates the variables and likelihood factors of a model."""
class InitContextMeta(type):
"""Metaclass that executes `__init__` of instance in it's context"""
def __call__(cls, *args, **kwargs):
instance = cls.__new__(cls, *args, **kwargs)
with instance: # appends context
instance.__init__(*args, **kwargs)
return instance


def withparent(meth):
"""Helper wrapper that passes calls to parent's instance"""
def wrapped(self, *args, **kwargs):
res = meth(self, *args, **kwargs)
if getattr(self, 'parent', None) is not None:
getattr(self.parent, meth.__name__)(*args, **kwargs)
return res
# Unfortunately functools wrapper fails
# when decorating built-in methods so we
# need to fix that improper behaviour
wrapped.__name__ = meth.__name__
return wrapped


class treelist(list):
"""A list that passes mutable extending operations used in Model
to parent list instance.
Extending treelist you will also extend its parent
"""
def __init__(self, iterable=(), parent=None):
super(treelist, self).__init__(iterable)
assert isinstance(parent, list) or parent is None
self.parent = parent
if self.parent is not None:
self.parent.extend(self)
# typechecking here works bad
append = withparent(list.append)
__iadd__ = withparent(list.__iadd__)
extend = withparent(list.extend)

def tree_contains(self, item):
if isinstance(self.parent, treedict):
return (list.__contains__(self, item) or
self.parent.tree_contains(item))
elif isinstance(self.parent, list):
return (list.__contains__(self, item) or
self.parent.__contains__(item))
else:
return list.__contains__(self, item)

def __setitem__(self, key, value):
raise NotImplementedError('Method is removed as we are not'
' able to determine '
'appropriate logic for it')

def __imul__(self, other):
t0 = len(self)
list.__imul__(self, other)
if self.parent is not None:
self.parent.extend(self[t0:])


class treedict(dict):
"""A dict that passes mutable extending operations used in Model
to parent dict instance.
Extending treedict you will also extend its parent
"""
def __init__(self, iterable=(), parent=None, **kwargs):
super(treedict, self).__init__(iterable, **kwargs)
assert isinstance(parent, dict) or parent is None
self.parent = parent
if self.parent is not None:
self.parent.update(self)
# typechecking here works bad
__setitem__ = withparent(dict.__setitem__)
update = withparent(dict.update)

def tree_contains(self, item):
# needed for `add_random_variable` method
if isinstance(self.parent, treedict):
return (dict.__contains__(self, item) or
self.parent.tree_contains(item))
elif isinstance(self.parent, dict):
return (dict.__contains__(self, item) or
self.parent.__contains__(item))
else:
return dict.__contains__(self, item)


class Model(six.with_metaclass(InitContextMeta, Context, Factor)):
"""Encapsulates the variables and likelihood factors of a model.

def __init__(self):
self.named_vars = {}
self.free_RVs = []
self.observed_RVs = []
self.deterministics = []
self.potentials = []
self.missing_values = []
self.model = self
Model class can be used for creating class based models. To create
a class based model you should inherit from `Model` and
override `__init__` with arbitrary definitions
(do not forget to call base class `__init__` first).

Parameters
----------
name : str, default '' - name that will be used as prefix for
names of all random variables defined within model
model : Model, default None - instance of Model that is
supposed to be a parent for the new instance. If None,
context will be used. All variables defined within instance
will be passed to the parent instance. So that 'nested' model
contributes to the variables and likelihood factors of
parent model.

Examples
--------
# How to define a custom model
class CustomModel(Model):
# 1) override init
def __init__(self, mean=0, sd=1, name='', model=None):
# 2) call super's init first, passing model and name to it
# name will be prefix for all variables here
# if no name specified for model there will be no prefix
super(CustomModel, self).__init__(name, model)
# now you are in the context of instance,
# `modelcontext` will return self
# you can define variables in several ways
# note, that all variables will get model's name prefix

# 3) you can create variables with Var method
self.Var('v1', Normal.dist(mu=mean, sd=sd))
# this will create variable named like '{prefix_}v1'
# and assign attribute 'v1' to instance
# created variable can be accessed with self.v1 or self['v1']

# 4) this syntax will also work as we are in the context
# of instance itself, names are given as usual
Normal('v2', mu=mean, sd=sd)

# something more complex is allowed too
Normal('v3', mu=mean, sd=HalfCauchy('sd', beta=10, testval=1.))

# Deterministic variables can be used in usual way
Deterministic('v3_sq', self.v3 ** 2)
# Potentials too
Potential('p1', tt.constant(1))

# After defining a class CustomModel you can use it in several ways

# I:
# state the model within a context
with Model() as model:
CustomModel()
# arbitrary actions

# II:
# use new class as entering point in context
with CustomModel() as model:
Normal('new_normal_var', mu=1, sd=0)

# III:
# just get model instance with all that was defined in it
model = CustomModel()

# IV:
# use many custom models within one context
with Model() as model:
CustomModel(mean=1, name='first')
CustomModel(mean=2, name='second')
"""
def __new__(cls, *args, **kwargs):
# resolves the parent instance
instance = object.__new__(cls)
if kwargs.get('model') is not None:
instance._parent = kwargs.get('model')
elif cls.get_contexts():
instance._parent = cls.get_contexts()[-1]
else:
instance._parent = None
return instance

def __init__(self, name='', model=None):
self.name = name
if self.parent is not None:
self.named_vars = treedict(parent=self.parent.named_vars)
self.free_RVs = treelist(parent=self.parent.free_RVs)
self.observed_RVs = treelist(parent=self.parent.observed_RVs)
self.deterministics = treelist(parent=self.parent.deterministics)
self.potentials = treelist(parent=self.parent.potentials)
self.missing_values = treelist(parent=self.parent.missing_values)
else:
self.named_vars = treedict()
self.free_RVs = treelist()
self.observed_RVs = treelist()
self.deterministics = treelist()
self.potentials = treelist()
self.missing_values = treelist()

@property
def model(self):
return self

@property
def parent(self):
return self._parent

@property
def root(self):
model = self
while not model.isroot:
model = model.parent
return model

@property
def isroot(self):
return self.parent is None

@property
@memoize
Expand Down Expand Up @@ -271,6 +470,7 @@ def Var(self, name, dist, data=None):
-------
FreeRV or ObservedRV
"""
name = self.name_for(name)
if data is None:
if getattr(dist, "transform", None) is None:
var = FreeRV(name=name, distribution=dist, model=self)
Expand Down Expand Up @@ -308,15 +508,46 @@ def Var(self, name, dist, data=None):

def add_random_variable(self, var):
"""Add a random variable to the named variables of the model."""
if var.name in self.named_vars:
if self.named_vars.tree_contains(var.name):
raise ValueError(
"Variable name {} already exists.".format(var.name))
self.named_vars[var.name] = var
if not hasattr(self, var.name):
setattr(self, var.name, var)
if not hasattr(self, self.name_of(var.name)):
setattr(self, self.name_of(var.name), var)

@property
def prefix(self):
return '%s_' % self.name if self.name else ''

def name_for(self, name):
"""Checks if name has prefix and adds if needed
"""
if self.prefix:
if not name.startswith(self.prefix):
return '{}{}'.format(self.prefix, name)
else:
return name
else:
return name

def name_of(self, name):
"""Checks if name has prefix and deletes if needed
"""
if not self.prefix or not name:
return name
elif name.startswith(self.prefix):
return name[len(self.prefix):]
else:
return name

def __getitem__(self, key):
return self.named_vars[key]
try:
return self.named_vars[key]
except KeyError as e:
try:
return self.named_vars[self.name_for(key)]
except KeyError:
raise e

@memoize
def makefn(self, outs, mode=None, *args, **kwargs):
Expand Down Expand Up @@ -633,9 +864,10 @@ def Deterministic(name, var, model=None):
-------
n : var but with name name
"""
var.name = name
modelcontext(model).deterministics.append(var)
modelcontext(model).add_random_variable(var)
model = modelcontext(model)
var.name = model.name_for(name)
model.deterministics.append(var)
model.add_random_variable(var)
return var


Expand All @@ -651,8 +883,9 @@ def Potential(name, var, model=None):
-------
var : var, with name attribute
"""
var.name = name
modelcontext(model).potentials.append(var)
model = modelcontext(model)
var.name = model.name_for(name)
model.potentials.append(var)
return var


Expand Down
6 changes: 6 additions & 0 deletions pymc3/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .linear import LinearComponent, Glm

__all__ = [
'LinearComponent',
'Glm'
]
Loading