@@ -646,7 +646,12 @@ class Repeat(Op):
 
     __props__ = ("axis",)
 
-    def __init__(self, axis=None):
+    def __init__(self, axis: int | None = None):
+        if axis is not None:
+            if not isinstance(axis, int) or axis < 0:
+                raise ValueError(
653+ f"Repeat only accepts positive integer axis or None, got { axis } "
+                )
         self.axis = axis
 
     def make_node(self, x, repeats):
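With the new check, `Repeat` itself only ever accepts a non-negative integer axis or `None`; negative user-facing axes are normalized in the `repeat` helper further down (via `normalize_axis_index`) before the Op is built. A minimal usage sketch, assuming `Repeat` is importable from `pytensor.tensor.extra_ops`:

    from pytensor.tensor.extra_ops import Repeat

    Repeat(axis=1)     # accepted: non-negative integer axis
    Repeat(axis=None)  # accepted: repeat over the flattened input
    try:
        Repeat(axis=-1)  # callers must normalize negative axes first
    except ValueError as err:
        print(err)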
@@ -687,58 +692,64 @@ def make_node(self, x, repeats):
         out_shape = list(x.type.shape)
         out_shape[self.axis] = None
 
-        out_type = TensorType(
-            x.dtype, shape=tuple(1 if s == 1 else None for s in out_shape)
-        )
-
+        out_type = TensorType(x.dtype, shape=out_shape)
         return Apply(self, [x, repeats], [out_type()])
 
     def perform(self, node, inputs, output_storage):
-        x = inputs[0]
-        repeats = inputs[1]
-        z = output_storage[0]
-        z[0] = np.repeat(x, repeats=repeats, axis=self.axis)
+        [x, repeats] = inputs
+        output_storage[0][0] = np.repeat(x, repeats=repeats, axis=self.axis)
 
     def connection_pattern(self, node):
         return [[True], [False]]
 
     def grad(self, inputs, gout):
         (x, repeats) = inputs
         (gz,) = gout
+        axis = self.axis
         if repeats.ndim == 0:
-            if self.axis is None:
-                axis = x.ndim
-            else:
-                if self.axis >= 0:
-                    axis = self.axis + 1
-                else:
-                    axis = self.axis + x.ndim + 1
-
-            shape = [x.shape[k] for k in range(x.ndim)]
-            shape.insert(axis, repeats)
+            # When repeats is a scalar (same number of reps for all elements),
+            # we can split the repetitions into their own axis with reshape
+            # and sum them back to the original element location
+            sum_axis = x.ndim if axis is None else axis + 1
+            shape = list(x.shape)
+            shape.insert(sum_axis, repeats)
+            gx = gz.reshape(shape).sum(axis=sum_axis)
 
-            return [
-                gz.reshape(shape, ndim=x.ndim + 1).sum(axis=axis),
-                DisconnectedType()(),
-            ]
         elif repeats.ndim == 1:
-            # For this implementation, we would need to specify the length
-            # of repeats in order to split gz in the right way to sum
-            # the good part.
-            raise NotImplementedError()
+            # To sum the gradients that belong to the same repeated x,
+            # we create a repeated eye and dot product it with the gradient.
+            axis_size = x.size if axis is None else x.shape[axis]
+            tiled_eye = repeat(
+                ptb.eye(axis_size), repeats, axis=0
+            )  # A sparse repeat would be neat
+
+            if axis is None:
+                gx = gz @ tiled_eye
+                # Undo the ravelling when axis=None
+                gx = gx.reshape(x.shape)
+            else:
+                # Place gradient axis at end for dot product
+                gx = ptb.moveaxis(gz, axis, -1)
+                gx = gx @ tiled_eye
+                # Place gradient back into the correct axis
+                gx = ptb.moveaxis(gx, -1, axis)
+
         else:
             raise ValueError()
 
+        return [gx, DisconnectedType()()]
+
     def infer_shape(self, fgraph, node, ins_shapes):
         i0_shapes = ins_shapes[0]
         repeats = node.inputs[1]
         out_shape = list(i0_shapes)
+        axis = self.axis
 
         # uint64 shape are not supported.
         dtype = None
         if repeats.dtype in ("uint8", "uint16", "uint32"):
             dtype = "int64"
-        if self.axis is None:
+        if axis is None:
             if repeats.ndim == 0:
                 if len(i0_shapes) == 0:
                     out_shape = [repeats]
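Both `grad` branches amount to summing `gz` over every output position that was copied from the same input element. A plain-NumPy sketch of the two tricks (illustrative shapes; this is not PyTensor code):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=(2, 3))
    axis = 1

    # Scalar repeats: give the copies their own trailing axis, then sum it away.
    reps = 4
    gz = rng.normal(size=(2, 3 * reps))       # cotangent of np.repeat(x, reps, axis=1)
    shape = list(x.shape)
    shape.insert(axis + 1, reps)              # (2, 3, 4)
    gx = gz.reshape(shape).sum(axis=axis + 1)
    assert gx.shape == x.shape

    # Vector repeats: a row-repeated identity maps every output position
    # back to the input element it was copied from.
    repeats = np.array([1, 0, 2])
    gz = rng.normal(size=(2, repeats.sum()))  # cotangent of np.repeat(x, repeats, axis=1)
    tiled_eye = np.repeat(np.eye(x.shape[axis]), repeats, axis=0)
    gx = gz @ tiled_eye                       # sums the copies of each column
    assert gx.shape == x.shape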
@@ -751,82 +762,97 @@ def infer_shape(self, fgraph, node, ins_shapes):
                 out_shape = [pt_sum(repeats, dtype=dtype)]
         else:
             if repeats.ndim == 0:
-                out_shape[self.axis] = out_shape[self.axis] * repeats
+                out_shape[axis] = out_shape[axis] * repeats
             else:
-                out_shape[self.axis] = pt_sum(repeats, dtype=dtype)
+                out_shape[axis] = pt_sum(repeats, dtype=dtype)
         return [out_shape]
 
 
-def repeat(x, repeats, axis=None):
-    """Repeat elements of an array.
+def repeat(a: TensorLike, repeats: TensorLike, axis: int | None = None) -> TensorVariable:
+    """Repeat elements of a tensor.
 
-    It returns an array which has the same shape as `x`, except along the given
-    `axis`. The `axis` parameter is used to specify the axis along which values
-    are repeated. By default, a flattened version of `x` is used.
+    See `numpy.repeat` for more information.
 
-    The number of repetitions for each element is `repeats`. `repeats` is
-    broadcasted to fit the length of the given `axis`.
 
     Parameters
     ----------
-    x
-        Input data, tensor variable.
-    repeats
-        int, scalar or tensor variable
+    a: tensor_like
+        Input tensor
+    repeats: tensor_like
+        The number of repetitions for each element. repeats is broadcasted to fit the shape of the given axis.
     axis : int, optional
+        The axis along which to repeat values. By default, use the flattened input array, and return a flat output array.
 
-    See Also
+    Returns
+    -------
+    repeated_tensor: TensorVariable
+        Output tensor which has the same shape as `a`, except along the given axis
+
+    Examples
     --------
-    tensor.tile
+
+    .. testcode::
+
+        import pytensor.tensor as pt
+
+        a = pt.arange(4).reshape((2, 2))
+        out = pt.repeat(a, repeats=[2, 3], axis=0)
+        print(out.eval())
+
+    .. testoutput::
+
+        [[0 1]
+         [0 1]
+         [2 3]
+         [2 3]
+         [2 3]]
+
 
     .. versionadded:: 0.6
 
     """
+    a = ptb.as_tensor_variable(a)
+
+    if axis is not None:
+        axis = normalize_axis_index(axis, a.ndim)
+
     repeats = ptb.as_tensor_variable(repeats, dtype=np.int64)
 
     if repeats.ndim > 1:
         raise ValueError("The dimension of repeats should not exceed 1.")
 
     if repeats.ndim == 1 and not repeats.broadcastable[0]:
-        return Repeat(axis=axis)(x, repeats)
+        # We only use the Repeat Op for vector repeats
+        return Repeat(axis=axis)(a, repeats)
     else:
         if repeats.ndim == 1:
             repeats = repeats[0]
 
-        if x.dtype == "uint64":
+        if a.dtype == "uint64":
             raise TypeError("repeat doesn't support dtype uint64")
 
         if axis is None:
             axis = 0
-            x = x.flatten()
-        else:
-            if axis >= x.ndim:
-                raise ValueError("Axis should not exceed x.ndim-1.")
-            if axis < 0:
-                axis = x.ndim + axis
+            a = a.flatten()
 
-        shape = [x.shape[i] for i in range(x.ndim)]
+        repeat_shape = list(a.shape)
 
-        # shape_ is the shape of the intermediate tensor which has
+        # alloc_shape is the shape of the intermediate tensor which has
         # an additional dimension comparing to x. We use alloc to
         # allocate space for this intermediate tensor to replicate x
         # along that additional dimension.
-        shape_ = shape[:]
-        shape_.insert(axis + 1, repeats)
+        alloc_shape = repeat_shape[:]
+        alloc_shape.insert(axis + 1, repeats)
 
-        # shape is now the shape of output, where shape[axis] becomes
+        # repeat_shape is now the shape of output, where shape[axis] becomes
         # shape[axis]*repeats.
-        shape[axis] = shape[axis] * repeats
-
-        # dims_ is the dimension of that intermediate tensor.
-        dims_ = list(np.arange(x.ndim))
-        dims_.insert(axis + 1, "x")
+        repeat_shape[axis] = repeat_shape[axis] * repeats
 
         # After the original tensor is duplicated along the additional
-        # dimension, we reshape it to the expected output shape, and
-        # return the output z.
-        z = ptb.alloc(x.dimshuffle(*dims_), *shape_).reshape(shape)
-        return z
+        # dimension, we reshape it to the expected output shape
+        return ptb.alloc(ptb.expand_dims(a, axis + 1), *alloc_shape).reshape(
+            repeat_shape
+        )
 
 
 class Bartlett(Op):
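For scalar repeats, the `repeat` helper above never builds a `Repeat` Op: it broadcasts `a` along a fresh axis and folds that axis back into `axis` with a reshape. The same alloc/reshape trick rendered in plain NumPy (illustrative only; `broadcast_to` stands in for `ptb.alloc`):

    import numpy as np

    a = np.arange(6).reshape(2, 3)
    repeats, axis = 2, 1

    alloc_shape = list(a.shape)
    alloc_shape.insert(axis + 1, repeats)  # (2, 3, 2): room for the copies
    expanded = np.broadcast_to(np.expand_dims(a, axis + 1), alloc_shape)

    repeat_shape = list(a.shape)
    repeat_shape[axis] *= repeats          # (2, 6): copies folded into `axis`
    out = expanded.reshape(repeat_shape)

    assert np.array_equal(out, np.repeat(a, repeats, axis=axis))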