Skip to content
This repository was archived by the owner on Aug 25, 2025. It is now read-only.

Commit ec157c3

Browse files
Update NoAD dispatch
1 parent 358b1e9 commit ec157c3

File tree

1 file changed

+36
-51
lines changed

1 file changed

+36
-51
lines changed

src/function.jl

Lines changed: 36 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -105,64 +105,49 @@ function instantiate_function(f::MultiObjectiveOptimizationFunction, cache::ReIn
105105
observed = f.observed)
106106
end
107107

108-
109108
function instantiate_function(f::OptimizationFunction{true}, x, ::SciMLBase.NoAD,
110-
p, num_cons = 0)
111-
grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, p, args...)
112-
hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, p, args...)
113-
hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, p, args...)
114-
cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, p)
115-
cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, p)
116-
cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, p)
117-
hess_prototype = f.hess_prototype === nothing ? nothing :
118-
convert.(eltype(x), f.hess_prototype)
119-
cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing :
120-
convert.(eltype(x), f.cons_jac_prototype)
121-
cons_hess_prototype = f.cons_hess_prototype === nothing ? nothing :
122-
[convert.(eltype(x), f.cons_hess_prototype[i])
123-
for i in 1:num_cons]
124-
expr = symbolify(f.expr)
125-
cons_expr = symbolify.(f.cons_expr)
126-
127-
return OptimizationFunction{true}(f.f, SciMLBase.NoAD(); grad = grad, hess = hess,
128-
hv = hv,
129-
cons = cons, cons_j = cons_j, cons_h = cons_h,
130-
hess_prototype = hess_prototype,
131-
cons_jac_prototype = cons_jac_prototype,
132-
cons_hess_prototype = cons_hess_prototype,
133-
expr = expr, cons_expr = cons_expr,
134-
sys = f.sys,
135-
observed = f.observed)
109+
p, num_cons = 0, kwargs...)
110+
grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, p, args...)
111+
fg = f.fg === nothing ? nothing : (G, x, args...) -> f.fg(G, x, p, args...)
112+
hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, p, args...)
113+
fgh = f.fgh === nothing ? nothing : (G, H, x, args...) -> f.fgh(G, H, x, p, args...)
114+
hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, p, args...)
115+
cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, p)
116+
cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, p)
117+
cons_vjp = f.cons_vjp === nothing ? nothing : (res, x) -> f.cons_vjp(res, x, p)
118+
cons_jvp = f.cons_jvp === nothing ? nothing : (res, x) -> f.cons_jvp(res, x, p)
119+
cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, p)
120+
lag_h = f.lag_h === nothing ? nothing : (res, x) -> f.lag_h(res, x, p)
121+
hess_prototype = f.hess_prototype === nothing ? nothing :
122+
convert.(eltype(x), f.hess_prototype)
123+
cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing :
124+
convert.(eltype(x), f.cons_jac_prototype)
125+
cons_hess_prototype = f.cons_hess_prototype === nothing ? nothing :
126+
[convert.(eltype(x), f.cons_hess_prototype[i])
127+
for i in 1:num_cons]
128+
expr = symbolify(f.expr)
129+
cons_expr = symbolify.(f.cons_expr)
130+
131+
return OptimizationFunction{true}(f.f, SciMLBase.NoAD();
132+
grad = grad, fg = fg, hess = hess, fgh = fgh, hv = hv,
133+
cons = cons, cons_j = cons_j, cons_h = cons_h,
134+
cons_vjp = cons_vjp, cons_jvp = cons_jvp,
135+
lag_h = lag_h,
136+
hess_prototype = hess_prototype,
137+
cons_jac_prototype = cons_jac_prototype,
138+
cons_hess_prototype = cons_hess_prototype,
139+
expr = expr, cons_expr = cons_expr,
140+
sys = f.sys,
141+
observed = f.observed)
136142
end
137143

138144
function instantiate_function(
139145
f::OptimizationFunction{true}, cache::ReInitCache, ::SciMLBase.NoAD,
140146
num_cons = 0, kwargs...)
141-
grad = f.grad === nothing ? nothing : (G, x, args...) -> f.grad(G, x, cache.p, args...)
142-
hess = f.hess === nothing ? nothing : (H, x, args...) -> f.hess(H, x, cache.p, args...)
143-
hv = f.hv === nothing ? nothing : (H, x, v, args...) -> f.hv(H, x, v, cache.p, args...)
144-
cons = f.cons === nothing ? nothing : (res, x) -> f.cons(res, x, cache.p)
145-
cons_j = f.cons_j === nothing ? nothing : (res, x) -> f.cons_j(res, x, cache.p)
146-
cons_h = f.cons_h === nothing ? nothing : (res, x) -> f.cons_h(res, x, cache.p)
147-
hess_prototype = f.hess_prototype === nothing ? nothing :
148-
convert.(eltype(cache.u0), f.hess_prototype)
149-
cons_jac_prototype = f.cons_jac_prototype === nothing ? nothing :
150-
convert.(eltype(cache.u0), f.cons_jac_prototype)
151-
cons_hess_prototype = f.cons_hess_prototype === nothing ? nothing :
152-
[convert.(eltype(cache.u0), f.cons_hess_prototype[i])
153-
for i in 1:num_cons]
154-
expr = symbolify(f.expr)
155-
cons_expr = symbolify.(f.cons_expr)
147+
x = cache.u0
148+
p = cache.p
156149

157-
return OptimizationFunction{true}(f.f, SciMLBase.NoAD(); grad = grad, hess = hess,
158-
hv = hv,
159-
cons = cons, cons_j = cons_j, cons_h = cons_h,
160-
hess_prototype = hess_prototype,
161-
cons_jac_prototype = cons_jac_prototype,
162-
cons_hess_prototype = cons_hess_prototype,
163-
expr = expr, cons_expr = cons_expr,
164-
sys = f.sys,
165-
observed = f.observed)
150+
return instantiate_function(f, x, SciMLBase.NoAD(), p, num_cons, kwargs...)
166151
end
167152

168153
function instantiate_function(f::OptimizationFunction, x, adtype::ADTypes.AbstractADType,

0 commit comments

Comments
 (0)