@@ -74,17 +74,46 @@ def test_make_initial_point_fns_per_chain_checks_kwargs(self):
7474 def test_dependent_initvals (self ):
7575 with pm .Model () as pmodel :
7676 L = pm .Uniform ("L" , 0 , 1 , initval = 0.5 )
77- B = pm .Uniform ("B" , lower = L , upper = 2 , initval = 1.25 )
77+ U = pm .Uniform ("U" , lower = 9 , upper = 10 , initval = 9.5 )
78+ B1 = pm .Uniform ("B1" , lower = L , upper = U , initval = 5 )
79+ B2 = pm .Uniform ("B2" , lower = L , upper = U , initval = (L + U ) / 2 )
7880 ip = pmodel .recompute_initial_point (seed = 0 )
7981 assert ip ["L_interval__" ] == 0
80- assert ip ["B_interval__" ] == 0
82+ assert ip ["U_interval__" ] == 0
83+ assert ip ["B1_interval__" ] == 0
84+ assert ip ["B2_interval__" ] == 0
8185
8286 # Modify initval of L and re-evaluate
83- pmodel .initial_values [L ] = 0 .9
87+ pmodel .initial_values [U ] = 9 .9
8488 ip = pmodel .recompute_initial_point (seed = 0 )
85- assert ip ["B_interval__" ] < 0
89+ assert ip ["B1_interval__" ] < 0
90+ assert ip ["B2_interval__" ] == 0
8691 pass
8792
93+ def test_nested_initvals (self ):
94+ # See issue #5168
95+ with pm .Model () as pmodel :
96+ one = pm .LogNormal ("one" , mu = np .log (1 ), sd = 1e-5 , initval = "prior" )
97+ two = pm .Lognormal ("two" , mu = np .log (one * 2 ), sd = 1e-5 , initval = "prior" )
98+ three = pm .LogNormal ("three" , mu = np .log (two * 2 ), sd = 1e-5 , initval = "prior" )
99+ four = pm .LogNormal ("four" , mu = np .log (three * 2 ), sd = 1e-5 , initval = "prior" )
100+ five = pm .LogNormal ("five" , mu = np .log (four * 2 ), sd = 1e-5 , initval = "prior" )
101+ six = pm .LogNormal ("six" , mu = np .log (five * 2 ), sd = 1e-5 , initval = "prior" )
102+
103+ ip_vals = list (make_initial_point_fn (model = pmodel , return_transformed = True )(0 ).values ())
104+ assert np .allclose (np .exp (ip_vals ), [1 , 2 , 4 , 8 , 16 , 32 ], rtol = 1e-3 )
105+
106+ ip_vals = list (make_initial_point_fn (model = pmodel , return_transformed = False )(0 ).values ())
107+ assert np .allclose (ip_vals , [1 , 2 , 4 , 8 , 16 , 32 ], rtol = 1e-3 )
108+
109+ pmodel .initial_values [four ] = 1
110+
111+ ip_vals = list (make_initial_point_fn (model = pmodel , return_transformed = True )(0 ).values ())
112+ assert np .allclose (np .exp (ip_vals ), [1 , 2 , 4 , 1 , 2 , 4 ], rtol = 1e-3 )
113+
114+ ip_vals = list (make_initial_point_fn (model = pmodel , return_transformed = False )(0 ).values ())
115+ assert np .allclose (ip_vals , [1 , 2 , 4 , 1 , 2 , 4 ], rtol = 1e-3 )
116+
88117 def test_initval_resizing (self ):
89118 with pm .Model () as pmodel :
90119 data = aesara .shared (np .arange (4 ))
0 commit comments