1919
2020from aesara import scan
2121from aesara .tensor .random .op import RandomVariable
22+ from aesara .tensor .random .utils import normalize_size_param
2223
2324from pymc .aesaraf import change_rv_size , floatX , intX
2425from pymc .distributions import distribution , logprob , multivariate
2526from pymc .distributions .continuous import Flat , Normal , get_tau_sigma
2627from pymc .distributions .dist_math import check_parameters
27- from pymc .distributions .shape_utils import to_tuple
28+ from pymc .distributions .shape_utils import rv_size_is_none , to_tuple
2829from pymc .util import check_dist_not_registered
2930
3031__all__ = [
@@ -54,6 +55,16 @@ def make_node(self, rng, size, dtype, mu, sigma, init, steps):
5455 if not steps .ndim == 0 or not steps .dtype .startswith ("int" ):
5556 raise ValueError ("steps must be an integer scalar (ndim=0)." )
5657
58+ mu = at .as_tensor_variable (mu )
59+ sigma = at .as_tensor_variable (sigma )
60+ init = at .as_tensor_variable (init )
61+
62+ # Resize init distribution
63+ size = normalize_size_param (size )
64+ # If not explicit, size is determined by the shapes of mu, sigma, and init
65+ init_size = size if not rv_size_is_none (size ) else at .broadcast_shape (mu , sigma , init )
66+ init = change_rv_size (init , init_size )
67+
5768 return super ().make_node (rng , size , dtype , mu , sigma , init , steps )
5869
5970 def _supp_shape_from_params (self , dist_params , reop_param_idx = 0 , param_shapes = None ):
@@ -160,15 +171,9 @@ def dist(
160171 raise ValueError ("Must specify steps parameter" )
161172 steps = at .as_tensor_variable (intX (steps ))
162173
163- shape = kwargs .get ("shape" , None )
164- if size is None and shape is None :
165- init_size = None
166- else :
167- init_size = to_tuple (size ) if size is not None else to_tuple (shape )[:- 1 ]
168-
169174 # If no scalar distribution is passed then initialize with a Normal of same mu and sigma
170175 if init is None :
171- init = Normal .dist (mu , sigma , size = init_size )
176+ init = Normal .dist (mu , sigma )
172177 else :
173178 if not (
174179 isinstance (init , at .TensorVariable )
@@ -178,13 +183,6 @@ def dist(
178183 ):
179184 raise TypeError ("init must be a univariate distribution variable" )
180185
181- if init_size is not None :
182- init = change_rv_size (init , init_size )
183- else :
184- # If not explicit, size is determined by the shapes of mu, sigma, and init
185- bcast_shape = at .broadcast_arrays (mu , sigma , init )[0 ].shape
186- init = change_rv_size (init , bcast_shape )
187-
188186 # Ignores logprob of init var because that's accounted for in the logp method
189187 init .tag .ignore_logprob = True
190188
0 commit comments