Skip to content

Commit e85ecf3

Browse files
author
Michal Raczycki
committed
adapted tests to use n_zerosum_axes
1 parent bd1daba commit e85ecf3

File tree

1 file changed

+30
-30
lines changed

1 file changed

+30
-30
lines changed

pymc/tests/distributions/test_multivariate.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,16 +1014,16 @@ def test_mv_normal_moment(self, mu, cov, size, expected):
10141014
assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3)
10151015

10161016
@pytest.mark.parametrize(
1017-
"shape, zerosum_axes, expected",
1017+
"shape, n_zerosum_axes, expected",
10181018
[
10191019
((2, 5), None, np.zeros((2, 5))),
10201020
((2, 5, 6), 2, np.zeros((2, 5, 6))),
10211021
((2, 5, 6), 3, np.zeros((2, 5, 6))),
10221022
],
10231023
)
1024-
def test_zerosum_normal_moment(self, shape, zerosum_axes, expected):
1024+
def test_zerosum_normal_moment(self, shape, n_zerosum_axes, expected):
10251025
with pm.Model() as model:
1026-
pm.ZeroSumNormal("x", shape=shape, zerosum_axes=zerosum_axes)
1026+
pm.ZeroSumNormal("x", shape=shape, n_zerosum_axes=n_zerosum_axes)
10271027
assert_moment_is_expected(model, expected)
10281028

10291029
@pytest.mark.parametrize(
@@ -1405,16 +1405,16 @@ def assert_zerosum_axes(self, random_samples, axes_to_check, check_zerosum_axes=
14051405
).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
14061406

14071407
@pytest.mark.parametrize(
1408-
"dims, zerosum_axes",
1408+
"dims, n_zerosum_axes",
14091409
[
14101410
(("regions", "answers"), None),
14111411
(("regions", "answers"), 1),
14121412
(("regions", "answers"), 2),
14131413
],
14141414
)
1415-
def test_zsn_dims(self, dims, zerosum_axes):
1415+
def test_zsn_dims(self, dims, n_zerosum_axes):
14161416
with pm.Model(coords=self.coords) as m:
1417-
v = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes)
1417+
v = pm.ZeroSumNormal("v", dims=dims, n_zerosum_axes=n_zerosum_axes)
14181418
s = pm.sample(10, chains=1, tune=100)
14191419

14201420
# to test forward graph
@@ -1428,24 +1428,24 @@ def test_zsn_dims(self, dims, zerosum_axes):
14281428
)
14291429

14301430
ndim_supp = v.owner.op.ndim_supp
1431-
zerosum_axes = np.arange(-ndim_supp, 0)
1431+
n_zerosum_axes = np.arange(-ndim_supp, 0)
14321432
nonzero_axes = np.arange(v.ndim - ndim_supp)
14331433
for samples in [
14341434
s.posterior.v,
14351435
random_samples,
14361436
]:
1437-
self.assert_zerosum_axes(samples, zerosum_axes)
1437+
self.assert_zerosum_axes(samples, n_zerosum_axes)
14381438
self.assert_zerosum_axes(samples, nonzero_axes, check_zerosum_axes=False)
14391439

14401440
@pytest.mark.parametrize(
1441-
"zerosum_axes",
1441+
"n_zerosum_axes",
14421442
(None, 1, 2),
14431443
)
1444-
def test_zsn_shape(self, zerosum_axes):
1444+
def test_zsn_shape(self, n_zerosum_axes):
14451445
shape = (len(self.coords["regions"]), len(self.coords["answers"]))
14461446

14471447
with pm.Model(coords=self.coords) as m:
1448-
v = pm.ZeroSumNormal("v", shape=shape, zerosum_axes=zerosum_axes)
1448+
v = pm.ZeroSumNormal("v", shape=shape, n_zerosum_axes=n_zerosum_axes)
14491449
s = pm.sample(10, chains=1, tune=100)
14501450

14511451
# to test forward graph
@@ -1459,17 +1459,17 @@ def test_zsn_shape(self, zerosum_axes):
14591459
)
14601460

14611461
ndim_supp = v.owner.op.ndim_supp
1462-
zerosum_axes = np.arange(-ndim_supp, 0)
1462+
n_zerosum_axes = np.arange(-ndim_supp, 0)
14631463
nonzero_axes = np.arange(v.ndim - ndim_supp)
14641464
for samples in [
14651465
s.posterior.v,
14661466
random_samples,
14671467
]:
1468-
self.assert_zerosum_axes(samples, zerosum_axes)
1468+
self.assert_zerosum_axes(samples, n_zerosum_axes)
14691469
self.assert_zerosum_axes(samples, nonzero_axes, check_zerosum_axes=False)
14701470

14711471
@pytest.mark.parametrize(
1472-
"error, match, shape, support_shape, zerosum_axes",
1472+
"error, match, shape, support_shape, n_zerosum_axes",
14731473
[
14741474
(
14751475
ValueError,
@@ -1485,14 +1485,14 @@ def test_zsn_shape(self, zerosum_axes):
14851485
(3, 4),
14861486
(3, 4),
14871487
None,
1488-
), # doesn't work because zerosum_axes = 1 by default
1488+
), # doesn't work because n_zerosum_axes = 1 by default
14891489
],
14901490
)
1491-
def test_zsn_fail_axis(self, error, match, shape, support_shape, zerosum_axes):
1491+
def test_zsn_fail_axis(self, error, match, shape, support_shape, n_zerosum_axes):
14921492
with pytest.raises(error, match=match):
14931493
with pm.Model() as m:
14941494
_ = pm.ZeroSumNormal(
1495-
"v", shape=shape, support_shape=support_shape, zerosum_axes=zerosum_axes
1495+
"v", shape=shape, support_shape=support_shape, n_zerosum_axes=n_zerosum_axes
14961496
)
14971497

14981498
@pytest.mark.parametrize(
@@ -1504,35 +1504,35 @@ def test_zsn_fail_axis(self, error, match, shape, support_shape, zerosum_axes):
15041504
)
15051505
def test_zsn_support_shape(self, shape, support_shape):
15061506
with pm.Model() as m:
1507-
v = pm.ZeroSumNormal("v", shape=shape, support_shape=support_shape, zerosum_axes=2)
1507+
v = pm.ZeroSumNormal("v", shape=shape, support_shape=support_shape, n_zerosum_axes=2)
15081508

15091509
random_samples = pm.draw(v, draws=10)
1510-
zerosum_axes = np.arange(-2, 0)
1511-
self.assert_zerosum_axes(random_samples, zerosum_axes)
1510+
n_zerosum_axes = np.arange(-2, 0)
1511+
self.assert_zerosum_axes(random_samples, n_zerosum_axes)
15121512

15131513
@pytest.mark.parametrize(
1514-
"zerosum_axes",
1514+
"n_zerosum_axes",
15151515
[1, 2],
15161516
)
1517-
def test_zsn_change_dist_size(self, zerosum_axes):
1518-
base_dist = pm.ZeroSumNormal.dist(shape=(4, 9), zerosum_axes=zerosum_axes)
1517+
def test_zsn_change_dist_size(self, n_zerosum_axes):
1518+
base_dist = pm.ZeroSumNormal.dist(shape=(4, 9), n_zerosum_axes=n_zerosum_axes)
15191519
random_samples = pm.draw(base_dist, draws=100)
15201520

1521-
zerosum_axes = np.arange(-zerosum_axes, 0)
1522-
self.assert_zerosum_axes(random_samples, zerosum_axes)
1521+
n_zerosum_axes = np.arange(-n_zerosum_axes, 0)
1522+
self.assert_zerosum_axes(random_samples, n_zerosum_axes)
15231523

15241524
new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=False)
15251525
try:
15261526
assert new_dist.eval().shape == (5, 3, 9)
15271527
except AssertionError:
15281528
assert new_dist.eval().shape == (5, 3, 4, 9)
15291529
random_samples = pm.draw(new_dist, draws=100)
1530-
self.assert_zerosum_axes(random_samples, zerosum_axes)
1530+
self.assert_zerosum_axes(random_samples, n_zerosum_axes)
15311531

15321532
new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=True)
15331533
assert new_dist.eval().shape == (5, 3, 4, 9)
15341534
random_samples = pm.draw(new_dist, draws=100)
1535-
self.assert_zerosum_axes(random_samples, zerosum_axes)
1535+
self.assert_zerosum_axes(random_samples, n_zerosum_axes)
15361536

15371537
@pytest.mark.parametrize(
15381538
"sigma, n",
@@ -1551,15 +1551,15 @@ def test_zsn_variance(self, sigma, n):
15511551
np.testing.assert_allclose(empirical_var, theoretical_var, atol=0.4)
15521552

15531553
@pytest.mark.parametrize(
1554-
"sigma, shape, zerosum_axes, mvn_axes",
1554+
"sigma, shape, n_zerosum_axes, mvn_axes",
15551555
[
15561556
(5, 3, None, [-1]),
15571557
(2, 6, None, [-1]),
15581558
(5, (7, 3), None, [-1]),
15591559
(5, (2, 7, 3), 2, [1, 2]),
15601560
],
15611561
)
1562-
def test_zsn_logp(self, sigma, shape, zerosum_axes, mvn_axes):
1562+
def test_zsn_logp(self, sigma, shape, n_zerosum_axes, mvn_axes):
15631563
def logp_norm(value, sigma, axes):
15641564
"""
15651565
Special case of the MvNormal, that's equivalent to the ZSN.
@@ -1588,7 +1588,7 @@ def logp_norm(value, sigma, axes):
15881588

15891589
return np.where(inds, np.sum(-psdet - exp, axis=-1), -np.inf)
15901590

1591-
zsn_dist = pm.ZeroSumNormal.dist(sigma=sigma, shape=shape, zerosum_axes=zerosum_axes)
1591+
zsn_dist = pm.ZeroSumNormal.dist(sigma=sigma, shape=shape, n_zerosum_axes=n_zerosum_axes)
15921592
zsn_logp = pm.logp(zsn_dist, value=np.zeros(shape)).eval()
15931593
mvn_logp = logp_norm(value=np.zeros(shape), sigma=sigma, axes=mvn_axes)
15941594

0 commit comments

Comments
 (0)