@@ -1688,6 +1688,28 @@ class TestOrderedLogistic(BaseTestDistributionRandom):
16881688 "check_rv_size" ,
16891689 ]
16901690
1691+ @pytest .mark .parametrize (
1692+ "eta, cutpoints, expected" ,
1693+ [
1694+ (0 , [- 2.0 , 0 , 2.0 ], (4 ,)),
1695+ ([- 1 ], [- 2.0 , 0 , 2.0 ], (1 , 4 )),
1696+ ([1.0 , - 2.0 ], [- 1.0 , 0 , 1.0 ], (2 , 4 )),
1697+ (np .zeros ((3 , 2 )), [- 2.0 , 0 , 1.0 ], (3 , 2 , 4 )),
1698+ (np .ones ((5 , 2 )), [[- 2.0 , 0 , 1.0 ], [- 1.0 , 0 , 1.0 ]], (5 , 2 , 4 )),
1699+ (np .ones ((3 , 5 , 2 )), [[- 2.0 , 0 , 1.0 ], [- 1.0 , 0 , 1.0 ]], (3 , 5 , 2 , 4 )),
1700+ ],
1701+ )
1702+ def test_shape_inputs (self , eta , cutpoints , expected ):
1703+ """
1704+ This test checks when providing different shapes for `eta` parameters.
1705+ """
1706+ categorical = _OrderedLogistic .dist (
1707+ eta = eta ,
1708+ cutpoints = cutpoints ,
1709+ )
1710+ p = categorical .owner .inputs [3 ].eval ()
1711+ assert p .shape == expected
1712+
16911713
16921714class TestOrderedProbit (BaseTestDistributionRandom ):
16931715 pymc_dist = _OrderedProbit
@@ -1698,6 +1720,30 @@ class TestOrderedProbit(BaseTestDistributionRandom):
16981720 "check_rv_size" ,
16991721 ]
17001722
1723+ @pytest .mark .parametrize (
1724+ "eta, cutpoints, sigma, expected" ,
1725+ [
1726+ (0 , [- 2.0 , 0 , 2.0 ], 1.0 , (4 ,)),
1727+ ([- 1 ], [- 1.0 , 0 , 2.0 ], [2.0 ], (1 , 4 )),
1728+ ([1.0 , - 2.0 ], [- 1.0 , 0 , 1.0 ], 1.0 , (2 , 4 )),
1729+ ([1.0 , - 2.0 , 3.0 ], [- 1.0 , 0 , 2.0 ], np .ones ((1 , 3 )), (1 , 3 , 4 )),
1730+ (np .zeros ((2 , 3 )), [- 2.0 , 0 , 1.0 ], [1.0 , 2.0 , 5.0 ], (2 , 3 , 4 )),
1731+ (np .ones ((2 , 3 )), [- 1.0 , 0 , 1.0 ], np .ones ((2 , 3 )), (2 , 3 , 4 )),
1732+ (np .zeros ((5 , 2 )), [[- 2 , 0 , 1 ], [- 1 , 0 , 1 ]], np .ones ((2 , 5 , 2 )), (2 , 5 , 2 , 4 )),
1733+ ],
1734+ )
1735+ def test_shape_inputs (self , eta , cutpoints , sigma , expected ):
1736+ """
1737+ This test checks when providing different shapes for `eta` and `sigma` parameters.
1738+ """
1739+ categorical = _OrderedProbit .dist (
1740+ eta = eta ,
1741+ cutpoints = cutpoints ,
1742+ sigma = sigma ,
1743+ )
1744+ p = categorical .owner .inputs [3 ].eval ()
1745+ assert p .shape == expected
1746+
17011747
17021748class TestOrderedMultinomial (BaseTestDistributionRandom ):
17031749 pymc_dist = _OrderedMultinomial
0 commit comments