2626from pymc .distributions import distribution , multivariate
2727from pymc .distributions .continuous import Flat , Normal , get_tau_sigma
2828from pymc .distributions .dist_math import check_parameters
29+ from pymc .distributions .distribution import moment
2930from pymc .distributions .logprob import ignore_logprob , logp
3031from pymc .distributions .shape_utils import rv_size_is_none , to_tuple
3132from pymc .util import check_dist_not_registered
@@ -131,7 +132,9 @@ def rng_fn(
131132 else :
132133 dist_shape = (* size , int (steps ))
133134
134- innovations = rng .normal (loc = mu , scale = sigma , size = dist_shape )
135+ # Add one dimension to the right, so that mu and sigma broadcast safely along
136+ # the steps dimension
137+ innovations = rng .normal (loc = mu [..., None ], scale = sigma [..., None ], size = dist_shape )
135138 grw = np .concatenate ([init [..., None ], innovations ], axis = - 1 )
136139 return np .cumsum (grw , axis = - 1 )
137140
@@ -211,6 +214,14 @@ def dist(
211214
212215 return super ().dist ([mu , sigma , init , steps ], size = size , ** kwargs )
213216
217+ def moment (rv , size , mu , sigma , init , steps ):
218+ grw_moment = at .zeros_like (rv )
219+ grw_moment = at .set_subtensor (grw_moment [..., 0 ], moment (init ))
220+ # Add one dimension to the right, so that mu broadcasts safely along the steps
221+ # dimension
222+ grw_moment = at .set_subtensor (grw_moment [..., 1 :], mu [..., None ])
223+ return at .cumsum (grw_moment , axis = - 1 )
224+
214225 def logp (
215226 value : at .Variable ,
216227 mu : at .Variable ,
@@ -225,7 +236,9 @@ def logp(
225236
226237 # Make time series stationary around the mean value
227238 stationary_series = value [..., 1 :] - value [..., :- 1 ]
228- series_logp = logp (Normal .dist (mu , sigma ), stationary_series )
239+ # Add one dimension to the right, so that mu and sigma broadcast safely along
240+ # the steps dimension
241+ series_logp = logp (Normal .dist (mu [..., None ], sigma [..., None ]), stationary_series )
229242
230243 return check_parameters (
231244 init_logp + series_logp .sum (axis = - 1 ),
0 commit comments