4 changes: 2 additions & 2 deletions pymc/distributions/dist_math.py
@@ -61,10 +61,10 @@ def check_parameters(logp: Variable, *conditions: Iterable[Variable], msg: str =
         check_bounds = False in pm.Model()
     """
     # at.all does not accept True/False, but accepts np.array(True)/np.array(False)
-    conditions = [
+    conditions_ = [
         cond if (cond is not True and cond is not False) else np.array(cond) for cond in conditions
     ]
-    all_true_scalar = at.all([at.all(cond) for cond in conditions])
+    all_true_scalar = at.all([at.all(cond) for cond in conditions_])
     return CheckParameterValue(msg)(logp, all_true_scalar)
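
Note: the rename is a mypy appeasement, not a behavior change. A minimal sketch of the pattern, using a hypothetical function rather than PyMC's own:

    from typing import Iterable, Union

    def check(conditions: Iterable[Union[bool, int]]) -> list:
        # Rebinding `conditions` itself to a list would make mypy infer two
        # incompatible types for one name ("Incompatible types in assignment");
        # a fresh name keeps each binding single-typed.
        conditions_ = [int(c) for c in conditions]
        return conditions_
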
28 changes: 13 additions & 15 deletions pymc/distributions/distribution.py
@@ -988,21 +988,19 @@ def __new__(
                 ndim_supp=ndim_supp,
                 **kwargs,
             )
-        else:
-            return _CustomDist(
-                name,
-                *dist_params,
-                class_name=name,
-                random=random,
-                logp=logp,
-                logcdf=logcdf,
-                moment=moment,
-                ndim_supp=ndim_supp,
-                ndims_params=ndims_params,
-                dtype=dtype,
-                **kwargs,
-            )
-        return super().__new__(cls, name, *args, **kwargs)
+        return _CustomDist(
+            name,
+            *dist_params,
+            class_name=name,
+            random=random,
+            logp=logp,
+            logcdf=logcdf,
+            moment=moment,
+            ndim_supp=ndim_supp,
+            ndims_params=ndims_params,
+            dtype=dtype,
+            **kwargs,
+        )

     @classmethod
     def dist(
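
Note: two cleanups in one hunk, sketched below with a hypothetical dispatcher. When every branch returns, the `else:` only adds indentation, and any statement placed after the branch structure is unreachable dead code:

    def dispatch(use_symbolic: bool) -> str:
        if use_symbolic:
            return "symbolic"
        # no `else:` needed - the early return already ended that branch
        return "custom"
        # a trailing fallback `return` here could never run, which is the
        # kind of dead code the diff deletes
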
2 changes: 1 addition & 1 deletion pymc/gp/util.py
@@ -28,10 +28,10 @@

 # Avoid circular dependency when importing modelcontext
 from pymc.distributions.distribution import Distribution
-from pymc.model import modelcontext
 from pymc.pytensorf import compile_pymc, walk_model

 _ = Distribution  # keep both pylint and black happy
+from pymc.model import modelcontext

 JITTER_DEFAULT = 1e-6
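
Note: the import is moved below the statement that forces `pymc.distributions` to finish loading, so importing `pymc.gp.util` no longer trips the import cycle. A minimal sketch of the failure mode, with two hypothetical modules:

    # a.py
    import b          # executing b.py starts here, before a.py has finished
    X = 1

    # b.py
    import a
    print(a.X)        # AttributeError if b was imported while a.py was still
                      # executing; deferring `import a` until later in the
                      # file (or into a function body) breaks the cycle
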
6 changes: 3 additions & 3 deletions pymc/logprob/transforms.py
@@ -961,16 +961,16 @@ def transformed_logprob(op, values, *inputs, use_jacobian=True, **kwargs):
     if use_jacobian:
         assert len(values) == len(logprobs) == len(op.transforms)
         logprobs_jac = []
-        for value, transform, logprob in zip(values, op.transforms, logprobs):
+        for value, transform, logp in zip(values, op.transforms, logprobs):
             if transform is None:
-                logprobs_jac.append(logprob)
+                logprobs_jac.append(logp)
                 continue
             assert isinstance(value.owner.op, TransformedVariable)
             original_forward_value = value.owner.inputs[1]
             jacobian = transform.log_jac_det(original_forward_value, *inputs).copy()
             if value.name:
                 jacobian.name = f"{value.name}_jacobian"
-            logprobs_jac.append(logprob + jacobian)
+            logprobs_jac.append(logp + jacobian)
         logprobs = logprobs_jac

     return logprobs
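
Note: a pure rename; the loop variable `logprob` presumably collided with a same-named callable at module scope, and shadowing an outer name makes it unreachable inside the loop and confuses type checkers. Sketch with a hypothetical module-level function:

    def logprob(x: float) -> float:   # module-level name
        return -0.5 * x * x

    def total(values: list) -> float:
        out = 0.0
        for logp in values:           # distinct name: `logprob` stays callable
            out += logprob(logp)
        return out
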
2 changes: 1 addition & 1 deletion pymc/model.py
@@ -195,7 +195,7 @@ def get_context(cls, error_if_none=True) -> Optional[T]:
         raise a ``TypeError`` instead of returning ``None``."""
         try:
             candidate: Optional[T] = cls.get_contexts()[-1]
-        except IndexError as e:
+        except IndexError:
             # Calling code expects to get a TypeError if the entity
             # is unfound, and there's too much to fix.
             if error_if_none:
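
Note: same family of cleanup as the unused `dist = getattr(...)` removal in nuts.py below - bindings that nothing reads are dead code (flake8 reports them as F841). Keep the `as err` alias only when the handler actually uses it:

    try:
        candidate = [][-1]
    except IndexError:        # no `as e`: the handler never reads it
        candidate = None
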
6 changes: 3 additions & 3 deletions pymc/model_graph.py
@@ -73,7 +73,7 @@ def _expand(x):
             return []

         parents = {
-            get_var_name(x)
+            VarName(get_var_name(x))
             for x in walk(nodes=_filter_non_parameter_inputs(var), expand=_expand)
             # Only consider nodes that are in the named model variables.
             if x.name and x.name in self._all_var_names
@@ -109,7 +109,7 @@ def vars_to_plot(self, var_names: Optional[Iterable[VarName]] = None) -> List[Va
                 selected_ancestors.add(self.model.rvs_to_values[var])

         # ordering of self._all_var_names is important
-        return [var.name for var in selected_ancestors]
+        return [VarName(var.name) for var in selected_ancestors]

     def make_compute_graph(
         self, var_names: Optional[Iterable[VarName]] = None
@@ -230,7 +230,7 @@ def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> Dict[str,
                 plate_label = " x ".join(dim_labels)
             else:
                 # The RV has no `dims` information.
-                dim_labels = map(str, shape)
+                dim_labels = [str(x) for x in shape]
                 plate_label = " x ".join(map(str, shape))
             plates[plate_label].add(var_name)
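
Note: `VarName` is presumably a `typing.NewType` over `str`: wrapping is a no-op at runtime, but mypy only treats explicitly wrapped values as `VarName`, so bare strings get converted at the boundary. The `map(str, shape)` change is related - `map` returns an iterator, and rebinding `dim_labels` (a list on the other branch) to an iterator would give the name two types. Sketch of the NewType side:

    from typing import List, NewType

    VarName = NewType("VarName", str)

    def to_var_names(names: List[str]) -> List[VarName]:
        return [VarName(n) for n in names]   # zero runtime cost
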
29 changes: 12 additions & 17 deletions pymc/sampling/jax.py
@@ -15,20 +15,10 @@
 import re
 import sys

+from datetime import datetime
 from functools import partial
 from typing import Any, Callable, Dict, List, Optional, Sequence, Union

-from pytensor.tensor.random.type import RandomType
-
-from pymc.initial_point import StartDict
-from pymc.sampling.mcmc import _init_jitter
-
-xla_flags = os.getenv("XLA_FLAGS", "")
-xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
-os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"] + xla_flags)
-
-from datetime import datetime
-
 import arviz as az
 import jax
 import numpy as np
@@ -43,18 +33,25 @@
 from pytensor.link.jax.dispatch import jax_funcify
 from pytensor.raise_op import Assert
 from pytensor.tensor import TensorVariable
+from pytensor.tensor.random.type import RandomType
 from pytensor.tensor.shape import SpecifyShape

 from pymc import Model, modelcontext
 from pymc.backends.arviz import find_constants, find_observations
+from pymc.initial_point import StartDict
 from pymc.logprob.utils import CheckParameterValue
+from pymc.sampling.mcmc import _init_jitter
 from pymc.util import (
     RandomSeed,
     RandomState,
     _get_seeds_per_chain,
     get_default_varnames,
 )

+xla_flags_env = os.getenv("XLA_FLAGS", "")
+xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags_env).split()
+os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"] + xla_flags)
+
 __all__ = (
     "get_jaxified_graph",
     "get_jaxified_logp",
@@ -111,7 +108,7 @@ def get_jaxified_graph(
 ) -> List[TensorVariable]:
     """Compile an PyTensor graph into an optimized JAX function"""

-    graph = _replace_shared_variables(outputs)
+    graph = _replace_shared_variables(outputs) if outputs is not None else None

     fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
     # We need to add a Supervisor to the fgraph to be able to run the
@@ -254,12 +251,10 @@ def _get_batched_jittered_initial_points(
         jitter=jitter,
         jitter_max_retries=jitter_max_retries,
     )
-    initial_points = [list(initial_point.values()) for initial_point in initial_points]
+    initial_points_values = [list(initial_point.values()) for initial_point in initial_points]
     if chains == 1:
-        initial_points = initial_points[0]
-    else:
-        initial_points = [np.stack(init_state) for init_state in zip(*initial_points)]
-    return initial_points
+        return initial_points_values[0]
+    return [np.stack(init_state) for init_state in zip(*initial_points_values)]


 def _update_coords_and_dims(
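
Note: besides regrouping imports (stdlib, then third-party, then pymc) and renaming the env-var intermediate so each name keeps one type, `get_jaxified_graph` now tolerates `outputs=None`, and the `_get_batched_jittered_initial_points` hunk swaps name reuse for early returns. A sketch of that last pattern, not PyMC's actual helper:

    import numpy as np

    def batch_points(points: list, chains: int):
        # one name per type: `values` is always a list of per-chain lists
        values = [list(p.values()) for p in points]
        if chains == 1:
            return values[0]                        # single chain: flat list
        return [np.stack(s) for s in zip(*values)]  # stack across chains
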
2 changes: 1 addition & 1 deletion pymc/step_methods/hmc/integration.py
@@ -82,7 +82,7 @@ def step(self, epsilon, state):
             return self._step(epsilon, state)
         except linalg.LinAlgError as err:
             msg = "LinAlgError during leapfrog step."
-            raise IntegrationError(msg)
+            raise IntegrationError(msg) from err
         except ValueError as err:
             # Raised by many scipy.linalg functions
             scipy_msg = "array must not contain infs or nans"
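
Note: `raise ... from err` sets `__cause__` on the new exception, so the traceback reads "The above exception was the direct cause of the following exception" instead of the misleading "During handling of the above exception, another exception occurred", and debuggers can walk back to the root error. A standalone sketch:

    class IntegrationError(RuntimeError):
        pass

    def leapfrog_step():
        try:
            raise ArithmeticError("singular mass matrix")  # stand-in failure
        except ArithmeticError as err:
            raise IntegrationError("LinAlgError during leapfrog step.") from err
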
1 change: 0 additions & 1 deletion pymc/step_methods/hmc/nuts.py
@@ -210,7 +210,6 @@ def _hamiltonian_step(self, start, p0, step_size):
     def competence(var, has_grad):
         """Check how appropriate this class is for sampling a random variable."""

-        dist = getattr(var.owner, "op", None)
         if var.dtype in continuous_types and has_grad:
             return Competence.PREFERRED
         return Competence.INCOMPATIBLE
4 changes: 2 additions & 2 deletions pymc/variational/approximations.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from typing import Optional

 import numpy as np
 import pytensor
@@ -331,7 +331,7 @@ def sample_approx(approx, draws=100, include_transformed=True):
 class SingleGroupApproximation(Approximation):
     """Base class for Single Group Approximation"""

-    _group_class = None
+    _group_class: Optional[type] = None

     def __init__(self, *args, **kwargs):
         groups = [self._group_class(None, *args, **kwargs)]
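
Note: without an annotation, mypy infers the attribute's type from its default (`None`) and then rejects subclasses that override it with an actual class. `Optional[type]` states the contract: `None` on the abstract base, a class object on concrete subclasses. Sketch with hypothetical classes:

    from typing import Optional

    class Base:
        _group_class: Optional[type] = None

    class Concrete(Base):
        _group_class = dict   # valid override under Optional[type]
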
16 changes: 9 additions & 7 deletions pymc/variational/opvi.py
@@ -51,6 +51,8 @@
 import itertools
 import warnings

+from typing import Any
+
 import numpy as np
 import pytensor
 import pytensor.tensor as at
@@ -673,11 +675,11 @@ class Group(WithMemoization):
     initial_dist_map = 0.0

     # for handy access using class methods
-    __param_spec__ = dict()
+    __param_spec__: dict = dict()
     short_name = ""
-    alias_names = frozenset()
-    __param_registry = dict()
-    __name_registry = dict()
+    alias_names: frozenset[str] = frozenset()
+    __param_registry: dict[frozenset, Any] = dict()
+    __name_registry: dict[str, Any] = dict()

     @classmethod
     def register(cls, sbcls):
@@ -1552,11 +1554,11 @@ def sample(
         finally:
             trace.close()

-        trace = MultiTrace([trace])
+        multi_trace = MultiTrace([trace])
         if not return_inferencedata:
-            return trace
+            return multi_trace
         else:
-            return pm.to_inference_data(trace, model=self.model, **kwargs)
+            return pm.to_inference_data(multi_trace, model=self.model, **kwargs)

     @property
     def ndim(self):
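
Note: empty literals such as `dict()` and `frozenset()` give mypy nothing to infer element types from, so later reads and writes fail to type-check; the annotations pin the key and value types up front. The `multi_trace` rename in `sample` follows the same single-type-per-name rule as the jax.py hunk: rebinding `trace` from a single trace to a `MultiTrace` would give one name two types. A sketch of the annotation fix, with a hypothetical registry:

    from typing import Any

    class Registry:
        params: dict = {}                      # bare `dict` when types vary
        names: dict[str, Any] = {}             # or pinned key type
        aliases: frozenset[str] = frozenset()

    Registry.names["mean_field"] = object()    # accepted: key is str
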
1 change: 0 additions & 1 deletion scripts/run_mypy.py
@@ -42,7 +42,6 @@
 pymc/printing.py
 pymc/pytensorf.py
 pymc/sampling/jax.py
-pymc/variational/approximations.py
 pymc/variational/opvi.py
 """
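
Note: with the `Optional[type]` annotation above, approximations.py presumably type-checks cleanly, so it drops off run_mypy.py's list of files expected to fail.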