2727 as_tensor_variable ,
2828 get_vector_length ,
2929)
30- from pytensor .tensor .basic import alloc , get_underlying_scalar_constant_value , nonzero
30+ from pytensor .tensor .basic import (
31+ alloc ,
32+ get_underlying_scalar_constant_value ,
33+ nonzero ,
34+ scalar_from_tensor ,
35+ )
3136from pytensor .tensor .blockwise import vectorize_node_fallback
3237from pytensor .tensor .elemwise import DimShuffle
3338from pytensor .tensor .exceptions import AdvancedIndexingError , NotScalarConstantError
5257 wscalar ,
5358 zscalar ,
5459)
55- from pytensor .tensor .type_other import NoneConst , NoneTypeT , SliceType , make_slice
56- from pytensor .tensor .variable import TensorVariable
60+ from pytensor .tensor .type_other import (
61+ NoneConst ,
62+ NoneTypeT ,
63+ SliceConstant ,
64+ SliceType ,
65+ make_slice ,
66+ )
67+ from pytensor .tensor .variable import TensorConstant , TensorVariable
5768
5869
5970_logger = logging .getLogger ("pytensor.tensor.subtensor" )
@@ -165,19 +176,26 @@ def as_index_literal(idx: None) -> None: ...
165176
166177
167178@overload
168- def as_index_literal (idx : slice ) -> slice : ...
179+ def as_index_literal (idx : slice | SliceConstant ) -> slice : ...
169180
170181
171182@overload
172- def as_index_literal (idx : Constant ) -> int | np .integer | None : ...
183+ def as_index_literal (idx : ScalarConstant | TensorConstant ) -> int | np .integer : ...
173184
174185
175186@overload
176- def as_index_literal (idx : Variable ) -> int | np . integer | slice | None : ...
187+ def as_index_literal (idx : Variable ): ...
177188
178189
179190def as_index_literal (
180- idx : int | np .integer | Variable | slice | None ,
191+ idx : None
192+ | int
193+ | np .integer
194+ | slice
195+ | SliceConstant
196+ | ScalarConstant
197+ | TensorConstant
198+ | Variable ,
181199) -> int | np .integer | slice | None :
182200 """Convert a symbolic index element to its Python equivalent.
183201
@@ -187,40 +205,61 @@ def as_index_literal(
187205 ------
188206 NotScalarConstantError
189207 """
190- if isinstance (idx , int | np .integer ):
208+ if idx is None or isinstance (idx , int | np .integer ):
191209 return idx
192210
193- # NOTE: np.newaxis is None
194- idxtype = getattr (idx , "type" , None )
195- if idx is None or isinstance (idxtype , NoneTypeT ):
196- return None
197-
198- if isinstance (idx , Constant ):
199- data = idx .data
200- if isinstance (data , np .ndarray ):
201- return cast (int , data .item ())
202- return cast (int | None , data )
203-
204- if isinstance (idx , Variable ) and isinstance (idxtype , SliceType ):
205- idx = slice (* idx .owner .inputs )
206-
207211 if isinstance (idx , slice ):
208212 return slice (
209213 as_index_literal (idx .start ),
210214 as_index_literal (idx .stop ),
211215 as_index_literal (idx .step ),
212216 )
213217
218+ if not isinstance (idx , Variable ):
219+ raise NotScalarConstantError ()
220+
221+ if isinstance (idx .type , NoneTypeT ):
222+ return None
223+
224+ if isinstance (idx , ScalarConstant ):
225+ return cast (int | None , idx .data )
226+
227+ if isinstance (idx , TensorConstant ):
228+ return cast (int , idx .data .item ())
229+
230+ if isinstance (idx , SliceConstant ):
231+ return cast (slice , idx .data )
232+
233+ if isinstance (idx .type , SliceType ):
234+ assert idx .owner is not None
235+ return slice (* map (as_index_literal , idx .owner .inputs ))
236+
237+ # Other kinds of variables are not supported
214238 raise NotScalarConstantError ()
215239
216240
217241def get_idx_list (inputs , idx_list ):
218242 return indices_from_subtensor (inputs [1 :], idx_list )
219243
220244
245+ @overload
246+ def get_canonical_form_slice (
247+ theslice : slice | SliceConstant ,
248+ length : int | np .integer | ScalarVariable | TensorVariable ,
249+ ) -> tuple [slice , int | ScalarConstant ]: ...
250+
251+
252+ @overload
253+ def get_canonical_form_slice (
254+ theslice : ScalarVariable | TensorVariable ,
255+ length : int | np .integer | ScalarVariable | TensorVariable ,
256+ ) -> tuple [slice | ScalarVariable , int ]: ...
257+
258+
221259def get_canonical_form_slice (
222- theslice : slice | Variable , length : ScalarVariable
223- ) -> tuple [slice | Variable , int | ScalarConstant ]:
260+ theslice : slice | Variable ,
261+ length : int | np .integer | ScalarVariable | TensorVariable ,
262+ ) -> tuple [slice | ScalarVariable , int | ScalarConstant ]:
224263 """Convert slices to canonical form.
225264
226265 Given a slice [start:stop:step] transform it into a canonical form
@@ -230,36 +269,46 @@ def get_canonical_form_slice(
230269 in which 0 <= start <= stop <= length and step > 0, and a flag which says
231270 if the resulting set of numbers needs to be reversed or not.
232271
272+ Given a scalar index `idx` that may or not be negative, convert it to
273+ a certainly positive form `idx if idx >= 0 else length + idx`.
274+
233275 Returns
234276 -------
235277 slc
236- Slice or scalar variable.
278+ Canonical form slice or scalar variable.
237279 direction
238- Direction to iterate the resulting elements in. (-1 or 1).
280+ Direction to iterate the resulting elements in. (-1 or 1). May be symbolic.
239281 """
240282 from pytensor .tensor import ge , lt , sign , switch
241283
242- if not isinstance (theslice , slice ):
243- # Try to extract a scalar slice and return it already.
244- sslice : int | np .integer | ScalarVariable | None = None
284+ # Convert the two symbolic slice types into a native slice
285+ if isinstance (theslice , SliceConstant ):
286+ theslice = cast (slice , theslice .data )
287+ elif isinstance (theslice , Variable ) and isinstance (theslice .type , SliceType ):
288+ theslice = slice (* theslice .owner .inputs )
289+ # Other non-slice types are the scalar indexing case
290+ elif not isinstance (theslice , slice ):
291+ sslice : int | np .integer | ScalarVariable | TensorVariable
245292 if isinstance (theslice , ScalarVariable ):
246293 sslice = theslice
294+ elif isinstance (theslice , TensorVariable ) and theslice .ndim == 0 :
295+ sslice = theslice
247296 else :
297+ # Only non-variable types should remain.
248298 try :
249299 vlit = as_index_literal (theslice )
250300 if vlit is None :
251- return slice (0 , length , 1 ), 1
252- if isinstance (vlit , slice ):
253- # Input was a SliceType variable, from which a slice was extracted.
254- theslice = vlit
255- else :
256- sslice = vlit
301+ raise ValueError ("Can't create canonical slice for `None` slices." )
302+ # Slice returns from as_index are not expected, because the
303+ # symbolic slice type inputs were already taken care of above.
304+ assert isinstance (vlit , int | np .integer )
305+ sslice = vlit
257306 except NotScalarConstantError :
258307 raise ValueError (f"Slice { theslice } is not a supported slice type." )
259- if isinstance ( sslice , int | np . integer | ScalarVariable ):
260- return switch ( lt ( sslice , 0 ), ( sslice + length ), sslice ), 1
308+ cano = switch ( lt ( sslice , 0 ), ( sslice + length ), sslice )
309+ return scalar_from_tensor ( cano ), 1
261310
262- # At this point we have a slice object. Possibly symbolic.
311+ # At this point we have a slice object. Possibly with symbolic inputs .
263312 assert isinstance (theslice , slice )
264313
265314 def analyze (x ):
0 commit comments