@@ -89,7 +89,7 @@ function lag_grad(x, dx, lagrangian::Function, _f::Function, cons::Function, p,
8989end
9090
9191function OptimizationBase. instantiate_function (f:: OptimizationFunction{true} , x,
92- adtype:: AutoEnzyme , p, num_cons = 0 ; fg = false , fgh = false , )
92+ adtype:: AutoEnzyme , p, num_cons = 0 ; fg = false , fgh = false )
9393 if f. grad === nothing
9494 function grad (res, θ)
9595 Enzyme. make_zero! (res)
@@ -198,7 +198,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
198198 seeds = Enzyme. onehot (zeros (eltype (x), num_cons))
199199 Jaccache = Tuple (zero (x) for i in 1 : num_cons)
200200 end
201-
201+
202202 y = zeros (eltype (x), num_cons)
203203
204204 function cons_j (J, θ)
@@ -243,7 +243,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
243243 Const,
244244 Duplicated (cons_res, v),
245245 Duplicated (θ, res),
246- Const (p),
246+ Const (p)
247247 )
248248 end
249249 else
@@ -261,7 +261,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
261261 f. cons,
262262 Duplicated (cons_res, res),
263263 Duplicated (θ, v),
264- Const (p),
264+ Const (p)
265265 )
266266 end
267267 else
@@ -325,7 +325,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
325325
326326 for i in eachindex (θ)
327327 vec_lagv = lag_vdbθ[i]
328- h[k + 1 : k + i ] .= @view (vec_lagv[1 : i])
328+ h[(k + 1 ) : (k + i) ] .= @view (vec_lagv[1 : i])
329329 k += i
330330 end
331331 end
@@ -356,8 +356,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true},
356356end
357357
358358function OptimizationBase. instantiate_function (f:: OptimizationFunction{false} , x,
359- adtype:: AutoEnzyme , p,
360- num_cons = 0 )
359+ adtype:: AutoEnzyme , p, num_cons = 0 ; fg = false , fgh = false )
361360 if f. grad === nothing
362361 res = zeros (eltype (x), size (x))
363362 function grad (θ)
@@ -375,12 +374,31 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
375374 grad = (θ) -> f. grad (θ, p)
376375 end
377376
377+ if fg == true
378+ res_fg = zeros (eltype (x), size (x))
# Out-of-place objective-and-gradient evaluator.
#
# Zeroes the captured buffer `res_fg`, then runs Enzyme in `ReverseWithPrimal`
# mode on `firstapply` with `f.f` and `p` held constant. `θ` is passed as
# `Duplicated(θ, res_fg)`, so the gradient with respect to `θ` is accumulated
# into `res_fg`.
#
# Returns `(y, res_fg)`, where `y` is element 2 of the autodiff result (the
# primal value in `ReverseWithPrimal` mode). `res_fg` is allocated once outside
# this closure and reused, so copy it if it must survive the next call.
function fg!(θ)
    Enzyme.make_zero!(res_fg)
    y = Enzyme.autodiff(Enzyme.ReverseWithPrimal,
        Const(firstapply),
        Active,
        Const(f.f),
        Enzyme.Duplicated(θ, res_fg),
        Const(p)
    )[2]
    # Return the buffer that was just filled. The previous `res` is never
    # written by this closure (it belongs to the separate `grad` closure).
    return y, res_fg
end
390+ else
391+ fg! = nothing
392+ end
393+
378394 if f. hess === nothing
379- function hess (θ)
380- vdθ = Tuple ((Array (r) for r in eachrow (I (length (θ)) * one (eltype (θ)))))
395+ vdθ = Tuple ((Array (r) for r in eachrow (I (length (x)) * one (eltype (x)))))
396+ bθ = zeros (eltype (x), length (x))
397+ vdbθ = Tuple (zeros (eltype (x), length (x)) for i in eachindex (x))
381398
382- bθ = zeros (eltype (θ), length (θ))
383- vdbθ = Tuple (zeros (eltype (θ), length (θ)) for i in eachindex (θ))
399+ function hess (θ)
400+ Enzyme. make_zero! (bθ)
401+ Enzyme. make_zero! .(vdbθ)
384402
385403 Enzyme. autodiff (Enzyme. Forward,
386404 inner_grad,
@@ -397,9 +415,37 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
397415 hess = (θ) -> f. hess (θ, p)
398416 end
399417
418+ if fgh == true
419+ vdθ_fgh = Tuple ((Array (r) for r in eachrow (I (length (x)) * one (eltype (x)))))
420+ vdbθ_fgh = Tuple (zeros (eltype (x), length (x)) for i in eachindex (x))
421+ G_fgh = zeros (eltype (x), length (x))
422+ H_fgh = zeros (eltype (x), length (x), length (x))
423+
# Out-of-place gradient-and-Hessian evaluator.
#
# Resets the captured buffers `G_fgh`, `H_fgh` and every vector in `vdbθ_fgh`,
# then differentiates `inner_grad` in batched forward mode: `θ` is seeded with
# the directions in `vdθ_fgh`, and `G_fgh` is paired with the shadows
# `vdbθ_fgh`. Each shadow vector is then copied into the matching row of
# `H_fgh`.
#
# Returns `(G_fgh, H_fgh)`; both are reused between calls.
# NOTE(review): no objective value is returned here, unlike `fg!`, which
# returns the primal first — confirm this is what callers of `fgh` expect.
function fgh!(θ)
    Enzyme.make_zero!(G_fgh)
    Enzyme.make_zero!(H_fgh)
    foreach(Enzyme.make_zero!, vdbθ_fgh)

    Enzyme.autodiff(
        Enzyme.Forward,
        inner_grad,
        Enzyme.BatchDuplicated(θ, vdθ_fgh),
        Enzyme.BatchDuplicatedNoNeed(G_fgh, vdbθ_fgh),
        Const(f.f),
        Const(p)
    )

    # Row `row` of the Hessian is the shadow produced for seed direction `row`.
    for row in eachindex(θ)
        H_fgh[row, :] .= vdbθ_fgh[row]
    end
    return G_fgh, H_fgh
end
442+ else
443+ fgh! = nothing
444+ end
445+
400446 if f. hv === nothing
401447 function hv (θ, v)
402- Enzyme. autodiff (
448+ return Enzyme. autodiff (
403449 Enzyme. Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated (θ, v),
404450 Const (_f), Const (f. f), Const (p)
405451 )[1 ]
@@ -411,60 +457,136 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
411457 if f. cons === nothing
412458 cons = nothing
413459 else
414- cons_oop = (θ) -> f. cons (θ, p)
# Out-of-place constraint evaluator: forwards `θ` and the captured parameters
# `p` to the user-supplied `f.cons` and returns whatever it returns.
cons(θ) = f.cons(θ, p)
415463 end
416464
417- if f. cons != = nothing && f. cons_j === nothing
418- seeds = Tuple ((Array (r) for r in eachrow (I (length (x)) * one (eltype (x)))))
465+ if cons != = nothing && f. cons_j === nothing
466+ seeds = Enzyme. onehot (x)
467+ Jaccache = Tuple (zeros (eltype (x), num_cons) for i in 1 : length (x))
468+
419469 function cons_j (θ)
420- J = Enzyme. autodiff (
421- Enzyme. Forward, f. cons, BatchDuplicated (θ, seeds), Const (p))[1 ]
422- if num_cons == 1
423- return reduce (vcat, J)
470+ for i in eachindex (Jaccache)
471+ Enzyme. make_zero! (Jaccache[i])
472+ end
473+ y, Jaccache = Enzyme. autodiff (Enzyme. Forward, f. cons, Duplicated,
474+ BatchDuplicated (θ, seeds), Const (p))
475+ if size (y, 1 ) == 1
476+ return reduce (vcat, Jaccache)
424477 else
425- return reduce (hcat, J )
478+ return reduce (hcat, Jaccache )
426479 end
427480 end
428481 else
429482 cons_j = (θ) -> f. cons_j (θ, p)
430483 end
431484
432- if f. cons != = nothing && f. cons_h === nothing
485+ if cons != = nothing && f. cons_vjp == true
486+ res_vjp = zeros (eltype (x), size (x))
487+ cons_vjp_res = zeros (eltype (x), num_cons)
488+
# Vector-Jacobian product of the constraints via Enzyme reverse mode.
#
# Clears the captured buffers `cons_vjp_res` and `res_vjp`, then runs
# `Enzyme.autodiff` in `Reverse` mode on `f.cons` with a `Const` return
# activity. `v` is supplied as the shadow of `cons_vjp_res`, and `res_vjp` is
# the shadow of `θ`.
#
# Returns `res_vjp`, which is allocated outside this closure and reused.
# NOTE(review): this call hands `f.cons` three arguments
# (`cons_vjp_res`, `θ`, `p`), while the sibling `cons` closure calls
# `f.cons(θ, p)` — confirm which calling convention applies here.
function cons_vjp(θ, v)
    Enzyme.make_zero!(cons_vjp_res)
    Enzyme.make_zero!(res_vjp)

    Enzyme.autodiff(
        Enzyme.Reverse,
        f.cons,
        Const,
        Duplicated(cons_vjp_res, v),
        Duplicated(θ, res_vjp),
        Const(p)
    )

    return res_vjp
end
502+ else
503+ cons_vjp = (θ, σ) -> f. cons_vjp (θ, σ, p)
504+ end
505+
506+ if cons != = nothing && f. cons_jvp == true
507+ res_jvp = zeros (eltype (x), size (x))
508+ cons_jvp_res = zeros (eltype (x), num_cons)
509+
# Jacobian-vector product of the constraints via Enzyme forward mode.
#
# Clears the captured buffers `cons_jvp_res` and `res_jvp`, then runs
# `Enzyme.autodiff` in `Forward` mode on `f.cons`. `res_jvp` is the shadow of
# `cons_jvp_res`, and `v` is the tangent paired with `θ`.
#
# Returns `res_jvp`, which is allocated outside this closure and reused.
# NOTE(review): `res_jvp` is allocated with `size(x)` while `cons_jvp_res` has
# `num_cons` entries, yet the two are paired in one `Duplicated` — confirm
# the intended shape of `res_jvp`.
# NOTE(review): this call hands `f.cons` three arguments, while the sibling
# `cons` closure calls `f.cons(θ, p)` — confirm the calling convention.
function cons_jvp(θ, v)
    Enzyme.make_zero!(cons_jvp_res)
    Enzyme.make_zero!(res_jvp)

    Enzyme.autodiff(
        Enzyme.Forward,
        f.cons,
        Duplicated(cons_jvp_res, res_jvp),
        Duplicated(θ, v),
        Const(p)
    )

    return res_jvp
end
522+ else
523+ cons_jvp = nothing
524+ end
525+
526+ if cons != = nothing && f. cons_h === nothing
527+ cons_vdθ = Tuple ((Array (r) for r in eachrow (I (length (x)) * one (eltype (x)))))
528+ cons_bθ = zeros (eltype (x), length (x))
529+ cons_vdbθ = Tuple (zeros (eltype (x), length (x)) for i in eachindex (x))
530+
433531 function cons_h (θ)
434- vdθ = Tuple ((Array (r) for r in eachrow (I (length (θ)) * one (eltype (θ)))))
435- bθ = zeros (eltype (θ), length (θ))
436- vdbθ = Tuple (zeros (eltype (θ), length (θ)) for i in eachindex (θ))
437- res = [zeros (eltype (x), length (θ), length (θ)) for i in 1 : num_cons]
438- for i in 1 : num_cons
439- Enzyme. make_zero! (bθ)
440- for el in vdbθ
441- Enzyme. make_zero! (el)
442- end
532+ return map (1 : num_cons) do i
533+ Enzyme. make_zero! (cons_bθ)
534+ Enzyme. make_zero! .(cons_vdbθ)
443535 Enzyme. autodiff (Enzyme. Forward,
444536 cons_f2_oop,
445- Enzyme. BatchDuplicated (θ, vdθ ),
446- Enzyme. BatchDuplicated (bθ, vdbθ ),
537+ Enzyme. BatchDuplicated (θ, cons_vdθ ),
538+ Enzyme. BatchDuplicated (cons_bθ, cons_vdbθ ),
447539 Const (f. cons),
448540 Const (p),
449541 Const (i))
450- for j in eachindex (θ)
451- res[i][j, :] = vdbθ[j]
452- end
542+
543+ return reduce (hcat, cons_vdbθ)
453544 end
454- return res
455545 end
456546 else
457547 cons_h = (θ) -> f. cons_h (θ, p)
458548 end
459549
460550 if f. lag_h === nothing
461- lag_h = nothing # Consider implementing this
551+ lag_vdθ = Tuple ((Array (r) for r in eachrow (I (length (x)) * one (eltype (x)))))
552+ lag_bθ = zeros (eltype (x), length (x))
553+ if f. hess_prototype === nothing
554+ lag_vdbθ = Tuple (zeros (eltype (x), length (x)) for i in eachindex (x))
555+ else
556+ lag_vdbθ = Tuple ((copy (r) for r in eachrow (f. hess_prototype)))
557+ end
558+
# Out-of-place Hessian of the Lagrangian, returned in packed lower-triangular
# form.
#
# Arguments
# - `θ`: point at which to evaluate.
# - `σ`, `μ`: passed through as `Const` to `lag_grad` after `lagrangian`,
#   `f.f`, `f.cons` and `p`.
#
# Clears the captured buffers `lag_bθ` and `lag_vdbθ`, then differentiates
# `lag_grad` in batched forward mode with `θ` seeded by `lag_vdθ`; the shadows
# land in `lag_vdbθ`. For each `i`, the first `i` entries of `lag_vdbθ[i]` are
# appended to the output, giving `n * (n + 1) / 2` entries for `n = length(θ)`.
# This is the same packing the in-place variant writes into its `h` argument.
#
# Returns a freshly allocated vector. The previous code wrote into `res`, which
# is not a buffer owned by this closure: it resolves to the gradient buffer of
# the enclosing function (length of `x`), so the write overflowed for `n ≥ 2`
# and clobbered the gradient otherwise.
function lag_h(θ, σ, μ)
    Enzyme.make_zero!(lag_bθ)
    Enzyme.make_zero!.(lag_vdbθ)

    Enzyme.autodiff(Enzyme.Forward,
        lag_grad,
        Enzyme.BatchDuplicated(θ, lag_vdθ),
        Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ),
        Const(lagrangian),
        Const(f.f),
        Const(f.cons),
        Const(p),
        Const(σ),
        Const(μ)
    )

    n = length(θ)
    # Dedicated output for the packed lower triangle (1 + 2 + … + n entries).
    h = zeros(eltype(lag_bθ), div(n * (n + 1), 2))

    k = 0
    for i in eachindex(θ)
        vec_lagv = lag_vdbθ[i]
        h[(k + 1):(k + i)] .= @view(vec_lagv[1:i])
        k += i
    end
    return h
end
462584 else
463585 lag_h = (θ, σ, μ) -> f. lag_h (θ, σ, μ, p)
464586 end
465587
466588 return OptimizationFunction {false} (f. f, adtype; grad = grad, hess = hess, hv = hv,
467- cons = cons_oop , cons_j = cons_j, cons_h = cons_h,
589+ cons = cons , cons_j = cons_j, cons_h = cons_h,
468590 hess_prototype = f. hess_prototype,
469591 cons_jac_prototype = f. cons_jac_prototype,
470592 cons_hess_prototype = f. cons_hess_prototype,
0 commit comments