@@ -129,11 +129,24 @@ class MeasurableMakeVector(MakeVector):
129129
130130
131131@_logprob .register (MeasurableMakeVector )
132- def logprob_make_vector (op , values , * base_vars , ** kwargs ):
132+ def logprob_make_vector (op , values , * base_rvs , ** kwargs ):
133133 """Compute the log-likelihood graph for a `MeasurableMakeVector`."""
134+ # TODO: Sort out this circular dependency issue
135+ from pymc .pytensorf import replace_rvs_by_values
136+
134137 (value ,) = values
135138
136- return at .stack ([logprob (base_var , value [i ]) for i , base_var in enumerate (base_vars )])
139+ base_rvs_to_values = {base_rv : value [i ] for i , base_rv in enumerate (base_rvs )}
140+ for i , (base_rv , value ) in enumerate (base_rvs_to_values .items ()):
141+ base_rv .name = f"base_rv[{ i } ]"
142+ value .name = f"value[{ i } ]"
143+
144+ logps = [logprob (base_rv , value ) for base_rv , value in base_rvs_to_values .items ()]
145+
146+ # If the stacked variables depend on each other, we have to replace them by the respective values
147+ logps = replace_rvs_by_values (logps , rvs_to_values = base_rvs_to_values )
148+
149+ return at .stack (logps )
137150
138151
139152class MeasurableJoin (Join ):
@@ -144,27 +157,28 @@ class MeasurableJoin(Join):
144157
145158
146159@_logprob .register (MeasurableJoin )
147- def logprob_join (op , values , axis , * base_vars , ** kwargs ):
160+ def logprob_join (op , values , axis , * base_rvs , ** kwargs ):
148161 """Compute the log-likelihood graph for a `Join`."""
149- (value ,) = values
162+ # TODO: Find better way to avoid circular dependency
163+ from pymc .pytensorf import constant_fold , replace_rvs_by_values
150164
151- base_var_shapes = [ base_var . shape [ axis ] for base_var in base_vars ]
165+ ( value ,) = values
152166
153- # TODO: Find better way to avoid circular dependency
154- from pymc .pytensorf import constant_fold
167+ base_rv_shapes = [base_var .shape [axis ] for base_var in base_rvs ]
155168
156169 # We don't need the graph to be constant, just to have RandomVariables removed
157- base_var_shapes = constant_fold (base_var_shapes , raise_not_constant = False )
170+ base_rv_shapes = constant_fold (base_rv_shapes , raise_not_constant = False )
158171
159172 split_values = at .split (
160173 value ,
161- splits_size = base_var_shapes ,
162- n_splits = len (base_vars ),
174+ splits_size = base_rv_shapes ,
175+ n_splits = len (base_rvs ),
163176 axis = axis ,
164177 )
165178
179+ base_rvs_to_split_values = {base_rv : value for base_rv , value in zip (base_rvs , split_values )}
166180 logps = [
167- logprob (base_var , split_value ) for base_var , split_value in zip ( base_vars , split_values )
181+ logprob (base_var , split_value ) for base_var , split_value in base_rvs_to_split_values . items ( )
168182 ]
169183
170184 if len ({logp .ndim for logp in logps }) != 1 :
@@ -173,12 +187,12 @@ def logprob_join(op, values, axis, *base_vars, **kwargs):
173187 "joining univariate and multivariate distributions" ,
174188 )
175189
190+ # If the stacked variables depend on each other, we have to replace them by the respective values
191+ logps = replace_rvs_by_values (logps , rvs_to_values = base_rvs_to_split_values )
192+
176193 base_vars_ndim_supp = split_values [0 ].ndim - logps [0 ].ndim
177194 join_logprob = at .concatenate (
178- [
179- at .atleast_1d (logprob (base_var , split_value ))
180- for base_var , split_value in zip (base_vars , split_values )
181- ],
195+ [at .atleast_1d (logp ) for logp in logps ],
182196 axis = axis - base_vars_ndim_supp ,
183197 )
184198
@@ -199,6 +213,8 @@ def find_measurable_stacks(
199213 if rv_map_feature is None :
200214 return None # pragma: no cover
201215
216+ rvs_to_values = rv_map_feature .rv_values
217+
202218 stack_out = node .outputs [0 ]
203219
204220 is_join = isinstance (node .op , Join )
@@ -211,18 +227,40 @@ def find_measurable_stacks(
211227 if not all (
212228 base_var .owner
213229 and isinstance (base_var .owner .op , MeasurableVariable )
214- and base_var not in rv_map_feature . rv_values
230+ and base_var not in rvs_to_values
215231 for base_var in base_vars
216232 ):
217233 return None # pragma: no cover
218234
219235 # Make base_vars unmeasurable
220- base_vars = [assign_custom_measurable_outputs (base_var .owner ) for base_var in base_vars ]
236+ base_to_unmeasurable_vars = {
237+ base_var : assign_custom_measurable_outputs (base_var .owner ).outputs [
238+ base_var .owner .outputs .index (base_var )
239+ ]
240+ for base_var in base_vars
241+ }
242+
243+ def replacement_fn (var , replacements ):
244+ if var in base_to_unmeasurable_vars :
245+ replacements [var ] = base_to_unmeasurable_vars [var ]
246+ # We don't want to clone valued nodes. Assigning a var to itself in the
247+ # replacements prevents this
248+ elif var in rvs_to_values :
249+ replacements [var ] = var
250+
251+ return []
252+
253+ # TODO: Fix this import circularity!
254+ from pymc .pytensorf import _replace_rvs_in_graphs
255+
256+ unmeasurable_base_vars , _ = _replace_rvs_in_graphs (
257+ graphs = base_vars , replacement_fn = replacement_fn
258+ )
221259
222260 if is_join :
223- measurable_stack = MeasurableJoin ()(axis , * base_vars )
261+ measurable_stack = MeasurableJoin ()(axis , * unmeasurable_base_vars )
224262 else :
225- measurable_stack = MeasurableMakeVector (node .op .dtype )(* base_vars )
263+ measurable_stack = MeasurableMakeVector (node .op .dtype )(* unmeasurable_base_vars )
226264
227265 measurable_stack .name = stack_out .name
228266
0 commit comments