Commit 72196a3

Handle non-constant NoneTypeT variables
1 parent: 049046d

File tree: 7 files changed, +64 -23 lines

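All of these changes share one theme: `NoneConst.equals(x)` only matches the shared `NoneConst` constant, so a non-constant variable whose type is `NoneTypeT` (e.g. a graph input that will be fed `None` at runtime) slipped past every such check. Testing `isinstance(x.type, NoneTypeT)` instead covers both cases. A minimal sketch of the distinction (the variable name is illustrative; both names come from `pytensor.tensor.type_other`, as the diffs below show):

from pytensor.tensor.type_other import NoneConst, NoneTypeT, none_type_t

size = none_type_t("size")  # non-constant variable of type NoneTypeT

assert isinstance(size.type, NoneTypeT)       # new check: matches
assert not NoneConst.equals(size)             # old check: constant-only, misses it
assert isinstance(NoneConst.type, NoneTypeT)  # new check still covers NoneConst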

pytensor/tensor/random/op.py

Lines changed: 3 additions & 1 deletion
@@ -385,7 +385,9 @@ def make_node(self, rng, size, *dist_params):
         dist_params = explicit_expand_dims(
             dist_params,
             self.ndims_params,
-            size_length=None if NoneConst.equals(size) else get_vector_length(size),
+            size_length=None
+            if isinstance(size.type, NoneTypeT)
+            else get_vector_length(size),
         )

         inputs = (rng, size, *dist_params)

pytensor/tensor/random/rewriting/basic.py

Lines changed: 3 additions & 3 deletions
@@ -9,7 +9,7 @@
     dfs_rewriter,
     node_rewriter,
 )
-from pytensor.tensor import NoneConst, TensorVariable
+from pytensor.tensor import TensorVariable
 from pytensor.tensor.basic import constant
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.extra_ops import broadcast_to

@@ -244,7 +244,7 @@ def is_nd_advanced_idx(idx, dtype) -> bool:
         # If we wanted to support that we could rewrite it as subtensor + dimshuffle
         # and make use of the dimshuffle lift rewrite
         if any(
-            is_nd_advanced_idx(idx, integer_dtypes) or NoneConst.equals(idx)
+            is_nd_advanced_idx(idx, integer_dtypes) or isinstance(idx.type, NoneTypeT)
             for idx in indices
         ):
             return False

@@ -267,7 +267,7 @@ def is_nd_advanced_idx(idx, dtype) -> bool:
         for idx in supp_indices:
             if not (
                 isinstance(idx.type, SliceType)
-                and all(NoneConst.equals(i) for i in idx.owner.inputs)
+                and all(isinstance(i.type, NoneTypeT) for i in idx.owner.inputs)
             ):
                 return False
         n_discarded_idxs = len(supp_indices)

pytensor/tensor/random/utils.py

Lines changed: 14 additions & 11 deletions
@@ -7,14 +7,15 @@
 import numpy as np

 from pytensor.compile.sharedvalue import shared
-from pytensor.graph.basic import Constant, Variable
+from pytensor.graph.basic import Variable
 from pytensor.scalar import ScalarVariable
 from pytensor.tensor import NoneConst, get_vector_length
 from pytensor.tensor.basic import as_tensor_variable, cast
 from pytensor.tensor.extra_ops import broadcast_arrays, broadcast_to
 from pytensor.tensor.math import maximum
 from pytensor.tensor.shape import shape_padleft, specify_shape
 from pytensor.tensor.type import int_dtypes
+from pytensor.tensor.type_other import NoneTypeT
 from pytensor.tensor.utils import faster_broadcast_to
 from pytensor.tensor.variable import TensorVariable

@@ -178,24 +179,26 @@ def normalize_size_param(
     shape: int | np.ndarray | Variable | Sequence | None,
 ) -> Variable:
     """Create an PyTensor value for a ``RandomVariable`` ``size`` parameter."""
-    if shape is None or NoneConst.equals(shape):
+    if shape is None:
         return NoneConst
-    elif isinstance(shape, int):
+    if isinstance(shape, Variable) and isinstance(shape.type, NoneTypeT):
+        return shape
+
+    if isinstance(shape, int):
         shape = as_tensor_variable([shape], ndim=1)
-    elif not isinstance(shape, np.ndarray | Variable | Sequence):
-        raise TypeError(
-            "Parameter size must be None, an integer, or a sequence with integers."
-        )
     else:
+        if not isinstance(shape, Sequence | Variable | np.ndarray):
+            raise TypeError(
+                "Parameter size must be None, an integer, or a sequence with integers."
+            )
         shape = cast(as_tensor_variable(shape, ndim=1, dtype="int64"), "int64")

-    if not isinstance(shape, Constant):
+    if shape.type.shape == (None,):
         # This should help ensure that the length of non-constant `size`s
-        # will be available after certain types of cloning (e.g. the kind
-        # `Scan` performs)
+        # will be available after certain types of cloning (e.g. the kind `Scan` performs)
         shape = specify_shape(shape, (get_vector_length(shape),))

-    assert not any(s is None for s in shape.type.shape)
+    assert shape.type.shape != (None,)
     assert shape.dtype in int_dtypes

     return shape
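Net effect: `normalize_size_param` now returns a symbolic none-typed `size` unchanged, where previously only the literal `None` (and the `NoneConst` constant) short-circuited, and a non-constant `NoneTypeT` variable fell through to the tensor-conversion branch. A hedged sketch of the new behavior, mirroring the test added below:

from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.type_other import NoneTypeT, none_type_t

sym_none = none_type_t("size")  # symbolic None, not a constant
assert normalize_size_param(sym_none) is sym_none  # passed through as-is

# The literal None still maps to the NoneConst constant.
assert isinstance(normalize_size_param(None).type, NoneTypeT)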

pytensor/tensor/rewriting/shape.py

Lines changed: 4 additions & 4 deletions
@@ -47,7 +47,7 @@
 )
 from pytensor.tensor.subtensor import Subtensor, get_idx_list
 from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes
-from pytensor.tensor.type_other import NoneConst, NoneTypeT
+from pytensor.tensor.type_other import NoneTypeT
 from pytensor.tensor.variable import TensorVariable

@@ -1137,7 +1137,7 @@ def local_merge_consecutive_specify_shape(fgraph, node):

     inner_obj, *shape = obj.owner.inputs
     for dim, sh in enumerate(node.inputs[1:]):
-        if not NoneConst.equals(sh):
+        if not isinstance(sh.type, NoneTypeT):
             shape[dim] = sh

     # TODO: We could make sure that the overlapping shapes of the two `SpecifyShape`s are

@@ -1183,7 +1183,7 @@ def local_Shape_of_SpecifyShape(fgraph, node):

     # Replace `NoneConst` by `shape_i`
     for i, sh in enumerate(shape):
-        if NoneConst.equals(sh):
+        if isinstance(sh.type, NoneTypeT):
             shape[i] = x.shape[i]

     return [stack(shape).astype(np.int64)]

@@ -1219,7 +1219,7 @@ def local_specify_shape_lift(fgraph, node):
         for i, (dim, bcast) in enumerate(
             zip(shape, out_broadcastable, strict=True)
         )
-        if (not bcast and not NoneConst.equals(dim))
+        if (not bcast and not isinstance(dim.type, NoneTypeT))
     }
     new_elem_inps = elem_inps.copy()
     for i, elem_inp in enumerate(elem_inps):

pytensor/tensor/shape.py

Lines changed: 8 additions & 3 deletions
@@ -408,7 +408,9 @@ def make_node(self, x, *shape):

         shape = tuple(
             NoneConst
-            if (s is None or NoneConst.equals(s))
+            if (
+                s is None or (isinstance(s, Variable) and isinstance(s.type, NoneTypeT))
+            )
             else ptb.as_tensor_variable(s, ndim=0)
             for s in shape
         )

@@ -506,7 +508,7 @@ def c_code(self, node, name, i_names, o_names, sub):
         for i, (shp_name, shp) in enumerate(
             zip(shape_names, node.inputs[1:], strict=True)
         ):
-            if NoneConst.equals(shp):
+            if isinstance(shp.type, NoneTypeT):
                 continue
             code += dedent(
                 f"""

@@ -594,7 +596,10 @@ def _vectorize_specify_shape(op, node, x, *shape):
     if any(
         as_tensor_variable(dim).type.ndim != 0
         for dim in shape
-        if not (NoneConst.equals(dim) or dim is None)
+        if not (
+            (isinstance(dim, Variable) and isinstance(dim.type, NoneTypeT))
+            or dim is None
+        )
     ):
         raise NotImplementedError(
             "It is not possible to vectorize the shape argument of SpecifyShape"

tests/tensor/random/test_op.py

Lines changed: 10 additions & 0 deletions
@@ -11,6 +11,7 @@
 from pytensor.tensor.random.op import RandomVariable, default_rng
 from pytensor.tensor.shape import specify_shape
 from pytensor.tensor.type import iscalar, tensor
+from pytensor.tensor.type_other import none_type_t


 @pytest.fixture(scope="function", autouse=False)

@@ -317,3 +318,12 @@ def test_size_none_vs_empty():
         ValueError, match="Size length is incompatible with batched dimensions"
     ):
         rv([0], [1], size=())
+
+
+def test_non_constant_none_size():
+    # Regression test for https://github.com/pymc-devs/pymc/issues/7901#issuecomment-3528479876
+    loc = pt.vector("loc")
+    size = none_type_t("none_size")
+
+    rv = normal(loc, size=size)
+    rv.eval({loc: np.arange(5, dtype="float64"), size: None}, mode="FAST_COMPILE")

tests/tensor/random/test_utils.py

Lines changed: 22 additions & 1 deletion
@@ -7,9 +7,11 @@
 from pytensor.tensor.random.utils import (
     RandomStream,
     broadcast_params,
+    normalize_size_param,
     supp_shape_from_ref_param_shape,
 )
-from pytensor.tensor.type import matrix, tensor
+from pytensor.tensor.type import TensorType, matrix, tensor
+from pytensor.tensor.type_other import NoneTypeT, none_type_t
 from tests import unittest_tools as utt

@@ -327,3 +329,22 @@ def test_supp_shape_from_ref_param_shape():
         ref_param_idx=1,
     )
     assert res == (3, 4)
+
+
+def test_normalize_size_param():
+    assert normalize_size_param(None).type == NoneTypeT()
+
+    sym_none_size = none_type_t()
+    assert normalize_size_param(sym_none_size) is sym_none_size
+
+    empty_size = normalize_size_param(())
+    assert empty_size.type == TensorType(dtype="int64", shape=(0,))
+
+    int_size = normalize_size_param(5)
+    assert int_size.type == TensorType(dtype="int64", shape=(1,))
+
+    seq_int_size = normalize_size_param((5, 3, 4))
+    assert seq_int_size.type == TensorType(dtype="int64", shape=(3,))
+
+    sym_tensor_size = tensor(shape=(3,), dtype="int64")
+    assert normalize_size_param(sym_tensor_size) is sym_tensor_size
