Skip to content

Commit 466a941

Browse files
committed
add test_shape_inputs for _OrderedLogistic
1 parent 32f6c89 commit 466a941

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

pymc/tests/test_distributions_random.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1688,6 +1688,26 @@ class TestOrderedLogistic(BaseTestDistributionRandom):
16881688
"check_rv_size",
16891689
]
16901690

1691+
@pytest.mark.parametrize(
1692+
"eta, cutpoints, expected",
1693+
[
1694+
(0, [-2.0, 0, 2.0], (4,)),
1695+
([-1], [-2.0, 0, 2.0], (1, 4)),
1696+
([1.0, -2.0], [-1.0, 0, 1.0], (2, 4)),
1697+
([[1.0, -1.0, 0.0], [-1.0, 3.0, 5.0]], [-2.0, 0, 1.0], (2, 3, 4)),
1698+
],
1699+
)
1700+
def test_shape_inputs(self, eta, cutpoints, expected):
1701+
"""
1702+
This test checks when providing different shapes for `eta` parameters.
1703+
"""
1704+
categorical = _OrderedLogistic.dist(
1705+
eta=eta,
1706+
cutpoints=cutpoints,
1707+
)
1708+
p = categorical.owner.inputs[3].eval()
1709+
assert p.shape == expected
1710+
16911711

16921712
class TestOrderedProbit(BaseTestDistributionRandom):
16931713
pymc_dist = _OrderedProbit

0 commit comments

Comments
 (0)