@@ -86,10 +86,10 @@ def test_gaussian_random_walk_init_dist_shape(self, init):
8686 grw = pm .GaussianRandomWalk .dist (mu = 0 , sigma = 1 , steps = 1 , init = init , size = (5 ,))
8787 assert tuple (grw .owner .inputs [- 2 ].shape .eval ()) == (5 ,)
8888
89- grw = pm .GaussianRandomWalk .dist (mu = 0 , sigma = 1 , steps = 1 , init = init , shape = 1 )
89+ grw = pm .GaussianRandomWalk .dist (mu = 0 , sigma = 1 , steps = 1 , init = init , shape = 2 )
9090 assert tuple (grw .owner .inputs [- 2 ].shape .eval ()) == ()
9191
92- grw = pm .GaussianRandomWalk .dist (mu = 0 , sigma = 1 , steps = 1 , init = init , shape = (5 , 1 ))
92+ grw = pm .GaussianRandomWalk .dist (mu = 0 , sigma = 1 , steps = 1 , init = init , shape = (5 , 2 ))
9393 assert tuple (grw .owner .inputs [- 2 ].shape .eval ()) == (5 ,)
9494
9595 grw = pm .GaussianRandomWalk .dist (mu = [0 , 0 ], sigma = 1 , steps = 1 , init = init )
@@ -113,6 +113,21 @@ def test_gaussianrandomwalk_broadcasted_by_init_dist(self):
113113 assert tuple (grw .shape .eval ()) == (2 , 3 , 5 )
114114 assert grw .eval ().shape == (2 , 3 , 5 )
115115
116+ @pytest .mark .parametrize ("shape" , ((6 ,), (3 , 6 )))
117+ def test_inferred_steps_from_shape (self , shape ):
118+ x = GaussianRandomWalk .dist (shape = shape )
119+ steps = x .owner .inputs [- 1 ]
120+ assert steps .eval () == 5
121+
122+ @pytest .mark .parametrize ("shape" , (None , (5 , ...)))
123+ def test_missing_steps (self , shape ):
124+ with pytest .raises (ValueError , match = "Must specify steps or shape parameter" ):
125+ GaussianRandomWalk .dist (shape = shape )
126+
127+ def test_inconsistent_steps_and_shape (self ):
128+ with pytest .raises (AssertionError , match = "Steps do not match last shape dimension" ):
129+ x = GaussianRandomWalk .dist (steps = 12 , shape = 45 )
130+
116131 @pytest .mark .parametrize (
117132 "init" ,
118133 [
0 commit comments