|
69 | 69 | from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceConstant, SliceType |
70 | 70 | from pytensor.tensor.var import TensorVariable |
71 | 71 |
|
72 | | -from pymc.logprob.abstract import ( |
73 | | - MeasurableVariable, |
74 | | - _logprob, |
75 | | - assign_custom_measurable_outputs, |
76 | | - logprob, |
77 | | -) |
| 72 | +from pymc.logprob.abstract import MeasurableVariable, _logprob, logprob |
78 | 73 | from pymc.logprob.rewriting import ( |
79 | 74 | local_lift_DiracDelta, |
80 | 75 | logprob_rewrites_db, |
81 | 76 | subtensor_ops, |
82 | 77 | ) |
83 | 78 | from pymc.logprob.tensor import naive_bcast_rv_lift |
| 79 | +from pymc.logprob.utils import ignore_logprob |
84 | 80 |
|
85 | 81 |
|
86 | 82 | def is_newaxis(x): |
@@ -328,9 +324,7 @@ def mixture_replace(fgraph, node): |
328 | 324 | # We create custom types for the mixture components and assign them |
329 | 325 | # null `get_measurable_outputs` dispatches so that they aren't |
330 | 326 | # erroneously encountered in places like `factorized_joint_logprob`. |
331 | | - new_node = assign_custom_measurable_outputs(component_rv.owner) |
332 | | - out_idx = component_rv.owner.outputs.index(component_rv) |
333 | | - new_comp_rv = new_node.outputs[out_idx] |
| 327 | + new_comp_rv = ignore_logprob(component_rv) |
334 | 328 | new_mixture_rvs.append(new_comp_rv) |
335 | 329 |
|
336 | 330 | # Replace this sub-graph with a `MixtureRV` |
@@ -379,9 +373,7 @@ def switch_mixture_replace(fgraph, node): |
379 | 373 | and component_rv not in rv_map_feature.rv_values |
380 | 374 | ): |
381 | 375 | return None |
382 | | - new_node = assign_custom_measurable_outputs(component_rv.owner) |
383 | | - out_idx = component_rv.owner.outputs.index(component_rv) |
384 | | - new_comp_rv = new_node.outputs[out_idx] |
| 376 | + new_comp_rv = ignore_logprob(component_rv) |
385 | 377 | mixture_rvs.append(new_comp_rv) |
386 | 378 |
|
387 | 379 | mix_op = MixtureRV( |
|
0 commit comments