|
14 | 14 |
|
15 | 15 | """Helper functions for MCMC, prior and posterior predictive sampling.""" |
16 | 16 |
|
17 | | -from typing import List, Optional, Sequence, Union |
| 17 | +from typing import Optional, Sequence, Union |
18 | 18 |
|
19 | 19 | import numpy as np |
20 | 20 |
|
21 | 21 | from typing_extensions import TypeAlias |
22 | 22 |
|
23 | 23 | from pymc.backends.base import BaseTrace, MultiTrace |
24 | 24 | from pymc.backends.ndarray import NDArray |
25 | | -from pymc.initial_point import PointType |
26 | | -from pymc.vartypes import discrete_types |
27 | 25 |
|
28 | | -ArrayLike: TypeAlias = Union[np.ndarray, List[float]] |
29 | | -PointList: TypeAlias = List[PointType] |
30 | 26 | Backend: TypeAlias = Union[BaseTrace, MultiTrace, NDArray] |
31 | 27 |
|
32 | 28 | RandomSeed = Optional[Union[int, Sequence[int], np.ndarray]] |
|
36 | 32 | __all__ = () |
37 | 33 |
|
38 | 34 |
|
39 | | -def all_continuous(vars): |
40 | | - """Check that vars not include discrete variables, excepting observed RVs.""" |
41 | | - |
42 | | - vars_ = [var for var in vars if not hasattr(var.tag, "observations")] |
43 | | - |
44 | | - if any([(var.dtype in discrete_types) for var in vars_]): |
45 | | - return False |
46 | | - else: |
47 | | - return True |
48 | | - |
49 | | - |
50 | 35 | def _get_seeds_per_chain( |
51 | 36 | random_state: RandomState, |
52 | 37 | chains: int, |
@@ -95,13 +80,3 @@ def _get_unique_seeds_per_chain(integers_fn): |
95 | 80 | ) |
96 | 81 |
|
97 | 82 | return random_state |
98 | | - |
99 | | - |
100 | | -def get_vars_in_point_list(trace, model): |
101 | | - """Get the list of Variable instances in the model that have values stored in the trace.""" |
102 | | - if not isinstance(trace, MultiTrace): |
103 | | - names_in_trace = list(trace[0]) |
104 | | - else: |
105 | | - names_in_trace = trace.varnames |
106 | | - vars_in_trace = [model[v] for v in names_in_trace if v in model] |
107 | | - return vars_in_trace |
0 commit comments