@@ -216,7 +216,7 @@ def as_index_literal(
216216 )
217217
218218 if not isinstance (idx , Variable ):
219- raise NotScalarConstantError ( )
219+ raise TypeError ( f"Not an index element: { idx } " )
220220
221221 if isinstance (idx .type , NoneTypeT ):
222222 return None
@@ -244,23 +244,26 @@ def get_idx_list(inputs, idx_list):
244244
245245@overload
246246def get_canonical_form_slice (
247- theslice : slice | SliceConstant ,
247+ theslice : slice ,
248248 length : int | np .integer | ScalarVariable | TensorVariable ,
249249) -> tuple [slice , int | ScalarConstant ]: ...
250250
251251
252252@overload
253253def get_canonical_form_slice (
254- theslice : ScalarVariable | TensorVariable ,
254+ theslice : int | np . integer | ScalarVariable | TensorVariable ,
255255 length : int | np .integer | ScalarVariable | TensorVariable ,
256256) -> tuple [ScalarVariable , int ]: ...
257257
258258
259259def get_canonical_form_slice (
260- theslice : slice | Variable ,
260+ theslice : slice | int | np . integer | ScalarVariable | TensorVariable ,
261261 length : int | np .integer | ScalarVariable | TensorVariable ,
262262) -> tuple [slice | ScalarVariable , int | ScalarConstant ]:
263- """Convert slices to canonical form.
263+ """Convert indices or slices to canonical form.
264+
265+ Symbolic indices or symbolic slice attributes are supported.
266+ Symbolic slices are not.
264267
265268 Given a slice [start:stop:step] transform it into a canonical form
266269 that respects the conventions imposed by python and numpy.
@@ -281,35 +284,16 @@ def get_canonical_form_slice(
281284 """
282285 from pytensor .tensor import ge , lt , sign , switch
283286
284- # Convert the two symbolic slice types into a native slice
285- if isinstance (theslice , SliceConstant ):
286- theslice = cast (slice , theslice .data )
287- elif isinstance (theslice , Variable ) and isinstance (theslice .type , SliceType ):
288- theslice = slice (* theslice .owner .inputs )
289287 # Other non-slice types are the scalar indexing case
290- elif not isinstance (theslice , slice ):
291- sslice : int | np .integer | ScalarVariable | TensorVariable
292- if isinstance (theslice , ScalarVariable ):
293- sslice = theslice
294- elif isinstance (theslice , TensorVariable ) and theslice .ndim == 0 :
295- sslice = theslice
296- else :
297- # Only non-variable types should remain.
298- try :
299- vlit = as_index_literal (theslice )
300- if vlit is None :
301- raise ValueError ("Can't create canonical slice for `None` slices." )
302- # Slice returns from as_index are not expected, because the
303- # symbolic slice type inputs were already taken care of above.
304- assert isinstance (vlit , int | np .integer )
305- sslice = vlit
306- except NotScalarConstantError :
307- raise ValueError (f"Slice { theslice } is not a supported slice type." )
308- cano = switch (lt (sslice , 0 ), (sslice + length ), sslice )
309- return scalar_from_tensor (cano ), 1
288+ if not isinstance (theslice , slice ):
289+ if isinstance (theslice , int | np .integer | ScalarVariable ) or (
290+ isinstance (theslice , TensorVariable ) and theslice .ndim == 0
291+ ):
292+ cano = switch (lt (theslice , 0 ), (theslice + length ), theslice )
293+ return scalar_from_tensor (cano ), 1
294+ raise ValueError (f"Slice { theslice } is not a supported slice type." )
310295
311296 # At this point we have a slice object. Possibly with symbolic inputs.
312- assert isinstance (theslice , slice )
313297
314298 def analyze (x ):
315299 try :
0 commit comments