|
15 | 15 | import re |
16 | 16 | import sys |
17 | 17 |
|
| 18 | +from datetime import datetime |
18 | 19 | from functools import partial |
19 | 20 | from typing import Any, Callable, Dict, List, Optional, Sequence, Union |
20 | 21 |
|
21 | | -from pytensor.tensor.random.type import RandomType |
22 | | - |
23 | | -from pymc.initial_point import StartDict |
24 | | -from pymc.sampling.mcmc import _init_jitter |
25 | | - |
26 | | -xla_flags = os.getenv("XLA_FLAGS", "") |
27 | | -xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags).split() |
28 | | -os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"] + xla_flags) |
29 | | - |
30 | | -from datetime import datetime |
31 | | - |
32 | 22 | import arviz as az |
33 | 23 | import jax |
34 | 24 | import numpy as np |
|
43 | 33 | from pytensor.link.jax.dispatch import jax_funcify |
44 | 34 | from pytensor.raise_op import Assert |
45 | 35 | from pytensor.tensor import TensorVariable |
| 36 | +from pytensor.tensor.random.type import RandomType |
46 | 37 | from pytensor.tensor.shape import SpecifyShape |
47 | 38 |
|
48 | 39 | from pymc import Model, modelcontext |
49 | 40 | from pymc.backends.arviz import find_constants, find_observations |
| 41 | +from pymc.initial_point import StartDict |
50 | 42 | from pymc.logprob.utils import CheckParameterValue |
| 43 | +from pymc.sampling.mcmc import _init_jitter |
51 | 44 | from pymc.util import ( |
52 | 45 | RandomSeed, |
53 | 46 | RandomState, |
54 | 47 | _get_seeds_per_chain, |
55 | 48 | get_default_varnames, |
56 | 49 | ) |
57 | 50 |
|
| 51 | +xla_flags_env = os.getenv("XLA_FLAGS", "") |
| 52 | +xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags_env).split() |
| 53 | +os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"] + xla_flags) |
| 54 | + |
58 | 55 | __all__ = ( |
59 | 56 | "get_jaxified_graph", |
60 | 57 | "get_jaxified_logp", |
@@ -111,7 +108,7 @@ def get_jaxified_graph( |
111 | 108 | ) -> List[TensorVariable]: |
112 | 109 | """Compile an PyTensor graph into an optimized JAX function""" |
113 | 110 |
|
114 | | - graph = _replace_shared_variables(outputs) |
| 111 | + graph = _replace_shared_variables(outputs) if outputs is not None else None |
115 | 112 |
|
116 | 113 | fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True) |
117 | 114 | # We need to add a Supervisor to the fgraph to be able to run the |
@@ -254,12 +251,10 @@ def _get_batched_jittered_initial_points( |
254 | 251 | jitter=jitter, |
255 | 252 | jitter_max_retries=jitter_max_retries, |
256 | 253 | ) |
257 | | - initial_points = [list(initial_point.values()) for initial_point in initial_points] |
| 254 | + initial_points_values = [list(initial_point.values()) for initial_point in initial_points] |
258 | 255 | if chains == 1: |
259 | | - initial_points = initial_points[0] |
260 | | - else: |
261 | | - initial_points = [np.stack(init_state) for init_state in zip(*initial_points)] |
262 | | - return initial_points |
| 256 | + return initial_points_values[0] |
| 257 | + return [np.stack(init_state) for init_state in zip(*initial_points_values)] |
263 | 258 |
|
264 | 259 |
|
265 | 260 | def _update_coords_and_dims( |
|
0 commit comments