@@ -1700,21 +1700,22 @@ def do_constant_folding(self, fgraph, node):
17001700 return False
17011701
17021702 for client , idx in clients :
1703- if isinstance (client .op , Output ):
1703+ client_op = client .op
1704+ if isinstance (client_op , Output ):
17041705 # If the output is a constant, it will have to be deepcopied
17051706 # each time the function is called. So we do not fold.
17061707 return False
1707- # Allow alloc to be lifted out of Elemwise before constant folding it
1708- elif isinstance (client . op , Elemwise ):
1709- return None
1708+ # Op's through which Alloc can be lifted
1709+ elif isinstance (client_op , Elemwise | DimShuffle | Alloc | Join ):
1710+ return False
17101711 # Same for Blockwise, unless it has no batch_dims
1711- elif isinstance (client . op , Blockwise ) and client .op .batch_ndim (client ):
1712- return None
1712+ elif isinstance (client_op , Blockwise ) and client .op .batch_ndim (client ):
1713+ return False
17131714 elif (
17141715 # The following ops work inplace of their input id 0.
17151716 idx == 0
17161717 and isinstance (
1717- client . op ,
1718+ client_op ,
17181719 pytensor .tensor .subtensor .IncSubtensor
17191720 | pytensor .tensor .subtensor .AdvancedIncSubtensor1
17201721 | pytensor .tensor .subtensor .AdvancedIncSubtensor
@@ -2035,10 +2036,15 @@ def transpose(x, axes=None):
20352036 _x = as_tensor_variable (x )
20362037
20372038 if axes is None :
2038- axes = list (range ((_x .type .ndim - 1 ), - 1 , - 1 ))
2039+ axes = tuple (range ((_x .type .ndim - 1 ), - 1 , - 1 ))
2040+
2041+ if tuple (axes ) == tuple (range (len (axes ))):
2042+ # No-op
2043+ return _x
2044+
20392045 ret = DimShuffle (tuple (s == 1 for s in _x .type .shape ), axes )(_x )
20402046
2041- if _x .name and axes == list (range ((_x .type .ndim - 1 ), - 1 , - 1 )):
2047+ if _x .name and axes == tuple (range ((_x .type .ndim - 1 ), - 1 , - 1 )):
20422048 ret .name = _x .name + ".T"
20432049
20442050 return ret
@@ -3950,6 +3956,10 @@ def moveaxis(
39503956 source = normalize_axis_tuple (source , a .ndim , "source" )
39513957 destination = normalize_axis_tuple (destination , a .ndim , "destination" )
39523958
3959+ if source == destination :
3960+ # It's a no-op
3961+ return a
3962+
39533963 if len (source ) != len (destination ):
39543964 raise ValueError (
39553965 "`source` and `destination` arguments must have the same number of elements"
@@ -4260,9 +4270,7 @@ def atleast_Nd(
42604270atleast_3d = partial (atleast_Nd , n = 3 )
42614271
42624272
4263- def expand_dims (
4264- a : np .ndarray | TensorVariable , axis : tuple [int , ...]
4265- ) -> TensorVariable :
4273+ def expand_dims (a : np .ndarray | TensorVariable , axis : Sequence [int ]) -> TensorVariable :
42664274 """Expand the shape of an array.
42674275
42684276 Insert a new axis that will appear at the `axis` position in the expanded
@@ -4281,7 +4289,7 @@ def expand_dims(
42814289 """
42824290 a = as_tensor (a )
42834291
4284- if not isinstance (axis , tuple | list ):
4292+ if not isinstance (axis , Sequence ):
42854293 axis = (axis ,)
42864294
42874295 out_ndim = len (axis ) + a .ndim
0 commit comments