Skip to content
This repository was archived by the owner on Aug 25, 2025. It is now read-only.

Commit fe1b709

Browse files
Oop improvements
1 parent e79d4cb commit fe1b709

File tree

5 files changed

+338
-90
lines changed

5 files changed

+338
-90
lines changed

ext/OptimizationEnzymeExt.jl

Lines changed: 160 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ function lag_grad(x, dx, lagrangian::Function, _f::Function, cons::Function, p,
8989
end
9090

9191
function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
92-
adtype::AutoEnzyme, p, num_cons = 0; fg = false, fgh = false,)
92+
adtype::AutoEnzyme, p, num_cons = 0; fg = false, fgh = false)
9393
if f.grad === nothing
9494
function grad(res, θ)
9595
Enzyme.make_zero!(res)
@@ -198,7 +198,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
198198
seeds = Enzyme.onehot(zeros(eltype(x), num_cons))
199199
Jaccache = Tuple(zero(x) for i in 1:num_cons)
200200
end
201-
201+
202202
y = zeros(eltype(x), num_cons)
203203

204204
function cons_j(J, θ)
@@ -243,7 +243,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
243243
Const,
244244
Duplicated(cons_res, v),
245245
Duplicated(θ, res),
246-
Const(p),
246+
Const(p)
247247
)
248248
end
249249
else
@@ -261,7 +261,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
261261
f.cons,
262262
Duplicated(cons_res, res),
263263
Duplicated(θ, v),
264-
Const(p),
264+
Const(p)
265265
)
266266
end
267267
else
@@ -325,7 +325,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
325325

326326
for i in eachindex(θ)
327327
vec_lagv = lag_vdbθ[i]
328-
h[k+1:k+i] .= @view(vec_lagv[1:i])
328+
h[(k + 1):(k + i)] .= @view(vec_lagv[1:i])
329329
k += i
330330
end
331331
end
@@ -356,8 +356,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true},
356356
end
357357

358358
function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x,
359-
adtype::AutoEnzyme, p,
360-
num_cons = 0)
359+
adtype::AutoEnzyme, p, num_cons = 0; fg = false, fgh = false)
361360
if f.grad === nothing
362361
res = zeros(eltype(x), size(x))
363362
function grad(θ)
@@ -375,12 +374,31 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
375374
grad = (θ) -> f.grad(θ, p)
376375
end
377376

377+
if fg == true
378+
res_fg = zeros(eltype(x), size(x))
379+
function fg!(θ)
380+
Enzyme.make_zero!(res_fg)
381+
y = Enzyme.autodiff(Enzyme.ReverseWithPrimal,
382+
Const(firstapply),
383+
Active,
384+
Const(f.f),
385+
Enzyme.Duplicated(θ, res_fg),
386+
Const(p)
387+
)[2]
388+
return y, res
389+
end
390+
else
391+
fg! = nothing
392+
end
393+
378394
if f.hess === nothing
379-
function hess(θ)
380-
vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ)))))
395+
vdθ = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x)))))
396+
= zeros(eltype(x), length(x))
397+
vdbθ = Tuple(zeros(eltype(x), length(x)) for i in eachindex(x))
381398

382-
= zeros(eltype(θ), length(θ))
383-
vdbθ = Tuple(zeros(eltype(θ), length(θ)) for i in eachindex(θ))
399+
function hess(θ)
400+
Enzyme.make_zero!(bθ)
401+
Enzyme.make_zero!.(vdbθ)
384402

385403
Enzyme.autodiff(Enzyme.Forward,
386404
inner_grad,
@@ -397,9 +415,37 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
397415
hess = (θ) -> f.hess(θ, p)
398416
end
399417

418+
if fgh == true
419+
vdθ_fgh = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x)))))
420+
vdbθ_fgh = Tuple(zeros(eltype(x), length(x)) for i in eachindex(x))
421+
G_fgh = zeros(eltype(x), length(x))
422+
H_fgh = zeros(eltype(x), length(x), length(x))
423+
424+
function fgh!(θ)
425+
Enzyme.make_zero!(G_fgh)
426+
Enzyme.make_zero!(H_fgh)
427+
Enzyme.make_zero!.(vdbθ_fgh)
428+
429+
Enzyme.autodiff(Enzyme.Forward,
430+
inner_grad,
431+
Enzyme.BatchDuplicated(θ, vdθ_fgh),
432+
Enzyme.BatchDuplicatedNoNeed(G_fgh, vdbθ_fgh),
433+
Const(f.f),
434+
Const(p)
435+
)
436+
437+
for i in eachindex(θ)
438+
H_fgh[i, :] .= vdbθ_fgh[i]
439+
end
440+
return G_fgh, H_fgh
441+
end
442+
else
443+
fgh! = nothing
444+
end
445+
400446
if f.hv === nothing
401447
function hv(θ, v)
402-
Enzyme.autodiff(
448+
return Enzyme.autodiff(
403449
Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v),
404450
Const(_f), Const(f.f), Const(p)
405451
)[1]
@@ -411,60 +457,136 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
411457
if f.cons === nothing
412458
cons = nothing
413459
else
414-
cons_oop = (θ) -> f.cons(θ, p)
460+
function cons(θ)
461+
return f.cons(θ, p)
462+
end
415463
end
416464

417-
if f.cons !== nothing && f.cons_j === nothing
418-
seeds = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x)))))
465+
if cons !== nothing && f.cons_j === nothing
466+
seeds = Enzyme.onehot(x)
467+
Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x))
468+
419469
function cons_j(θ)
420-
J = Enzyme.autodiff(
421-
Enzyme.Forward, f.cons, BatchDuplicated(θ, seeds), Const(p))[1]
422-
if num_cons == 1
423-
return reduce(vcat, J)
470+
for i in eachindex(Jaccache)
471+
Enzyme.make_zero!(Jaccache[i])
472+
end
473+
y, Jaccache = Enzyme.autodiff(Enzyme.Forward, f.cons, Duplicated,
474+
BatchDuplicated(θ, seeds), Const(p))
475+
if size(y, 1) == 1
476+
return reduce(vcat, Jaccache)
424477
else
425-
return reduce(hcat, J)
478+
return reduce(hcat, Jaccache)
426479
end
427480
end
428481
else
429482
cons_j = (θ) -> f.cons_j(θ, p)
430483
end
431484

432-
if f.cons !== nothing && f.cons_h === nothing
485+
if cons !== nothing && f.cons_vjp == true
486+
res_vjp = zeros(eltype(x), size(x))
487+
cons_vjp_res = zeros(eltype(x), num_cons)
488+
489+
function cons_vjp(θ, v)
490+
Enzyme.make_zero!(res_vjp)
491+
Enzyme.make_zero!(cons_vjp_res)
492+
493+
Enzyme.autodiff(Enzyme.Reverse,
494+
f.cons,
495+
Const,
496+
Duplicated(cons_vjp_res, v),
497+
Duplicated(θ, res_vjp),
498+
Const(p)
499+
)
500+
return res_vjp
501+
end
502+
else
503+
cons_vjp = (θ, σ) -> f.cons_vjp(θ, σ, p)
504+
end
505+
506+
if cons !== nothing && f.cons_jvp == true
507+
res_jvp = zeros(eltype(x), size(x))
508+
cons_jvp_res = zeros(eltype(x), num_cons)
509+
510+
function cons_jvp(θ, v)
511+
Enzyme.make_zero!(res_jvp)
512+
Enzyme.make_zero!(cons_jvp_res)
513+
514+
Enzyme.autodiff(Enzyme.Forward,
515+
f.cons,
516+
Duplicated(cons_jvp_res, res_jvp),
517+
Duplicated(θ, v),
518+
Const(p)
519+
)
520+
return res_jvp
521+
end
522+
else
523+
cons_jvp = nothing
524+
end
525+
526+
if cons !== nothing && f.cons_h === nothing
527+
cons_vdθ = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x)))))
528+
cons_bθ = zeros(eltype(x), length(x))
529+
cons_vdbθ = Tuple(zeros(eltype(x), length(x)) for i in eachindex(x))
530+
433531
function cons_h(θ)
434-
vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ)))))
435-
= zeros(eltype(θ), length(θ))
436-
vdbθ = Tuple(zeros(eltype(θ), length(θ)) for i in eachindex(θ))
437-
res = [zeros(eltype(x), length(θ), length(θ)) for i in 1:num_cons]
438-
for i in 1:num_cons
439-
Enzyme.make_zero!(bθ)
440-
for el in vdbθ
441-
Enzyme.make_zero!(el)
442-
end
532+
return map(1:num_cons) do i
533+
Enzyme.make_zero!(cons_bθ)
534+
Enzyme.make_zero!.(cons_vdbθ)
443535
Enzyme.autodiff(Enzyme.Forward,
444536
cons_f2_oop,
445-
Enzyme.BatchDuplicated(θ, vdθ),
446-
Enzyme.BatchDuplicated(bθ, vdbθ),
537+
Enzyme.BatchDuplicated(θ, cons_vdθ),
538+
Enzyme.BatchDuplicated(cons_bθ, cons_vdbθ),
447539
Const(f.cons),
448540
Const(p),
449541
Const(i))
450-
for j in eachindex(θ)
451-
res[i][j, :] = vdbθ[j]
452-
end
542+
543+
return reduce(hcat, cons_vdbθ)
453544
end
454-
return res
455545
end
456546
else
457547
cons_h = (θ) -> f.cons_h(θ, p)
458548
end
459549

460550
if f.lag_h === nothing
461-
lag_h = nothing # Consider implementing this
551+
lag_vdθ = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x)))))
552+
lag_bθ = zeros(eltype(x), length(x))
553+
if f.hess_prototype === nothing
554+
lag_vdbθ = Tuple(zeros(eltype(x), length(x)) for i in eachindex(x))
555+
else
556+
lag_vdbθ = Tuple((copy(r) for r in eachrow(f.hess_prototype)))
557+
end
558+
559+
function lag_h(θ, σ, μ)
560+
Enzyme.make_zero!(lag_bθ)
561+
Enzyme.make_zero!.(lag_vdbθ)
562+
563+
Enzyme.autodiff(Enzyme.Forward,
564+
lag_grad,
565+
Enzyme.BatchDuplicated(θ, lag_vdθ),
566+
Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ),
567+
Const(lagrangian),
568+
Const(f.f),
569+
Const(f.cons),
570+
Const(p),
571+
Const(σ),
572+
Const(μ)
573+
)
574+
575+
k = 0
576+
577+
for i in eachindex(θ)
578+
vec_lagv = lag_vdbθ[i]
579+
res[(k + 1):(k + i), :] .= @view(vec_lagv[1:i])
580+
k += i
581+
end
582+
return res
583+
end
462584
else
463585
lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p)
464586
end
465587

466588
return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv,
467-
cons = cons_oop, cons_j = cons_j, cons_h = cons_h,
589+
cons = cons, cons_j = cons_j, cons_h = cons_h,
468590
hess_prototype = f.hess_prototype,
469591
cons_jac_prototype = f.cons_jac_prototype,
470592
cons_hess_prototype = f.cons_hess_prototype,

0 commit comments

Comments
 (0)