Skip to content
This repository was archived by the owner on Aug 25, 2025. It is now read-only.

Commit 9813fd4

Browse files
vjp and jvp and jacobian based on dimensions
1 parent 6d4f6d7 commit 9813fd4

File tree

2 files changed

+87
-34
lines changed

2 files changed

+87
-34
lines changed

ext/OptimizationEnzymeExt.jl

Lines changed: 85 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ function inner_grad(θ, bθ, f, p)
2424
Const(f),
2525
Enzyme.Duplicated(θ, bθ),
2626
Const(p)
27-
),
27+
)
2828
return nothing
2929
end
3030

@@ -89,8 +89,7 @@ function lag_grad(x, dx, lagrangian::Function, _f::Function, cons::Function, p,
8989
end
9090

9191
function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
92-
adtype::AutoEnzyme, p,
93-
num_cons = 0)
92+
adtype::AutoEnzyme, p, num_cons = 0; fg = false, fgh = false,)
9493
if f.grad === nothing
9594
function grad(res, θ)
9695
Enzyme.make_zero!(res)
@@ -106,16 +105,20 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
106105
grad = (G, θ) -> f.grad(G, θ, p)
107106
end
108107

109-
function fg!(res, θ)
110-
Enzyme.make_zero!(res)
111-
y = Enzyme.autodiff(Enzyme.ReverseWithPrimal,
112-
Const(firstapply),
113-
Active,
114-
Const(f.f),
115-
Enzyme.Duplicated(θ, res),
116-
Const(p)
117-
)[2]
118-
return y
108+
if fg == true
109+
function fg!(res, θ)
110+
Enzyme.make_zero!(res)
111+
y = Enzyme.autodiff(Enzyme.ReverseWithPrimal,
112+
Const(firstapply),
113+
Active,
114+
Const(f.f),
115+
Enzyme.Duplicated(θ, res),
116+
Const(p)
117+
)[2]
118+
return y
119+
end
120+
else
121+
fg! = nothing
119122
end
120123

121124
if f.hess === nothing
@@ -130,7 +133,6 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
130133
end
131134

132135
function hess(res, θ)
133-
Enzyme.make_zero!.(vdθ)
134136
Enzyme.make_zero!(bθ)
135137
Enzyme.make_zero!.(vdbθ)
136138

@@ -150,8 +152,25 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
150152
hess = (H, θ) -> f.hess(H, θ, p)
151153
end
152154

153-
function fgh!(G, H, θ)
154-
155+
if fgh == true
156+
function fgh!(G, H, θ)
157+
vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ)))))
158+
vdbθ = Tuple(zeros(eltype(θ), length(θ)) for i in eachindex(θ))
159+
160+
Enzyme.autodiff(Enzyme.Forward,
161+
inner_grad,
162+
Enzyme.BatchDuplicated(θ, vdθ),
163+
Enzyme.BatchDuplicatedNoNeed(G, vdbθ),
164+
Const(f.f),
165+
Const(p)
166+
)
167+
168+
for i in eachindex(θ)
169+
H[i, :] .= vdbθ[i]
170+
end
171+
end
172+
else
173+
fgh! = nothing
155174
end
156175

157176
if f.hv === nothing
@@ -175,13 +194,19 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
175194
seeds = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x)))))
176195
Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x))
177196
y = zeros(eltype(x), num_cons)
197+
178198
function cons_j(J, θ)
179199
for i in 1:length(θ)
180200
Enzyme.make_zero!(Jaccache[i])
181201
end
182202
Enzyme.make_zero!(y)
183-
Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache),
184-
BatchDuplicated(θ, seeds), Const(p))
203+
if num_cons > length(θ)
204+
Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache),
205+
BatchDuplicated(θ, seeds), Const(p))
206+
else
207+
Enzyme.autodiff(Enzyme.Reverse, f.cons, BatchDuplicated(y, seeds),
208+
BatchDuplicated(θ, Jaccache), Const(p))
209+
end
185210
for i in 1:length(θ)
186211
if J isa Vector
187212
J[i] = Jaccache[i][1]
@@ -194,35 +219,63 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
194219
cons_j = (J, θ) -> f.cons_j(J, θ, p)
195220
end
196221

197-
if cons !== nothing && f.cons_vjp === nothing
198-
function cons_vjp(res, θ, v)
199-
222+
if cons !== nothing && f.cons_vjp == true
223+
cons_res = zeros(eltype(x), num_cons)
224+
function cons_vjp!(res, θ, v)
225+
Enzyme.make_zero!(res)
226+
Enzyme.make_zero!(cons_res)
227+
228+
Enzyme.autodiff(Enzyme.Reverse,
229+
f.cons,
230+
Const,
231+
Duplicated(cons_res, v),
232+
Duplicated(θ, res),
233+
Const(p),
234+
)
235+
end
236+
else
237+
cons_vjp! = (θ, σ) -> f.cons_vjp(θ, σ, p)
238+
end
239+
240+
if cons !== nothing && f.cons_jvp == true
241+
cons_res = zeros(eltype(x), num_cons)
242+
243+
function cons_jvp!(res, θ, v)
244+
Enzyme.make_zero!(res)
245+
Enzyme.make_zero!(cons_res)
246+
247+
Enzyme.autodiff(Enzyme.Forward,
248+
f.cons,
249+
Duplicated(cons_res, res),
250+
Duplicated(θ, v),
251+
Const(p),
252+
)
200253
end
201254
else
202-
cons_vjp = (θ, σ) -> f.cons_vjp(θ, σ, p)
255+
cons_vjp! = nothing
203256
end
204257

205258
if cons !== nothing && f.cons_h === nothing
259+
cons_vdθ = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x)))))
260+
cons_bθ = zeros(eltype(x), length(x))
261+
cons_vdbθ = Tuple(zeros(eltype(x), length(x)) for i in eachindex(x))
262+
206263
function cons_h(res, θ)
207-
vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * one(eltype(θ)))))
208-
= zeros(eltype(θ), length(θ))
209-
vdbθ = Tuple(zeros(eltype(θ), length(θ)) for i in eachindex(θ))
264+
Enzyme.make_zero!(cons_bθ)
265+
Enzyme.make_zero!.(cons_vdbθ)
266+
210267
for i in 1:num_cons
211-
bθ .= zero(eltype(bθ))
212-
for el in vdbθ
213-
Enzyme.make_zero!(el)
214-
end
215268
Enzyme.autodiff(Enzyme.Forward,
216269
cons_f2,
217-
Enzyme.BatchDuplicated(θ, vdθ),
218-
Enzyme.BatchDuplicated(bθ, vdbθ),
270+
Enzyme.BatchDuplicated(θ, cons_vdθ),
271+
Enzyme.BatchDuplicated(bθ, cons_vdbθ),
219272
Const(f.cons),
220273
Const(p),
221274
Const(num_cons),
222275
Const(i))
223276

224277
for j in eachindex(θ)
225-
res[i][j, :] .= vdbθ[j]
278+
res[i][j, :] .= cons_vdbθ[j]
226279
end
227280
end
228281
end
@@ -242,7 +295,6 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
242295
end
243296

244297
function lag_h(h, θ, σ, μ)
245-
Enzyme.make_zero!.(lag_vdθ)
246298
Enzyme.make_zero!(lag_bθ)
247299
Enzyme.make_zero!.(lag_vdbθ)
248300

test/adtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ H2 = Array{Float64}(undef, 2, 2)
2626
g!(G1, x0)
2727
h!(H1, x0)
2828

29-
cons = (res, x, p) -> (res .= [x[1]^2 + x[2]^2])
29+
cons = (res, x, p) -> (res .= [x[1]^2 + x[2]^2]; return nothing)
3030
optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoModelingToolkit(), cons = cons)
3131
optprob = OptimizationBase.instantiate_function(optf, x0,
3232
OptimizationBase.AutoModelingToolkit(),
@@ -47,6 +47,7 @@ optprob.cons_h(H3, x0)
4747

4848
function con2_c(res, x, p)
4949
res .= [x[1]^2 + x[2]^2, x[2] * sin(x[1]) - x[1]]
50+
return nothing
5051
end
5152
optf = OptimizationFunction(rosenbrock,
5253
OptimizationBase.AutoModelingToolkit(),

0 commit comments

Comments
 (0)