Merged
2 changes: 1 addition & 1 deletion conda-envs/environment-dev.yml
@@ -38,5 +38,5 @@ dependencies:
 - watermark
 - polyagamma
 - sphinx-remove-toctrees
-- mypy=0.982
+- mypy=0.990
 - types-cachetools
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
@@ -27,5 +27,5 @@ dependencies:
 - pre-commit>=2.8.0
 - pytest-cov>=2.5
 - pytest>=3.0
-- mypy=0.982
+- mypy=0.990
 - types-cachetools
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev.yml
@@ -35,5 +35,5 @@ dependencies:
 - sphinx>=1.5
 - watermark
 - sphinx-remove-toctrees
-- mypy=0.982
+- mypy=0.990
 - types-cachetools
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
@@ -28,5 +28,5 @@ dependencies:
 - pre-commit>=2.8.0
 - pytest-cov>=2.5
 - pytest>=3.0
-- mypy=0.982
+- mypy=0.990
 - types-cachetools
4 changes: 2 additions & 2 deletions pymc/backends/report.py
@@ -15,7 +15,7 @@
 import dataclasses
 import logging
 
-from typing import Optional
+from typing import Dict, List, Optional
 
 import arviz
 
@@ -32,7 +32,7 @@
 class SamplerReport:
     """Bundle warnings, convergence stats and metadata of a sampling run."""
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._chain_warnings: Dict[int, List[SamplerWarning]] = {}
         self._global_warnings: List[SamplerWarning] = []
         self._n_tune = None
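Note on the `__init__` annotation above: mypy skips the body of a function that carries no annotations at all, so adding `-> None` is what actually turns on type checking for the attribute assignments inside. A minimal sketch of the effect (the `Report` class here is hypothetical, not part of this diff):

from typing import Dict, List


class Report:
    # Without any annotation mypy treats __init__ as untyped and skips its
    # body; the explicit `-> None` return type opts it in to checking.
    def __init__(self) -> None:
        self.warnings: Dict[int, List[str]] = {}  # now actually verified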
12 changes: 7 additions & 5 deletions pymc/blocking.py
@@ -17,10 +17,10 @@
 
 Classes for working with subsets of parameters.
 """
-import collections
+from __future__ import annotations
 
 from functools import partial
-from typing import Callable, Dict, Generic, Optional, TypeVar
+from typing import Callable, Dict, Generic, NamedTuple, TypeVar
 
 import numpy as np
 
@@ -32,7 +32,9 @@
 
 # `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for
 # each of the raveled variables.
-RaveledVars = collections.namedtuple("RaveledVars", "data, point_map_info")
+class RaveledVars(NamedTuple):
+    data: np.ndarray
+    point_map_info: tuple[tuple[str, tuple[int, ...], np.dtype], ...]
 
 
 class Compose(Generic[T]):
@@ -69,7 +71,7 @@ def map(var_dict: PointType) -> RaveledVars:
     @staticmethod
     def rmap(
         array: RaveledVars,
-        start_point: Optional[PointType] = None,
+        start_point: PointType | None = None,
     ) -> PointType:
         """Map 1D concatenated array to a dictionary of variables in their original spaces.
 
@@ -100,7 +102,7 @@ def rmap(
 
     @classmethod
     def mapf(
-        cls, f: Callable[[PointType], T], start_point: Optional[PointType] = None
+        cls, f: Callable[[PointType], T], start_point: PointType | None = None
     ) -> Callable[[RaveledVars], T]:
         """Create a callable that first maps back to ``dict`` inputs and then applies a function.
 
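The `RaveledVars` rewrite above is the pattern this PR applies in several files: a `collections.namedtuple` becomes a `typing.NamedTuple` subclass. Runtime behavior (construction, indexing, unpacking, immutability) is unchanged, but mypy now knows each field's type. A small sketch of the difference, using a hypothetical `Pair` class rather than anything from the diff:

from typing import NamedTuple

import numpy as np


class Pair(NamedTuple):
    data: np.ndarray
    label: str


p = Pair(np.zeros(3), "x")
data, label = p          # tuple unpacking works as before
assert p[0] is p.data    # positional indexing works as before
print(p.data.shape)      # mypy sees an ndarray attribute here; a field of a
                         # collections.namedtuple would just be Any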
2 changes: 1 addition & 1 deletion pymc/gp/util.py
@@ -31,7 +31,7 @@
     # Avoid circular dependency when importing modelcontext
     from pymc.distributions.distribution import Distribution
 
-    assert Distribution  # keep both pylint and black happy
+    _ = Distribution  # keep both pylint and black happy
     from pymc.model import modelcontext
 
 JITTER_DEFAULT = 1e-6
29 changes: 21 additions & 8 deletions pymc/step_methods/hmc/base_hmc.py
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 import logging
 import time
 
 from abc import abstractmethod
-from collections import namedtuple
-from typing import Optional
+from typing import Any, NamedTuple
 
 import numpy as np
 
@@ -29,20 +30,32 @@
 from pymc.step_methods import step_sizes
 from pymc.step_methods.arraystep import GradientSharedStep
 from pymc.step_methods.hmc import integration
+from pymc.step_methods.hmc.integration import IntegrationError, State
 from pymc.step_methods.hmc.quadpotential import QuadPotentialDiagAdapt, quad_potential
 from pymc.tuning import guess_scaling
 from pymc.util import get_value_vars_from_user_vars
 
 logger = logging.getLogger("pymc")
 
-HMCStepData = namedtuple("HMCStepData", "end, accept_stat, divergence_info, stats")
 
-DivergenceInfo = namedtuple("DivergenceInfo", "message, exec_info, state, state_div")
+class DivergenceInfo(NamedTuple):
+    message: str
+    exec_info: IntegrationError | None
+    state: State
+    state_div: State | None
+
+
+class HMCStepData(NamedTuple):
+    end: State
+    accept_stat: int
+    divergence_info: DivergenceInfo | None
+    stats: dict[str, Any]
 
 
 class BaseHMC(GradientSharedStep):
     """Superclass to implement Hamiltonian/hybrid monte carlo."""
 
+    integrator: integration.CpuLeapfrogIntegrator
     default_blocked = True
 
     def __init__(
@@ -138,13 +151,13 @@ def __init__(
         self._num_divs_sample = 0
 
     @abstractmethod
-    def _hamiltonian_step(self, start, p0, step_size):
+    def _hamiltonian_step(self, start, p0, step_size) -> HMCStepData:
         """Compute one Hamiltonian trajectory and return the next state.
 
         Subclasses must overwrite this abstract method and return an `HMCStepData` object.
         """
 
-    def astep(self, q0):
+    def astep(self, q0: RaveledVars) -> tuple[RaveledVars, list[dict[str, Any]]]:
         """Perform a single HMC iteration."""
         perf_start = time.perf_counter()
         process_start = time.process_time()
@@ -154,6 +167,7 @@
 
         start = self.integrator.compute_state(q0, p0)
 
+        warning: SamplerWarning | None = None
         if not np.isfinite(start.energy):
             model = self._model
             check_test_point_dict = model.point_logps()
@@ -188,7 +202,6 @@
 
         self.step_adapt.update(hmc_step.accept_stat, adapt_step)
         self.potential.update(hmc_step.end.q, hmc_step.end.q_grad, self.tune)
-        warning: Optional[SamplerWarning] = None
        if hmc_step.divergence_info:
             info = hmc_step.divergence_info
             point = None
@@ -221,7 +234,7 @@
 
         self.iter_count += 1
 
-        stats = {
+        stats: dict[str, Any] = {
             "tune": self.tune,
             "diverging": bool(hmc_step.divergence_info),
             "perf_counter_diff": perf_end - perf_start,
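Two smaller moves in `astep` are easy to miss: the `warning` annotation is hoisted above the first branch that may assign it, so the name is bound and typed on every code path, and the `stats` literal is annotated `dict[str, Any]` so its heterogeneous values don't fight the inferred value type. A sketch of the hoisting pattern with a hypothetical function, not code from this diff:

from typing import Optional


def step(diverged: bool) -> Optional[str]:
    # Declaring the Optional once, up front, means every path defines
    # `warning`; mypy no longer sees a possibly-unbound name or two
    # conflicting declarations.
    warning: Optional[str] = None
    if diverged:
        warning = "divergence encountered"
    return warning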
8 changes: 6 additions & 2 deletions pymc/step_methods/hmc/hmc.py
@@ -12,6 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
+from typing import Any
+
 import numpy as np
 
 from pymc.stats.convergence import SamplerWarning
@@ -119,7 +123,7 @@ def __init__(self, vars=None, path_length=2.0, max_steps=1024, **kwargs):
         self.path_length = path_length
         self.max_steps = max_steps
 
-    def _hamiltonian_step(self, start, p0, step_size):
+    def _hamiltonian_step(self, start, p0, step_size: float) -> HMCStepData:
         n_steps = max(1, int(self.path_length / step_size))
         n_steps = min(self.max_steps, n_steps)
 
@@ -156,7 +160,7 @@ def _hamiltonian_step(self, start, p0, step_size):
             end = state
             accepted = True
 
-        stats = {
+        stats: dict[str, Any] = {
             "path_length": self.path_length,
             "n_steps": n_steps,
             "accept": accept_stat,
19 changes: 14 additions & 5 deletions pymc/step_methods/hmc/integration.py
@@ -12,23 +12,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections import namedtuple
+from typing import NamedTuple
 
 import numpy as np
 
 from scipy import linalg
 
 from pymc.blocking import RaveledVars
+from pymc.step_methods.hmc.quadpotential import QuadPotential
 
-State = namedtuple("State", "q, p, v, q_grad, energy, model_logp, index_in_trajectory")
+
+class State(NamedTuple):
+    q: RaveledVars
+    p: RaveledVars
+    v: np.ndarray
+    q_grad: np.ndarray
+    energy: float
+    model_logp: float
+    index_in_trajectory: int
 
 
 class IntegrationError(RuntimeError):
     pass
 
 
 class CpuLeapfrogIntegrator:
-    def __init__(self, potential, logp_dlogp_func):
+    def __init__(self, potential: QuadPotential, logp_dlogp_func):
         """Leapfrog integrator using CPU."""
         self._potential = potential
         self._logp_dlogp_func = logp_dlogp_func
@@ -39,14 +48,14 @@ def __init__(self, potential, logp_dlogp_func):
                 "don't match." % (self._potential.dtype, self._dtype)
             )
 
-    def compute_state(self, q, p):
+    def compute_state(self, q: RaveledVars, p: RaveledVars):
         """Compute Hamiltonian functions using a position and momentum."""
         if q.data.dtype != self._dtype or p.data.dtype != self._dtype:
             raise ValueError("Invalid dtype. Must be %s" % self._dtype)
 
         logp, dlogp = self._logp_dlogp_func(q)
 
-        v = self._potential.velocity(p.data)
+        v = self._potential.velocity(p.data, out=None)
         kinetic = self._potential.energy(p.data, velocity=v)
         energy = kinetic - logp
         return State(q, p, v, dlogp, energy, logp, 0)
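For orientation, `compute_state` is doing the standard Hamiltonian bookkeeping: velocity is v = M⁻¹p, kinetic energy is ½ pᵀv, and the energy stored on `State` is the kinetic term minus the model log-probability. A standalone sketch of those identities for a diagonal mass matrix (illustrative only, not code from this diff):

import numpy as np


def hamiltonian(logp: float, p: np.ndarray, m_diag: np.ndarray) -> float:
    """H(q, p) = -logp(q) + 0.5 * p.T @ M^{-1} @ p for diagonal M."""
    v = p / m_diag                # the velocity QuadPotential.velocity returns
    kinetic = 0.5 * float(p @ v)  # what QuadPotential.energy(p, velocity=v) gives
    return kinetic - logp         # the `energy` field on State


print(hamiltonian(logp=-1.5, p=np.array([0.3, -0.4]), m_diag=np.ones(2)))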
31 changes: 22 additions & 9 deletions pymc/step_methods/hmc/nuts.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 from collections import namedtuple
 
 import numpy as np
@@ -20,8 +22,9 @@
 from pymc.math import logbern
 from pymc.stats.convergence import SamplerWarning
 from pymc.step_methods.arraystep import Competence
+from pymc.step_methods.hmc import integration
 from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
-from pymc.step_methods.hmc.integration import IntegrationError
+from pymc.step_methods.hmc.integration import IntegrationError, State
 from pymc.vartypes import continuous_types
 
 __all__ = ["NUTS"]
@@ -227,7 +230,14 @@ def competence(var, has_grad):
 
 
 class _Tree:
-    def __init__(self, ndim, integrator, start, step_size, Emax):
+    def __init__(
+        self,
+        ndim: int,
+        integrator: integration.CpuLeapfrogIntegrator,
+        start: State,
+        step_size: float,
+        Emax: float,
+    ):
         """Binary tree from the NUTS algorithm.
 
         Parameters
@@ -247,17 +257,17 @@ def __init__(self, ndim, integrator, start, step_size, Emax):
         self.start = start
         self.step_size = step_size
         self.Emax = Emax
-        self.start_energy = np.array(start.energy)
+        self.start_energy = start.energy
 
         self.left = self.right = start
         self.proposal = Proposal(start.q.data, start.q_grad, start.energy, start.model_logp, 0)
         self.depth = 0
-        self.log_size = 0
+        self.log_size = 0.0
         self.log_accept_sum = -np.inf
         self.mean_tree_accept = 0.0
         self.n_proposals = 0
         self.p_sum = start.p.data.copy()
-        self.max_energy_change = 0
+        self.max_energy_change = 0.0
 
     def extend(self, direction):
         """Double the treesize by extending the tree in the given direction.
@@ -315,16 +325,19 @@ def extend(self, direction):
 
         return diverging, turning
 
-    def _single_step(self, left, epsilon):
+    def _single_step(self, left: State, epsilon: float):
         """Perform a leapfrog step and handle error cases."""
+        right: State | None
+        error: IntegrationError | None
+        error_msg: str | None
         try:
-            # `State` type
             right = self.integrator.step(epsilon, left)
         except IntegrationError as err:
             error_msg = str(err)
             error = err
             right = None
         else:
+            assert right is not None  # since there was no IntegrationError
             # h - H0
             energy_change = right.energy - self.start_energy
             if np.isnan(energy_change):
@@ -354,8 +367,8 @@
         finally:
             self.n_proposals += 1
         tree = Subtree(None, None, None, None, -np.inf)
-        divergance_info = DivergenceInfo(error_msg, error, left, right)
-        return tree, divergance_info, False
+        divergence_info = DivergenceInfo(error_msg, error, left, right)
+        return tree, divergence_info, False
 
     def _build_subtree(self, left, depth, epsilon):
         if depth == 0:
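The `0` → `0.0` and `np.array(start.energy)` → `start.energy` edits look cosmetic but are presumably mypy-driven: an attribute first assigned an `int` literal is inferred as `int`, and later float updates then fail to type-check. A minimal sketch (hypothetical class, not from this diff):

class Accumulator:
    def __init__(self) -> None:
        self.log_size = 0    # mypy infers int from this literal ...
        self.total = 0.0     # ... and float from this one

    def update(self, delta: float) -> None:
        self.total += delta     # fine: `total` is float
        self.log_size += delta  # mypy error: float assigned to int attribute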
16 changes: 15 additions & 1 deletion pymc/step_methods/hmc/quadpotential.py
@@ -12,8 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 import warnings
 
+from typing import overload
+
 import aesara
 import numpy as np
 import scipy.linalg
@@ -94,7 +98,17 @@ def __str__(self):
 
 
 class QuadPotential:
-    def velocity(self, x, out=None):
+    dtype: np.dtype
+
+    @overload
+    def velocity(self, x: np.ndarray, out: None) -> np.ndarray:
+        ...
+
+    @overload
+    def velocity(self, x: np.ndarray, out: np.ndarray) -> None:
+        ...
+
+    def velocity(self, x: np.ndarray, out: np.ndarray | None = None) -> np.ndarray | None:
         """Compute the current velocity at a position in parameter space."""
         raise NotImplementedError("Abstract method")

Review thread on the first `velocity` overload:

Member: Shouldn't `out` get a real type here? Because `None` is not a type..

@maresb (Contributor, author), Nov 14, 2022: You're technically correct. Although the singleton `None` has type `NoneType`, I believe that type checkers automatically do an implicit conversion here. (Who really wants to remember to type the extra letters and be so pedantic? There's no potential for ambiguity, since it's a singleton type.) If you check the mypy docs, they always use `None`, so it's not a problem.
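The overload pair on `velocity` is the usual typing recipe for NumPy-style `out` parameters: `out=None` means allocate and return a fresh array, while passing an array means write in place and return `None`. As written, neither overload gives `out` a default, which lines up with the `velocity(p.data, out=None)` call made explicit in `integration.py` above. A standalone sketch of the same recipe with a hypothetical `scale` function, not from this diff:

from __future__ import annotations

from typing import overload

import numpy as np


@overload
def scale(x: np.ndarray, out: None) -> np.ndarray:
    ...


@overload
def scale(x: np.ndarray, out: np.ndarray) -> None:
    ...


def scale(x: np.ndarray, out: np.ndarray | None = None) -> np.ndarray | None:
    """Double `x`, writing into `out` when one is supplied."""
    if out is None:
        return x * 2.0
    np.multiply(x, 2.0, out=out)
    return None


y = scale(np.ones(3), out=None)  # mypy resolves this to the ndarray overload
assert y is not None and y[0] == 2.0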