|
29 | 29 | from pytensor.raise_op import Assert |
30 | 30 | from pytensor.scalar import Cast |
31 | 31 | from pytensor.tensor.elemwise import Elemwise |
32 | | -from pytensor.tensor.random import RandomStream |
33 | 32 | from pytensor.tensor.random.basic import IntegersRV |
34 | 33 | from pytensor.tensor.subtensor import AdvancedSubtensor |
35 | 34 | from pytensor.tensor.type import TensorType |
@@ -132,6 +131,12 @@ def __hash__(self): |
132 | 131 | class MinibatchIndexRV(IntegersRV): |
133 | 132 | _print_name = ("minibatch_index", r"\operatorname{minibatch\_index}") |
134 | 133 |
|
| 134 | + # Work-around for https://github.com/pymc-devs/pytensor/issues/97 |
| 135 | + def make_node(self, rng, *args, **kwargs): |
| 136 | + if rng is None: |
| 137 | + rng = pytensor.shared(np.random.default_rng()) |
| 138 | + return super().make_node(rng, *args, **kwargs) |
| 139 | + |
135 | 140 |
|
136 | 141 | minibatch_index = MinibatchIndexRV() |
137 | 142 |
|
@@ -184,10 +189,9 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size: |
184 | 189 | >>> mdata1, mdata2 = Minibatch(data1, data2, batch_size=10) |
185 | 190 | """ |
186 | 191 |
|
187 | | - rng = RandomStream() |
188 | 192 | tensor, *tensors = tuple(map(at.as_tensor, (variable, *variables))) |
189 | 193 | upper = assert_all_scalars_equal(*[t.shape[0] for t in (tensor, *tensors)]) |
190 | | - slc = rng.gen(minibatch_index, 0, upper, size=batch_size) |
| 194 | + slc = minibatch_index(0, upper, size=batch_size) |
191 | 195 | for i, v in enumerate((tensor, *tensors)): |
192 | 196 | if not valid_for_minibatch(v): |
193 | 197 | raise ValueError( |
|
0 commit comments