From b68db1277021a71a32655f48910f3b9fdea7553e Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Tue, 14 Feb 2023 15:08:23 +0100 Subject: [PATCH 1/3] added n_zerosum_axes and added backwards compatibility for previous parameter name --- pymc/distributions/multivariate.py | 82 +++++++++++++++++------------- 1 file changed, 48 insertions(+), 34 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index b0f15bb0c7..e247ce9553 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2394,7 +2394,7 @@ class ZeroSumNormal(Distribution): ZeroSumNormal distribution, i.e Normal distribution where one or several axes are constrained to sum to zero. By default, the last axis is constrained to sum to zero. - See `zerosum_axes` kwarg for more details. + See `n_zerosum_axes` kwarg for more details. .. math:: @@ -2411,9 +2411,10 @@ class ZeroSumNormal(Distribution): It's actually the standard deviation of the underlying, unconstrained Normal distribution. Defaults to 1 if not specified. For now, ``sigma`` has to be a scalar, to ensure the zero-sum constraint. - zerosum_axes: int, defaults to 1 + n_zerosum_axes: int, defaults to 1 Number of axes along which the zero-sum constraint is enforced, starting from the rightmost position. Defaults to 1, i.e the rightmost axis. + zerosum_axes: int, deprecated please use n_zerosum_axes as its successor dims: sequence of strings, optional Dimension names of the distribution. Works the same as for other PyMC distributions. Necessary if ``shape`` is not passed. @@ -2452,25 +2453,38 @@ class ZeroSumNormal(Distribution): """ rv_type = ZeroSumNormalRV - def __new__(cls, *args, zerosum_axes=None, support_shape=None, dims=None, **kwargs): + def __new__( + cls, *args, zerosum_axes=None, n_zerosum_axes=None, support_shape=None, dims=None, **kwargs + ): + if zerosum_axes is not None: + n_nezosum_axes = zerosum_axes + warnings.warn( + "The 'zerosum_axes' parameter is deprecated. Use 'n_zerosum_axes' instead.", + DeprecationWarning, + ) if dims is not None or kwargs.get("observed") is not None: - zerosum_axes = cls.check_zerosum_axes(zerosum_axes) + n_zerosum_axes = cls.check_zerosum_axes(n_zerosum_axes) support_shape = get_support_shape( support_shape=support_shape, shape=None, # Shape will be checked in `cls.dist` dims=dims, observed=kwargs.get("observed", None), - ndim_supp=zerosum_axes, + ndim_supp=n_zerosum_axes, ) return super().__new__( - cls, *args, zerosum_axes=zerosum_axes, support_shape=support_shape, dims=dims, **kwargs + cls, + *args, + n_zerosum_axes=n_zerosum_axes, + support_shape=support_shape, + dims=dims, + **kwargs, ) @classmethod - def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs): - zerosum_axes = cls.check_zerosum_axes(zerosum_axes) + def dist(cls, sigma=1, n_zerosum_axes=None, support_shape=None, **kwargs): + n_zerosum_axes = cls.check_zerosum_axes(n_zerosum_axes) sigma = at.as_tensor_variable(floatX(sigma)) if sigma.ndim > 0: @@ -2479,41 +2493,41 @@ def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs): support_shape = get_support_shape( support_shape=support_shape, shape=kwargs.get("shape"), - ndim_supp=zerosum_axes, + ndim_supp=n_zerosum_axes, ) if support_shape is None: - if zerosum_axes > 0: + if n_zerosum_axes > 0: raise ValueError("You must specify dims, shape or support_shape parameter") # TODO: edge-case doesn't work for now, because at.stack in get_support_shape fails # else: # support_shape = () # because it's just a Normal in that case support_shape = at.as_tensor_variable(intX(support_shape)) - assert zerosum_axes == at.get_vector_length( + assert n_zerosum_axes == at.get_vector_length( support_shape - ), "support_shape has to be as long as zerosum_axes" + ), "support_shape has to be as long as n_zerosum_axes" return super().dist( - [sigma], zerosum_axes=zerosum_axes, support_shape=support_shape, **kwargs + [sigma], n_zerosum_axes=n_zerosum_axes, support_shape=support_shape, **kwargs ) @classmethod - def check_zerosum_axes(cls, zerosum_axes: Optional[int]) -> int: - if zerosum_axes is None: - zerosum_axes = 1 - if not isinstance(zerosum_axes, int): - raise TypeError("zerosum_axes has to be an integer") - if not zerosum_axes > 0: - raise ValueError("zerosum_axes has to be > 0") - return zerosum_axes + def check_zerosum_axes(cls, n_zerosum_axes: Optional[int]) -> int: + if n_zerosum_axes is None: + n_zerosum_axes = 1 + if not isinstance(n_zerosum_axes, int): + raise TypeError("n_zerosum_axes has to be an integer") + if not n_zerosum_axes > 0: + raise ValueError("n_zerosum_axes has to be > 0") + return n_zerosum_axes @classmethod - def rv_op(cls, sigma, zerosum_axes, support_shape, size=None): + def rv_op(cls, sigma, n_zerosum_axes, support_shape, size=None): shape = to_tuple(size) + tuple(support_shape) normal_dist = ignore_logprob(pm.Normal.dist(sigma=sigma, shape=shape)) - if zerosum_axes > normal_dist.ndim: + if n_zerosum_axes > normal_dist.ndim: raise ValueError("Shape of distribution is too small for the number of zerosum axes") normal_dist_, sigma_, support_shape_ = ( @@ -2522,15 +2536,15 @@ def rv_op(cls, sigma, zerosum_axes, support_shape, size=None): support_shape.type(), ) - # Zerosum-normaling is achieved by subtracting the mean along the given zerosum_axes + # Zerosum-normaling is achieved by subtracting the mean along the given n_zerosum_axes zerosum_rv_ = normal_dist_ - for axis in range(zerosum_axes): + for axis in range(n_zerosum_axes): zerosum_rv_ -= zerosum_rv_.mean(axis=-axis - 1, keepdims=True) return ZeroSumNormalRV( inputs=[normal_dist_, sigma_, support_shape_], outputs=[zerosum_rv_, support_shape_], - ndim_supp=zerosum_axes, + ndim_supp=n_zerosum_axes, )(normal_dist, sigma, support_shape) @@ -2544,7 +2558,7 @@ def change_zerosum_size(op, normal_dist, new_size, expand=False): new_size = tuple(new_size) + old_size return ZeroSumNormal.rv_op( - sigma=sigma, zerosum_axes=op.ndim_supp, support_shape=support_shape, size=new_size + sigma=sigma, n_zerosum_axes=op.ndim_supp, support_shape=support_shape, size=new_size ) @@ -2555,28 +2569,28 @@ def zerosumnormal_moment(op, rv, *rv_inputs): @_default_transform.register(ZeroSumNormalRV) def zerosum_default_transform(op, rv): - zerosum_axes = tuple(np.arange(-op.ndim_supp, 0)) - return ZeroSumTransform(zerosum_axes) + n_zerosum_axes = tuple(np.arange(-op.ndim_supp, 0)) + return ZeroSumTransform(n_zerosum_axes) @_logprob.register(ZeroSumNormalRV) def zerosumnormal_logp(op, values, normal_dist, sigma, support_shape, **kwargs): (value,) = values shape = value.shape - zerosum_axes = op.ndim_supp + n_zerosum_axes = op.ndim_supp - _deg_free_support_shape = at.inc_subtensor(shape[-zerosum_axes:], -1) + _deg_free_support_shape = at.inc_subtensor(shape[-n_zerosum_axes:], -1) _full_size = at.prod(shape) _degrees_of_freedom = at.prod(_deg_free_support_shape) zerosums = [ at.all(at.isclose(at.mean(value, axis=-axis - 1), 0, atol=1e-9)) - for axis in range(zerosum_axes) + for axis in range(n_zerosum_axes) ] out = at.sum( pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size, - axis=tuple(np.arange(-zerosum_axes, 0)), + axis=tuple(np.arange(-n_zerosum_axes, 0)), ) - return check_parameters(out, *zerosums, msg="mean(value, axis=zerosum_axes) = 0") + return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0") From bd1daba7bb3133665791294520511a0281247967 Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Mon, 20 Feb 2023 10:53:14 +0100 Subject: [PATCH 2/3] fixed typo that caused the zerosum_axes param not to be saved correctly in the new param --- pymc/distributions/multivariate.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index e247ce9553..edba23f3fe 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2427,13 +2427,13 @@ class ZeroSumNormal(Distribution): ``sigma`` has to be a scalar, to ensure the zero-sum constraint. The ability to specify a vector of ``sigma`` may be added in future versions. - ``zerosum_axes`` has to be > 0. If you want the behavior of ``zerosum_axes = 0``, + ``n_zerosum_axes`` has to be > 0. If you want the behavior of ``n_zerosum_axes = 0``, just use ``pm.Normal``. Examples -------- Define a `ZeroSumNormal` variable, with `sigma=1` and - `zerosum_axes=1` by default:: + `n_zerosum_axes=1` by default:: COORDS = { "regions": ["a", "b", "c"], @@ -2445,11 +2445,11 @@ class ZeroSumNormal(Distribution): with pm.Model(coords=COORDS) as m: # the zero sum axes will be 'answers' and 'regions' - v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=2) + v = pm.ZeroSumNormal("v", dims=("regions", "answers"), n_zerosum_axes=2) with pm.Model(coords=COORDS) as m: # the zero sum axes will be the last two - v = pm.ZeroSumNormal("v", shape=(3, 4, 5), zerosum_axes=2) + v = pm.ZeroSumNormal("v", shape=(3, 4, 5), n_zerosum_axes=2) """ rv_type = ZeroSumNormalRV @@ -2457,7 +2457,7 @@ def __new__( cls, *args, zerosum_axes=None, n_zerosum_axes=None, support_shape=None, dims=None, **kwargs ): if zerosum_axes is not None: - n_nezosum_axes = zerosum_axes + n_zerosum_axes = zerosum_axes warnings.warn( "The 'zerosum_axes' parameter is deprecated. Use 'n_zerosum_axes' instead.", DeprecationWarning, From e85ecf3631311c78bfc169987faefde220d59515 Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Mon, 20 Feb 2023 11:24:34 +0100 Subject: [PATCH 3/3] adapted tests to use n_zerosum_axes --- pymc/tests/distributions/test_multivariate.py | 60 +++++++++---------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/pymc/tests/distributions/test_multivariate.py b/pymc/tests/distributions/test_multivariate.py index 3c5a0099a0..dc69133331 100644 --- a/pymc/tests/distributions/test_multivariate.py +++ b/pymc/tests/distributions/test_multivariate.py @@ -1014,16 +1014,16 @@ def test_mv_normal_moment(self, mu, cov, size, expected): assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3) @pytest.mark.parametrize( - "shape, zerosum_axes, expected", + "shape, n_zerosum_axes, expected", [ ((2, 5), None, np.zeros((2, 5))), ((2, 5, 6), 2, np.zeros((2, 5, 6))), ((2, 5, 6), 3, np.zeros((2, 5, 6))), ], ) - def test_zerosum_normal_moment(self, shape, zerosum_axes, expected): + def test_zerosum_normal_moment(self, shape, n_zerosum_axes, expected): with pm.Model() as model: - pm.ZeroSumNormal("x", shape=shape, zerosum_axes=zerosum_axes) + pm.ZeroSumNormal("x", shape=shape, n_zerosum_axes=n_zerosum_axes) assert_moment_is_expected(model, expected) @pytest.mark.parametrize( @@ -1405,16 +1405,16 @@ def assert_zerosum_axes(self, random_samples, axes_to_check, check_zerosum_axes= ).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples." @pytest.mark.parametrize( - "dims, zerosum_axes", + "dims, n_zerosum_axes", [ (("regions", "answers"), None), (("regions", "answers"), 1), (("regions", "answers"), 2), ], ) - def test_zsn_dims(self, dims, zerosum_axes): + def test_zsn_dims(self, dims, n_zerosum_axes): with pm.Model(coords=self.coords) as m: - v = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes) + v = pm.ZeroSumNormal("v", dims=dims, n_zerosum_axes=n_zerosum_axes) s = pm.sample(10, chains=1, tune=100) # to test forward graph @@ -1428,24 +1428,24 @@ def test_zsn_dims(self, dims, zerosum_axes): ) ndim_supp = v.owner.op.ndim_supp - zerosum_axes = np.arange(-ndim_supp, 0) + n_zerosum_axes = np.arange(-ndim_supp, 0) nonzero_axes = np.arange(v.ndim - ndim_supp) for samples in [ s.posterior.v, random_samples, ]: - self.assert_zerosum_axes(samples, zerosum_axes) + self.assert_zerosum_axes(samples, n_zerosum_axes) self.assert_zerosum_axes(samples, nonzero_axes, check_zerosum_axes=False) @pytest.mark.parametrize( - "zerosum_axes", + "n_zerosum_axes", (None, 1, 2), ) - def test_zsn_shape(self, zerosum_axes): + def test_zsn_shape(self, n_zerosum_axes): shape = (len(self.coords["regions"]), len(self.coords["answers"])) with pm.Model(coords=self.coords) as m: - v = pm.ZeroSumNormal("v", shape=shape, zerosum_axes=zerosum_axes) + v = pm.ZeroSumNormal("v", shape=shape, n_zerosum_axes=n_zerosum_axes) s = pm.sample(10, chains=1, tune=100) # to test forward graph @@ -1459,17 +1459,17 @@ def test_zsn_shape(self, zerosum_axes): ) ndim_supp = v.owner.op.ndim_supp - zerosum_axes = np.arange(-ndim_supp, 0) + n_zerosum_axes = np.arange(-ndim_supp, 0) nonzero_axes = np.arange(v.ndim - ndim_supp) for samples in [ s.posterior.v, random_samples, ]: - self.assert_zerosum_axes(samples, zerosum_axes) + self.assert_zerosum_axes(samples, n_zerosum_axes) self.assert_zerosum_axes(samples, nonzero_axes, check_zerosum_axes=False) @pytest.mark.parametrize( - "error, match, shape, support_shape, zerosum_axes", + "error, match, shape, support_shape, n_zerosum_axes", [ ( ValueError, @@ -1485,14 +1485,14 @@ def test_zsn_shape(self, zerosum_axes): (3, 4), (3, 4), None, - ), # doesn't work because zerosum_axes = 1 by default + ), # doesn't work because n_zerosum_axes = 1 by default ], ) - def test_zsn_fail_axis(self, error, match, shape, support_shape, zerosum_axes): + def test_zsn_fail_axis(self, error, match, shape, support_shape, n_zerosum_axes): with pytest.raises(error, match=match): with pm.Model() as m: _ = pm.ZeroSumNormal( - "v", shape=shape, support_shape=support_shape, zerosum_axes=zerosum_axes + "v", shape=shape, support_shape=support_shape, n_zerosum_axes=n_zerosum_axes ) @pytest.mark.parametrize( @@ -1504,22 +1504,22 @@ def test_zsn_fail_axis(self, error, match, shape, support_shape, zerosum_axes): ) def test_zsn_support_shape(self, shape, support_shape): with pm.Model() as m: - v = pm.ZeroSumNormal("v", shape=shape, support_shape=support_shape, zerosum_axes=2) + v = pm.ZeroSumNormal("v", shape=shape, support_shape=support_shape, n_zerosum_axes=2) random_samples = pm.draw(v, draws=10) - zerosum_axes = np.arange(-2, 0) - self.assert_zerosum_axes(random_samples, zerosum_axes) + n_zerosum_axes = np.arange(-2, 0) + self.assert_zerosum_axes(random_samples, n_zerosum_axes) @pytest.mark.parametrize( - "zerosum_axes", + "n_zerosum_axes", [1, 2], ) - def test_zsn_change_dist_size(self, zerosum_axes): - base_dist = pm.ZeroSumNormal.dist(shape=(4, 9), zerosum_axes=zerosum_axes) + def test_zsn_change_dist_size(self, n_zerosum_axes): + base_dist = pm.ZeroSumNormal.dist(shape=(4, 9), n_zerosum_axes=n_zerosum_axes) random_samples = pm.draw(base_dist, draws=100) - zerosum_axes = np.arange(-zerosum_axes, 0) - self.assert_zerosum_axes(random_samples, zerosum_axes) + n_zerosum_axes = np.arange(-n_zerosum_axes, 0) + self.assert_zerosum_axes(random_samples, n_zerosum_axes) new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=False) try: @@ -1527,12 +1527,12 @@ def test_zsn_change_dist_size(self, zerosum_axes): except AssertionError: assert new_dist.eval().shape == (5, 3, 4, 9) random_samples = pm.draw(new_dist, draws=100) - self.assert_zerosum_axes(random_samples, zerosum_axes) + self.assert_zerosum_axes(random_samples, n_zerosum_axes) new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=True) assert new_dist.eval().shape == (5, 3, 4, 9) random_samples = pm.draw(new_dist, draws=100) - self.assert_zerosum_axes(random_samples, zerosum_axes) + self.assert_zerosum_axes(random_samples, n_zerosum_axes) @pytest.mark.parametrize( "sigma, n", @@ -1551,7 +1551,7 @@ def test_zsn_variance(self, sigma, n): np.testing.assert_allclose(empirical_var, theoretical_var, atol=0.4) @pytest.mark.parametrize( - "sigma, shape, zerosum_axes, mvn_axes", + "sigma, shape, n_zerosum_axes, mvn_axes", [ (5, 3, None, [-1]), (2, 6, None, [-1]), @@ -1559,7 +1559,7 @@ def test_zsn_variance(self, sigma, n): (5, (2, 7, 3), 2, [1, 2]), ], ) - def test_zsn_logp(self, sigma, shape, zerosum_axes, mvn_axes): + def test_zsn_logp(self, sigma, shape, n_zerosum_axes, mvn_axes): def logp_norm(value, sigma, axes): """ Special case of the MvNormal, that's equivalent to the ZSN. @@ -1588,7 +1588,7 @@ def logp_norm(value, sigma, axes): return np.where(inds, np.sum(-psdet - exp, axis=-1), -np.inf) - zsn_dist = pm.ZeroSumNormal.dist(sigma=sigma, shape=shape, zerosum_axes=zerosum_axes) + zsn_dist = pm.ZeroSumNormal.dist(sigma=sigma, shape=shape, n_zerosum_axes=n_zerosum_axes) zsn_logp = pm.logp(zsn_dist, value=np.zeros(shape)).eval() mvn_logp = logp_norm(value=np.zeros(shape), sigma=sigma, axes=mvn_axes)