Skip to content

Commit e016dc7

Browse files
Further type refactoring of slice helper functions
1 parent 6200b05 commit e016dc7

File tree

2 files changed

+126
-50
lines changed

2 files changed

+126
-50
lines changed

pytensor/tensor/subtensor.py

Lines changed: 87 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@
2727
as_tensor_variable,
2828
get_vector_length,
2929
)
30-
from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value, nonzero
30+
from pytensor.tensor.basic import (
31+
alloc,
32+
get_underlying_scalar_constant_value,
33+
nonzero,
34+
scalar_from_tensor,
35+
)
3136
from pytensor.tensor.blockwise import vectorize_node_fallback
3237
from pytensor.tensor.elemwise import DimShuffle
3338
from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
@@ -52,8 +57,14 @@
5257
wscalar,
5358
zscalar,
5459
)
55-
from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceType, make_slice
56-
from pytensor.tensor.variable import TensorVariable
60+
from pytensor.tensor.type_other import (
61+
NoneConst,
62+
NoneTypeT,
63+
SliceConstant,
64+
SliceType,
65+
make_slice,
66+
)
67+
from pytensor.tensor.variable import TensorConstant, TensorVariable
5768

5869

5970
_logger = logging.getLogger("pytensor.tensor.subtensor")
@@ -165,19 +176,26 @@ def as_index_literal(idx: None) -> None: ...
165176

166177

167178
@overload
168-
def as_index_literal(idx: slice) -> slice: ...
179+
def as_index_literal(idx: slice | SliceConstant) -> slice: ...
169180

170181

171182
@overload
172-
def as_index_literal(idx: Constant) -> int | np.integer | None: ...
183+
def as_index_literal(idx: ScalarConstant | TensorConstant) -> int | np.integer: ...
173184

174185

175186
@overload
176-
def as_index_literal(idx: Variable) -> int | np.integer | slice | None: ...
187+
def as_index_literal(idx: Variable): ...
177188

178189

179190
def as_index_literal(
180-
idx: int | np.integer | Variable | slice | None,
191+
idx: None
192+
| int
193+
| np.integer
194+
| slice
195+
| SliceConstant
196+
| ScalarConstant
197+
| TensorConstant
198+
| Variable,
181199
) -> int | np.integer | slice | None:
182200
"""Convert a symbolic index element to its Python equivalent.
183201
@@ -187,40 +205,61 @@ def as_index_literal(
187205
------
188206
NotScalarConstantError
189207
"""
190-
if isinstance(idx, int | np.integer):
208+
if idx is None or isinstance(idx, int | np.integer):
191209
return idx
192210

193-
# NOTE: np.newaxis is None
194-
idxtype = getattr(idx, "type", None)
195-
if idx is None or isinstance(idxtype, NoneTypeT):
196-
return None
197-
198-
if isinstance(idx, Constant):
199-
data = idx.data
200-
if isinstance(data, np.ndarray):
201-
return cast(int, data.item())
202-
return cast(int | None, data)
203-
204-
if isinstance(idx, Variable) and isinstance(idxtype, SliceType):
205-
idx = slice(*idx.owner.inputs)
206-
207211
if isinstance(idx, slice):
208212
return slice(
209213
as_index_literal(idx.start),
210214
as_index_literal(idx.stop),
211215
as_index_literal(idx.step),
212216
)
213217

218+
if not isinstance(idx, Variable):
219+
raise NotScalarConstantError()
220+
221+
if isinstance(idx.type, NoneTypeT):
222+
return None
223+
224+
if isinstance(idx, ScalarConstant):
225+
return cast(int | None, idx.data)
226+
227+
if isinstance(idx, TensorConstant):
228+
return cast(int, idx.data.item())
229+
230+
if isinstance(idx, SliceConstant):
231+
return cast(slice, idx.data)
232+
233+
if isinstance(idx.type, SliceType):
234+
assert idx.owner is not None
235+
return slice(*map(as_index_literal, idx.owner.inputs))
236+
237+
# Other kinds of variables are not supported
214238
raise NotScalarConstantError()
215239

216240

217241
def get_idx_list(inputs, idx_list):
218242
return indices_from_subtensor(inputs[1:], idx_list)
219243

220244

245+
@overload
246+
def get_canonical_form_slice(
247+
theslice: slice | SliceConstant,
248+
length: int | np.integer | ScalarVariable | TensorVariable,
249+
) -> tuple[slice, int | ScalarConstant]: ...
250+
251+
252+
@overload
253+
def get_canonical_form_slice(
254+
theslice: ScalarVariable | TensorVariable,
255+
length: int | np.integer | ScalarVariable | TensorVariable,
256+
) -> tuple[slice | ScalarVariable, int]: ...
257+
258+
221259
def get_canonical_form_slice(
222-
theslice: slice | Variable, length: ScalarVariable
223-
) -> tuple[slice | Variable, int | ScalarConstant]:
260+
theslice: slice | Variable,
261+
length: int | np.integer | ScalarVariable | TensorVariable,
262+
) -> tuple[slice | ScalarVariable, int | ScalarConstant]:
224263
"""Convert slices to canonical form.
225264
226265
Given a slice [start:stop:step] transform it into a canonical form
@@ -230,36 +269,46 @@ def get_canonical_form_slice(
230269
in which 0 <= start <= stop <= length and step > 0, and a flag which says
231270
if the resulting set of numbers needs to be reversed or not.
232271
272+
Given a scalar index `idx` that may or not be negative, convert it to
273+
a certainly positive form `idx if idx >= 0 else length + idx`.
274+
233275
Returns
234276
-------
235277
slc
236-
Slice or scalar variable.
278+
Canonical form slice or scalar variable.
237279
direction
238-
Direction to iterate the resulting elements in. (-1 or 1).
280+
Direction to iterate the resulting elements in. (-1 or 1). May be symbolic.
239281
"""
240282
from pytensor.tensor import ge, lt, sign, switch
241283

242-
if not isinstance(theslice, slice):
243-
# Try to extract a scalar slice and return it already.
244-
sslice: int | np.integer | ScalarVariable | None = None
284+
# Convert the two symbolic slice types into a native slice
285+
if isinstance(theslice, SliceConstant):
286+
theslice = cast(slice, theslice.data)
287+
elif isinstance(theslice, Variable) and isinstance(theslice.type, SliceType):
288+
theslice = slice(*theslice.owner.inputs)
289+
# Other non-slice types are the scalar indexing case
290+
elif not isinstance(theslice, slice):
291+
sslice: int | np.integer | ScalarVariable | TensorVariable
245292
if isinstance(theslice, ScalarVariable):
246293
sslice = theslice
294+
elif isinstance(theslice, TensorVariable) and theslice.ndim == 0:
295+
sslice = theslice
247296
else:
297+
# Only non-variable types should remain.
248298
try:
249299
vlit = as_index_literal(theslice)
250300
if vlit is None:
251-
return slice(0, length, 1), 1
252-
if isinstance(vlit, slice):
253-
# Input was a SliceType variable, from which a slice was extracted.
254-
theslice = vlit
255-
else:
256-
sslice = vlit
301+
raise ValueError("Can't create canonical slice for `None` slices.")
302+
# Slice returns from as_index are not expected, because the
303+
# symbolic slice type inputs were already taken care of above.
304+
assert isinstance(vlit, int | np.integer)
305+
sslice = vlit
257306
except NotScalarConstantError:
258307
raise ValueError(f"Slice {theslice} is not a supported slice type.")
259-
if isinstance(sslice, int | np.integer | ScalarVariable):
260-
return switch(lt(sslice, 0), (sslice + length), sslice), 1
308+
cano = switch(lt(sslice, 0), (sslice + length), sslice)
309+
return scalar_from_tensor(cano), 1
261310

262-
# At this point we have a slice object. Possibly symbolic.
311+
# At this point we have a slice object. Possibly with symbolic inputs.
263312
assert isinstance(theslice, slice)
264313

265314
def analyze(x):

tests/tensor/test_subtensor.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from pytensor.graph.op import get_test_value
1717
from pytensor.graph.rewriting.utils import is_same_graph
1818
from pytensor.printing import pprint
19-
from pytensor.scalar.basic import as_scalar
19+
from pytensor.scalar.basic import as_scalar, int16
2020
from pytensor.tensor import as_tensor, get_vector_length, vectorize
2121
from pytensor.tensor.blockwise import Blockwise
2222
from pytensor.tensor.elemwise import DimShuffle
@@ -68,7 +68,14 @@
6868
tensor5,
6969
vector,
7070
)
71-
from pytensor.tensor.type_other import NoneConst, SliceConstant, make_slice, slicetype
71+
from pytensor.tensor.type_other import (
72+
NoneConst,
73+
SliceConstant,
74+
SliceType,
75+
as_symbolic_slice,
76+
make_slice,
77+
slicetype,
78+
)
7279
from tests import unittest_tools as utt
7380
from tests.tensor.utils import inplace_func, integers_ranged, random
7481

@@ -105,31 +112,51 @@ def test_as_index_literal():
105112

106113

107114
class TestGetCanonicalFormSlice:
108-
def test_none_constant(self):
109-
slc, rev = get_canonical_form_slice(NoneConst, 5)
110-
assert isinstance(slc, slice)
111-
assert slc.start == 0
112-
assert slc.stop == 5
113-
assert slc.step == 1
114-
assert rev == 1
115+
def test_none_raises(self):
116+
with pytest.raises(ValueError, match="Can't create"):
117+
get_canonical_form_slice(NoneConst, 5)
118+
with pytest.raises(ValueError, match="Can't create"):
119+
get_canonical_form_slice(None, 5)
115120

116121
def test_scalar_constant(self):
117122
a = as_scalar(0)
118123
length = lscalar()
119124
res = get_canonical_form_slice(a, length)
120-
assert res[0].owner.op == ptb.switch
125+
assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
121126
assert res[1] == 1
122127

123128
def test_tensor_constant(self):
124129
a = as_tensor(0)
125130
length = lscalar()
126131
res = get_canonical_form_slice(a, length)
127-
assert res[0].owner.op == ptb.switch
132+
assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
133+
assert res[1] == 1
134+
135+
def test_symbolic_scalar(self):
136+
a = int16()
137+
length = lscalar()
138+
res = get_canonical_form_slice(a, length)
139+
assert res[0].owner.op, ptb.switch
140+
assert res[1] == 1
141+
142+
def test_symbolic_tensor(self):
143+
a = lscalar()
144+
length = lscalar()
145+
res = get_canonical_form_slice(a, length)
146+
assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
147+
assert res[1] == 1
148+
149+
def test_constant_slice(self):
150+
idx = as_symbolic_slice(slice(3, 7, 2))
151+
assert isinstance(idx, SliceConstant)
152+
res = get_canonical_form_slice(idx, 10)
153+
assert isinstance(res[0], slice)
128154
assert res[1] == 1
129155

130156
def test_symbolic_slice(self):
131-
idx = make_slice(slice(3, 7, 2))
157+
idx = as_symbolic_slice(slice(3, int16(), 2))
132158
assert not isinstance(idx, slice)
159+
assert isinstance(idx.type, SliceType)
133160
res = get_canonical_form_slice(idx, 10)
134161
assert isinstance(res[0], slice)
135162
assert res[1] == 1

0 commit comments

Comments
 (0)