@@ -118,7 +118,8 @@ def create_energy_func(q):
118118 self .H , q = create_hamiltonian (vars , shared , model )
119119 self .compute_energy = create_energy_func (q )
120120
121- self .leapfrog1_dE = leapfrog1_dE (self .H , q , profile = profile )
121+ self .leapfrog1_dE = leapfrog1_dE (self .H , q , profile = profile ,
122+ mode = self .mode )
122123
123124 super (NUTS , self ).__init__ (vars , shared , ** kwargs )
124125
@@ -204,7 +205,7 @@ def buildtree(leapfrog1_dE, q, p, u, v, j, e, Emax, E0):
204205 return
205206
206207
207- def leapfrog1_dE (H , q , profile ):
208+ def leapfrog1_dE (H , q , profile , mode ):
208209 """Computes a theano function that computes one leapfrog step and the energy difference between the beginning and end of the trajectory.
209210 Parameters
210211 ----------
@@ -232,6 +233,6 @@ def leapfrog1_dE(H, q, profile):
232233 dE = E - E0
233234
234235 f = theano .function ([q , p , e , E0 ], [q1 , p1 , dE ],
235- profile = profile , mode = self . mode )
236+ profile = profile , mode = mode )
236237 f .trust_input = True
237238 return f
0 commit comments