Skip to content

Commit 9b311bf

Browse files
committed
add 2d cutpoints and positive sigma
1 parent 466a941 commit 9b311bf

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

pymc/tests/test_distributions_random.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1694,7 +1694,9 @@ class TestOrderedLogistic(BaseTestDistributionRandom):
16941694
(0, [-2.0, 0, 2.0], (4,)),
16951695
([-1], [-2.0, 0, 2.0], (1, 4)),
16961696
([1.0, -2.0], [-1.0, 0, 1.0], (2, 4)),
1697-
([[1.0, -1.0, 0.0], [-1.0, 3.0, 5.0]], [-2.0, 0, 1.0], (2, 3, 4)),
1697+
(np.zeros((3, 2)), [-2.0, 0, 1.0], (3, 2, 4)),
1698+
(np.ones((5, 2)), [[-2.0, 0, 1.0], [-1.0, 0, 1.0]], (5, 2, 4)),
1699+
(np.ones((3, 5, 2)), [[-2.0, 0, 1.0], [-1.0, 0, 1.0]], (3, 5, 2, 4)),
16981700
],
16991701
)
17001702
def test_shape_inputs(self, eta, cutpoints, expected):
@@ -1722,16 +1724,12 @@ class TestOrderedProbit(BaseTestDistributionRandom):
17221724
"eta, cutpoints, sigma, expected",
17231725
[
17241726
(0, [-2.0, 0, 2.0], 1.0, (4,)),
1725-
([-1], [-2.0, 0, 2.0], [2.0], (1, 4)),
1727+
([-1], [-1.0, 0, 2.0], [2.0], (1, 4)),
17261728
([1.0, -2.0], [-1.0, 0, 1.0], 1.0, (2, 4)),
1727-
([1.0, -2.0, 3.0], [-2.0, 0, 2.0], [-1.0, -2.0, 5.0], (3, 4)),
1728-
([[1.0, -1.0, 0.0], [-1.0, 3.0, 5.0]], [-2.0, 0, 1.0], [-1.0, -2.0, 5.0], (2, 3, 4)),
1729-
(
1730-
[[1.0, -2.0, 3.0], [1.0, 2.0, -4.0]],
1731-
[-2.0, 0, 1.0],
1732-
[[0.0, 2.0, -4.0], [-1.0, 1.0, 3.0]],
1733-
(2, 3, 4),
1734-
),
1729+
([1.0, -2.0, 3.0], [-1.0, 0, 2.0], np.ones((1, 3)), (1, 3, 4)),
1730+
(np.zeros((2, 3)), [-2.0, 0, 1.0], [1.0, 2.0, 5.0], (2, 3, 4)),
1731+
(np.ones((2, 3)), [-1.0, 0, 1.0], np.ones((2, 3)), (2, 3, 4)),
1732+
(np.zeros((5, 2)), [[-2, 0, 1], [-1, 0, 1]], np.ones((2, 5, 2)), (2, 5, 2, 4)),
17351733
],
17361734
)
17371735
def test_shape_inputs(self, eta, cutpoints, sigma, expected):

0 commit comments

Comments
 (0)