@@ -1014,16 +1014,16 @@ def test_mv_normal_moment(self, mu, cov, size, expected):
10141014 assert_moment_is_expected (model , expected , check_finite_logp = x .ndim < 3 )
10151015
10161016 @pytest .mark .parametrize (
1017- "shape, zerosum_axes , expected" ,
1017+ "shape, n_zerosum_axes , expected" ,
10181018 [
10191019 ((2 , 5 ), None , np .zeros ((2 , 5 ))),
10201020 ((2 , 5 , 6 ), 2 , np .zeros ((2 , 5 , 6 ))),
10211021 ((2 , 5 , 6 ), 3 , np .zeros ((2 , 5 , 6 ))),
10221022 ],
10231023 )
1024- def test_zerosum_normal_moment (self , shape , zerosum_axes , expected ):
1024+ def test_zerosum_normal_moment (self , shape , n_zerosum_axes , expected ):
10251025 with pm .Model () as model :
1026- pm .ZeroSumNormal ("x" , shape = shape , zerosum_axes = zerosum_axes )
1026+ pm .ZeroSumNormal ("x" , shape = shape , n_zerosum_axes = n_zerosum_axes )
10271027 assert_moment_is_expected (model , expected )
10281028
10291029 @pytest .mark .parametrize (
@@ -1405,16 +1405,16 @@ def assert_zerosum_axes(self, random_samples, axes_to_check, check_zerosum_axes=
14051405 ).all (), f"{ ax } is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
14061406
14071407 @pytest .mark .parametrize (
1408- "dims, zerosum_axes " ,
1408+ "dims, n_zerosum_axes " ,
14091409 [
14101410 (("regions" , "answers" ), None ),
14111411 (("regions" , "answers" ), 1 ),
14121412 (("regions" , "answers" ), 2 ),
14131413 ],
14141414 )
1415- def test_zsn_dims (self , dims , zerosum_axes ):
1415+ def test_zsn_dims (self , dims , n_zerosum_axes ):
14161416 with pm .Model (coords = self .coords ) as m :
1417- v = pm .ZeroSumNormal ("v" , dims = dims , zerosum_axes = zerosum_axes )
1417+ v = pm .ZeroSumNormal ("v" , dims = dims , n_zerosum_axes = n_zerosum_axes )
14181418 s = pm .sample (10 , chains = 1 , tune = 100 )
14191419
14201420 # to test forward graph
@@ -1428,24 +1428,24 @@ def test_zsn_dims(self, dims, zerosum_axes):
14281428 )
14291429
14301430 ndim_supp = v .owner .op .ndim_supp
1431- zerosum_axes = np .arange (- ndim_supp , 0 )
1431+ n_zerosum_axes = np .arange (- ndim_supp , 0 )
14321432 nonzero_axes = np .arange (v .ndim - ndim_supp )
14331433 for samples in [
14341434 s .posterior .v ,
14351435 random_samples ,
14361436 ]:
1437- self .assert_zerosum_axes (samples , zerosum_axes )
1437+ self .assert_zerosum_axes (samples , n_zerosum_axes )
14381438 self .assert_zerosum_axes (samples , nonzero_axes , check_zerosum_axes = False )
14391439
14401440 @pytest .mark .parametrize (
1441- "zerosum_axes " ,
1441+ "n_zerosum_axes " ,
14421442 (None , 1 , 2 ),
14431443 )
1444- def test_zsn_shape (self , zerosum_axes ):
1444+ def test_zsn_shape (self , n_zerosum_axes ):
14451445 shape = (len (self .coords ["regions" ]), len (self .coords ["answers" ]))
14461446
14471447 with pm .Model (coords = self .coords ) as m :
1448- v = pm .ZeroSumNormal ("v" , shape = shape , zerosum_axes = zerosum_axes )
1448+ v = pm .ZeroSumNormal ("v" , shape = shape , n_zerosum_axes = n_zerosum_axes )
14491449 s = pm .sample (10 , chains = 1 , tune = 100 )
14501450
14511451 # to test forward graph
@@ -1459,17 +1459,17 @@ def test_zsn_shape(self, zerosum_axes):
14591459 )
14601460
14611461 ndim_supp = v .owner .op .ndim_supp
1462- zerosum_axes = np .arange (- ndim_supp , 0 )
1462+ n_zerosum_axes = np .arange (- ndim_supp , 0 )
14631463 nonzero_axes = np .arange (v .ndim - ndim_supp )
14641464 for samples in [
14651465 s .posterior .v ,
14661466 random_samples ,
14671467 ]:
1468- self .assert_zerosum_axes (samples , zerosum_axes )
1468+ self .assert_zerosum_axes (samples , n_zerosum_axes )
14691469 self .assert_zerosum_axes (samples , nonzero_axes , check_zerosum_axes = False )
14701470
14711471 @pytest .mark .parametrize (
1472- "error, match, shape, support_shape, zerosum_axes " ,
1472+ "error, match, shape, support_shape, n_zerosum_axes " ,
14731473 [
14741474 (
14751475 ValueError ,
@@ -1485,14 +1485,14 @@ def test_zsn_shape(self, zerosum_axes):
14851485 (3 , 4 ),
14861486 (3 , 4 ),
14871487 None ,
1488- ), # doesn't work because zerosum_axes = 1 by default
1488+ ), # doesn't work because n_zerosum_axes = 1 by default
14891489 ],
14901490 )
1491- def test_zsn_fail_axis (self , error , match , shape , support_shape , zerosum_axes ):
1491+ def test_zsn_fail_axis (self , error , match , shape , support_shape , n_zerosum_axes ):
14921492 with pytest .raises (error , match = match ):
14931493 with pm .Model () as m :
14941494 _ = pm .ZeroSumNormal (
1495- "v" , shape = shape , support_shape = support_shape , zerosum_axes = zerosum_axes
1495+ "v" , shape = shape , support_shape = support_shape , n_zerosum_axes = n_zerosum_axes
14961496 )
14971497
14981498 @pytest .mark .parametrize (
@@ -1504,35 +1504,35 @@ def test_zsn_fail_axis(self, error, match, shape, support_shape, zerosum_axes):
15041504 )
15051505 def test_zsn_support_shape (self , shape , support_shape ):
15061506 with pm .Model () as m :
1507- v = pm .ZeroSumNormal ("v" , shape = shape , support_shape = support_shape , zerosum_axes = 2 )
1507+ v = pm .ZeroSumNormal ("v" , shape = shape , support_shape = support_shape , n_zerosum_axes = 2 )
15081508
15091509 random_samples = pm .draw (v , draws = 10 )
1510- zerosum_axes = np .arange (- 2 , 0 )
1511- self .assert_zerosum_axes (random_samples , zerosum_axes )
1510+ n_zerosum_axes = np .arange (- 2 , 0 )
1511+ self .assert_zerosum_axes (random_samples , n_zerosum_axes )
15121512
15131513 @pytest .mark .parametrize (
1514- "zerosum_axes " ,
1514+ "n_zerosum_axes " ,
15151515 [1 , 2 ],
15161516 )
1517- def test_zsn_change_dist_size (self , zerosum_axes ):
1518- base_dist = pm .ZeroSumNormal .dist (shape = (4 , 9 ), zerosum_axes = zerosum_axes )
1517+ def test_zsn_change_dist_size (self , n_zerosum_axes ):
1518+ base_dist = pm .ZeroSumNormal .dist (shape = (4 , 9 ), n_zerosum_axes = n_zerosum_axes )
15191519 random_samples = pm .draw (base_dist , draws = 100 )
15201520
1521- zerosum_axes = np .arange (- zerosum_axes , 0 )
1522- self .assert_zerosum_axes (random_samples , zerosum_axes )
1521+ n_zerosum_axes = np .arange (- n_zerosum_axes , 0 )
1522+ self .assert_zerosum_axes (random_samples , n_zerosum_axes )
15231523
15241524 new_dist = change_dist_size (base_dist , new_size = (5 , 3 ), expand = False )
15251525 try :
15261526 assert new_dist .eval ().shape == (5 , 3 , 9 )
15271527 except AssertionError :
15281528 assert new_dist .eval ().shape == (5 , 3 , 4 , 9 )
15291529 random_samples = pm .draw (new_dist , draws = 100 )
1530- self .assert_zerosum_axes (random_samples , zerosum_axes )
1530+ self .assert_zerosum_axes (random_samples , n_zerosum_axes )
15311531
15321532 new_dist = change_dist_size (base_dist , new_size = (5 , 3 ), expand = True )
15331533 assert new_dist .eval ().shape == (5 , 3 , 4 , 9 )
15341534 random_samples = pm .draw (new_dist , draws = 100 )
1535- self .assert_zerosum_axes (random_samples , zerosum_axes )
1535+ self .assert_zerosum_axes (random_samples , n_zerosum_axes )
15361536
15371537 @pytest .mark .parametrize (
15381538 "sigma, n" ,
@@ -1551,15 +1551,15 @@ def test_zsn_variance(self, sigma, n):
15511551 np .testing .assert_allclose (empirical_var , theoretical_var , atol = 0.4 )
15521552
15531553 @pytest .mark .parametrize (
1554- "sigma, shape, zerosum_axes , mvn_axes" ,
1554+ "sigma, shape, n_zerosum_axes , mvn_axes" ,
15551555 [
15561556 (5 , 3 , None , [- 1 ]),
15571557 (2 , 6 , None , [- 1 ]),
15581558 (5 , (7 , 3 ), None , [- 1 ]),
15591559 (5 , (2 , 7 , 3 ), 2 , [1 , 2 ]),
15601560 ],
15611561 )
1562- def test_zsn_logp (self , sigma , shape , zerosum_axes , mvn_axes ):
1562+ def test_zsn_logp (self , sigma , shape , n_zerosum_axes , mvn_axes ):
15631563 def logp_norm (value , sigma , axes ):
15641564 """
15651565 Special case of the MvNormal, that's equivalent to the ZSN.
@@ -1588,7 +1588,7 @@ def logp_norm(value, sigma, axes):
15881588
15891589 return np .where (inds , np .sum (- psdet - exp , axis = - 1 ), - np .inf )
15901590
1591- zsn_dist = pm .ZeroSumNormal .dist (sigma = sigma , shape = shape , zerosum_axes = zerosum_axes )
1591+ zsn_dist = pm .ZeroSumNormal .dist (sigma = sigma , shape = shape , n_zerosum_axes = n_zerosum_axes )
15921592 zsn_logp = pm .logp (zsn_dist , value = np .zeros (shape )).eval ()
15931593 mvn_logp = logp_norm (value = np .zeros (shape ), sigma = sigma , axes = mvn_axes )
15941594
0 commit comments