|
7 | 7 | import numpy as np |
8 | 8 |
|
9 | 9 | from pytensor.compile.sharedvalue import shared |
10 | | -from pytensor.graph.basic import Constant, Variable |
| 10 | +from pytensor.graph.basic import Variable |
11 | 11 | from pytensor.scalar import ScalarVariable |
12 | 12 | from pytensor.tensor import NoneConst, get_vector_length |
13 | 13 | from pytensor.tensor.basic import as_tensor_variable, cast |
14 | 14 | from pytensor.tensor.extra_ops import broadcast_arrays, broadcast_to |
15 | 15 | from pytensor.tensor.math import maximum |
16 | 16 | from pytensor.tensor.shape import shape_padleft, specify_shape |
17 | 17 | from pytensor.tensor.type import int_dtypes |
| 18 | +from pytensor.tensor.type_other import NoneTypeT |
18 | 19 | from pytensor.tensor.utils import faster_broadcast_to |
19 | 20 | from pytensor.tensor.variable import TensorVariable |
20 | 21 |
|
@@ -178,24 +179,26 @@ def normalize_size_param( |
178 | 179 | shape: int | np.ndarray | Variable | Sequence | None, |
179 | 180 | ) -> Variable: |
180 | 181 | """Create an PyTensor value for a ``RandomVariable`` ``size`` parameter.""" |
181 | | - if shape is None or NoneConst.equals(shape): |
| 182 | + if shape is None: |
182 | 183 | return NoneConst |
183 | | - elif isinstance(shape, int): |
| 184 | + if isinstance(shape, Variable) and isinstance(shape.type, NoneTypeT): |
| 185 | + return shape |
| 186 | + |
| 187 | + if isinstance(shape, int): |
184 | 188 | shape = as_tensor_variable([shape], ndim=1) |
185 | | - elif not isinstance(shape, np.ndarray | Variable | Sequence): |
186 | | - raise TypeError( |
187 | | - "Parameter size must be None, an integer, or a sequence with integers." |
188 | | - ) |
189 | 189 | else: |
| 190 | + if not isinstance(shape, Sequence | Variable | np.ndarray): |
| 191 | + raise TypeError( |
| 192 | + "Parameter size must be None, an integer, or a sequence with integers." |
| 193 | + ) |
190 | 194 | shape = cast(as_tensor_variable(shape, ndim=1, dtype="int64"), "int64") |
191 | 195 |
|
192 | | - if not isinstance(shape, Constant): |
| 196 | + if shape.type.shape == (None,): |
193 | 197 | # This should help ensure that the length of non-constant `size`s |
194 | | - # will be available after certain types of cloning (e.g. the kind |
195 | | - # `Scan` performs) |
| 198 | + # will be available after certain types of cloning (e.g. the kind `Scan` performs) |
196 | 199 | shape = specify_shape(shape, (get_vector_length(shape),)) |
197 | 200 |
|
198 | | - assert not any(s is None for s in shape.type.shape) |
| 201 | + assert shape.type.shape != (None,) |
199 | 202 | assert shape.dtype in int_dtypes |
200 | 203 |
|
201 | 204 | return shape |
|
0 commit comments