From 42d3aa7b601be806aaf1398f1fb16db1631eca06 Mon Sep 17 00:00:00 2001
From: Virgile Andreani
Date: Mon, 13 Feb 2023 15:18:36 -0500
Subject: [PATCH 1/7] Make isort happy with imports order

---
 pymc/gp/util.py      |  2 +-
 pymc/sampling/jax.py | 19 ++++++++-----------
 2 files changed, 9 insertions(+), 12 deletions(-)

diff --git a/pymc/gp/util.py b/pymc/gp/util.py
index a677a25593..f2ae803895 100644
--- a/pymc/gp/util.py
+++ b/pymc/gp/util.py
@@ -28,10 +28,10 @@
 
 # Avoid circular dependency when importing modelcontext
 from pymc.distributions.distribution import Distribution
+from pymc.model import modelcontext
 from pymc.pytensorf import compile_pymc, walk_model
 
 _ = Distribution  # keep both pylint and black happy
-from pymc.model import modelcontext
 
 JITTER_DEFAULT = 1e-6
diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py
index 584c1d91dc..ba972e4cbf 100644
--- a/pymc/sampling/jax.py
+++ b/pymc/sampling/jax.py
@@ -15,20 +15,10 @@
 import re
 import sys
 
+from datetime import datetime
 from functools import partial
 from typing import Any, Callable, Dict, List, Optional, Sequence, Union
 
-from pytensor.tensor.random.type import RandomType
-
-from pymc.initial_point import StartDict
-from pymc.sampling.mcmc import _init_jitter
-
-xla_flags = os.getenv("XLA_FLAGS", "")
-xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
-os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"] + xla_flags)
-
-from datetime import datetime
-
 import arviz as az
 import jax
 import numpy as np
@@ -43,11 +33,14 @@
 from pytensor.link.jax.dispatch import jax_funcify
 from pytensor.raise_op import Assert
 from pytensor.tensor import TensorVariable
+from pytensor.tensor.random.type import RandomType
 from pytensor.tensor.shape import SpecifyShape
 
 from pymc import Model, modelcontext
 from pymc.backends.arviz import find_constants, find_observations
+from pymc.initial_point import StartDict
 from pymc.logprob.utils import CheckParameterValue
+from pymc.sampling.mcmc import _init_jitter
 from pymc.util import (
     RandomSeed,
     RandomState,
@@ -55,6 +48,10 @@
     get_default_varnames,
 )
 
+xla_flags = os.getenv("XLA_FLAGS", "")
+xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
+os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"] + xla_flags)
+
 __all__ = (
     "get_jaxified_graph",
     "get_jaxified_logp",
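
Note on [PATCH 1/7]: isort orders imports in three sections, standard library, then third-party, then first-party, each sorted alphabetically, which is why `from pymc.model import modelcontext` moves up next to the other imports and the `xla_flags` statements move below the import block. A minimal sketch of the layout isort enforces (module names here are illustrative, not taken from this diff):

    # Section 1: standard library, alphabetized.
    import os
    import re
    from datetime import datetime

    # Section 2: third-party packages.
    import numpy as np

    # Section 3: first-party project modules.
    from pymc.model import modelcontext

When a late import is deliberate, for example to break an import cycle, isort can also be told to leave the line alone with an `# isort: skip` action comment instead of reordering.
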
From b5617b4a52475bfdb12d6421670c74ef604f9a30 Mon Sep 17 00:00:00 2001
From: Virgile Andreani
Date: Mon, 13 Feb 2023 15:36:27 -0500
Subject: [PATCH 2/7] Avoid shadowing the logprob fn with a local name

---
 pymc/logprob/transforms.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py
index bb9c3ab057..5f3f320be5 100644
--- a/pymc/logprob/transforms.py
+++ b/pymc/logprob/transforms.py
@@ -961,16 +961,16 @@ def transformed_logprob(op, values, *inputs, use_jacobian=True, **kwargs):
     if use_jacobian:
         assert len(values) == len(logprobs) == len(op.transforms)
         logprobs_jac = []
-        for value, transform, logprob in zip(values, op.transforms, logprobs):
+        for value, transform, logp in zip(values, op.transforms, logprobs):
             if transform is None:
-                logprobs_jac.append(logprob)
+                logprobs_jac.append(logp)
                 continue
             assert isinstance(value.owner.op, TransformedVariable)
             original_forward_value = value.owner.inputs[1]
             jacobian = transform.log_jac_det(original_forward_value, *inputs).copy()
             if value.name:
                 jacobian.name = f"{value.name}_jacobian"
-            logprobs_jac.append(logprob + jacobian)
+            logprobs_jac.append(logp + jacobian)
         logprobs = logprobs_jac
 
     return logprobs

From c8361746036c564b8e5a6028451fc0be2f193b83 Mon Sep 17 00:00:00 2001
From: Virgile Andreani
Date: Mon, 13 Feb 2023 15:39:50 -0500
Subject: [PATCH 3/7] Deal with unused error names

---
 pymc/model.py                        | 2 +-
 pymc/step_methods/hmc/integration.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/pymc/model.py b/pymc/model.py
index 104d810ff8..f10c3bf22d 100644
--- a/pymc/model.py
+++ b/pymc/model.py
@@ -195,7 +195,7 @@ def get_context(cls, error_if_none=True) -> Optional[T]:
         raise a ``TypeError`` instead of returning ``None``."""
         try:
             candidate: Optional[T] = cls.get_contexts()[-1]
-        except IndexError as e:
+        except IndexError:
             # Calling code expects to get a TypeError if the entity
             # is unfound, and there's too much to fix.
             if error_if_none:
diff --git a/pymc/step_methods/hmc/integration.py b/pymc/step_methods/hmc/integration.py
index ab02fc83dd..5e0bdb8ee0 100644
--- a/pymc/step_methods/hmc/integration.py
+++ b/pymc/step_methods/hmc/integration.py
@@ -82,7 +82,7 @@ def step(self, epsilon, state):
             return self._step(epsilon, state)
         except linalg.LinAlgError as err:
             msg = "LinAlgError during leapfrog step."
-            raise IntegrationError(msg)
+            raise IntegrationError(msg) from err
         except ValueError as err:
             # Raised by many scipy.linalg functions
             scipy_msg = "array must not contain infs or nans"

From 7c8cca0bff5f484e233a56512732c0e660ad4fa9 Mon Sep 17 00:00:00 2001
From: Virgile Andreani
Date: Thu, 23 Feb 2023 22:18:54 -0500
Subject: [PATCH 4/7] Remove unreachable return

---
 pymc/distributions/distribution.py | 28 +++++++++++++---------------
 1 file changed, 13 insertions(+), 15 deletions(-)

diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py
index 5ece0b1bb6..9a47cf1390 100644
--- a/pymc/distributions/distribution.py
+++ b/pymc/distributions/distribution.py
@@ -988,21 +988,19 @@ def __new__(
                 ndim_supp=ndim_supp,
                 **kwargs,
             )
-        else:
-            return _CustomDist(
-                name,
-                *dist_params,
-                class_name=name,
-                random=random,
-                logp=logp,
-                logcdf=logcdf,
-                moment=moment,
-                ndim_supp=ndim_supp,
-                ndims_params=ndims_params,
-                dtype=dtype,
-                **kwargs,
-            )
-        return super().__new__(cls, name, *args, **kwargs)
+        return _CustomDist(
+            name,
+            *dist_params,
+            class_name=name,
+            random=random,
+            logp=logp,
+            logcdf=logcdf,
+            moment=moment,
+            ndim_supp=ndim_supp,
+            ndims_params=ndims_params,
+            dtype=dtype,
+            **kwargs,
+        )
 
     @classmethod
     def dist(

From b336ee3b85cda328aca54276043316c15a2a0cc4 Mon Sep 17 00:00:00 2001
From: Virgile Andreani
Date: Thu, 23 Feb 2023 22:20:30 -0500
Subject: [PATCH 5/7] Remove unread variable

---
 pymc/step_methods/hmc/nuts.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py
index 9d377205b3..61dc56a8a6 100644
--- a/pymc/step_methods/hmc/nuts.py
+++ b/pymc/step_methods/hmc/nuts.py
@@ -210,7 +210,6 @@ def _hamiltonian_step(self, start, p0, step_size):
     def competence(var, has_grad):
         """Check how appropriate this class is for sampling a random variable."""
-        dist = getattr(var.owner, "op", None)
         if var.dtype in continuous_types and has_grad:
             return Competence.PREFERRED
         return Competence.INCOMPATIBLE
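
Note on [PATCH 3/7]: `except IndexError:` drops an exception name that was never read, and `raise IntegrationError(msg) from err` makes the LinAlgError the explicit cause of the new exception. Without `from err`, Python still attaches the original as implicit context, but the traceback then says the second error happened "during handling" of the first; explicit chaining marks it as the direct cause and satisfies linters such as pylint's raise-missing-from and flake8-bugbear's B904. A self-contained sketch of the pattern, using a stand-in exception class rather than pymc's real one:

    class IntegrationError(RuntimeError):
        """Stand-in for pymc's IntegrationError."""

    def leapfrog_step():
        try:
            raise ValueError("array must not contain infs or nans")
        except ValueError as err:
            # "from err" stores the ValueError in __cause__, so the
            # traceback presents it as the direct cause of this failure.
            raise IntegrationError("Numerical error in leapfrog step.") from err
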
From ba64afa4c82a2d99fd4a4f0c7dff0f8a566487be Mon Sep 17 00:00:00 2001
From: Virgile Andreani
Date: Thu, 23 Feb 2023 22:39:04 -0500
Subject: [PATCH 6/7] Fix type checking for variational/approximations

---
 pymc/variational/approximations.py | 4 ++--
 scripts/run_mypy.py                | 1 -
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py
index 144ca09b71..ec64c13fb8 100644
--- a/pymc/variational/approximations.py
+++ b/pymc/variational/approximations.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from typing import Optional
 
 import numpy as np
 import pytensor
@@ -331,7 +331,7 @@ def sample_approx(approx, draws=100, include_transformed=True):
 class SingleGroupApproximation(Approximation):
     """Base class for Single Group Approximation"""
 
-    _group_class = None
+    _group_class: Optional[type] = None
 
     def __init__(self, *args, **kwargs):
         groups = [self._group_class(None, *args, **kwargs)]
diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py
index 717881f012..3774ad8333 100644
--- a/scripts/run_mypy.py
+++ b/scripts/run_mypy.py
@@ -42,7 +42,6 @@
 pymc/printing.py
 pymc/pytensorf.py
 pymc/sampling/jax.py
-pymc/variational/approximations.py
 pymc/variational/opvi.py
 """
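
Note on [PATCH 6/7]: without an annotation, mypy infers the type of `_group_class = None` to be `None`, so subclasses that assign an actual class, and call sites like `self._group_class(None, ...)`, fail to type-check; `Optional[type]` declares the attribute as "a class object or None". Judging by the surrounding entries, the list in scripts/run_mypy.py enumerates files still failing mypy, so deleting the entry marks the file as expected to pass. A toy sketch of the annotation at work (these classes are illustrative, not pymc's real Group hierarchy):

    from typing import Optional

    class SingleGroupApproximation:
        # Optional[type] admits both the None default here and the
        # concrete class objects assigned by subclasses.
        _group_class: Optional[type] = None

        def __init__(self) -> None:
            assert self._group_class is not None  # narrows away None for mypy
            self.group = self._group_class()

    class MeanFieldLike(SingleGroupApproximation):
        _group_class = dict  # any class object is a valid `type`
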
From 96b7160b2b4cb0057947f8c01079f9fd54c2dda7 Mon Sep 17 00:00:00 2001
From: Virgile Andreani
Date: Thu, 23 Feb 2023 23:34:15 -0500
Subject: [PATCH 7/7] A few typing fixes

---
 pymc/distributions/dist_math.py |  4 ++--
 pymc/model_graph.py             |  6 +++---
 pymc/sampling/jax.py            | 14 ++++++--------
 pymc/variational/opvi.py        | 16 +++++++++-------
 4 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/pymc/distributions/dist_math.py b/pymc/distributions/dist_math.py
index b075238788..fbdea97440 100644
--- a/pymc/distributions/dist_math.py
+++ b/pymc/distributions/dist_math.py
@@ -61,10 +61,10 @@ def check_parameters(logp: Variable, *conditions: Iterable[Variable], msg: str =
             check_bounds = False in pm.Model()
     """
 
     # at.all does not accept True/False, but accepts np.array(True)/np.array(False)
-    conditions = [
+    conditions_ = [
         cond if (cond is not True and cond is not False) else np.array(cond) for cond in conditions
     ]
-    all_true_scalar = at.all([at.all(cond) for cond in conditions])
+    all_true_scalar = at.all([at.all(cond) for cond in conditions_])
     return CheckParameterValue(msg)(logp, all_true_scalar)
diff --git a/pymc/model_graph.py b/pymc/model_graph.py
index 03de550127..43e7d1f728 100644
--- a/pymc/model_graph.py
+++ b/pymc/model_graph.py
@@ -73,7 +73,7 @@ def _expand(x):
             return []
 
         parents = {
-            get_var_name(x)
+            VarName(get_var_name(x))
             for x in walk(nodes=_filter_non_parameter_inputs(var), expand=_expand)
             # Only consider nodes that are in the named model variables.
             if x.name and x.name in self._all_var_names
@@ -109,7 +109,7 @@ def vars_to_plot(self, var_names: Optional[Iterable[VarName]] = None) -> List[Va
                 selected_ancestors.add(self.model.rvs_to_values[var])
 
         # ordering of self._all_var_names is important
-        return [var.name for var in selected_ancestors]
+        return [VarName(var.name) for var in selected_ancestors]
 
     def make_compute_graph(
         self, var_names: Optional[Iterable[VarName]] = None
@@ -230,7 +230,7 @@ def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> Dict[str,
                 plate_label = " x ".join(dim_labels)
             else:
                 # The RV has no `dims` information.
-                dim_labels = map(str, shape)
+                dim_labels = [str(x) for x in shape]
                 plate_label = " x ".join(map(str, shape))
             plates[plate_label].add(var_name)
diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py
index ba972e4cbf..0cdc8afd0c 100644
--- a/pymc/sampling/jax.py
+++ b/pymc/sampling/jax.py
@@ -48,8 +48,8 @@
     get_default_varnames,
 )
 
-xla_flags = os.getenv("XLA_FLAGS", "")
-xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
+xla_flags_env = os.getenv("XLA_FLAGS", "")
+xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags_env).split()
 os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"] + xla_flags)
 
 __all__ = (
     "get_jaxified_graph",
     "get_jaxified_logp",
@@ -108,7 +108,7 @@ def get_jaxified_graph(
 ) -> List[TensorVariable]:
     """Compile an PyTensor graph into an optimized JAX function"""
 
-    graph = _replace_shared_variables(outputs)
+    graph = _replace_shared_variables(outputs) if outputs is not None else None
 
     fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
     # We need to add a Supervisor to the fgraph to be able to run the
@@ -251,12 +251,10 @@ def _get_batched_jittered_initial_points(
         jitter=jitter,
         jitter_max_retries=jitter_max_retries,
     )
-    initial_points = [list(initial_point.values()) for initial_point in initial_points]
+    initial_points_values = [list(initial_point.values()) for initial_point in initial_points]
     if chains == 1:
-        initial_points = initial_points[0]
-    else:
-        initial_points = [np.stack(init_state) for init_state in zip(*initial_points)]
-    return initial_points
+        return initial_points_values[0]
+    return [np.stack(init_state) for init_state in zip(*initial_points_values)]
 
 
 def _update_coords_and_dims(
diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py
index 8adcd4ca4e..a2e371c1ab 100644
--- a/pymc/variational/opvi.py
+++ b/pymc/variational/opvi.py
@@ -51,6 +51,8 @@
 import itertools
 import warnings
 
+from typing import Any
+
 import numpy as np
 import pytensor
 import pytensor.tensor as at
@@ -673,11 +675,11 @@ class Group(WithMemoization):
     initial_dist_map = 0.0
 
     # for handy access using class methods
-    __param_spec__ = dict()
+    __param_spec__: dict = dict()
     short_name = ""
-    alias_names = frozenset()
-    __param_registry = dict()
-    __name_registry = dict()
+    alias_names: frozenset[str] = frozenset()
+    __param_registry: dict[frozenset, Any] = dict()
+    __name_registry: dict[str, Any] = dict()
 
     @classmethod
     def register(cls, sbcls):
@@ -1552,11 +1554,11 @@ def sample(
         finally:
             trace.close()
 
-        trace = MultiTrace([trace])
+        multi_trace = MultiTrace([trace])
         if not return_inferencedata:
-            return trace
+            return multi_trace
         else:
-            return pm.to_inference_data(trace, model=self.model, **kwargs)
+            return pm.to_inference_data(multi_trace, model=self.model, **kwargs)
 
     @property
     def ndim(self):
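
Note on [PATCH 7/7]: all four files get the same medicine for mypy's "incompatible types in assignment": when a value of a new type is produced, it is bound to a fresh name (`conditions_`, `xla_flags_env`, `initial_points_values`, `multi_trace`) instead of rebinding a name mypy has already inferred with another type. The model_graph.py hunk is the same idea in disguise: `dim_labels` holds a list in the `if` branch, so the `else` branch builds a list comprehension rather than a one-shot `map` iterator. A condensed before/after based on the xla_flags hunk above:

    import os
    import re
    from typing import List

    # Before, one name went from str to List[str]:
    #     xla_flags = os.getenv("XLA_FLAGS", "")   # str
    #     xla_flags = re.sub(...).split()          # List[str]  <- mypy error
    # After, each type keeps its own name:
    xla_flags_env: str = os.getenv("XLA_FLAGS", "")
    xla_flags: List[str] = re.sub(
        r"--xla_force_host_platform_device_count=.+\s", "", xla_flags_env
    ).split()
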