Commit 0c9b2e4

Merge pull request #916 from ParasPuneetSingh/master
OptimizationODE package for ODE solvers
2 parents: fd771c8 + e68e640 · commit 0c9b2e4

File tree

4 files changed: +176 −0 lines changed

- .github/workflows/CI.yml
- lib/OptimizationODE/Project.toml
- lib/OptimizationODE/src/OptimizationODE.jl
- lib/OptimizationODE/test/runtests.jl

4 files changed

+176
-0
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 0 deletions

```diff
@@ -29,6 +29,7 @@ jobs:
           - OptimizationMultistartOptimization
           - OptimizationNLopt
           - OptimizationNOMAD
+          - OptimizationODE
           - OptimizationOptimJL
           - OptimizationOptimisers
           - OptimizationPRIMA
```
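
The hunk above only adds the matrix entry; the wiring that maps a group name to a test invocation lives elsewhere in the workflow and is not part of this diff. As a rough, hypothetical sketch of the common Julia pattern (the `GROUP` variable and the include path here are assumptions):

```julia
# Hypothetical sketch: dispatching on a CI matrix group from a test runner.
# `GROUP` and the include path are assumptions; they are not shown in this diff.
const GROUP = get(ENV, "GROUP", "All")

if GROUP == "All" || GROUP == "OptimizationODE"
    include(joinpath(@__DIR__, "..", "lib", "OptimizationODE", "test", "runtests.jl"))
end
```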

lib/OptimizationODE/Project.toml

Lines changed: 25 additions & 0 deletions

```toml
name = "OptimizationODE"
uuid = "dfa73e59-e644-4d8a-bf84-188d7ecb34e4"
authors = ["Paras Puneet Singh <[email protected]>"]
version = "0.1.0"

[deps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Optimization = "4"
Reexport = "1"
julia = "1.10"

[extras]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["ADTypes", "Test"]
```
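
Since the subpackage lives in a repository subdirectory rather than the General registry at this point, a plausible way to try it locally is to `develop` it by path. A sketch, assuming the working directory is the Optimization.jl repository root:

```julia
# Sketch: develop the subpackage from a local checkout of the repository.
# Assumes the current directory is the Optimization.jl repo root.
using Pkg
Pkg.develop(path = joinpath("lib", "OptimizationODE"))
Pkg.test("OptimizationODE")  # runs lib/OptimizationODE/test/runtests.jl
```
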
lib/OptimizationODE/src/OptimizationODE.jl

Lines changed: 106 additions & 0 deletions

```julia
module OptimizationODE

using Reexport
@reexport using Optimization, SciMLBase
using OrdinaryDiffEq, SteadyStateDiffEq

export ODEOptimizer, ODEGradientDescent, RKChebyshevDescent, RKAccelerated, HighOrderDescent

# Wraps an ODE integrator (and an optional fixed step size) as an optimizer.
struct ODEOptimizer{T, T2}
    solver::T
    dt::T2
end
ODEOptimizer(solver; dt=nothing) = ODEOptimizer(solver, dt)

# Solver constructors (users call these)
ODEGradientDescent(; dt) = ODEOptimizer(Euler(); dt)  # explicit Euler = plain gradient descent
RKChebyshevDescent() = ODEOptimizer(ROCK2())          # stabilized Runge-Kutta-Chebyshev
RKAccelerated() = ODEOptimizer(Tsit5())               # adaptive 5th-order Runge-Kutta
HighOrderDescent() = ODEOptimizer(Vern7())            # adaptive 7th-order Runge-Kutta

SciMLBase.requiresbounds(::ODEOptimizer) = false
SciMLBase.allowsbounds(::ODEOptimizer) = false
SciMLBase.allowscallback(::ODEOptimizer) = true
SciMLBase.supports_opt_cache_interface(::ODEOptimizer) = true
SciMLBase.requiresgradient(::ODEOptimizer) = true
SciMLBase.requireshessian(::ODEOptimizer) = false
SciMLBase.requiresconsjac(::ODEOptimizer) = false
SciMLBase.requiresconshess(::ODEOptimizer) = false

function SciMLBase.__init(prob::OptimizationProblem, opt::ODEOptimizer;
        callback=Optimization.DEFAULT_CALLBACK, progress=false,
        maxiters=nothing, kwargs...)
    return OptimizationCache(prob, opt; callback=callback, progress=progress,
        maxiters=maxiters, kwargs...)
end

function SciMLBase.__solve(
        cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}
) where {F,RC,LB,UB,LC,UC,S,O<:ODEOptimizer,D,P,C}

    dt = cache.opt.dt
    maxit = get(cache.solver_args, :maxiters, 1000)

    u0 = copy(cache.u0)
    p = cache.p

    if cache.f.grad === nothing
        error("ODEOptimizer requires a gradient. Please provide a function with `grad` defined.")
    end

    # Gradient flow: du/dt = -∇f(u), whose steady states are critical points of f.
    function f!(du, u, p, t)
        cache.f.grad(du, u, p)
        @. du = -du
        return nothing
    end

    ss_prob = SteadyStateProblem(f!, u0, p)
    algorithm = DynamicSS(cache.opt.solver)

    cb = cache.callback
    if cb != Optimization.DEFAULT_CALLBACK || get(cache.solver_args, :progress, false) === true
        # Fire the user callback on every accepted integrator step.
        function condition(u, t, integrator)
            true
        end
        function affect!(integrator)
            u_now = integrator.u
            state = Optimization.OptimizationState(u=u_now,
                objective=cache.f(integrator.u, integrator.p))
            Optimization.callback_function(cb, state)
        end
        cb_struct = DiscreteCallback(condition, affect!)
        callback = CallbackSet(cb_struct)
    else
        callback = nothing
    end

    solve_kwargs = Dict{Symbol, Any}(:callback => callback)
    if !isnothing(maxit)
        solve_kwargs[:maxiters] = maxit
    end
    if dt !== nothing
        solve_kwargs[:dt] = dt
    end

    sol = solve(ss_prob, algorithm; solve_kwargs...)

    has_destats = hasproperty(sol, :destats)
    has_t = hasproperty(sol, :t) && !isempty(sol.t)

    # Note: `time` here reports the final integration time of the flow, not wall-clock time.
    stats = Optimization.OptimizationStats(
        iterations = has_destats ? get(sol.destats, :iters, 10) : (has_t ? length(sol.t) - 1 : 10),
        time = has_t ? sol.t[end] : 0.0,
        fevals = has_destats ? get(sol.destats, :f_calls, 0) : 0,
        gevals = has_destats ? get(sol.destats, :iters, 0) : 0,
        hevals = 0
    )

    SciMLBase.build_solution(cache, cache.opt, sol.u, cache.f(sol.u, p);
        retcode = ReturnCode.Success,
        stats = stats
    )
end

end # module OptimizationODE
```

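The core idea the module implements: minimization of f is recast as integrating the gradient flow du/dt = -∇f(u) to steady state, since any steady state of that flow is a critical point of f. A standalone sketch of this reduction, using the same SteadyStateProblem/DynamicSS machinery the module calls internally (the quadratic objective here is illustrative, not from the commit):

```julia
# Standalone sketch of the gradient-flow reduction used by ODEOptimizer.
# For f(u) = sum(abs2, u), the flow du/dt = -∇f(u) = -2u decays to the minimizer 0.
using OrdinaryDiffEq, SteadyStateDiffEq

negative_gradient!(du, u, p, t) = (du .= -2 .* u)  # -∇f for this toy objective

prob = SteadyStateProblem(negative_gradient!, [2.0, -3.0])
sol = solve(prob, DynamicSS(Tsit5()))
@show sol.u  # ≈ [0.0, 0.0]: the flow's steady state is the minimizer
```
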
lib/OptimizationODE/test/runtests.jl

Lines changed: 44 additions & 0 deletions

```julia
using Test
using OptimizationODE, SciMLBase, ADTypes

@testset "OptimizationODE Tests" begin

    # Convex quadratic objective with known minimizer at the origin.
    function f(x, p)
        return sum(abs2, x)
    end

    function g!(g, x, p)
        @. g = 2 * x
    end

    x0 = [2.0, -3.0]
    p = [5.0]

    # Gradient supplied via automatic differentiation.
    f_autodiff = OptimizationFunction(f, ADTypes.AutoForwardDiff())
    prob_auto = OptimizationProblem(f_autodiff, x0, p)

    for opt in (ODEGradientDescent(dt=0.01), RKChebyshevDescent(), RKAccelerated(), HighOrderDescent())
        sol = solve(prob_auto, opt; maxiters=50_000)
        @test sol.u ≈ [0.0, 0.0] atol=1e-2
        @test sol.objective ≈ 0.0 atol=1e-2
        @test sol.retcode == ReturnCode.Success
    end

    # Gradient supplied manually.
    f_manual = OptimizationFunction(f, SciMLBase.NoAD(); grad=g!)
    prob_manual = OptimizationProblem(f_manual, x0)

    for opt in (ODEGradientDescent(dt=0.01), RKChebyshevDescent(), RKAccelerated(), HighOrderDescent())
        sol = solve(prob_manual, opt; maxiters=50_000)
        @test sol.u ≈ [0.0, 0.0] atol=1e-2
        @test sol.objective ≈ 0.0 atol=1e-2
        @test sol.retcode == ReturnCode.Success
    end

    # No gradient available: the solver should raise an error.
    f_fail = OptimizationFunction(f, SciMLBase.NoAD())
    prob_fail = OptimizationProblem(f_fail, x0)

    for opt in (ODEGradientDescent(dt=0.001), RKChebyshevDescent(), RKAccelerated(), HighOrderDescent())
        @test_throws ErrorException solve(prob_fail, opt; maxiters=20_000)
    end

end
```

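Putting it together, end-user usage follows the pattern in the tests above: build an OptimizationProblem with a gradient source (AD or a manual `grad`) and pass one of the exported optimizers to `solve`. A sketch on the Rosenbrock function (the objective and iteration budget here are illustrative, not taken from the commit):

```julia
# Sketch: solving a standard test problem with the new ODE-based optimizers.
using Optimization, OptimizationODE, ADTypes

rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2

optf = OptimizationFunction(rosenbrock, ADTypes.AutoForwardDiff())
prob = OptimizationProblem(optf, zeros(2), [1.0, 100.0])

sol = solve(prob, HighOrderDescent(); maxiters=100_000)
@show sol.u          # should approach [1.0, 1.0]
@show sol.objective  # should approach 0.0
```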