This repository was archived by the owner on Aug 25, 2025. It is now read-only.

Commit 54162cf

format
1 parent ec3067b commit 54162cf

2 files changed: 47 additions & 43 deletions

src/cache.jl

Lines changed: 5 additions & 4 deletions

@@ -35,10 +35,11 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt, data = DEFA
         kwargs...)
     reinit_cache = OptimizationBase.ReInitCache(prob.u0, prob.p)
     num_cons = prob.ucons === nothing ? 0 : length(prob.ucons)
-    f = OptimizationBase.instantiate_function(prob.f, reinit_cache, prob.f.adtype, num_cons,
-        g = SciMLBase.requiresgradient(opt), h = SciMLBase.requireshessian(opt), fg = SciMLBase.allowsfg(opt),
-        fgh = SciMLBase.allowsfgh(opt), cons_j = SciMLBase.requiresconsjac(opt), cons_h = SciMLBase.requiresconshess(opt),
-        cons_vjp = SciMLBase.allowsconsjvp(opt), cons_jvp = SciMLBase.allowsconsjvp(opt), lag_h = SciMLBase.requireslagh(opt))
+    f = OptimizationBase.instantiate_function(
+        prob.f, reinit_cache, prob.f.adtype, num_cons,
+        g = SciMLBase.requiresgradient(opt), h = SciMLBase.requireshessian(opt), fg = SciMLBase.allowsfg(opt),
+        fgh = SciMLBase.allowsfgh(opt), cons_j = SciMLBase.requiresconsjac(opt), cons_h = SciMLBase.requiresconshess(opt),
+        cons_vjp = SciMLBase.allowsconsjvp(opt), cons_jvp = SciMLBase.allowsconsjvp(opt), lag_h = SciMLBase.requireslagh(opt))
 
     if (f.sys === nothing ||
         f.sys isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}) &&
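
A minimal sketch (not part of this commit) of the SciMLBase solver-capability traits that the reformatted call above forwards into `instantiate_function`; the `Newton()` optimizer and the OptimizationOptimJL dependency are assumptions chosen only for illustration.

using SciMLBase, OptimizationOptimJL   # assumed packages, not loaded by this commit

opt = Newton()                      # hypothetical optimizer choice
SciMLBase.requiresgradient(opt)     # whether a `grad` closure must be generated
SciMLBase.requireshessian(opt)      # whether a `hess` closure must be generated
SciMLBase.requiresconsjac(opt)      # whether a constraint Jacobian is needed
SciMLBase.requireslagh(opt)         # whether a Lagrangian Hessian is needed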

src/function.jl

Lines changed: 42 additions & 39 deletions

@@ -46,7 +46,8 @@ documentation of the `AbstractADType` types.
 function instantiate_function(f::MultiObjectiveOptimizationFunction, x, ::SciMLBase.NoAD,
         p, num_cons = 0)
     jac = f.jac === nothing ? nothing : (J, x, args...) -> f.jac(J, x, p, args...)
-    hess = f.hess === nothing ? nothing : [(H, x, args...) -> h(H, x, p, args...) for h in f.hess]
+    hess = f.hess === nothing ? nothing :
+           [(H, x, args...) -> h(H, x, p, args...) for h in f.hess]
     hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, p, args...)
     cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, p)
     cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, p)
@@ -63,7 +64,8 @@ function instantiate_function(f::MultiObjectiveOptimizationFunction, x, ::SciMLB
     expr = symbolify(f.expr)
     cons_expr = symbolify.(f.cons_expr)
 
-    return MultiObjectiveOptimizationFunction{true}(f.f, SciMLBase.NoAD(); jac = jac, hess = hess,
+    return MultiObjectiveOptimizationFunction{true}(
+        f.f, SciMLBase.NoAD(); jac = jac, hess = hess,
         hv = hv,
         cons = cons, cons_j = cons_j, cons_jvp = cons_jvp, cons_vjp = cons_vjp, cons_h = cons_h,
         hess_prototype = hess_prototype,
@@ -74,10 +76,12 @@ function instantiate_function(f::MultiObjectiveOptimizationFunction, x, ::SciMLB
         observed = f.observed)
 end
 
-function instantiate_function(f::MultiObjectiveOptimizationFunction, cache::ReInitCache, ::SciMLBase.NoAD,
+function instantiate_function(
+        f::MultiObjectiveOptimizationFunction, cache::ReInitCache, ::SciMLBase.NoAD,
         num_cons = 0)
     jac = f.jac === nothing ? nothing : (J, x, args...) -> f.jac(J, x, cache.p, args...)
-    hess = f.hess === nothing ? nothing : [(H, x, args...) -> h(H, x, cache.p, args...) for h in f.hess]
+    hess = f.hess === nothing ? nothing :
+           [(H, x, args...) -> h(H, x, cache.p, args...) for h in f.hess]
     hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, cache.p, args...)
     cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, cache.p)
     cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, cache.p)
@@ -94,7 +98,8 @@ function instantiate_function(f::MultiObjectiveOptimizationFunction, cache::ReIn
     expr = symbolify(f.expr)
     cons_expr = symbolify.(f.cons_expr)
 
-    return MultiObjectiveOptimizationFunction{true}(f.f, SciMLBase.NoAD(); jac = jac, hess = hess,
+    return MultiObjectiveOptimizationFunction{true}(
+        f.f, SciMLBase.NoAD(); jac = jac, hess = hess,
         hv = hv,
         cons = cons, cons_j = cons_j, cons_jvp = cons_jvp, cons_vjp = cons_vjp, cons_h = cons_h,
         hess_prototype = hess_prototype,
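
The two `MultiObjectiveOptimizationFunction` methods above do no AD work; they only close each user-supplied callback over the parameters, and the per-objective Hessians in particular are handled by a comprehension over a vector of functions. A standalone plain-Julia illustration of that wrapping pattern follows; the objectives, Hessians, and variable names are hypothetical, not library API.

p = [1.0, 2.0]
user_hessians = [(H, u, p) -> (H .= 2.0),          # Hessian of a first, hypothetical objective
                 (H, u, p) -> (H .= 2.0 * p[1])]   # Hessian of a second, hypothetical objective

# Same shape as the comprehension in the diff: bind `p` into each Hessian callback.
wrapped = [(H, u, args...) -> h(H, u, p, args...) for h in user_hessians]

H = zeros(2, 2)
wrapped[2](H, zeros(2))   # calls the second user Hessian with `p` already bound
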
@@ -107,38 +112,38 @@ end
 
 function instantiate_function(f::OptimizationFunction{true}, x, ::SciMLBase.NoAD,
         p, num_cons = 0, kwargs...)
-    grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, p, args...)
-    fg = f.fg === nothing ? nothing : (G, x, args...) -> f.fg(G, x, p, args...)
-    hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, p, args...)
-    fgh = f.fgh === nothing ? nothing : (G, H, x, args...) -> f.fgh(G, H, x, p, args...)
-    hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, p, args...)
-    cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, p)
-    cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, p)
-    cons_vjp = f.cons_vjp === nothing ? nothing : (res, x) -> f.cons_vjp(res, x, p)
-    cons_jvp = f.cons_jvp === nothing ? nothing : (res, x) -> f.cons_jvp(res, x, p)
-    cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, p)
-    lag_h = f.lag_h === nothing ? nothing : (res, x) -> f.lag_h(res, x, p)
-    hess_prototype = f.hess_prototype === nothing ? nothing :
-        convert.(eltype(x), f.hess_prototype)
-    cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing :
-        convert.(eltype(x), f.cons_jac_prototype)
-    cons_hess_prototype = f.cons_hess_prototype === nothing ? nothing :
-        [convert.(eltype(x), f.cons_hess_prototype[i])
-            for i in 1:num_cons]
-    expr = symbolify(f.expr)
-    cons_expr = symbolify.(f.cons_expr)
-
-    return OptimizationFunction{true}(f.f, SciMLBase.NoAD();
-        grad = grad, fg = fg, hess = hess, fgh = fgh, hv = hv,
-        cons = cons, cons_j = cons_j, cons_h = cons_h,
-        cons_vjp = cons_vjp, cons_jvp = cons_jvp,
-        lag_h = lag_h,
-        hess_prototype = hess_prototype,
-        cons_jac_prototype = cons_jac_prototype,
-        cons_hess_prototype = cons_hess_prototype,
-        expr = expr, cons_expr = cons_expr,
-        sys = f.sys,
-        observed = f.observed)
+    grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, p, args...)
+    fg = f.fg === nothing ? nothing : (G, x, args...) -> f.fg(G, x, p, args...)
+    hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, p, args...)
+    fgh = f.fgh === nothing ? nothing : (G, H, x, args...) -> f.fgh(G, H, x, p, args...)
+    hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, p, args...)
+    cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, p)
+    cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, p)
+    cons_vjp = f.cons_vjp === nothing ? nothing : (res, x) -> f.cons_vjp(res, x, p)
+    cons_jvp = f.cons_jvp === nothing ? nothing : (res, x) -> f.cons_jvp(res, x, p)
+    cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, p)
+    lag_h = f.lag_h === nothing ? nothing : (res, x) -> f.lag_h(res, x, p)
+    hess_prototype = f.hess_prototype === nothing ? nothing :
+                     convert.(eltype(x), f.hess_prototype)
+    cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing :
+                         convert.(eltype(x), f.cons_jac_prototype)
+    cons_hess_prototype = f.cons_hess_prototype === nothing ? nothing :
+                          [convert.(eltype(x), f.cons_hess_prototype[i])
+                           for i in 1:num_cons]
+    expr = symbolify(f.expr)
+    cons_expr = symbolify.(f.cons_expr)
+
+    return OptimizationFunction{true}(f.f, SciMLBase.NoAD();
+        grad = grad, fg = fg, hess = hess, fgh = fgh, hv = hv,
+        cons = cons, cons_j = cons_j, cons_h = cons_h,
+        cons_vjp = cons_vjp, cons_jvp = cons_jvp,
+        lag_h = lag_h,
+        hess_prototype = hess_prototype,
+        cons_jac_prototype = cons_jac_prototype,
+        cons_hess_prototype = cons_hess_prototype,
+        expr = expr, cons_expr = cons_expr,
+        sys = f.sys,
+        observed = f.observed)
 end
 
 function instantiate_function(
@@ -162,5 +167,3 @@ function instantiate_function(f::OptimizationFunction, x, adtype::ADTypes.Abstra
     adpkg = adtypestr[strtind:(open_brkt_ind - 1)]
     throw(ArgumentError("The passed automatic differentiation backend choice is not available. Please load the corresponding AD package $adpkg."))
 end
-
-
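
For the single-objective `NoAD` path above, the net effect is that user-supplied derivative callbacks of the form `(G, x, p)` come back as `(G, x)` closures with the parameters baked in. A minimal usage sketch, assuming OptimizationBase and SciMLBase are loaded; the Rosenbrock objective and hand-written gradient are illustrative, not part of this commit.

using OptimizationBase, SciMLBase

rosen(u, p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2
function rosen_grad!(G, u, p)
    G[1] = -2 * (p[1] - u[1]) - 4 * p[2] * u[1] * (u[2] - u[1]^2)
    G[2] = 2 * p[2] * (u[2] - u[1]^2)
    return nothing
end

optf = OptimizationFunction(rosen, SciMLBase.NoAD(); grad = rosen_grad!)
f = OptimizationBase.instantiate_function(optf, zeros(2), SciMLBase.NoAD(), [1.0, 100.0])

G = zeros(2)
f.grad(G, zeros(2))   # `p = [1.0, 100.0]` is already closed over by the wrapper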
