Skip to content

Commit ac8fe5a

Browse files
springcoilColCarroll
authored andcommitted
BUG: Attempt to fix 2909 (#2979)
Fixing up the test and implementation Adding other draw_values Small test fix
1 parent 850a2a7 commit ac8fe5a

File tree

4 files changed

+43
-27
lines changed

4 files changed

+43
-27
lines changed

pymc3/distributions/continuous.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def __init__(self, lower=0, upper=1, transform='interval',
163163

164164
def random(self, point=None, size=None):
165165
lower, upper = draw_values([self.lower, self.upper],
166-
point=point)
166+
point=point, size=size)
167167
return generate_samples(stats.uniform.rvs, loc=lower,
168168
scale=upper - lower,
169169
dist_shape=self.shape,
@@ -748,7 +748,7 @@ def __init__(self, lam, *args, **kwargs):
748748
assert_negative_support(lam, 'lam', 'Exponential')
749749

750750
def random(self, point=None, size=None):
751-
lam = draw_values([self.lam], point=point)[0]
751+
lam = draw_values([self.lam], point=point, size=size)[0]
752752
return generate_samples(np.random.exponential, scale=1. / lam,
753753
dist_shape=self.shape,
754754
size=size)
@@ -817,7 +817,7 @@ def __init__(self, mu, b, *args, **kwargs):
817817
assert_negative_support(b, 'b', 'Laplace')
818818

819819
def random(self, point=None, size=None):
820-
mu, b = draw_values([self.mu, self.b], point=point)
820+
mu, b = draw_values([self.mu, self.b], point=point, size=size)
821821
return generate_samples(np.random.laplace, mu, b,
822822
dist_shape=self.shape,
823823
size=size)
@@ -921,7 +921,7 @@ def _random(self, mu, tau, size=None):
921921
return np.exp(mu + (tau**-0.5) * samples)
922922

923923
def random(self, point=None, size=None):
924-
mu, tau = draw_values([self.mu, self.tau], point=point)
924+
mu, tau = draw_values([self.mu, self.tau], point=point, size=size)
925925
return generate_samples(self._random, mu, tau,
926926
dist_shape=self.shape,
927927
size=size)
@@ -1023,7 +1023,7 @@ def __init__(self, nu, mu=0, lam=None, sd=None, *args, **kwargs):
10231023

10241024
def random(self, point=None, size=None):
10251025
nu, mu, lam = draw_values([self.nu, self.mu, self.lam],
1026-
point=point)
1026+
point=point, size=size)
10271027
return generate_samples(stats.t.rvs, nu, loc=mu, scale=lam**-0.5,
10281028
dist_shape=self.shape,
10291029
size=size)
@@ -1121,7 +1121,7 @@ def _random(self, alpha, m, size=None):
11211121

11221122
def random(self, point=None, size=None):
11231123
alpha, m = draw_values([self.alpha, self.m],
1124-
point=point)
1124+
point=point, size=size)
11251125
return generate_samples(self._random, alpha, m,
11261126
dist_shape=self.shape,
11271127
size=size)
@@ -1202,7 +1202,7 @@ def _random(self, alpha, beta, size=None):
12021202

12031203
def random(self, point=None, size=None):
12041204
alpha, beta = draw_values([self.alpha, self.beta],
1205-
point=point)
1205+
point=point, size=size)
12061206
return generate_samples(self._random, alpha, beta,
12071207
dist_shape=self.shape,
12081208
size=size)
@@ -1276,7 +1276,7 @@ def _random(self, beta, size=None):
12761276
return beta * np.abs(np.tan(np.pi * (u - 0.5)))
12771277

12781278
def random(self, point=None, size=None):
1279-
beta = draw_values([self.beta], point=point)[0]
1279+
beta = draw_values([self.beta], point=point, size=size)[0]
12801280
return generate_samples(self._random, beta,
12811281
dist_shape=self.shape,
12821282
size=size)
@@ -1381,7 +1381,7 @@ def get_alpha_beta(self, alpha=None, beta=None, mu=None, sd=None):
13811381

13821382
def random(self, point=None, size=None):
13831383
alpha, beta = draw_values([self.alpha, self.beta],
1384-
point=point)
1384+
point=point, size=size)
13851385
return generate_samples(stats.gamma.rvs, alpha, scale=1. / beta,
13861386
dist_shape=self.shape,
13871387
size=size)
@@ -1474,7 +1474,7 @@ def _calculate_mean(self):
14741474

14751475
def random(self, point=None, size=None):
14761476
alpha, beta = draw_values([self.alpha, self.beta],
1477-
point=point)
1477+
point=point, size=size)
14781478
return generate_samples(stats.invgamma.rvs, a=alpha, scale=beta,
14791479
dist_shape=self.shape,
14801480
size=size)
@@ -1610,7 +1610,7 @@ def __init__(self, alpha, beta, *args, **kwargs):
16101610

16111611
def random(self, point=None, size=None):
16121612
alpha, beta = draw_values([self.alpha, self.beta],
1613-
point=point)
1613+
point=point, size=size)
16141614

16151615
def _random(a, b, size=None):
16161616
return b * (-np.log(np.random.uniform(size=size)))**(1 / a)
@@ -1708,7 +1708,7 @@ def __init__(self, nu=1, sd=None, lam=None, *args, **kwargs):
17081708
assert_negative_support(nu, 'nu', 'HalfStudentT')
17091709

17101710
def random(self, point=None, size=None):
1711-
nu, sd = draw_values([self.nu, self.sd], point=point)
1711+
nu, sd = draw_values([self.nu, self.sd], point=point, size=size)
17121712
return np.abs(generate_samples(stats.t.rvs, nu, loc=0, scale=sd,
17131713
dist_shape=self.shape,
17141714
size=size))
@@ -1813,7 +1813,7 @@ def __init__(self, mu, sigma, nu, *args, **kwargs):
18131813

18141814
def random(self, point=None, size=None):
18151815
mu, sigma, nu = draw_values([self.mu, self.sigma, self.nu],
1816-
point=point)
1816+
point=point, size=size)
18171817

18181818
def _random(mu, sigma, nu, size=None):
18191819
return (np.random.normal(mu, sigma, size=size)
@@ -1905,7 +1905,7 @@ def __init__(self, mu=0.0, kappa=None, transform='circular',
19051905

19061906
def random(self, point=None, size=None):
19071907
mu, kappa = draw_values([self.mu, self.kappa],
1908-
point=point)
1908+
point=point, size=size)
19091909
return generate_samples(stats.vonmises.rvs, loc=mu, kappa=kappa,
19101910
dist_shape=self.shape,
19111911
size=size)
@@ -2002,7 +2002,7 @@ def __init__(self, mu=0.0, sd=None, tau=None, alpha=1, *args, **kwargs):
20022002

20032003
def random(self, point=None, size=None):
20042004
mu, tau, _, alpha = draw_values(
2005-
[self.mu, self.tau, self.sd, self.alpha], point=point)
2005+
[self.mu, self.tau, self.sd, self.alpha], point=point, size=size)
20062006
return generate_samples(stats.skewnorm.rvs,
20072007
a=alpha, loc=mu, scale=tau**-0.5,
20082008
dist_shape=self.shape,
@@ -2095,7 +2095,7 @@ def __init__(self, lower=0, upper=1, c=0.5,
20952095

20962096
def random(self, point=None, size=None):
20972097
c, lower, upper = draw_values([self.c, self.lower, self.upper],
2098-
point=point)
2098+
point=point, size=size)
20992099
return generate_samples(stats.triang.rvs, c=c-lower, loc=lower, scale=upper-lower,
21002100
size=size, dist_shape=self.shape, random_state=None)
21012101

@@ -2178,7 +2178,7 @@ def __init__(self, mu=0, beta=1.0, **kwargs):
21782178
super(Gumbel, self).__init__(**kwargs)
21792179

21802180
def random(self, point=None, size=None):
2181-
mu, sd = draw_values([self.mu, self.beta], point=point)
2181+
mu, sd = draw_values([self.mu, self.beta], point=point, size=size)
21822182
return generate_samples(stats.gumbel_r.rvs, loc=mu, scale=sd,
21832183
dist_shape=self.shape,
21842184
size=size)
@@ -2257,7 +2257,7 @@ def logp(self, value):
22572257
-(value - mu) / s - tt.log(s) - 2 * tt.log1p(tt.exp(-(value - mu) / s)), s > 0)
22582258

22592259
def random(self, point=None, size=None):
2260-
mu, s = draw_values([self.mu, self.s], point=point)
2260+
mu, s = draw_values([self.mu, self.s], point=point, size=size)
22612261

22622262
return generate_samples(
22632263
stats.logistic.rvs,
@@ -2333,7 +2333,7 @@ def __init__(self, mu=0, sd=None, tau=None, **kwargs):
23332333
super(LogitNormal, self).__init__(**kwargs)
23342334

23352335
def random(self, point=None, size=None):
2336-
mu, _, sd = draw_values([self.mu, self.tau, self.sd], point=point)
2336+
mu, _, sd = draw_values([self.mu, self.tau, self.sd], point=point, size=size)
23372337
return expit(generate_samples(stats.norm.rvs, loc=mu, scale=sd, dist_shape=self.shape,
23382338
size=size))
23392339

pymc3/distributions/discrete.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(self, n, p, *args, **kwargs):
6666
self.mode = tt.cast(tround(n * p), self.dtype)
6767

6868
def random(self, point=None, size=None):
69-
n, p = draw_values([self.n, self.p], point=point)
69+
n, p = draw_values([self.n, self.p], point=point, size=size)
7070
return generate_samples(stats.binom.rvs, n=n, p=p,
7171
dist_shape=self.shape,
7272
size=size)

pymc3/distributions/distribution.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def random(self, *args, **kwargs):
210210

211211

212212

213-
def draw_values(params, point=None):
213+
def draw_values(params, point=None, size=None):
214214
"""
215215
Draw (fix) parameter values. Handles a number of cases:
216216
@@ -251,10 +251,10 @@ def draw_values(params, point=None):
251251
named_nodes_children[k] = nnc[k]
252252
else:
253253
named_nodes_children[k].update(nnc[k])
254-
254+
255255
# Init givens and the stack of nodes to try to `_draw_value` from
256256
givens = {}
257-
stored = set([]) # Some nodes
257+
stored = set([]) # Some nodes
258258
stack = list(leaf_nodes.values()) # A queue would be more appropriate
259259
while stack:
260260
next_ = stack.pop(0)
@@ -269,7 +269,7 @@ def draw_values(params, point=None):
269269
# we can skip it. Furthermore, if this node was treated as a
270270
# TensorVariable that should be compiled by theano in
271271
# _compile_theano_function, it would raise a `TypeError:
272-
# ('Constants not allowed in param list', ...)` for
272+
# ('Constants not allowed in param list', ...)` for
273273
# TensorConstant, and a `TypeError: Cannot use a shared
274274
# variable (...) as explicit input` for SharedVariable.
275275
stored.add(next_.name)
@@ -285,7 +285,7 @@ def draw_values(params, point=None):
285285
# have the random method
286286
givens[next_.name] = (next_, _draw_value(next_,
287287
point=point,
288-
givens=temp_givens))
288+
givens=temp_givens, size=size))
289289
stored.add(next_.name)
290290
except theano.gof.fg.MissingInputError:
291291
# The node failed, so we must add the node's parents to
@@ -297,7 +297,7 @@ def draw_values(params, point=None):
297297
node not in params])
298298
values = []
299299
for param in params:
300-
values.append(_draw_value(param, point=point, givens=givens.values()))
300+
values.append(_draw_value(param, point=point, givens=givens.values(), size=size))
301301
return values
302302

303303

@@ -326,7 +326,7 @@ def _compile_theano_function(param, vars, givens=None):
326326
allow_input_downcast=True)
327327

328328

329-
def _draw_value(param, point=None, givens=None):
329+
def _draw_value(param, point=None, givens=None, size=None):
330330
"""Draw a random value from a distribution or return a constant.
331331
332332
Parameters
@@ -342,6 +342,8 @@ def _draw_value(param, point=None, givens=None):
342342
givens : dict, optional
343343
A dictionary from theano variables to their values. These values
344344
are used to evaluate `param` if it is a theano variable.
345+
size : int, optional
346+
Number of samples
345347
"""
346348
if isinstance(param, numbers.Number):
347349
return param

pymc3/tests/test_distributions_random.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,20 @@ def test_random_sample_returns_nd_array(self):
9797
assert isinstance(mu, np.ndarray)
9898
assert isinstance(tau, np.ndarray)
9999

100+
def test_random_sample_returns_correctly(self):
101+
# Based on what we discovered in #GH2909
102+
with pm.Model():
103+
a = pm.Uniform('a', lower=0, upper=1, shape=10)
104+
b = pm.Binomial('b', n=1, p=a, shape=10)
105+
array_of_uniform = a.random(size=10000).mean(axis=0)
106+
array_of_binomial = b.random(size=10000).mean(axis=0)
107+
npt.assert_allclose(array_of_uniform, [0.49886929, 0.49949713, 0.49946077, 0.49922606, 0.49927498, 0.50003914,
108+
0.49980687, 0.50180495, 0.500905, 0.50035121], rtol=1e-2, atol=0)
109+
npt.assert_allclose(array_of_binomial, [0.7232, 0.131 , 0.9457, 0.8279, 0.2911, 0.8686, 0.57 , 0.9184,
110+
0.8177, 0.1625], rtol=1e-2, atol=0)
111+
assert isinstance(array_of_binomial, np.ndarray)
112+
assert isinstance(array_of_uniform, np.ndarray)
113+
100114

101115
class BaseTestCases(object):
102116
class BaseTestCase(SeededTest):

0 commit comments

Comments
 (0)