Skip to content

Commit 67c8df9

Browse files
Don't support symbolic slices in get_canonical_form_slice
1 parent 15918a6 commit 67c8df9

File tree

2 files changed

+28
-57
lines changed

2 files changed

+28
-57
lines changed

pytensor/tensor/subtensor.py

Lines changed: 15 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def as_index_literal(
216216
)
217217

218218
if not isinstance(idx, Variable):
219-
raise NotScalarConstantError()
219+
raise TypeError(f"Not an index element: {idx}")
220220

221221
if isinstance(idx.type, NoneTypeT):
222222
return None
@@ -244,23 +244,26 @@ def get_idx_list(inputs, idx_list):
244244

245245
@overload
246246
def get_canonical_form_slice(
247-
theslice: slice | SliceConstant,
247+
theslice: slice,
248248
length: int | np.integer | ScalarVariable | TensorVariable,
249249
) -> tuple[slice, int | ScalarConstant]: ...
250250

251251

252252
@overload
253253
def get_canonical_form_slice(
254-
theslice: ScalarVariable | TensorVariable,
254+
theslice: int | np.integer | ScalarVariable | TensorVariable,
255255
length: int | np.integer | ScalarVariable | TensorVariable,
256256
) -> tuple[ScalarVariable, int]: ...
257257

258258

259259
def get_canonical_form_slice(
260-
theslice: slice | Variable,
260+
theslice: slice | int | np.integer | ScalarVariable | TensorVariable,
261261
length: int | np.integer | ScalarVariable | TensorVariable,
262262
) -> tuple[slice | ScalarVariable, int | ScalarConstant]:
263-
"""Convert slices to canonical form.
263+
"""Convert indices or slices to canonical form.
264+
265+
Symbolic indices or symbolic slice attributes are supported.
266+
Symbolic slices are not.
264267
265268
Given a slice [start:stop:step] transform it into a canonical form
266269
that respects the conventions imposed by python and numpy.
@@ -281,35 +284,16 @@ def get_canonical_form_slice(
281284
"""
282285
from pytensor.tensor import ge, lt, sign, switch
283286

284-
# Convert the two symbolic slice types into a native slice
285-
if isinstance(theslice, SliceConstant):
286-
theslice = cast(slice, theslice.data)
287-
elif isinstance(theslice, Variable) and isinstance(theslice.type, SliceType):
288-
theslice = slice(*theslice.owner.inputs)
289287
# Other non-slice types are the scalar indexing case
290-
elif not isinstance(theslice, slice):
291-
sslice: int | np.integer | ScalarVariable | TensorVariable
292-
if isinstance(theslice, ScalarVariable):
293-
sslice = theslice
294-
elif isinstance(theslice, TensorVariable) and theslice.ndim == 0:
295-
sslice = theslice
296-
else:
297-
# Only non-variable types should remain.
298-
try:
299-
vlit = as_index_literal(theslice)
300-
if vlit is None:
301-
raise ValueError("Can't create canonical slice for `None` slices.")
302-
# Slice returns from as_index are not expected, because the
303-
# symbolic slice type inputs were already taken care of above.
304-
assert isinstance(vlit, int | np.integer)
305-
sslice = vlit
306-
except NotScalarConstantError:
307-
raise ValueError(f"Slice {theslice} is not a supported slice type.")
308-
cano = switch(lt(sslice, 0), (sslice + length), sslice)
309-
return scalar_from_tensor(cano), 1
288+
if not isinstance(theslice, slice):
289+
if isinstance(theslice, int | np.integer | ScalarVariable) or (
290+
isinstance(theslice, TensorVariable) and theslice.ndim == 0
291+
):
292+
cano = switch(lt(theslice, 0), (theslice + length), theslice)
293+
return scalar_from_tensor(cano), 1
294+
raise ValueError(f"Slice {theslice} is not a supported slice type.")
310295

311296
# At this point we have a slice object. Possibly with symbolic inputs.
312-
assert isinstance(theslice, slice)
313297

314298
def analyze(x):
315299
try:

tests/tensor/test_subtensor.py

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@
7171
from pytensor.tensor.type_other import (
7272
NoneConst,
7373
SliceConstant,
74-
SliceType,
7574
as_symbolic_slice,
7675
make_slice,
7776
slicetype,
@@ -112,11 +111,19 @@ def test_as_index_literal():
112111

113112

114113
class TestGetCanonicalFormSlice:
115-
def test_none_raises(self):
116-
with pytest.raises(ValueError, match="Can't create"):
117-
get_canonical_form_slice(NoneConst, 5)
118-
with pytest.raises(ValueError, match="Can't create"):
119-
get_canonical_form_slice(None, 5)
114+
@pytest.mark.parametrize(
115+
"idx",
116+
[
117+
NoneConst,
118+
None,
119+
as_symbolic_slice(slice(3, 7, 2)),
120+
as_symbolic_slice(slice(3, int16(), 2)),
121+
vector(),
122+
],
123+
)
124+
def test_unsupported_inputs(self, idx):
125+
with pytest.raises(ValueError, match="not a supported slice"):
126+
get_canonical_form_slice(idx, 5)
120127

121128
def test_scalar_constant(self):
122129
a = as_scalar(0)
@@ -146,26 +153,6 @@ def test_symbolic_tensor(self):
146153
assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
147154
assert res[1] == 1
148155

149-
def test_constant_slice(self):
150-
idx = as_symbolic_slice(slice(3, 7, 2))
151-
assert isinstance(idx, SliceConstant)
152-
res = get_canonical_form_slice(idx, 10)
153-
assert isinstance(res[0], slice)
154-
assert res[1] == 1
155-
156-
def test_symbolic_slice(self):
157-
idx = as_symbolic_slice(slice(3, int16(), 2))
158-
assert not isinstance(idx, slice)
159-
assert isinstance(idx.type, SliceType)
160-
res = get_canonical_form_slice(idx, 10)
161-
assert isinstance(res[0], slice)
162-
assert res[1] == 1
163-
164-
def test_unsupported_variable(self):
165-
idx = vector()
166-
with pytest.raises(ValueError, match="slice type"):
167-
get_canonical_form_slice(idx, 4)
168-
169156
def test_all_integer(self):
170157
res = get_canonical_form_slice(slice(1, 5, 2), 7)
171158
assert isinstance(res[0], slice)

0 commit comments

Comments
 (0)