@@ -74,7 +74,7 @@ def get_jaxified_logp(model: Model) -> Callable:

     logpt = replace_shared_variables([model.logpt()])[0]

-    logpt_fgraph = FunctionGraph(outputs=[logpt], clone=False)
+    logpt_fgraph = FunctionGraph(outputs=[logpt], clone=True)
     optimize_graph(logpt_fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])

     # We now jaxify the optimized fgraph
@@ -123,7 +123,7 @@ def _get_log_likelihood(model, samples):
     data = {}
     for v in model.observed_RVs:
         logp_v = replace_shared_variables([model.logpt(v, sum=False)[0]])
-        fgraph = FunctionGraph(model.value_vars, logp_v, clone=False)
+        fgraph = FunctionGraph(model.value_vars, logp_v, clone=True)
         optimize_graph(fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
         jax_fn = jax_funcify(fgraph)
         result = jax.jit(jax.vmap(jax.vmap(jax_fn)))(*samples)[0]
@@ -229,7 +229,7 @@ def sample_numpyro_nuts(
     print("Transforming variables...", file=sys.stdout)
     mcmc_samples = {}
     for v in vars_to_sample:
-        fgraph = FunctionGraph(model.value_vars, [v], clone=False)
+        fgraph = FunctionGraph(model.value_vars, [v], clone=True)
         optimize_graph(fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
         jax_fn = jax_funcify(fgraph)
         result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
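
All three hunks make the same change: the FunctionGraph is now built with clone=True, so it works on a copy of the graph and the in-place "fast_run" rewrites applied by optimize_graph no longer touch the model's original Aesara variables. Below is a minimal standalone sketch of that build-optimize-jaxify pattern, not part of the commit; the import paths for optimize_graph and jax_funcify and the toy x/y graph are assumptions for illustration and may differ between Aesara versions.

# Sketch only: clone=True protects the original graph from in-place rewrites.
import numpy as np
import jax
import aesara.tensor as at
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt_utils import optimize_graph  # assumed import path
from aesara.link.jax.dispatch import jax_funcify  # assumed import path

x = at.vector("x")
y = (x**2).sum()

# clone=True copies the graph, so optimizing the FunctionGraph cannot
# mutate the original variables x and y.
fgraph = FunctionGraph(inputs=[x], outputs=[y], clone=True)
optimize_graph(fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])

# jax_funcify returns a function over the fgraph inputs whose result is a
# tuple of outputs, hence the [0] indexing seen in the diff above.
jax_fn = jax.jit(jax_funcify(fgraph))
print(jax_fn(np.array([1.0, 2.0, 3.0]))[0])  # 14.0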