Commit 96b7160

A few typing fixes
1 parent ba64afa · commit 96b7160

4 files changed (+20 additions, −20 deletions)

pymc/distributions/dist_math.py

Lines changed: 2 additions & 2 deletions
@@ -61,10 +61,10 @@ def check_parameters(logp: Variable, *conditions: Iterable[Variable], msg: str =
     check_bounds = False in pm.Model()
     """
     # at.all does not accept True/False, but accepts np.array(True)/np.array(False)
-    conditions = [
+    conditions_ = [
         cond if (cond is not True and cond is not False) else np.array(cond) for cond in conditions
     ]
-    all_true_scalar = at.all([at.all(cond) for cond in conditions])
+    all_true_scalar = at.all([at.all(cond) for cond in conditions_])
     return CheckParameterValue(msg)(logp, all_true_scalar)
 
 
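Note: the rename sidesteps mypy's incompatible-assignment error. `*conditions` is typed as a tuple of `Iterable[Variable]`, so rebinding the same name to the list built from it changes its type. A minimal sketch of the pattern, using a hypothetical function rather than PyMC code:

from typing import Iterable, List

def as_lists(*groups: Iterable[float]) -> List[List[float]]:
    # `groups` is a tuple of Iterable[float]; rebinding that name to the
    # list below would change its type, which mypy reports as an
    # incompatible assignment. Binding to a fresh name keeps both stable.
    groups_ = [list(g) for g in groups]
    return groups_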

pymc/model_graph.py

Lines changed: 3 additions & 3 deletions
@@ -73,7 +73,7 @@ def _expand(x):
             return []
 
         parents = {
-            get_var_name(x)
+            VarName(get_var_name(x))
             for x in walk(nodes=_filter_non_parameter_inputs(var), expand=_expand)
             # Only consider nodes that are in the named model variables.
             if x.name and x.name in self._all_var_names
@@ -109,7 +109,7 @@ def vars_to_plot(self, var_names: Optional[Iterable[VarName]] = None) -> List[Va
                     selected_ancestors.add(self.model.rvs_to_values[var])
 
         # ordering of self._all_var_names is important
-        return [var.name for var in selected_ancestors]
+        return [VarName(var.name) for var in selected_ancestors]
 
     def make_compute_graph(
         self, var_names: Optional[Iterable[VarName]] = None
@@ -230,7 +230,7 @@ def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> Dict[str,
                 plate_label = " x ".join(dim_labels)
             else:
                 # The RV has no `dims` information.
-                dim_labels = map(str, shape)
+                dim_labels = [str(x) for x in shape]
                 plate_label = " x ".join(map(str, shape))
             plates[plate_label].add(var_name)
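
Note: `VarName` is presumably the module's `NewType` alias over `str`, so mypy rejects a plain `str` where `VarName` is declared and each value must pass through the constructor. Likewise, rebinding `dim_labels` from the list built in the `if` branch to the iterator returned by `map` would give the two branches incompatible types, hence the list comprehension. A minimal sketch of the NewType motif, assuming a `VarName = NewType("VarName", str)` definition:

from typing import List, NewType

VarName = NewType("VarName", str)

def to_var_names(raw: List[str]) -> List[VarName]:
    # mypy rejects returning List[str] where List[VarName] is declared;
    # wrapping each value in the NewType constructor satisfies the check
    # at zero runtime cost (VarName(s) is just s).
    return [VarName(s) for s in raw]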

pymc/sampling/jax.py

Lines changed: 6 additions & 8 deletions
@@ -48,8 +48,8 @@
     get_default_varnames,
 )
 
-xla_flags = os.getenv("XLA_FLAGS", "")
-xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
+xla_flags_env = os.getenv("XLA_FLAGS", "")
+xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags_env).split()
 os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"] + xla_flags)
 
 __all__ = (
@@ -108,7 +108,7 @@ def get_jaxified_graph(
 ) -> List[TensorVariable]:
     """Compile an PyTensor graph into an optimized JAX function"""
 
-    graph = _replace_shared_variables(outputs)
+    graph = _replace_shared_variables(outputs) if outputs is not None else None
 
     fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
     # We need to add a Supervisor to the fgraph to be able to run the
@@ -251,12 +251,10 @@ def _get_batched_jittered_initial_points(
         jitter=jitter,
         jitter_max_retries=jitter_max_retries,
     )
-    initial_points = [list(initial_point.values()) for initial_point in initial_points]
+    initial_points_values = [list(initial_point.values()) for initial_point in initial_points]
     if chains == 1:
-        initial_points = initial_points[0]
-    else:
-        initial_points = [np.stack(init_state) for init_state in zip(*initial_points)]
-    return initial_points
+        return initial_points_values[0]
+    return [np.stack(init_state) for init_state in zip(*initial_points_values)]
 
 
 def _update_coords_and_dims(
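
Note: all three hunks avoid reusing one name for values of different types: the raw XLA_FLAGS string gets its own name before being split into a list, `outputs` is only passed to `_replace_shared_variables` when it is not None, and the `chains == 1` case returns early instead of rebinding `initial_points` to a differently shaped value. A minimal sketch of the env-flag motif, as a hypothetical standalone helper:

import os
import re

def host_device_flags(count: int) -> str:
    # The env var is read as a str, then split into a list[str]; giving
    # each stage its own name avoids rebinding one variable from str to
    # list[str], which mypy flags as an incompatible assignment.
    flags_env = os.getenv("XLA_FLAGS", "")
    flags = re.sub(r"--xla_force_host_platform_device_count=\S+\s?", "", flags_env).split()
    return " ".join([f"--xla_force_host_platform_device_count={count}"] + flags)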

pymc/variational/opvi.py

Lines changed: 9 additions & 7 deletions
@@ -51,6 +51,8 @@
 import itertools
 import warnings
 
+from typing import Any
+
 import numpy as np
 import pytensor
 import pytensor.tensor as at
@@ -673,11 +675,11 @@ class Group(WithMemoization):
     initial_dist_map = 0.0
 
     # for handy access using class methods
-    __param_spec__ = dict()
+    __param_spec__: dict = dict()
     short_name = ""
-    alias_names = frozenset()
-    __param_registry = dict()
-    __name_registry = dict()
+    alias_names: frozenset[str] = frozenset()
+    __param_registry: dict[frozenset, Any] = dict()
+    __name_registry: dict[str, Any] = dict()
 
     @classmethod
     def register(cls, sbcls):
@@ -1552,11 +1554,11 @@ def sample(
         finally:
             trace.close()
 
-        trace = MultiTrace([trace])
+        multi_trace = MultiTrace([trace])
         if not return_inferencedata:
-            return trace
+            return multi_trace
         else:
-            return pm.to_inference_data(trace, model=self.model, **kwargs)
+            return pm.to_inference_data(multi_trace, model=self.model, **kwargs)
 
     @property
     def ndim(self):
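
Note: the same two motifs again. Empty class-level containers get explicit annotations because mypy cannot infer element types from bare `dict()` or `frozenset()` and reports "Need type annotation" (which is why `Any` is now imported), and the `MultiTrace` wrapper gets a fresh name instead of rebinding `trace`. A minimal sketch of the annotated-registry motif, with a hypothetical class (Python 3.9+ builtin generics):

from typing import Any

class Registry:
    # Without the annotations, mypy reports "Need type annotation" for
    # these empty class-level containers, since dict() and frozenset()
    # alone say nothing about the element types.
    aliases: frozenset[str] = frozenset()
    entries: dict[str, Any] = dict()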
