diff --git a/pytensor/__init__.py b/pytensor/__init__.py index 78f0f31489..7d0bdedf76 100644 --- a/pytensor/__init__.py +++ b/pytensor/__init__.py @@ -108,7 +108,7 @@ def as_symbolic(x: Any, name: str | None = None, **kwargs) -> Variable: @singledispatch -def _as_symbolic(x, **kwargs) -> Variable: +def _as_symbolic(x: Any, **kwargs) -> Variable: from pytensor.tensor import as_tensor_variable return as_tensor_variable(x, **kwargs) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 70e36ab60c..f71c591473 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -1302,8 +1302,8 @@ def clone_node_and_cache( def clone_get_equiv( - inputs: Sequence[Variable], - outputs: Sequence[Variable], + inputs: Iterable[Variable], + outputs: Reversible[Variable], copy_inputs: bool = True, copy_orphans: bool = True, memo: dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]] diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index ff310d3d4b..83c1ca0dd4 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -1,7 +1,7 @@ import warnings from collections.abc import Sequence from copy import copy -from typing import cast +from typing import Any, cast import numpy as np @@ -218,6 +218,7 @@ def _infer_shape( from pytensor.tensor.extra_ops import broadcast_shape_iter + supp_shape: tuple[Any] if self.ndim_supp == 0: supp_shape = () else: diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py index 9ddedb34b1..b77580e515 100644 --- a/pytensor/tensor/random/utils.py +++ b/pytensor/tensor/random/utils.py @@ -147,7 +147,9 @@ def explicit_expand_dims( return new_params -def compute_batch_shape(params, ndims_params: Sequence[int]) -> TensorVariable: +def compute_batch_shape( + params: Sequence[TensorVariable], ndims_params: Sequence[int] +) -> TensorVariable: params = explicit_expand_dims(params, ndims_params) batch_params = [ param[(..., *(0,) * core_ndim)] diff --git 
a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index b6fcd9fb21..66cab27dc6 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -144,14 +144,14 @@ def c_code_cache_version(self): _shape = Shape() -def shape(x: np.ndarray | Number | Variable) -> Variable: +def shape(x: np.ndarray | Number | Variable) -> TensorVariable: """Return the shape of `x`.""" if not isinstance(x, Variable): # The following is a type error in Python 3.9 but not 3.12. # Thus we need to ignore unused-ignore on 3.12. x = ptb.as_tensor_variable(x) # type: ignore[arg-type,unused-ignore] - return cast(Variable, _shape(x)) + return cast(TensorVariable, _shape(x)) @_get_vector_length.register(Shape) # type: ignore @@ -195,7 +195,7 @@ def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]: # TODO: Why not use uint64? res += (pytensor.scalar.ScalarConstant(pytensor.scalar.int64, shape_val),) else: - res += (symbolic_shape[i],) # type: ignore + res += (symbolic_shape[i],) return res diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 2be0c7cd83..f0f5555499 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -3,6 +3,7 @@ from collections.abc import Callable, Iterable from itertools import chain, groupby from textwrap import dedent +from typing import cast, overload import numpy as np @@ -19,13 +20,19 @@ from pytensor.link.c.params_type import ParamsType from pytensor.misc.safe_asarray import _asarray from pytensor.printing import Printer, pprint, set_precedence -from pytensor.scalar.basic import ScalarConstant -from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length +from pytensor.scalar.basic import ScalarConstant, ScalarVariable +from pytensor.tensor import ( + TensorLike, + _get_vector_length, + as_tensor_variable, + get_vector_length, +) from pytensor.tensor.basic import ( ScalarFromTensor, alloc, get_underlying_scalar_constant_value, nonzero, + scalar_from_tensor, ) from 
pytensor.tensor.blockwise import vectorize_node_fallback from pytensor.tensor.elemwise import DimShuffle @@ -51,8 +58,14 @@ wscalar, zscalar, ) -from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceType, make_slice -from pytensor.tensor.variable import TensorVariable +from pytensor.tensor.type_other import ( + NoneConst, + NoneTypeT, + SliceConstant, + SliceType, + make_slice, +) +from pytensor.tensor.variable import TensorConstant, TensorVariable _logger = logging.getLogger("pytensor.tensor.subtensor") @@ -134,7 +147,7 @@ def convert_indices(indices, entry): def as_index_constant( - a: slice | int | np.integer | Variable | None, + a: slice | int | np.integer | Variable | None | TensorLike, ) -> Variable | slice | None: r"""Convert Python literals to PyTensor constants--when possible--in `Subtensor` arguments. @@ -150,15 +163,41 @@ def as_index_constant( ) elif isinstance(a, int | np.integer): return ps.ScalarConstant(ps.int64, a) - elif not isinstance(a, Variable): - return as_tensor_variable(a) - else: + elif isinstance(a, Variable): return a + return as_tensor_variable(a) + + +@overload +def as_index_literal(idx: int | np.integer) -> int | np.integer: ... + + +@overload +def as_index_literal(idx: None) -> None: ... + + +@overload +def as_index_literal(idx: slice | SliceConstant) -> slice: ... + + +@overload +def as_index_literal(idx: ScalarConstant | TensorConstant) -> int | np.integer: ... + + +@overload +def as_index_literal(idx: Variable): ... def as_index_literal( - idx: Variable | slice | None, -) -> int | slice | None: + idx: None + | int + | np.integer + | slice + | SliceConstant + | ScalarConstant + | TensorConstant + | Variable, +) -> int | np.integer | slice | None: """Convert a symbolic index element to its Python equivalent. 
This is like the inverse of `as_index_constant` @@ -167,22 +206,8 @@ def as_index_literal( ------ NotScalarConstantError """ - if idx == np.newaxis or isinstance(getattr(idx, "type", None), NoneTypeT): - return np.newaxis - - if isinstance(idx, Constant): - return idx.data.item() if isinstance(idx, np.ndarray) else idx.data - - if isinstance(idx, Variable): - if ( - isinstance(idx.type, ps.ScalarType) - and idx.owner - and isinstance(idx.owner.op, ScalarFromTensor) - ): - return as_index_literal(idx.owner.inputs[0]) - - if isinstance(idx.type, SliceType): - idx = slice(*idx.owner.inputs) + if idx is None or isinstance(idx, int | np.integer): + return idx if isinstance(idx, slice): return slice( @@ -191,6 +216,33 @@ def as_index_literal( as_index_literal(idx.step), ) + if not isinstance(idx, Variable): + raise TypeError(f"Not an index element: {idx}") + + if isinstance(idx.type, NoneTypeT): + return None + + if isinstance(idx, ScalarConstant): + return cast(int, idx.data) + + if ( + isinstance(idx.type, ps.ScalarType) + and idx.owner + and isinstance(idx.owner.op, ScalarFromTensor) + ): + return cast(int | np.integer, as_index_literal(idx.owner.inputs[0])) + + if isinstance(idx, TensorConstant): + return cast(int, idx.data.item()) + + if isinstance(idx, SliceConstant): + return cast(slice, idx.data) + + if isinstance(idx.type, SliceType): + assert idx.owner is not None + return slice(*map(as_index_literal, idx.owner.inputs)) + + # Other kinds of variables are not supported raise NotScalarConstantError() @@ -198,10 +250,30 @@ def get_idx_list(inputs, idx_list): return indices_from_subtensor(inputs[1:], idx_list) +@overload +def get_canonical_form_slice( + theslice: slice, + length: int | np.integer | ScalarVariable | TensorVariable, +) -> tuple[slice, int | ScalarConstant]: ... 
+ + +@overload +def get_canonical_form_slice( + theslice: int | np.integer | ScalarVariable | TensorVariable, + length: int | np.integer | ScalarVariable | TensorVariable, +) -> tuple[ScalarVariable, int]: ... + + def get_canonical_form_slice( - theslice: slice | Variable, length: Variable -) -> tuple[Variable, int]: - """Convert slices to canonical form. + theslice: slice | int | np.integer | ScalarVariable | TensorVariable, + length: int | np.integer | ScalarVariable | TensorVariable, +) -> tuple[slice | ScalarVariable, int | ScalarConstant]: + """Convert indices or slices to canonical form. + + Scalar integer indices or Python slices with Scalar/None attributes + used in basic Subtensor Ops are supported. + Symbolic slices (of SliceType) or vector indices + used in advanced Subtensor Ops are not supported. Given a slice [start:stop:step] transform it into a canonical form that respects the conventions imposed by python and numpy. @@ -210,18 +282,28 @@ def get_canonical_form_slice( in which 0 <= start <= stop <= length and step > 0, and a flag which says if the resulting set of numbers needs to be reversed or not. + Given a scalar index `idx` that may or may not be negative, convert it to + a guaranteed non-negative form `idx if idx >= 0 else length + idx`. + + Returns + ------- + slc + Canonical form slice or scalar variable. + direction + Direction to iterate the resulting elements in (-1 or 1). May be symbolic.
""" from pytensor.tensor import ge, lt, sign, switch + # Other non-slice types are the scalar indexing case if not isinstance(theslice, slice): - try: - value = as_index_literal(theslice) - except NotScalarConstantError: - value = theslice - - value = switch(lt(value, 0), (value + length), value) + if isinstance(theslice, int | np.integer | ScalarVariable) or ( + isinstance(theslice, TensorVariable) and theslice.ndim == 0 + ): + cano = switch(lt(theslice, 0), (theslice + length), theslice) + return scalar_from_tensor(cano), 1 + raise ValueError(f"Slice {theslice} is not a supported slice type.") - return value, 1 + # At this point we have a slice object. Possibly with symbolic inputs. def analyze(x): try: @@ -243,6 +325,7 @@ def analyze(x): and is_step_constant and is_length_constant ): + assert isinstance(length, int) _start, _stop, _step = slice(start, stop, step).indices(length) if _start <= _stop and _step >= 1: return slice(_start, _stop, _step), 1 @@ -2917,7 +3000,7 @@ def take(a, indices, axis=None, mode="raise"): return a[full_indices] -@_get_vector_length.register(Subtensor) +@_get_vector_length.register(Subtensor) # type: ignore def _get_vector_length_Subtensor(op, var): # If we take a slice, we know how many elements it will result in # TODO: We can cover more `*Subtensor` cases. diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index 730ae9b07b..952c22982a 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -138,7 +138,7 @@ def clone( shape = self.shape return type(self)(dtype, shape, name=self.name) - def filter(self, data, strict=False, allow_downcast=None): + def filter(self, data, strict=False, allow_downcast=None) -> np.ndarray: """Convert `data` to something which can be associated to a `TensorVariable`. This function is not meant to be called in user code. 
It is for diff --git a/scripts/mypy-failing.txt b/scripts/mypy-failing.txt index d73c19752b..e97b008e54 100644 --- a/scripts/mypy-failing.txt +++ b/scripts/mypy-failing.txt @@ -25,7 +25,6 @@ pytensor/tensor/random/op.py pytensor/tensor/random/utils.py pytensor/tensor/rewriting/basic.py pytensor/tensor/slinalg.py -pytensor/tensor/subtensor.py pytensor/tensor/type.py pytensor/tensor/type_other.py pytensor/tensor/variable.py diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index ac2bf58446..f4ba58e26a 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -16,8 +16,8 @@ from pytensor.graph.op import get_test_value from pytensor.graph.rewriting.utils import is_same_graph from pytensor.printing import pprint -from pytensor.scalar.basic import as_scalar -from pytensor.tensor import get_vector_length, vectorize +from pytensor.scalar.basic import as_scalar, int16 +from pytensor.tensor import as_tensor, get_vector_length, vectorize from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.math import exp, isinf @@ -69,7 +69,13 @@ tensor5, vector, ) -from pytensor.tensor.type_other import NoneConst, SliceConstant, make_slice, slicetype +from pytensor.tensor.type_other import ( + NoneConst, + SliceConstant, + as_symbolic_slice, + make_slice, + slicetype, +) from tests import unittest_tools as utt from tests.tensor.utils import inplace_func, integers_ranged, random @@ -106,11 +112,51 @@ def test_as_index_literal(): class TestGetCanonicalFormSlice: + @pytest.mark.parametrize( + "idx", + [ + NoneConst, + None, + as_symbolic_slice(slice(3, 7, 2)), + as_symbolic_slice(slice(3, int16(), 2)), + vector(), + ], + ) + def test_unsupported_inputs(self, idx): + with pytest.raises(ValueError, match="not a supported slice"): + get_canonical_form_slice(idx, 5) + def test_scalar_constant(self): a = as_scalar(0) length = lscalar() res = get_canonical_form_slice(a, length) - 
assert res[0].owner.op == ptb.switch + assert isinstance(res[0].owner.op, ptb.ScalarFromTensor) + assert res[1] == 1 + + def test_tensor_constant(self): + a = as_tensor(0) + length = lscalar() + res = get_canonical_form_slice(a, length) + assert isinstance(res[0].owner.op, ptb.ScalarFromTensor) + assert res[1] == 1 + + def test_symbolic_scalar(self): + a = int16() + length = lscalar() + res = get_canonical_form_slice(a, length) + assert res[0].owner.op, ptb.switch + assert res[1] == 1 + + def test_symbolic_tensor(self): + a = lscalar() + length = lscalar() + res = get_canonical_form_slice(a, length) + assert isinstance(res[0].owner.op, ptb.ScalarFromTensor) + assert res[1] == 1 + + def test_all_integer(self): + res = get_canonical_form_slice(slice(1, 5, 2), 7) + assert isinstance(res[0], slice) assert res[1] == 1 def test_all_symbolic(self):