@@ -1050,8 +1050,32 @@ class TestMvNormalCov(BaseTestDistribution):
10501050 "check_pymc_params_match_rv_op" ,
10511051 "check_pymc_draws_match_reference" ,
10521052 "check_rv_size" ,
1053+ "check_mu_broadcast_helper" ,
10531054 ]
10541055
1056+ def check_mu_broadcast_helper (self ):
1057+ """Test that mu is broadcasted to the shape of cov"""
1058+ x = pm .MvNormal .dist (mu = 1 , cov = np .eye (3 ))
1059+ mu = x .owner .inputs [3 ]
1060+ assert mu .eval ().shape == (3 ,)
1061+
1062+ x = pm .MvNormal .dist (mu = np .ones (1 ), cov = np .eye (3 ))
1063+ mu = x .owner .inputs [3 ]
1064+ assert mu .eval ().shape == (3 ,)
1065+
1066+ x = pm .MvNormal .dist (mu = np .ones ((1 , 1 )), cov = np .eye (3 ))
1067+ mu = x .owner .inputs [3 ]
1068+ assert mu .eval ().shape == (1 , 3 )
1069+
1070+ x = pm .MvNormal .dist (mu = np .ones ((10 , 1 )), cov = np .eye (3 ))
1071+ mu = x .owner .inputs [3 ]
1072+ assert mu .eval ().shape == (10 , 3 )
1073+
1074+ # Cov is artificually limited to being 2D
1075+ # x = pm.MvNormal.dist(mu=np.ones((10, 1)), cov=np.full((2, 3, 3), np.eye(3)))
1076+ # mu = x.owner.inputs[3]
1077+ # assert mu.eval().shape == (10, 2, 3)
1078+
10551079
10561080class TestMvNormalChol (BaseTestDistribution ):
10571081 pymc_dist = pm .MvNormal
@@ -1111,6 +1135,7 @@ def mvstudentt_rng_fn(self, size, nu, mu, cov, rng):
11111135 "check_pymc_draws_match_reference" ,
11121136 "check_rv_size" ,
11131137 "check_errors" ,
1138+ "check_mu_broadcast_helper" ,
11141139 ]
11151140
11161141 def check_errors (self ):
@@ -1124,6 +1149,29 @@ def check_errors(self):
11241149 cov = np .full ((2 , 2 ), np .ones (2 )),
11251150 )
11261151
1152+ def check_mu_broadcast_helper (self ):
1153+ """Test that mu is broadcasted to the shape of cov"""
1154+ x = pm .MvStudentT .dist (nu = 4 , mu = 1 , cov = np .eye (3 ))
1155+ mu = x .owner .inputs [4 ]
1156+ assert mu .eval ().shape == (3 ,)
1157+
1158+ x = pm .MvStudentT .dist (nu = 4 , mu = np .ones (1 ), cov = np .eye (3 ))
1159+ mu = x .owner .inputs [4 ]
1160+ assert mu .eval ().shape == (3 ,)
1161+
1162+ x = pm .MvStudentT .dist (nu = 4 , mu = np .ones ((1 , 1 )), cov = np .eye (3 ))
1163+ mu = x .owner .inputs [4 ]
1164+ assert mu .eval ().shape == (1 , 3 )
1165+
1166+ x = pm .MvStudentT .dist (nu = 4 , mu = np .ones ((10 , 1 )), cov = np .eye (3 ))
1167+ mu = x .owner .inputs [4 ]
1168+ assert mu .eval ().shape == (10 , 3 )
1169+
1170+ # Cov is artificually limited to being 2D
1171+ # x = pm.MvStudentT.dist(nu=4, mu=np.ones((10, 1)), cov=np.full((2, 3, 3), np.eye(3)))
1172+ # mu = x.owner.inputs[4]
1173+ # assert mu.eval().shape == (10, 2, 3)
1174+
11271175
11281176class TestMvStudentTChol (BaseTestDistribution ):
11291177 pymc_dist = pm .MvStudentT
0 commit comments