File tree Expand file tree Collapse file tree 4 files changed +24
-0
lines changed Expand file tree Collapse file tree 4 files changed +24
-0
lines changed Original file line number Diff line number Diff line change 3737warnings .warn ("This module is experimental." )
3838
3939
40+ __all__ = (
41+ "get_jaxified_graph" ,
42+ "get_jaxified_logp" ,
43+ "sample_blackjax_nuts" ,
44+ "sample_numpyro_nuts" ,
45+ )
46+
47+
4048@jax_funcify .register (Assert )
4149@jax_funcify .register (CheckParameterValue )
4250@jax_funcify .register (SpecifyShape )
Original file line number Diff line number Diff line change 1+ # This file exists only for backward-compatibility with imports like
2+ # `import pymc.sampling_jax` or `from pymc import sampling_jax`.
3+
4+ # pylint: disable=wildcard-import
5+ # pylint: disable=unused-wildcard-import
6+
7+ from pymc .sampling .jax import *
Original file line number Diff line number Diff line change 1616
1717import pymc as pm
1818
19+
20+ def test_old_import_route ():
21+ import pymc .sampling .jax as new_sj
22+ import pymc .sampling_jax as old_sj
23+
24+ assert set (new_sj .__all__ ) <= set (dir (old_sj ))
25+
26+
1927with pytest .warns (UserWarning , match = "module is experimental" ):
2028 from pymc .sampling .jax import (
2129 _get_batched_jittered_initial_points ,
Original file line number Diff line number Diff line change 5050pymc/ode/ode.py
5151pymc/ode/utils.py
5252pymc/plots/__init__.py
53+ pymc/sampling_jax.py
5354pymc/sampling/__init__.py
5455pymc/sampling/forward.py
5556pymc/sampling/mcmc.py
You can’t perform that action at this time.
0 commit comments