@@ -393,14 +393,19 @@ class AR(SymbolicDistribution):
 
     """
 
-    def __new__(cls, *args, steps=None, **kwargs):
+    def __new__(cls, name, rho, *args, steps=None, constant=False, ar_order=None, **kwargs):
+        rhos = at.atleast_1d(at.as_tensor_variable(floatX(rho)))
+        ar_order = cls._get_ar_order(rhos=rhos, constant=constant, ar_order=ar_order)
         steps = get_steps(
             steps=steps,
             shape=None,  # Shape will be checked in `cls.dist`
             dims=kwargs.get("dims", None),
             observed=kwargs.get("observed", None),
+            step_shape_offset=ar_order,
+        )
+        return super().__new__(
+            cls, name, rhos, *args, steps=steps, constant=constant, ar_order=ar_order, **kwargs
         )
-        return super().__new__(cls, *args, steps=steps, **kwargs)
 
     @classmethod
     def dist(
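At the user level, the reworked `__new__` means the AR order is settled before `get_steps` runs. A hypothetical call, just to make the bookkeeping concrete (the model and argument values below are illustrative, not taken from this diff):

```python
import pymc as pm

with pm.Model():
    # Two lag coefficients and constant=False -> ar_order inferred as 2.
    # With step_shape_offset=ar_order, the series holds the 2 initial
    # values plus 10 scan steps, i.e. 12 values in total.
    ar = pm.AR("ar", rho=[0.9, -0.2], sigma=1.0, constant=False, steps=10)
```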
@@ -426,34 +431,12 @@ def dist(
         )
         init_dist = kwargs["init"]
 
-        steps = get_steps(steps=steps, shape=kwargs.get("shape", None))
+        ar_order = cls._get_ar_order(rhos=rhos, constant=constant, ar_order=ar_order)
+        steps = get_steps(steps=steps, shape=kwargs.get("shape", None), step_shape_offset=ar_order)
         if steps is None:
             raise ValueError("Must specify steps or shape parameter")
         steps = at.as_tensor_variable(intX(steps), ndim=0)
 
-        if ar_order is None:
-            # If ar_order is not specified we do constant folding on the shape of rhos
-            # to retrieve it. For example, this will detect that
-            # Normal(size=(5, 3)).shape[-1] == 3, which is not known by Aesara before.
-            shape_fg = FunctionGraph(
-                outputs=[rhos.shape[-1]],
-                features=[ShapeFeature()],
-                clone=True,
-            )
-            (folded_shape,) = optimize_graph(shape_fg, custom_opt=topo_constant_folding).outputs
-            folded_shape = getattr(folded_shape, "data", None)
-            if folded_shape is None:
-                raise ValueError(
-                    "Could not infer ar_order from last dimension of rho. Pass it "
-                    "explicitly or make sure rho has a static shape"
-                )
-            ar_order = int(folded_shape) - int(constant)
-            if ar_order < 1:
-                raise ValueError(
-                    "Inferred ar_order is smaller than 1. Increase the last dimension "
-                    "of rho or remove constant_term"
-                )
-
         if init_dist is not None:
             if not isinstance(init_dist, TensorVariable) or not isinstance(
                 init_dist.owner.op, RandomVariable
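The `dist` side gets the same treatment: `step_shape_offset=ar_order` lets `get_steps` reconcile `steps` with an explicit `shape`. A sketch of that bookkeeping, assuming `get_steps` resolves a missing `steps` as `shape[-1] - step_shape_offset` (inferred from how this diff uses it):

```python
import pymc as pm

with pm.Model():
    # shape=(12,) with an AR(2): steps should resolve to 12 - 2 = 10,
    # so two initial values plus ten draws fill the requested shape.
    ar = pm.AR("ar", rho=[0.9, -0.2], sigma=1.0, shape=(12,))
```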
@@ -477,6 +460,41 @@ def dist(
 
         return super().dist([rhos, sigma, init_dist, steps, ar_order, constant], **kwargs)
 
+    @classmethod
+    def _get_ar_order(cls, rhos: TensorVariable, ar_order: Optional[int], constant: bool) -> int:
+        """Compute ar_order given inputs
+
+        If ar_order is not specified we do constant folding on the shape of rhos
+        to retrieve it. For example, this will detect that
+        Normal(size=(5, 3)).shape[-1] == 3, which is not known by Aesara before.
+
+        Raises
+        ------
+        ValueError
+            If ar_order cannot be inferred from rhos or if it is smaller than 1
+        """
+        if ar_order is None:
+            shape_fg = FunctionGraph(
+                outputs=[rhos.shape[-1]],
+                features=[ShapeFeature()],
+                clone=True,
+            )
+            (folded_shape,) = optimize_graph(shape_fg, custom_opt=topo_constant_folding).outputs
+            folded_shape = getattr(folded_shape, "data", None)
+            if folded_shape is None:
+                raise ValueError(
+                    "Could not infer ar_order from last dimension of rho. Pass it "
+                    "explicitly or make sure rho has a static shape"
+                )
+            ar_order = int(folded_shape) - int(constant)
+            if ar_order < 1:
+                raise ValueError(
+                    "Inferred ar_order is smaller than 1. Increase the last dimension "
+                    "of rho or remove constant_term"
+                )
+
+        return ar_order
+
     @classmethod
     def num_rngs(cls, *args, **kwargs):
         return 2
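The constant-folding trick that `_get_ar_order` centralizes can be sanity-checked in isolation. A minimal sketch using the same calls as the diff, assuming Aesara 2.x import paths (`aesara.graph`, `aesara.tensor.basic_opt`):

```python
import pymc as pm
from aesara.graph import FunctionGraph, optimize_graph
from aesara.tensor.basic_opt import ShapeFeature, topo_constant_folding

# rhos.shape[-1] is symbolic, but ShapeFeature plus constant folding can
# recover the static 3 from the declared size, as the docstring describes.
rhos = pm.Normal.dist(size=(5, 3))
shape_fg = FunctionGraph(outputs=[rhos.shape[-1]], features=[ShapeFeature()], clone=True)
(folded,) = optimize_graph(shape_fg, custom_opt=topo_constant_folding).outputs
print(getattr(folded, "data", None))  # array(3) if folding succeeded, else None
```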
@@ -540,7 +558,7 @@ def step(*args):
             fn=step,
             outputs_info=[{"initial": init_.T, "taps": range(-ar_order, 0)}],
             non_sequences=[rhos_bcast_.T[::-1], sigma_.T, noise_rng],
-            n_steps=at.max((0, steps_ - ar_order)),
+            n_steps=steps_,
             strict=True,
         )
         (noise_next_rng,) = tuple(innov_updates_.values())
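Because `get_steps` now subtracts `ar_order` up front, the `steps_` reaching `scan` already counts only the post-initialization draws, which is why the `at.max((0, steps_ - ar_order))` clamp can go. One more hypothetical call showing the `constant=True` interplay (`ar_order = rho.shape[-1] - 1` in that case, per the `int(folded_shape) - int(constant)` line above):

```python
import pymc as pm

with pm.Model():
    # With constant=True the last rho entry is the intercept, so the
    # inferred order here is len(rho) - 1 = 2; passing ar_order=2
    # explicitly skips the shape-folding path entirely.
    ar = pm.AR("ar", rho=[0.2, 0.5, 0.3], sigma=1.0, constant=True, ar_order=2, steps=100)
```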