From b106c79484dadf9a2994f6d7af7be263166d9ade Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Sat, 12 Nov 2022 19:09:00 +0100 Subject: [PATCH 01/17] Add a few missing type imports --- pymc/backends/report.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/backends/report.py b/pymc/backends/report.py index 11cdf93139..a7fd1ebcd4 100644 --- a/pymc/backends/report.py +++ b/pymc/backends/report.py @@ -15,7 +15,7 @@ import dataclasses import logging -from typing import Optional +from typing import Dict, List, Optional import arviz From 8fbc8af7fdadec465257793d6b78c64f12e4131a Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Sat, 12 Nov 2022 19:19:57 +0100 Subject: [PATCH 02/17] Trade assert with assignment to keep mypy happy --- pymc/gp/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/gp/util.py b/pymc/gp/util.py index 4c68bb737c..68dc46cd9e 100644 --- a/pymc/gp/util.py +++ b/pymc/gp/util.py @@ -31,7 +31,7 @@ # Avoid circular dependency when importing modelcontext from pymc.distributions.distribution import Distribution -assert Distribution # keep both pylint and black happy +_ = Distribution # keep both pylint and black happy from pymc.model import modelcontext JITTER_DEFAULT = 1e-6 From a096474f3f435efc920f4caf578bb380b28546d6 Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Sat, 12 Nov 2022 21:51:15 +0100 Subject: [PATCH 03/17] Add a few type annotations --- pymc/variational/callbacks.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/pymc/variational/callbacks.py b/pymc/variational/callbacks.py index 7e0f831083..32e3ff69ff 100644 --- a/pymc/variational/callbacks.py +++ b/pymc/variational/callbacks.py @@ -14,6 +14,8 @@ import collections +from typing import Callable, Dict + import numpy as np __all__ = ["Callback", "CheckParametersConvergence", "Tracker"] @@ -24,15 +26,19 @@ def __call__(self, approx, loss, i): raise NotImplementedError -def relative(current, prev, eps=1e-6): - return (np.abs(current - prev) + eps) / (np.abs(prev) + eps) +def relative(current: np.ndarray, prev: np.ndarray, eps=1e-6) -> np.ndarray: + diff = current - prev # type: ignore + return (np.abs(diff) + eps) / (np.abs(prev) + eps) -def absolute(current, prev): - return np.abs(current - prev) +def absolute(current: np.ndarray, prev: np.ndarray) -> np.ndarray: + diff = current - prev # type: ignore + return np.abs(diff) -_diff = dict(relative=relative, absolute=absolute) +_diff: Dict[str, Callable[[np.ndarray, np.ndarray], np.ndarray]] = dict( + relative=relative, absolute=absolute +) class CheckParametersConvergence(Callback): @@ -76,7 +82,7 @@ def __call__(self, approx, _, i): return current = self.flatten_shared(approx.params) prev = self.prev - delta = self._diff(current, prev) # type: np.ndarray + delta: np.ndarray = self._diff(current, prev) self.prev = current norm = np.linalg.norm(delta, self.ord) if norm < self.tolerance: From 15e9a43327beea17e772b01b653d51bddd42a5e8 Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Sat, 12 Nov 2022 21:51:46 +0100 Subject: [PATCH 04/17] Add missing return type for __call__ --- pymc/variational/callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/variational/callbacks.py b/pymc/variational/callbacks.py index 32e3ff69ff..faa343f358 100644 --- a/pymc/variational/callbacks.py +++ b/pymc/variational/callbacks.py @@ -74,7 +74,7 @@ def __init__(self, every=100, tolerance=1e-3, diff="relative", ord=np.inf): self.prev = None self.tolerance = tolerance - def __call__(self, approx, _, i): + def __call__(self, approx, _, i) -> None: if self.prev is None: self.prev = self.flatten_shared(approx.params) return From ab61f5cde772070778be0f2106dbdfe1ad1600e8 Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Sat, 12 Nov 2022 21:53:37 +0100 Subject: [PATCH 05/17] Switch comment type declaration to raw --- pymc/variational/operators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/variational/operators.py b/pymc/variational/operators.py index 2c67b787b4..6de2a1f28f 100644 --- a/pymc/variational/operators.py +++ b/pymc/variational/operators.py @@ -77,7 +77,7 @@ def __init__(self, op, tf): @aesara.config.change_flags(compute_test_value="off") def __call__(self, nmc, **kwargs): - op = self.op # type: KSD + op: KSD = self.op grad = op.apply(self.tf) if self.approx.all_histograms: z = self.approx.joint_histogram From fde69191dcf2d5633a25948068954f08e96f70dd Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Sat, 12 Nov 2022 22:17:23 +0100 Subject: [PATCH 06/17] Get operators.py to pass --- pymc/variational/operators.py | 12 ++++++++---- pymc/variational/opvi.py | 8 ++++++-- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/pymc/variational/operators.py b/pymc/variational/operators.py index 6de2a1f28f..9192ee4b25 100644 --- a/pymc/variational/operators.py +++ b/pymc/variational/operators.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import aesara -from aesara import tensor as at +from aesara.graph.basic import Variable import pymc as pm @@ -70,13 +72,15 @@ class KSDObjective(ObjectiveFunction): OPVI TestFunction """ - def __init__(self, op, tf): + op: KSD + + def __init__(self, op: KSD, tf: opvi.TestFunction): if not isinstance(op, KSD): raise opvi.ParametrizationError("Op should be KSD") super().__init__(op, tf) @aesara.config.change_flags(compute_test_value="off") - def __call__(self, nmc, **kwargs): + def __call__(self, nmc, **kwargs) -> list[Variable]: op: KSD = self.op grad = op.apply(self.tf) if self.approx.all_histograms: @@ -88,7 +92,7 @@ def __call__(self, nmc, **kwargs): else: params = self.test_params + kwargs["more_tf_params"] grad *= pm.floatX(-1) - grads = at.grad(None, params, known_grads={z: grad}) + grads = aesara.grad(None, params, known_grads={z: grad}) return self.approx.set_size_and_deterministic( grads, nmc, 0, kwargs.get("more_replacements") ) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 1b3c331a8f..9002b381c2 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -45,6 +45,8 @@ https://arxiv.org/abs/1610.09033 (2016) """ +from __future__ import annotations + import collections import itertools import warnings @@ -181,7 +183,7 @@ class ObjectiveFunction: OPVI TestFunction """ - def __init__(self, op, tf): + def __init__(self, op: Operator, tf: TestFunction): self.op = op self.tf = tf @@ -962,7 +964,9 @@ def symbolic_random(self): raise NotImplementedError @aesara.config.change_flags(compute_test_value="off") - def set_size_and_deterministic(self, node, s, d, more_replacements=None): + def set_size_and_deterministic( + self, node: Variable, s, d: bool, more_replacements: dict | None = None + ) -> list[Variable]: """*Dev* - after node is sampled via :func:`symbolic_sample_over_posterior` or :func:`symbolic_single_sample` new random generator can be allocated and applied to node From 77130eba64ded9827a66c60b7439dd3cfd6f65a4 Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Sat, 12 Nov 2022 22:18:33 +0100 Subject: [PATCH 07/17] Fix pymc.backends.report --- pymc/backends/report.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/backends/report.py b/pymc/backends/report.py index a7fd1ebcd4..b336eeb58c 100644 --- a/pymc/backends/report.py +++ b/pymc/backends/report.py @@ -32,7 +32,7 @@ class SamplerReport: """Bundle warnings, convergence stats and metadata of a sampling run.""" - def __init__(self): + def __init__(self) -> None: self._chain_warnings: Dict[int, List[SamplerWarning]] = {} self._global_warnings: List[SamplerWarning] = [] self._n_tune = None From e396635023c2b8d1e4cf9c74c03bf3381884f2b3 Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Sat, 12 Nov 2022 23:53:28 +0100 Subject: [PATCH 08/17] Fix a bunch of typing issues --- pymc/blocking.py | 7 ++++--- pymc/step_methods/hmc/base_hmc.py | 27 ++++++++++++++++++-------- pymc/step_methods/hmc/hmc.py | 6 ++++-- pymc/step_methods/hmc/integration.py | 19 +++++++++++++----- pymc/step_methods/hmc/nuts.py | 23 ++++++++++++++++------ pymc/step_methods/hmc/quadpotential.py | 14 ++++++++++++- 6 files changed, 71 insertions(+), 25 deletions(-) diff --git a/pymc/blocking.py b/pymc/blocking.py index bd58cc38a3..76b7d62335 100644 --- a/pymc/blocking.py +++ b/pymc/blocking.py @@ -17,10 +17,9 @@ Classes for working with subsets of parameters. """ -import collections from functools import partial -from typing import Callable, Dict, Generic, Optional, TypeVar +from typing import Callable, Dict, Generic, NamedTuple, Optional, TypeVar import numpy as np @@ -32,7 +31,9 @@ # `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for # each of the raveled variables. -RaveledVars = collections.namedtuple("RaveledVars", "data, point_map_info") +class RaveledVars(NamedTuple): + data: np.ndarray + point_map_info: tuple[tuple[str, tuple[int, ...], np.dtype], ...] class Compose(Generic[T]): diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py index 6ccf063b68..59eefdea39 100644 --- a/pymc/step_methods/hmc/base_hmc.py +++ b/pymc/step_methods/hmc/base_hmc.py @@ -16,8 +16,7 @@ import time from abc import abstractmethod -from collections import namedtuple -from typing import Optional +from typing import Any, NamedTuple, Optional import numpy as np @@ -29,20 +28,32 @@ from pymc.step_methods import step_sizes from pymc.step_methods.arraystep import GradientSharedStep from pymc.step_methods.hmc import integration +from pymc.step_methods.hmc.integration import IntegrationError, State from pymc.step_methods.hmc.quadpotential import QuadPotentialDiagAdapt, quad_potential from pymc.tuning import guess_scaling from pymc.util import get_value_vars_from_user_vars logger = logging.getLogger("pymc") -HMCStepData = namedtuple("HMCStepData", "end, accept_stat, divergence_info, stats") -DivergenceInfo = namedtuple("DivergenceInfo", "message, exec_info, state, state_div") +class DivergenceInfo(NamedTuple): + message: str + exec_info: IntegrationError | None + state: State + state_div: State | None + + +class HMCStepData(NamedTuple): + end: State + accept_stat: int + divergence_info: DivergenceInfo | None + stats: dict[str, Any] class BaseHMC(GradientSharedStep): """Superclass to implement Hamiltonian/hybrid monte carlo.""" + integrator: integration.CpuLeapfrogIntegrator default_blocked = True def __init__( @@ -138,13 +149,13 @@ def __init__( self._num_divs_sample = 0 @abstractmethod - def _hamiltonian_step(self, start, p0, step_size): + def _hamiltonian_step(self, start, p0, step_size) -> HMCStepData: """Compute one Hamiltonian trajectory and return the next state. Subclasses must overwrite this abstract method and return an `HMCStepData` object. """ - def astep(self, q0): + def astep(self, q0: RaveledVars) -> tuple[RaveledVars, list[dict[str, Any]]]: """Perform a single HMC iteration.""" perf_start = time.perf_counter() process_start = time.process_time() @@ -154,6 +165,7 @@ def astep(self, q0): start = self.integrator.compute_state(q0, p0) + warning: Optional[SamplerWarning] = None if not np.isfinite(start.energy): model = self._model check_test_point_dict = model.point_logps() @@ -188,7 +200,6 @@ def astep(self, q0): self.step_adapt.update(hmc_step.accept_stat, adapt_step) self.potential.update(hmc_step.end.q, hmc_step.end.q_grad, self.tune) - warning: Optional[SamplerWarning] = None if hmc_step.divergence_info: info = hmc_step.divergence_info point = None @@ -221,7 +232,7 @@ def astep(self, q0): self.iter_count += 1 - stats = { + stats: dict[str, Any] = { "tune": self.tune, "diverging": bool(hmc_step.divergence_info), "perf_counter_diff": perf_end - perf_start, diff --git a/pymc/step_methods/hmc/hmc.py b/pymc/step_methods/hmc/hmc.py index 4ed192ac99..c9eebf67b1 100644 --- a/pymc/step_methods/hmc/hmc.py +++ b/pymc/step_methods/hmc/hmc.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any + import numpy as np from pymc.stats.convergence import SamplerWarning @@ -119,7 +121,7 @@ def __init__(self, vars=None, path_length=2.0, max_steps=1024, **kwargs): self.path_length = path_length self.max_steps = max_steps - def _hamiltonian_step(self, start, p0, step_size): + def _hamiltonian_step(self, start, p0, step_size: float) -> HMCStepData: n_steps = max(1, int(self.path_length / step_size)) n_steps = min(self.max_steps, n_steps) @@ -156,7 +158,7 @@ def _hamiltonian_step(self, start, p0, step_size): end = state accepted = True - stats = { + stats: dict[str, Any] = { "path_length": self.path_length, "n_steps": n_steps, "accept": accept_stat, diff --git a/pymc/step_methods/hmc/integration.py b/pymc/step_methods/hmc/integration.py index f631777010..67c9915eb1 100644 --- a/pymc/step_methods/hmc/integration.py +++ b/pymc/step_methods/hmc/integration.py @@ -12,15 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple +from typing import Any, NamedTuple import numpy as np from scipy import linalg from pymc.blocking import RaveledVars +from pymc.step_methods.hmc.quadpotential import QuadPotential -State = namedtuple("State", "q, p, v, q_grad, energy, model_logp, index_in_trajectory") + +class State(NamedTuple): + q: RaveledVars + p: RaveledVars + v: np.ndarray + q_grad: Any + energy: Any + model_logp: Any + index_in_trajectory: int class IntegrationError(RuntimeError): @@ -28,7 +37,7 @@ class IntegrationError(RuntimeError): class CpuLeapfrogIntegrator: - def __init__(self, potential, logp_dlogp_func): + def __init__(self, potential: QuadPotential, logp_dlogp_func): """Leapfrog integrator using CPU.""" self._potential = potential self._logp_dlogp_func = logp_dlogp_func @@ -39,14 +48,14 @@ def __init__(self, potential, logp_dlogp_func): "don't match." % (self._potential.dtype, self._dtype) ) - def compute_state(self, q, p): + def compute_state(self, q: RaveledVars, p: RaveledVars): """Compute Hamiltonian functions using a position and momentum.""" if q.data.dtype != self._dtype or p.data.dtype != self._dtype: raise ValueError("Invalid dtype. Must be %s" % self._dtype) logp, dlogp = self._logp_dlogp_func(q) - v = self._potential.velocity(p.data) + v = self._potential.velocity(p.data, out=None) kinetic = self._potential.energy(p.data, velocity=v) energy = kinetic - logp return State(q, p, v, dlogp, energy, logp, 0) diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 0eed003da1..282c002ade 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -20,8 +20,9 @@ from pymc.math import logbern from pymc.stats.convergence import SamplerWarning from pymc.step_methods.arraystep import Competence +from pymc.step_methods.hmc import integration from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData -from pymc.step_methods.hmc.integration import IntegrationError +from pymc.step_methods.hmc.integration import IntegrationError, State from pymc.vartypes import continuous_types __all__ = ["NUTS"] @@ -227,7 +228,14 @@ def competence(var, has_grad): class _Tree: - def __init__(self, ndim, integrator, start, step_size, Emax): + def __init__( + self, + ndim, + integrator: integration.CpuLeapfrogIntegrator, + start: State, + step_size: float, + Emax: float, + ): """Binary tree from the NUTS algorithm. Parameters @@ -315,16 +323,19 @@ def extend(self, direction): return diverging, turning - def _single_step(self, left, epsilon): + def _single_step(self, left: State, epsilon: float): """Perform a leapfrog step and handle error cases.""" + right: State | None + error: IntegrationError | None + error_msg: str | None try: - # `State` type right = self.integrator.step(epsilon, left) except IntegrationError as err: error_msg = str(err) error = err right = None else: + assert right is not None # since there was no IntegrationError # h - H0 energy_change = right.energy - self.start_energy if np.isnan(energy_change): @@ -354,8 +365,8 @@ def _single_step(self, left, epsilon): finally: self.n_proposals += 1 tree = Subtree(None, None, None, None, -np.inf) - divergance_info = DivergenceInfo(error_msg, error, left, right) - return tree, divergance_info, False + divergence_info = DivergenceInfo(error_msg, error, left, right) + return tree, divergence_info, False def _build_subtree(self, left, depth, epsilon): if depth == 0: diff --git a/pymc/step_methods/hmc/quadpotential.py b/pymc/step_methods/hmc/quadpotential.py index d429b56b21..f5b55fe119 100644 --- a/pymc/step_methods/hmc/quadpotential.py +++ b/pymc/step_methods/hmc/quadpotential.py @@ -14,6 +14,8 @@ import warnings +from typing import overload + import aesara import numpy as np import scipy.linalg @@ -94,7 +96,17 @@ def __str__(self): class QuadPotential: - def velocity(self, x, out=None): + dtype: np.dtype + + @overload + def velocity(self, x: np.ndarray, out: None) -> np.ndarray: + ... + + @overload + def velocity(self, x: np.ndarray, out: np.ndarray) -> None: + ... + + def velocity(self, x: np.ndarray, out: np.ndarray | None = None) -> np.ndarray | None: """Compute the current velocity at a position in parameter space.""" raise NotImplementedError("Abstract method") From 165b4cac5fbda1738da5ccaf62bd0e272ac7364e Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Sat, 12 Nov 2022 23:55:14 +0100 Subject: [PATCH 09/17] Unpin mypy --- conda-envs/environment-dev.yml | 2 +- conda-envs/environment-test.yml | 2 +- conda-envs/windows-environment-dev.yml | 2 +- conda-envs/windows-environment-test.yml | 2 +- requirements-dev.txt | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index 5ceef53dfd..d6a1895a34 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -38,5 +38,5 @@ dependencies: - watermark - polyagamma - sphinx-remove-toctrees -- mypy=0.982 +- mypy - types-cachetools diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 19fcd7ecd4..2dea663920 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -27,5 +27,5 @@ dependencies: - pre-commit>=2.8.0 - pytest-cov>=2.5 - pytest>=3.0 -- mypy=0.982 +- mypy - types-cachetools diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml index c5aea815e5..6a7012eea7 100644 --- a/conda-envs/windows-environment-dev.yml +++ b/conda-envs/windows-environment-dev.yml @@ -35,5 +35,5 @@ dependencies: - sphinx>=1.5 - watermark - sphinx-remove-toctrees -- mypy=0.982 +- mypy - types-cachetools diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index b27c453284..a510c4d4aa 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -28,5 +28,5 @@ dependencies: - pre-commit>=2.8.0 - pytest-cov>=2.5 - pytest>=3.0 -- mypy=0.982 +- mypy - types-cachetools diff --git a/requirements-dev.txt b/requirements-dev.txt index 09819e6317..9475a128f4 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -10,7 +10,7 @@ fastprogress>=0.2.0 h5py>=2.7 ipython>=7.16 jupyter-sphinx -mypy==0.982 +mypy myst-nb numpy>=1.15.0 numpydoc From 363b4c8e90710047979e2455cc7c6879c1710297 Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Sun, 13 Nov 2022 00:51:27 +0100 Subject: [PATCH 10/17] Import __future__.annotations to fix "| None" --- pymc/blocking.py | 7 ++++--- pymc/step_methods/hmc/base_hmc.py | 6 ++++-- pymc/step_methods/hmc/nuts.py | 2 ++ pymc/step_methods/hmc/quadpotential.py | 2 ++ 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/pymc/blocking.py b/pymc/blocking.py index 76b7d62335..15e249d51e 100644 --- a/pymc/blocking.py +++ b/pymc/blocking.py @@ -17,9 +17,10 @@ Classes for working with subsets of parameters. """ +from __future__ import annotations from functools import partial -from typing import Callable, Dict, Generic, NamedTuple, Optional, TypeVar +from typing import Callable, Dict, Generic, NamedTuple, TypeVar import numpy as np @@ -70,7 +71,7 @@ def map(var_dict: PointType) -> RaveledVars: @staticmethod def rmap( array: RaveledVars, - start_point: Optional[PointType] = None, + start_point: PointType | None = None, ) -> PointType: """Map 1D concatenated array to a dictionary of variables in their original spaces. @@ -101,7 +102,7 @@ def rmap( @classmethod def mapf( - cls, f: Callable[[PointType], T], start_point: Optional[PointType] = None + cls, f: Callable[[PointType], T], start_point: PointType | None = None ) -> Callable[[RaveledVars], T]: """Create a callable that first maps back to ``dict`` inputs and then applies a function. diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py index 59eefdea39..0eb6fd77b1 100644 --- a/pymc/step_methods/hmc/base_hmc.py +++ b/pymc/step_methods/hmc/base_hmc.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import time from abc import abstractmethod -from typing import Any, NamedTuple, Optional +from typing import Any, NamedTuple import numpy as np @@ -165,7 +167,7 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, list[dict[str, Any]]]: start = self.integrator.compute_state(q0, p0) - warning: Optional[SamplerWarning] = None + warning: SamplerWarning | None = None if not np.isfinite(start.energy): model = self._model check_test_point_dict = model.point_logps() diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 282c002ade..393532eb44 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections import namedtuple import numpy as np diff --git a/pymc/step_methods/hmc/quadpotential.py b/pymc/step_methods/hmc/quadpotential.py index f5b55fe119..f79f0e4302 100644 --- a/pymc/step_methods/hmc/quadpotential.py +++ b/pymc/step_methods/hmc/quadpotential.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import warnings from typing import overload From 814acc5b4066adc2ecfb4ea0ceee721c25d9b2a0 Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Sun, 13 Nov 2022 18:38:43 +0100 Subject: [PATCH 11/17] Update pymc/step_methods/hmc/integration.py Co-authored-by: Adrian Seyboldt --- pymc/step_methods/hmc/integration.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc/step_methods/hmc/integration.py b/pymc/step_methods/hmc/integration.py index 67c9915eb1..5ccdca7e4d 100644 --- a/pymc/step_methods/hmc/integration.py +++ b/pymc/step_methods/hmc/integration.py @@ -26,9 +26,9 @@ class State(NamedTuple): q: RaveledVars p: RaveledVars v: np.ndarray - q_grad: Any - energy: Any - model_logp: Any + q_grad: np.ndarray + energy: float + model_logp: float index_in_trajectory: int From 99346dd451054e130ad616b6243427e3b64ec4d6 Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Sun, 13 Nov 2022 18:45:48 +0100 Subject: [PATCH 12/17] Add __future__.annotations to hmc.py --- pymc/step_methods/hmc/hmc.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymc/step_methods/hmc/hmc.py b/pymc/step_methods/hmc/hmc.py index c9eebf67b1..804fadc304 100644 --- a/pymc/step_methods/hmc/hmc.py +++ b/pymc/step_methods/hmc/hmc.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from typing import Any import numpy as np From bd939e0a573b3386c64726bd473ad1b3cfd24fa7 Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Sun, 13 Nov 2022 18:48:20 +0100 Subject: [PATCH 13/17] Remove unused Any import --- pymc/step_methods/hmc/integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/step_methods/hmc/integration.py b/pymc/step_methods/hmc/integration.py index 5ccdca7e4d..c2dc7cf191 100644 --- a/pymc/step_methods/hmc/integration.py +++ b/pymc/step_methods/hmc/integration.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, NamedTuple +from typing import NamedTuple import numpy as np From 8b127a6de1fc796535c9e93172d089565f731080 Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Sun, 13 Nov 2022 18:57:48 +0100 Subject: [PATCH 14/17] Don't cast float to np.array --- pymc/step_methods/hmc/nuts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 393532eb44..114a523479 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -257,7 +257,7 @@ def __init__( self.start = start self.step_size = step_size self.Emax = Emax - self.start_energy = np.array(start.energy) + self.start_energy = start.energy self.left = self.right = start self.proposal = Proposal(start.q.data, start.q_grad, start.energy, start.model_logp, 0) From 9becee3439710c18b946f4f680961c92600c05d7 Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Sun, 13 Nov 2022 19:00:46 +0100 Subject: [PATCH 15/17] Replace 0 with 0.0 for float zeros --- pymc/step_methods/hmc/nuts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 114a523479..744e4ded39 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -262,12 +262,12 @@ def __init__( self.left = self.right = start self.proposal = Proposal(start.q.data, start.q_grad, start.energy, start.model_logp, 0) self.depth = 0 - self.log_size = 0 + self.log_size = 0.0 self.log_accept_sum = -np.inf self.mean_tree_accept = 0.0 self.n_proposals = 0 self.p_sum = start.p.data.copy() - self.max_energy_change = 0 + self.max_energy_change = 0.0 def extend(self, direction): """Double the treesize by extending the tree in the given direction. From 2582846bc334a707d844e3ddad48aae487db37ba Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Mon, 14 Nov 2022 12:57:24 +0100 Subject: [PATCH 16/17] Update pymc/step_methods/hmc/nuts.py Co-authored-by: Michael Osthege --- pymc/step_methods/hmc/nuts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 744e4ded39..f1251d5c55 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -232,7 +232,7 @@ def competence(var, has_grad): class _Tree: def __init__( self, - ndim, + ndim: int, integrator: integration.CpuLeapfrogIntegrator, start: State, step_size: float, From dc3abb2a4b45c7fad2acce6f82efc000a7930990 Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Mon, 14 Nov 2022 15:33:35 +0100 Subject: [PATCH 17/17] Repin mypy with v0.990 --- conda-envs/environment-dev.yml | 2 +- conda-envs/environment-test.yml | 2 +- conda-envs/windows-environment-dev.yml | 2 +- conda-envs/windows-environment-test.yml | 2 +- requirements-dev.txt | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index d6a1895a34..fd56a26372 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -38,5 +38,5 @@ dependencies: - watermark - polyagamma - sphinx-remove-toctrees -- mypy +- mypy=0.990 - types-cachetools diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 2dea663920..58ad3be53a 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -27,5 +27,5 @@ dependencies: - pre-commit>=2.8.0 - pytest-cov>=2.5 - pytest>=3.0 -- mypy +- mypy=0.990 - types-cachetools diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml index 6a7012eea7..dc3a80fcdd 100644 --- a/conda-envs/windows-environment-dev.yml +++ b/conda-envs/windows-environment-dev.yml @@ -35,5 +35,5 @@ dependencies: - sphinx>=1.5 - watermark - sphinx-remove-toctrees -- mypy +- mypy=0.990 - types-cachetools diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index a510c4d4aa..b1515cb7b7 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -28,5 +28,5 @@ dependencies: - pre-commit>=2.8.0 - pytest-cov>=2.5 - pytest>=3.0 -- mypy +- mypy=0.990 - types-cachetools diff --git a/requirements-dev.txt b/requirements-dev.txt index 9475a128f4..785b368998 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -10,7 +10,7 @@ fastprogress>=0.2.0 h5py>=2.7 ipython>=7.16 jupyter-sphinx -mypy +mypy==0.990 myst-nb numpy>=1.15.0 numpydoc