Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conda-envs/environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,5 @@ dependencies:
- watermark
- polyagamma
- sphinx-remove-toctrees
- mypy=0.982
- mypy=0.990
- types-cachetools
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,5 @@ dependencies:
- pre-commit>=2.8.0
- pytest-cov>=2.5
- pytest>=3.0
- mypy=0.982
- mypy=0.990
- types-cachetools
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@ dependencies:
- sphinx>=1.5
- watermark
- sphinx-remove-toctrees
- mypy=0.982
- mypy=0.990
- types-cachetools
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ dependencies:
- pre-commit>=2.8.0
- pytest-cov>=2.5
- pytest>=3.0
- mypy=0.982
- mypy=0.990
- types-cachetools
98 changes: 82 additions & 16 deletions pymc/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

from typing import (
Callable,
Dict,
Expand All @@ -31,6 +33,7 @@
import scipy.sparse as sps

from aeppl.logprob import CheckParameterValue
from aeppl.transforms import RVTransform
from aesara import scalar
from aesara.compile.mode import Mode, get_mode
from aesara.gradient import grad
Expand Down Expand Up @@ -205,10 +208,9 @@ def expand(var):
yield from walk(graphs, expand, bfs=False)


def replace_rvs_in_graphs(
def _replace_rvs_in_graphs(
graphs: Iterable[TensorVariable],
replacement_fn: Callable[[TensorVariable], Dict[TensorVariable, TensorVariable]],
initial_replacements: Optional[Dict[TensorVariable, TensorVariable]] = None,
**kwargs,
) -> Tuple[List[TensorVariable], Dict[TensorVariable, TensorVariable]]:
"""Replace random variables in graphs
Expand All @@ -226,8 +228,6 @@ def replace_rvs_in_graphs(
that were made.
"""
replacements = {}
if initial_replacements:
replacements.update(initial_replacements)

def expand_replace(var):
new_nodes = []
Expand All @@ -239,6 +239,7 @@ def expand_replace(var):
new_nodes.extend(replacement_fn(var, replacements))
return new_nodes

# This iteration populates the replacements
for var in walk_model(graphs, expand_fn=expand_replace, **kwargs):
pass

Expand All @@ -253,7 +254,15 @@ def expand_replace(var):
clone=False,
)

fg.replace_all(replacements.items(), import_missing=True)
# replacements have to be done in reverse topological order so that nested
# expressions get recursively replaced correctly
toposort = fg.toposort()
sorted_replacements = sorted(
tuple(replacements.items()),
key=lambda pair: toposort.index(pair[0].owner),
reverse=True,
)
fg.replace_all(sorted_replacements, import_missing=True)

graphs = list(fg.outputs)

Expand All @@ -263,7 +272,6 @@ def expand_replace(var):
def rvs_to_value_vars(
graphs: Iterable[Variable],
apply_transforms: bool = True,
initial_replacements: Optional[Dict[Variable, Variable]] = None,
**kwargs,
) -> List[Variable]:
"""Clone and replace random variables in graphs with their value variables.
Expand All @@ -276,10 +284,11 @@ def rvs_to_value_vars(
The graphs in which to perform the replacements.
apply_transforms
If ``True``, apply each value variable's transform.
initial_replacements
A ``dict`` containing the initial replacements to be made.

"""
warnings.warn(
"rvs_to_value_vars is deprecated. Use model.replace_rvs_by_values instead",
FutureWarning,
)

def populate_replacements(
random_var: TensorVariable, replacements: Dict[TensorVariable, TensorVariable]
Expand Down Expand Up @@ -311,15 +320,72 @@ def populate_replacements(
equiv = clone_get_equiv(inputs, graphs, False, False, {})
graphs = [equiv[n] for n in graphs]

if initial_replacements:
initial_replacements = {
equiv.get(k, k): equiv.get(v, v) for k, v in initial_replacements.items()
}

graphs, _ = replace_rvs_in_graphs(
graphs, _ = _replace_rvs_in_graphs(
graphs,
replacement_fn=populate_replacements,
initial_replacements=initial_replacements,
**kwargs,
)

return graphs


def replace_rvs_by_values(
graphs: Sequence[TensorVariable],
*,
rvs_to_values: Dict[TensorVariable, TensorVariable],
rvs_to_transforms: Dict[TensorVariable, RVTransform],
**kwargs,
) -> List[TensorVariable]:
"""Clone and replace random variables in graphs with their value variables.

This will *not* recompute test values in the resulting graphs.

Parameters
----------
graphs
The graphs in which to perform the replacements.
rvs_to_values
Mapping between the original graph RVs and respective value variables
rvs_to_transforms
Mapping between the original graph RVs and respective value transforms
"""

# Clone original graphs so that we don't modify variables in place
inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
equiv = clone_get_equiv(inputs, graphs, False, False, {})
graphs = [equiv[n] for n in graphs]

# Get needed mappings for equivalent cloned variables
equiv_rvs_to_values = {}
equiv_rvs_to_transforms = {}
for rv, value in rvs_to_values.items():
equiv_rv = equiv.get(rv, rv)
equiv_rvs_to_values[equiv_rv] = equiv.get(value, value)
equiv_rvs_to_transforms[equiv_rv] = rvs_to_transforms[rv]

def poulate_replacements(rv, replacements):
# Populate replacements dict with {rv: value} pairs indicating which graph
# RVs should be replaced by what value variables.

# No value variable to replace RV with
value = equiv_rvs_to_values.get(rv, None)
if value is None:
return []

transform = equiv_rvs_to_transforms.get(rv, None)
if transform is not None:
# We want to replace uses of the RV by the back-transformation of its value
value = transform.backward(value, *rv.owner.inputs)
value.name = rv.name

replacements[rv] = value
# Also walk the graph of the value variable to make any additional
# replacements if that is not a simple input variable
return [value]

graphs, _ = _replace_rvs_in_graphs(
graphs,
replacement_fn=poulate_replacements,
**kwargs,
)

Expand Down
4 changes: 2 additions & 2 deletions pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def find_observations(model: "Model") -> Dict[str, Var]:
"""If there are observations available, return them as a dictionary."""
observations = {}
for obs in model.observed_RVs:
aux_obs = getattr(obs.tag, "observations", None)
aux_obs = model.rvs_to_values.get(obs, None)
if aux_obs is not None:
try:
obs_data = extract_obs_data(aux_obs)
Expand Down Expand Up @@ -261,7 +261,7 @@ def log_likelihood_vals_point(self, point, var, log_like_fun):

if isinstance(var.owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1)):
try:
obs_data = extract_obs_data(var.tag.observations)
obs_data = extract_obs_data(self.model.rvs_to_values[var])
except TypeError:
warnings.warn(f"Could not extract data from symbolic observation {var}")

Expand Down
4 changes: 2 additions & 2 deletions pymc/backends/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import dataclasses
import logging

from typing import Optional
from typing import Dict, List, Optional

import arviz

Expand All @@ -32,7 +32,7 @@
class SamplerReport:
"""Bundle warnings, convergence stats and metadata of a sampling run."""

def __init__(self):
def __init__(self) -> None:
self._chain_warnings: Dict[int, List[SamplerWarning]] = {}
self._global_warnings: List[SamplerWarning] = []
self._n_tune = None
Expand Down
12 changes: 7 additions & 5 deletions pymc/blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

Classes for working with subsets of parameters.
"""
import collections
from __future__ import annotations

from functools import partial
from typing import Callable, Dict, Generic, Optional, TypeVar
from typing import Callable, Dict, Generic, NamedTuple, TypeVar

import numpy as np

Expand All @@ -32,7 +32,9 @@

# `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for
# each of the raveled variables.
RaveledVars = collections.namedtuple("RaveledVars", "data, point_map_info")
class RaveledVars(NamedTuple):
data: np.ndarray
point_map_info: tuple[tuple[str, tuple[int, ...], np.dtype], ...]


class Compose(Generic[T]):
Expand Down Expand Up @@ -69,7 +71,7 @@ def map(var_dict: PointType) -> RaveledVars:
@staticmethod
def rmap(
array: RaveledVars,
start_point: Optional[PointType] = None,
start_point: PointType | None = None,
) -> PointType:
"""Map 1D concatenated array to a dictionary of variables in their original spaces.

Expand Down Expand Up @@ -100,7 +102,7 @@ def rmap(

@classmethod
def mapf(
cls, f: Callable[[PointType], T], start_point: Optional[PointType] = None
cls, f: Callable[[PointType], T], start_point: PointType | None = None
) -> Callable[[RaveledVars], T]:
"""Create a callable that first maps back to ``dict`` inputs and then applies a function.

Expand Down
2 changes: 0 additions & 2 deletions pymc/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
logcdf,
logp,
joint_logp,
joint_logpt,
)

from pymc.distributions.bound import Bound
Expand Down Expand Up @@ -199,7 +198,6 @@
"Censored",
"CAR",
"PolyaGamma",
"joint_logpt",
"joint_logp",
"logp",
"logcdf",
Expand Down
14 changes: 9 additions & 5 deletions pymc/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,17 +1140,18 @@ def rng_fn_scipy(cls, rng, lower, upper, size=None):
return stats.randint.rvs(lower, upper + 1, size=size, random_state=rng)


discrete_uniform = DiscreteUniformRV()
discrete_uniform = RV()


class DiscreteUniform(Discrete):
R"""
Discrete uniform distribution.
R"""Discrete uniform distribution.

The pmf of this distribution is

.. math:: f(x \mid lower, upper) = \frac{1}{upper-lower+1}

.. plot::
:context: close-figs

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -1177,9 +1178,9 @@ class DiscreteUniform(Discrete):

Parameters
----------
lower: int
lower : tensor_like of int
Lower limit.
upper: int
upper : tensor_like of int
Upper limit (upper > lower).
"""

Expand Down Expand Up @@ -1256,6 +1257,7 @@ class Categorical(Discrete):
.. math:: f(x \mid p) = p_x

.. plot::
:context: close-figs

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -1549,6 +1551,7 @@ class ZeroInflatedBinomial:
\end{array} \right.

.. plot::
:context: close-figs

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -1624,6 +1627,7 @@ class ZeroInflatedNegativeBinomial:
\right.

.. plot::
:context: close-figs

import matplotlib.pyplot as plt
import numpy as np
Expand Down
3 changes: 2 additions & 1 deletion pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
shape_from_dims,
)
from pymc.printing import str_for_dist
from pymc.util import UNSET
from pymc.util import UNSET, _add_future_warning_tag
from pymc.vartypes import string_types

__all__ = [
Expand Down Expand Up @@ -371,6 +371,7 @@ def dist(
rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)")
_add_future_warning_tag(rv_out)
return rv_out


Expand Down
Loading