Skip to content

Commit 724a9af

Browse files
Merge pull request #95 from SciML/remake
specialize ODEProblem remake for compile times
2 parents 2df7c8c + 488a7f1 commit 724a9af

File tree

2 files changed

+31
-13
lines changed

2 files changed

+31
-13
lines changed

src/remake.jl

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,37 @@ end
3333

3434
isrecompile(prob::ODEProblem{iip}) where {iip} = (prob.f isa ODEFunction) ? !isfunctionwrapper(prob.f.f) : true
3535

36-
function remake(thing::ODEProblem; kwargs...)
37-
T = remaker_of(thing)
38-
tup = merge(merge(struct_as_namedtuple(thing),thing.kwargs),kwargs)
39-
if !isrecompile(thing)
40-
if isinplace(thing)
41-
f = wrapfun_iip(unwrap_fw(tup.f.f),(tup.u0,tup.u0,tup.p,tup.tspan[1]))
36+
function remake(prob::ODEProblem; f=missing,
37+
u0=missing,
38+
tspan=missing,
39+
p=missing,
40+
kwargs...)
41+
if f === missing
42+
f = prob.f
43+
elseif !isrecompile(prob)
44+
if isinplace(prob)
45+
f = wrapfun_iip(unwrap_fw(f),(u0,u0,p,tspan[1]))
4246
else
43-
f = wrapfun_oop(unwrap_fw(tup.f.f),(tup.u0,tup.p,tup.tspan[1]))
47+
f = wrapfun_oop(unwrap_fw(f),(u0,p,tspan[1]))
4448
end
45-
tup2 = (f = convert(ODEFunction{isinplace(thing)},f),)
46-
tup = merge(tup, tup2)
49+
f = convert(ODEFunction{isinplace(prob)},f)
50+
elseif prob.f isa ODEFunction # avoid the SplitFunction etc. cases
51+
f = convert(ODEFunction{isinplace(prob)},f)
52+
end
53+
54+
if u0 === missing
55+
u0 = prob.u0
56+
end
57+
58+
if tspan === missing
59+
tspan = prob.tspan
60+
end
61+
62+
if p === missing
63+
p = prob.p
4764
end
48-
T(; tup...)
65+
66+
ODEProblem{isinplace(prob)}(f,u0,tspan,p,prob.problem_type;prob.kwargs..., kwargs...)
4967
end
5068

5169
function remake(thing::AbstractJumpProblem; kwargs...)

test/downstream/symbol_indexing.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@ eqs = [D(x) ~ σ*(y-x),
88
D(y) ~ x*-z)-y,
99
D(z) ~ x*y - β*z]
1010

11-
lorenz1 = ODESystem(eqs,name=:lorenz1)
12-
lorenz2 = ODESystem(eqs,name=:lorenz2)
11+
@named lorenz1 = ODESystem(eqs)
12+
@named lorenz2 = ODESystem(eqs)
1313

1414
@parameters γ
1515
@variables a(t) α(t)
1616
connections = [0 ~ lorenz1.x + lorenz2.y + a*γ,
1717
α ~ 2lorenz1.x + a*γ]
18-
sys = ODESystem(connections,t,[a,α],[γ],systems=[lorenz1,lorenz2])
18+
@named sys = ODESystem(connections,t,[a,α],[γ],systems=[lorenz1,lorenz2])
1919
sys_simplified = structural_simplify(sys)
2020

2121
u0 = [lorenz1.x => 1.0,

0 commit comments

Comments
 (0)