Skip to content

Commit b67d128

Browse files
committed
Refactor utility to ignore the logprob of multiple variables while keeping their interdependencies intact
1 parent d14ecd5 commit b67d128

File tree

2 files changed

+43
-22
lines changed

2 files changed

+43
-22
lines changed

pymc/logprob/tensor.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252

5353
from pymc.logprob.abstract import MeasurableVariable, _logprob, logprob
5454
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
55-
from pymc.logprob.utils import ignore_logprob
55+
from pymc.logprob.utils import ignore_logprob, ignore_logprob_multiple_vars
5656

5757

5858
@node_rewriter([BroadcastTo])
@@ -228,25 +228,7 @@ def find_measurable_stacks(
228228
):
229229
return None # pragma: no cover
230230

231-
# Make base_vars unmeasurable
232-
base_to_unmeasurable_vars = {base_var: ignore_logprob(base_var) for base_var in base_vars}
233-
234-
def replacement_fn(var, replacements):
235-
if var in base_to_unmeasurable_vars:
236-
replacements[var] = base_to_unmeasurable_vars[var]
237-
# We don't want to clone valued nodes. Assigning a var to itself in the
238-
# replacements prevents this
239-
elif var in rvs_to_values:
240-
replacements[var] = var
241-
242-
return []
243-
244-
# TODO: Fix this import circularity!
245-
from pymc.pytensorf import _replace_rvs_in_graphs
246-
247-
unmeasurable_base_vars, _ = _replace_rvs_in_graphs(
248-
graphs=base_vars, replacement_fn=replacement_fn
249-
)
231+
unmeasurable_base_vars = ignore_logprob_multiple_vars(base_vars, rvs_to_values)
250232

251233
if is_join:
252234
measurable_stack = MeasurableJoin()(axis, *unmeasurable_base_vars)

pymc/logprob/utils.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,17 @@
3737
import warnings
3838

3939
from copy import copy
40-
from typing import Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple
40+
from typing import (
41+
Callable,
42+
Dict,
43+
Generator,
44+
Iterable,
45+
List,
46+
Optional,
47+
Sequence,
48+
Set,
49+
Tuple,
50+
)
4151

4252
import numpy as np
4353

@@ -262,7 +272,7 @@ def diracdelta_logprob(op, values, *inputs, **kwargs):
262272
def ignore_logprob(rv: TensorVariable) -> TensorVariable:
263273
"""Return a duplicated variable that is ignored when creating logprob graphs
264274
265-
This is used in SymbolicDistributions that use other RVs as inputs but account
275+
This is used in by MeasurableRVs that use other RVs as inputs but account
266276
for their logp terms explicitly.
267277
268278
If the variable is already ignored, it is returned directly.
@@ -295,3 +305,32 @@ def reconsider_logprob(rv: TensorVariable) -> TensorVariable:
295305
new_node.op = copy(new_node.op)
296306
new_node.op.__class__ = original_op_type
297307
return new_node.outputs[node.outputs.index(rv)]
308+
309+
310+
def ignore_logprob_multiple_vars(
311+
vars: Sequence[TensorVariable], rvs_to_values: Dict[TensorVariable, TensorVariable]
312+
) -> List[TensorVariable]:
313+
"""Return duplicated variables that are ignored when creating logprob graphs.
314+
315+
This function keeps any interdependencies between variables intact, after
316+
making each "unmeasurable", whereas a sequential call to `ignore_logprob`
317+
would not do this correctly.
318+
"""
319+
from pymc.pytensorf import _replace_rvs_in_graphs
320+
321+
measurable_vars_to_unmeasurable_vars = {
322+
measurable_var: ignore_logprob(measurable_var) for measurable_var in vars
323+
}
324+
325+
def replacement_fn(var, replacements):
326+
if var in measurable_vars_to_unmeasurable_vars:
327+
replacements[var] = measurable_vars_to_unmeasurable_vars[var]
328+
# We don't want to clone valued nodes. Assigning a var to itself in the
329+
# replacements prevents this
330+
elif var in rvs_to_values:
331+
replacements[var] = var
332+
333+
return []
334+
335+
unmeasurable_vars, _ = _replace_rvs_in_graphs(graphs=vars, replacement_fn=replacement_fn)
336+
return unmeasurable_vars

0 commit comments

Comments
 (0)