@@ -24,7 +24,7 @@ function inner_grad(θ, bθ, f, p)
2424 Const (f),
2525 Enzyme. Duplicated (θ, bθ),
2626 Const (p)
27- ),
27+ )
2828 return nothing
2929end
3030
@@ -89,8 +89,7 @@ function lag_grad(x, dx, lagrangian::Function, _f::Function, cons::Function, p,
8989end
9090
9191function OptimizationBase. instantiate_function (f:: OptimizationFunction{true} , x,
92- adtype:: AutoEnzyme , p,
93- num_cons = 0 )
92+ adtype:: AutoEnzyme , p, num_cons = 0 ; fg = false , fgh = false ,)
9493 if f. grad === nothing
9594 function grad (res, θ)
9695 Enzyme. make_zero! (res)
@@ -106,16 +105,20 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
106105 grad = (G, θ) -> f. grad (G, θ, p)
107106 end
108107
109- function fg! (res, θ)
110- Enzyme. make_zero! (res)
111- y = Enzyme. autodiff (Enzyme. ReverseWithPrimal,
112- Const (firstapply),
113- Active,
114- Const (f. f),
115- Enzyme. Duplicated (θ, res),
116- Const (p)
117- )[2 ]
118- return y
108+ if fg == true  # the fused objective+gradient closure is only built when the caller passes fg = true
109+ function fg! (res, θ)
110+ Enzyme. make_zero! (res)  # clear res before it is used as the shadow paired with θ below
111+ y = Enzyme. autodiff (Enzyme. ReverseWithPrimal,
112+ Const (firstapply),
113+ Active,
114+ Const (f. f),
115+ Enzyme. Duplicated (θ, res),
116+ Const (p)
117+ )[2 ]  # keeps the second element of autodiff's return; presumably the primal value under ReverseWithPrimal — confirm against the Enzyme docs
118+ return y
119+ end
120+ else
121+ fg! = nothing  # no closure requested: fg! is bound to nothing
119122 end
120123
121124 if f. hess === nothing
@@ -130,7 +133,6 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
130133 end
131134
132135 function hess (res, θ)
133- Enzyme. make_zero! .(vdθ)
134136 Enzyme. make_zero! (bθ)
135137 Enzyme. make_zero! .(vdbθ)
136138
@@ -150,8 +152,25 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
150152 hess = (H, θ) -> f. hess (H, θ, p)
151153 end
152154
153- function fgh! (G, H, θ)
154-
155+ if fgh == true  # the fused gradient+Hessian closure is only built when the caller passes fgh = true
156+ function fgh! (G, H, θ)  # G is handed to inner_grad as the buffer paired with θ; H receives one row per seed direction
157+ vdθ = Tuple ((Array (r) for r in eachrow (I (length (θ)) * one (eltype (θ)))))
158+ vdbθ = Tuple (zeros (eltype (θ), length (θ)) for i in eachindex (θ))
159+ Enzyme.make_zero!(G)  # clear stale entries first, as grad, fg! and hess do for their own buffers
160+ Enzyme. autodiff (Enzyme. Forward,
161+ inner_grad,
162+ Enzyme. BatchDuplicated (θ, vdθ),
163+ Enzyme.BatchDuplicated(G, vdbθ),  # G is an output the caller reads, so its primal must not be marked NoNeed
164+ Const (f. f),
165+ Const (p)
166+ )
167+
168+ for i in eachindex (θ)
169+ H[i, :] .= vdbθ[i]
170+ end
171+ end
172+ else
173+ fgh! = nothing  # no closure requested: fgh! is bound to nothing
155174 end
156175
157176 if f. hv === nothing
@@ -175,13 +194,19 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
175194 seeds = Tuple ((Array (r) for r in eachrow (I (length (x)) * one (eltype (x)))))
176195 Jaccache = Tuple (zeros (eltype (x), num_cons) for i in 1 : length (x))
177196 y = zeros (eltype (x), num_cons)
197+
178198 function cons_j (J, θ)
179199 for i in 1 : length (θ)
180200 Enzyme. make_zero! (Jaccache[i])
181201 end
182202 Enzyme. make_zero! (y)
183- Enzyme. autodiff (Enzyme. Forward, f. cons, BatchDuplicated (y, Jaccache),
184- BatchDuplicated (θ, seeds), Const (p))
203+ if num_cons > length (θ)
204+ Enzyme. autodiff (Enzyme. Forward, f. cons, BatchDuplicated (y, Jaccache),
205+ BatchDuplicated (θ, seeds), Const (p))
206+ else
207+ Enzyme. autodiff (Enzyme. Reverse, f. cons, BatchDuplicated (y, seeds),
208+ BatchDuplicated (θ, Jaccache), Const (p))
209+ end
185210 for i in 1 : length (θ)
186211 if J isa Vector
187212 J[i] = Jaccache[i][1 ]
@@ -194,35 +219,63 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
194219 cons_j = (J, θ) -> f. cons_j (J, θ, p)
195220 end
196221
197- if cons != = nothing && f. cons_vjp === nothing
198- function cons_vjp (res, θ, v)
199-
222+ if cons !== nothing && f.cons_vjp === nothing  # build the AD-based VJP when the user supplied none; `f.cons_vjp == true` could never hold for a function-or-nothing field
223+ cons_res = zeros (eltype (x), num_cons)
224+ function cons_vjp! (res, θ, v)  # res is paired with θ as its shadow, v with the constraint output buffer
225+ Enzyme. make_zero! (res)
226+ Enzyme. make_zero! (cons_res)
227+
228+ Enzyme. autodiff (Enzyme. Reverse,
229+ f. cons,
230+ Const,
231+ Duplicated (cons_res, v),
232+ Duplicated (θ, res),
233+ Const (p),
234+ )
235+ end
236+ else
237+ cons_vjp! = (θ, σ) -> f. cons_vjp (θ, σ, p)  # otherwise forward to the user-supplied cons_vjp with p appended
238+ end
239+
240+ if cons !== nothing && f.cons_jvp === nothing  # build the AD-based JVP when the user supplied none; `f.cons_jvp == true` could never hold for a function-or-nothing field
241+ cons_res = zeros (eltype (x), num_cons)
242+
243+ function cons_jvp! (res, θ, v)  # v is paired with θ as its tangent, res with the constraint output buffer
244+ Enzyme. make_zero! (res)
245+ Enzyme. make_zero! (cons_res)
246+
247+ Enzyme. autodiff (Enzyme. Forward,
248+ f. cons,
249+ Duplicated (cons_res, res),
250+ Duplicated (θ, v),
251+ Const (p),
252+ )
200253 end
201254 else
202- cons_vjp = (θ, σ) -> f . cons_vjp (θ, σ, p)
255+ cons_jvp! = f.cons_jvp === nothing ? nothing : (θ, v) -> f.cons_jvp(θ, v, p)  # bind the JVP name (this line used to overwrite cons_vjp!); forward to a user-supplied JVP when there is one
203256 end
204257
205258 if cons != = nothing && f. cons_h === nothing
259+ cons_vdθ = Tuple ((Array (r) for r in eachrow (I (length (x)) * one (eltype (x)))))
260+ cons_bθ = zeros (eltype (x), length (x))
261+ cons_vdbθ = Tuple (zeros (eltype (x), length (x)) for i in eachindex (x))
262+
206263 function cons_h (res, θ)  # fills res[i] with one row per seed direction for each constraint i in 1:num_cons
207- vdθ = Tuple (( Array (r) for r in eachrow ( I ( length (θ)) * one ( eltype (θ)))) )
208- bθ = zeros ( eltype (θ), length (θ) )
209- vdbθ = Tuple ( zeros ( eltype (θ), length (θ)) for i in eachindex (θ))
210264 for i in 1 : num_cons
211- bθ .= zero (eltype (bθ))
212- for el in vdbθ
213- Enzyme. make_zero! (el)
214- end
265+ Enzyme.make_zero!(cons_bθ)  # reset per constraint, as the removed per-iteration zeroing did, so constraint i does not inherit results from i-1
266+ Enzyme.make_zero!.(cons_vdbθ)
267+
215268 Enzyme. autodiff (Enzyme. Forward,
216269 cons_f2,
217- Enzyme. BatchDuplicated (θ, vdθ ),
218- Enzyme. BatchDuplicated (bθ, vdbθ ),
270+ Enzyme. BatchDuplicated (θ, cons_vdθ ),
271+ Enzyme.BatchDuplicated(cons_bθ, cons_vdbθ),  # use the hoisted cons_bθ cache; a bare bθ is not defined by this block
219272 Const (f. cons),
220273 Const (p),
221274 Const (num_cons),
222275 Const (i))
223276
224277 for j in eachindex (θ)
225- res[i][j, :] .= vdbθ [j]
278+ res[i][j, :] .= cons_vdbθ [j]
226279 end
227280 end
228281 end
@@ -242,7 +295,6 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
242295 end
243296
244297 function lag_h (h, θ, σ, μ)
245- Enzyme. make_zero! .(lag_vdθ)
246298 Enzyme. make_zero! (lag_bθ)
247299 Enzyme. make_zero! .(lag_vdbθ)
248300
0 commit comments