diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml
index 5ceef53dfd..fd56a26372 100644
--- a/conda-envs/environment-dev.yml
+++ b/conda-envs/environment-dev.yml
@@ -38,5 +38,5 @@ dependencies:
 - watermark
 - polyagamma
 - sphinx-remove-toctrees
-- mypy=0.982
+- mypy=0.990
 - types-cachetools
diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml
index 19fcd7ecd4..58ad3be53a 100644
--- a/conda-envs/environment-test.yml
+++ b/conda-envs/environment-test.yml
@@ -27,5 +27,5 @@ dependencies:
 - pre-commit>=2.8.0
 - pytest-cov>=2.5
 - pytest>=3.0
-- mypy=0.982
+- mypy=0.990
 - types-cachetools
diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml
index c5aea815e5..dc3a80fcdd 100644
--- a/conda-envs/windows-environment-dev.yml
+++ b/conda-envs/windows-environment-dev.yml
@@ -35,5 +35,5 @@ dependencies:
 - sphinx>=1.5
 - watermark
 - sphinx-remove-toctrees
-- mypy=0.982
+- mypy=0.990
 - types-cachetools
diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml
index b27c453284..b1515cb7b7 100644
--- a/conda-envs/windows-environment-test.yml
+++ b/conda-envs/windows-environment-test.yml
@@ -28,5 +28,5 @@ dependencies:
 - pre-commit>=2.8.0
 - pytest-cov>=2.5
 - pytest>=3.0
-- mypy=0.982
+- mypy=0.990
 - types-cachetools
diff --git a/pymc/backends/report.py b/pymc/backends/report.py
index 11cdf93139..b336eeb58c 100644
--- a/pymc/backends/report.py
+++ b/pymc/backends/report.py
@@ -15,7 +15,7 @@
 import dataclasses
 import logging
 
-from typing import Optional
+from typing import Dict, List, Optional
 
 import arviz
 
@@ -32,7 +32,7 @@ class SamplerReport:
 
     """Bundle warnings, convergence stats and metadata of a sampling run."""
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._chain_warnings: Dict[int, List[SamplerWarning]] = {}
         self._global_warnings: List[SamplerWarning] = []
         self._n_tune = None
diff --git a/pymc/blocking.py b/pymc/blocking.py
index bd58cc38a3..15e249d51e 100644
--- a/pymc/blocking.py
+++ b/pymc/blocking.py
@@ -17,10 +17,10 @@
 Classes for working with subsets of parameters.
 """
 
-import collections
+from __future__ import annotations
 
 from functools import partial
-from typing import Callable, Dict, Generic, Optional, TypeVar
+from typing import Callable, Dict, Generic, NamedTuple, TypeVar
 
 import numpy as np
 
@@ -32,7 +32,9 @@
 # `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for
 # each of the raveled variables.
-RaveledVars = collections.namedtuple("RaveledVars", "data, point_map_info")
+class RaveledVars(NamedTuple):
+    data: np.ndarray
+    point_map_info: tuple[tuple[str, tuple[int, ...], np.dtype], ...]
 
 
 class Compose(Generic[T]):
@@ -69,7 +71,7 @@ def map(var_dict: PointType) -> RaveledVars:
 
     @staticmethod
     def rmap(
         array: RaveledVars,
-        start_point: Optional[PointType] = None,
+        start_point: PointType | None = None,
     ) -> PointType:
         """Map 1D concatenated array to a dictionary of variables in their original spaces.
@@ -100,7 +102,7 @@ def mapf(
-        cls, f: Callable[[PointType], T], start_point: Optional[PointType] = None
+        cls, f: Callable[[PointType], T], start_point: PointType | None = None
     ) -> Callable[[RaveledVars], T]:
         """Create a callable that first maps back to ``dict`` inputs and then applies a function.
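The `RaveledVars` hunk above is the pattern this patch applies throughout: a `collections.namedtuple`, whose fields type-check as `Any`, becomes a class-based `typing.NamedTuple` with annotated fields. A minimal sketch of the payoff, not part of the patch (the `bad` call is a hypothetical misuse, added here for illustration):

```python
from __future__ import annotations

from typing import NamedTuple

import numpy as np


class RaveledVars(NamedTuple):
    data: np.ndarray
    point_map_info: tuple[tuple[str, tuple[int, ...], np.dtype], ...]


# Field access is now typed: mypy sees rv.data as np.ndarray instead of Any.
rv = RaveledVars(np.zeros(3), (("x", (3,), np.dtype("float64")),))

# With the old namedtuple this passed silently; mypy now rejects the str.
bad = RaveledVars(data="oops", point_map_info=())  # mypy: incompatible type "str"
```

At runtime both forms behave identically (NamedTuple performs no validation); the change only buys static checking.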
diff --git a/pymc/gp/util.py b/pymc/gp/util.py
index 4c68bb737c..68dc46cd9e 100644
--- a/pymc/gp/util.py
+++ b/pymc/gp/util.py
@@ -31,7 +31,7 @@
 # Avoid circular dependency when importing modelcontext
 from pymc.distributions.distribution import Distribution
 
-assert Distribution  # keep both pylint and black happy
+_ = Distribution  # keep both pylint and black happy
 from pymc.model import modelcontext
 
 JITTER_DEFAULT = 1e-6
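The one-line `gp/util.py` change swaps an `assert` for a throwaway assignment, presumably because an assert on an object that can never be falsy is the kind of statement newer mypy releases can flag (the truthy-related error codes), while the import still has to look "used" to pylint. A sketch of the idiom, not part of the patch:

```python
# The import exists only for its side effects / re-export, so linters must
# be convinced it is intentional.
from pymc.distributions.distribution import Distribution

# Old style (always-true assert, which type checkers may complain about):
# assert Distribution  # keep both pylint and black happy

# New style: bind and discard; no checker objects to a no-op assignment.
_ = Distribution
```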
""" - def astep(self, q0): + def astep(self, q0: RaveledVars) -> tuple[RaveledVars, list[dict[str, Any]]]: """Perform a single HMC iteration.""" perf_start = time.perf_counter() process_start = time.process_time() @@ -154,6 +167,7 @@ def astep(self, q0): start = self.integrator.compute_state(q0, p0) + warning: SamplerWarning | None = None if not np.isfinite(start.energy): model = self._model check_test_point_dict = model.point_logps() @@ -188,7 +202,6 @@ def astep(self, q0): self.step_adapt.update(hmc_step.accept_stat, adapt_step) self.potential.update(hmc_step.end.q, hmc_step.end.q_grad, self.tune) - warning: Optional[SamplerWarning] = None if hmc_step.divergence_info: info = hmc_step.divergence_info point = None @@ -221,7 +234,7 @@ def astep(self, q0): self.iter_count += 1 - stats = { + stats: dict[str, Any] = { "tune": self.tune, "diverging": bool(hmc_step.divergence_info), "perf_counter_diff": perf_end - perf_start, diff --git a/pymc/step_methods/hmc/hmc.py b/pymc/step_methods/hmc/hmc.py index 4ed192ac99..804fadc304 100644 --- a/pymc/step_methods/hmc/hmc.py +++ b/pymc/step_methods/hmc/hmc.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +from typing import Any + import numpy as np from pymc.stats.convergence import SamplerWarning @@ -119,7 +123,7 @@ def __init__(self, vars=None, path_length=2.0, max_steps=1024, **kwargs): self.path_length = path_length self.max_steps = max_steps - def _hamiltonian_step(self, start, p0, step_size): + def _hamiltonian_step(self, start, p0, step_size: float) -> HMCStepData: n_steps = max(1, int(self.path_length / step_size)) n_steps = min(self.max_steps, n_steps) @@ -156,7 +160,7 @@ def _hamiltonian_step(self, start, p0, step_size): end = state accepted = True - stats = { + stats: dict[str, Any] = { "path_length": self.path_length, "n_steps": n_steps, "accept": accept_stat, diff --git a/pymc/step_methods/hmc/integration.py b/pymc/step_methods/hmc/integration.py index f631777010..c2dc7cf191 100644 --- a/pymc/step_methods/hmc/integration.py +++ b/pymc/step_methods/hmc/integration.py @@ -12,15 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple +from typing import NamedTuple import numpy as np from scipy import linalg from pymc.blocking import RaveledVars +from pymc.step_methods.hmc.quadpotential import QuadPotential -State = namedtuple("State", "q, p, v, q_grad, energy, model_logp, index_in_trajectory") + +class State(NamedTuple): + q: RaveledVars + p: RaveledVars + v: np.ndarray + q_grad: np.ndarray + energy: float + model_logp: float + index_in_trajectory: int class IntegrationError(RuntimeError): @@ -28,7 +37,7 @@ class IntegrationError(RuntimeError): class CpuLeapfrogIntegrator: - def __init__(self, potential, logp_dlogp_func): + def __init__(self, potential: QuadPotential, logp_dlogp_func): """Leapfrog integrator using CPU.""" self._potential = potential self._logp_dlogp_func = logp_dlogp_func @@ -39,14 +48,14 @@ def __init__(self, potential, logp_dlogp_func): "don't match." % (self._potential.dtype, self._dtype) ) - def compute_state(self, q, p): + def compute_state(self, q: RaveledVars, p: RaveledVars): """Compute Hamiltonian functions using a position and momentum.""" if q.data.dtype != self._dtype or p.data.dtype != self._dtype: raise ValueError("Invalid dtype. 
Must be %s" % self._dtype) logp, dlogp = self._logp_dlogp_func(q) - v = self._potential.velocity(p.data) + v = self._potential.velocity(p.data, out=None) kinetic = self._potential.energy(p.data, velocity=v) energy = kinetic - logp return State(q, p, v, dlogp, energy, logp, 0) diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 0eed003da1..f1251d5c55 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections import namedtuple import numpy as np @@ -20,8 +22,9 @@ from pymc.math import logbern from pymc.stats.convergence import SamplerWarning from pymc.step_methods.arraystep import Competence +from pymc.step_methods.hmc import integration from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData -from pymc.step_methods.hmc.integration import IntegrationError +from pymc.step_methods.hmc.integration import IntegrationError, State from pymc.vartypes import continuous_types __all__ = ["NUTS"] @@ -227,7 +230,14 @@ def competence(var, has_grad): class _Tree: - def __init__(self, ndim, integrator, start, step_size, Emax): + def __init__( + self, + ndim: int, + integrator: integration.CpuLeapfrogIntegrator, + start: State, + step_size: float, + Emax: float, + ): """Binary tree from the NUTS algorithm. Parameters @@ -247,17 +257,17 @@ def __init__(self, ndim, integrator, start, step_size, Emax): self.start = start self.step_size = step_size self.Emax = Emax - self.start_energy = np.array(start.energy) + self.start_energy = start.energy self.left = self.right = start self.proposal = Proposal(start.q.data, start.q_grad, start.energy, start.model_logp, 0) self.depth = 0 - self.log_size = 0 + self.log_size = 0.0 self.log_accept_sum = -np.inf self.mean_tree_accept = 0.0 self.n_proposals = 0 self.p_sum = start.p.data.copy() - self.max_energy_change = 0 + self.max_energy_change = 0.0 def extend(self, direction): """Double the treesize by extending the tree in the given direction. @@ -315,16 +325,19 @@ def extend(self, direction): return diverging, turning - def _single_step(self, left, epsilon): + def _single_step(self, left: State, epsilon: float): """Perform a leapfrog step and handle error cases.""" + right: State | None + error: IntegrationError | None + error_msg: str | None try: - # `State` type right = self.integrator.step(epsilon, left) except IntegrationError as err: error_msg = str(err) error = err right = None else: + assert right is not None # since there was no IntegrationError # h - H0 energy_change = right.energy - self.start_energy if np.isnan(energy_change): @@ -354,8 +367,8 @@ def _single_step(self, left, epsilon): finally: self.n_proposals += 1 tree = Subtree(None, None, None, None, -np.inf) - divergance_info = DivergenceInfo(error_msg, error, left, right) - return tree, divergance_info, False + divergence_info = DivergenceInfo(error_msg, error, left, right) + return tree, divergence_info, False def _build_subtree(self, left, depth, epsilon): if depth == 0: diff --git a/pymc/step_methods/hmc/quadpotential.py b/pymc/step_methods/hmc/quadpotential.py index d429b56b21..f79f0e4302 100644 --- a/pymc/step_methods/hmc/quadpotential.py +++ b/pymc/step_methods/hmc/quadpotential.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations
+
 import warnings
 
+from typing import overload
+
 import aesara
 import numpy as np
 import scipy.linalg
@@ -94,7 +98,17 @@ def __str__(self):
 
 
 class QuadPotential:
-    def velocity(self, x, out=None):
+    dtype: np.dtype
+
+    @overload
+    def velocity(self, x: np.ndarray, out: None) -> np.ndarray:
+        ...
+
+    @overload
+    def velocity(self, x: np.ndarray, out: np.ndarray) -> None:
+        ...
+
+    def velocity(self, x: np.ndarray, out: np.ndarray | None = None) -> np.ndarray | None:
         """Compute the current velocity at a position in parameter space."""
         raise NotImplementedError("Abstract method")
diff --git a/pymc/variational/callbacks.py b/pymc/variational/callbacks.py
index 7e0f831083..faa343f358 100644
--- a/pymc/variational/callbacks.py
+++ b/pymc/variational/callbacks.py
@@ -14,6 +14,8 @@
 
 import collections
 
+from typing import Callable, Dict
+
 import numpy as np
 
 __all__ = ["Callback", "CheckParametersConvergence", "Tracker"]
@@ -24,15 +26,19 @@ def __call__(self, approx, loss, i):
         raise NotImplementedError
 
 
-def relative(current, prev, eps=1e-6):
-    return (np.abs(current - prev) + eps) / (np.abs(prev) + eps)
+def relative(current: np.ndarray, prev: np.ndarray, eps=1e-6) -> np.ndarray:
+    diff = current - prev  # type: ignore
+    return (np.abs(diff) + eps) / (np.abs(prev) + eps)
 
 
-def absolute(current, prev):
-    return np.abs(current - prev)
+def absolute(current: np.ndarray, prev: np.ndarray) -> np.ndarray:
+    diff = current - prev  # type: ignore
+    return np.abs(diff)
 
 
-_diff = dict(relative=relative, absolute=absolute)
+_diff: Dict[str, Callable[[np.ndarray, np.ndarray], np.ndarray]] = dict(
+    relative=relative, absolute=absolute
+)
 
 
 class CheckParametersConvergence(Callback):
@@ -68,7 +74,7 @@ def __init__(self, every=100, tolerance=1e-3, diff="relative", ord=np.inf):
         self.prev = None
         self.tolerance = tolerance
 
-    def __call__(self, approx, _, i):
+    def __call__(self, approx, _, i) -> None:
         if self.prev is None:
             self.prev = self.flatten_shared(approx.params)
             return
@@ -76,7 +82,7 @@ def __call__(self, approx, _, i):
             return
         current = self.flatten_shared(approx.params)
         prev = self.prev
-        delta = self._diff(current, prev)  # type: np.ndarray
+        delta: np.ndarray = self._diff(current, prev)
         self.prev = current
         norm = np.linalg.norm(delta, self.ord)
         if norm < self.tolerance:
diff --git a/pymc/variational/operators.py b/pymc/variational/operators.py
index 2c67b787b4..9192ee4b25 100644
--- a/pymc/variational/operators.py
+++ b/pymc/variational/operators.py
@@ -11,9 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
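The `@overload` pair on `QuadPotential.velocity` encodes its out-parameter convention for mypy: pass `out=None` and an array comes back, pass a preallocated array and the method fills it and returns `None`. It is also why the call in `integration.py` above became `velocity(p.data, out=None)`: the overloads declare `out` without a default, so the argument must be spelled explicitly for a call to match. A reduced sketch with a stand-in class, not the real `QuadPotential`:

```python
from __future__ import annotations

from typing import overload

import numpy as np


class DiagPotential:
    def __init__(self, v: np.ndarray) -> None:
        self._v = v

    @overload
    def velocity(self, x: np.ndarray, out: None) -> np.ndarray:
        ...

    @overload
    def velocity(self, x: np.ndarray, out: np.ndarray) -> None:
        ...

    def velocity(self, x: np.ndarray, out: np.ndarray | None = None) -> np.ndarray | None:
        if out is None:
            return self._v * x
        np.multiply(self._v, x, out=out)
        return None


pot = DiagPotential(np.ones(3))
v = pot.velocity(np.arange(3.0), out=None)
print(v.sum())  # mypy accepts this: v is np.ndarray, not np.ndarray | None
```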
+from __future__ import annotations
+
 import aesara
 
-from aesara import tensor as at
+from aesara.graph.basic import Variable
 
 import pymc as pm
 
@@ -70,14 +72,16 @@ class KSDObjective(ObjectiveFunction):
     OPVI TestFunction
     """
 
-    def __init__(self, op, tf):
+    op: KSD
+
+    def __init__(self, op: KSD, tf: opvi.TestFunction):
         if not isinstance(op, KSD):
             raise opvi.ParametrizationError("Op should be KSD")
         super().__init__(op, tf)
 
     @aesara.config.change_flags(compute_test_value="off")
-    def __call__(self, nmc, **kwargs):
-        op = self.op  # type: KSD
+    def __call__(self, nmc, **kwargs) -> list[Variable]:
+        op: KSD = self.op
         grad = op.apply(self.tf)
         if self.approx.all_histograms:
             z = self.approx.joint_histogram
@@ -88,7 +92,7 @@ def __call__(self, nmc, **kwargs):
         else:
             params = self.test_params + kwargs["more_tf_params"]
             grad *= pm.floatX(-1)
-        grads = at.grad(None, params, known_grads={z: grad})
+        grads = aesara.grad(None, params, known_grads={z: grad})
         return self.approx.set_size_and_deterministic(
             grads, nmc, 0, kwargs.get("more_replacements")
         )
diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py
index 1b3c331a8f..9002b381c2 100644
--- a/pymc/variational/opvi.py
+++ b/pymc/variational/opvi.py
@@ -45,6 +45,8 @@
 https://arxiv.org/abs/1610.09033 (2016)
 """
 
+from __future__ import annotations
+
 import collections
 import itertools
 import warnings
@@ -181,7 +183,7 @@ class ObjectiveFunction:
     OPVI TestFunction
     """
 
-    def __init__(self, op, tf):
+    def __init__(self, op: Operator, tf: TestFunction):
         self.op = op
         self.tf = tf
 
@@ -962,7 +964,9 @@ def symbolic_random(self):
         raise NotImplementedError
 
     @aesara.config.change_flags(compute_test_value="off")
-    def set_size_and_deterministic(self, node, s, d, more_replacements=None):
+    def set_size_and_deterministic(
+        self, node: Variable, s, d: bool, more_replacements: dict | None = None
+    ) -> list[Variable]:
         """*Dev* - after node is sampled via :func:`symbolic_sample_over_posterior`
         or :func:`symbolic_single_sample` new random generator can be allocated and applied to node
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 09819e6317..785b368998 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -10,7 +10,7 @@ fastprogress>=0.2.0
 h5py>=2.7
 ipython>=7.16
 jupyter-sphinx
-mypy==0.982
+mypy==0.990
 myst-nb
 numpy>=1.15.0
 numpydoc
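A note on the recurring `from __future__ import annotations`: it makes every annotation a lazily-evaluated string (PEP 563), which is what lets these modules write PEP 604 unions such as `PointType | None` and builtin generics such as `dict[str, Any]` that mypy understands, while the modules still import on the pre-3.10 Pythons PyMC supported at the time, where `X | None` would raise `TypeError` if evaluated at runtime. A minimal sketch:

```python
from __future__ import annotations


# `int | None` in an annotation is fine on Python < 3.10 with the future
# import, because the expression is never evaluated at runtime.
def clamp(x: int | None, lo: int = 0) -> int:
    return lo if x is None else max(x, lo)


print(clamp(None), clamp(5))  # -> 0 5
```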