Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/DelayDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ using Reexport
using OrdinaryDiffEq, DataStructures, RecursiveArrayTools, Combinatorics

import OrdinaryDiffEq: initialize!, perform_step!, loopfooter!, loopheader!, alg_order,
handle_tstop!, ODEIntegrator, savevalues!, handle_callback_modifiers!
handle_tstop!, ODEIntegrator, savevalues!,
handle_callback_modifiers!, @tight_loop_macros

import DiffEqBase: solve, solve!, init, resize!, u_cache, user_cache, du_cache, full_cache,
deleteat!, terminate!, u_modified!, get_proposed_dt, set_proposed_dt!
Expand Down
179 changes: 108 additions & 71 deletions src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,7 @@ Update solution of `integrator`, if necessary or forced by `force_save`.
"""
function savevalues!(integrator::DDEIntegrator, force_save=false)
# update ODE integrator
integrator.integrator.u = integrator.u
integrator.integrator.k = integrator.k
integrator.integrator.t = integrator.t

# add steps for interpolation to ODE integrator when needed
OrdinaryDiffEq.ode_addsteps!(integrator.integrator, integrator.f)
update_ode_integrator!(integrator)

# update solution of ODE integrator
savevalues!(integrator.integrator, force_save)
Expand All @@ -19,9 +14,7 @@ function savevalues!(integrator::DDEIntegrator, force_save=false)
reduce_solution!(integrator,
# function values at later time points might be necessary for
# calculation of next step, thus keep those interpolation data
# note: we always have length(integrator.integrator.sol.t) >= 2 since
# save_everystep=true and save_start=true are applied to ODE integrator
integrator.integrator.sol.t[end-1] - maximum(integrator.prob.lags))
integrator.integrator.tprev - maximum(integrator.prob.lags))
end

"""
Expand All @@ -31,9 +24,7 @@ Clean up solution of `integrator`.
"""
function postamble!(integrator::DDEIntegrator)
# update ODE integrator
integrator.integrator.u = integrator.u
integrator.integrator.k = integrator.k
integrator.integrator.t = integrator.t
update_ode_integrator!(integrator)

# clean up solution of ODE integrator
OrdinaryDiffEq.postamble!(integrator.integrator)
Expand All @@ -48,88 +39,128 @@ end
Calculate next step of `integrator`.
"""
function perform_step!(integrator::DDEIntegrator)
# update previous time to extrapolate from current interval
integrator.tprev = integrator.t

# update ODE integrator
integrator.integrator.uprev = integrator.uprev
integrator.integrator.tprev = integrator.tprev
integrator.integrator.fsalfirst = integrator.fsalfirst
integrator.integrator.t = integrator.t
integrator.integrator.dt = integrator.dt

# if dt is greater than the minimal lag, then it's explicit so use fixed-point iteration
if integrator.dt > minimum(integrator.prob.lags)
# cache error estimate of integrator and interpolation data of interval [tprev, t]
# (maybe with already updated entry k[1] = fsalfirst == fsallast, if k[1] points to
# fsalfirst) to be able to reset the corresponding variables in case calculation results
# in numbers that are not finite
recursivecopy!(integrator.k_cache, integrator.k)
integrator.integrator.EEst = integrator.EEst

# perform always at least one calculation
perform_step!(integrator, integrator.cache)

# if dt is greater than the minimal lag, then use a fixed-point iteration
if integrator.dt > minimum(integrator.prob.lags) && isfinite(integrator.EEst)

# update cached error estimate of integrator
integrator.integrator.EEst = integrator.EEst

# save dt, u(tprev) and interpolation data of interval [tprev, t] of ODE
# integrator since they are overwritten by fixed-point iteration
if typeof(integrator.uprev_cache) <: AbstractArray
recursivecopy!(integrator.uprev_cache, integrator.integrator.uprev)
else
integrator.uprev_cache = integrator.integrator.uprev
end
recursivecopy!(integrator.k_integrator_cache, integrator.integrator.k)
integrator.integrator.dtcache = integrator.integrator.dt

# move ODE integrator to interval [t, t+dt] to use interpolation of ODE integrator
# in the next iterations when evaluating the history function
integrator.integrator.t = integrator.t + integrator.dt
integrator.integrator.tprev = integrator.t
integrator.integrator.dt = integrator.dt
if typeof(integrator.integrator.uprev) <: AbstractArray
recursivecopy!(integrator.integrator.uprev, integrator.uprev)
else
integrator.integrator.uprev = integrator.uprev
end

# save these values to correct the extrapolation after the last iteration
tprev_cache = integrator.tprev
t_cache = integrator.t # same as tprev_cache?
uprev_cache = integrator.uprev
numiters=1

numiters = 1
while true

# save u to calculate residuals (u is overwritten in calculation of next step)
# update value u(t+dt) and interpolation data of interval [t, t+dt] that are
# used for the interpolation of the history function in the next iteration
if typeof(integrator.u) <: AbstractArray
copy!(integrator.u_cache, integrator.u)
recursivecopy!(integrator.integrator.u, integrator.u)
else
integrator.u_cache = integrator.u
integrator.integrator.u = integrator.u
end
recursivecopy!(integrator.integrator.k, integrator.k)

# calculate next step
perform_step!(integrator, integrator.cache)

# calculate residuals
if typeof(integrator.resid) <: AbstractArray
@. integrator.resid = (integrator.u - integrator.u_cache) /
@muladd(integrator.fixedpoint_abstol + max(abs(integrator.u),
abs(integrator.u_cache)) *
integrator.fixedpoint_reltol)
# calculate residuals of fixed-point iteration
# can be fixed with new @muladd: https://github.com/JuliaDiffEq/DiffEqBase.jl/pull/57
if typeof(integrator.u) <: AbstractArray
@tight_loop_macros @inbounds for (i, atol, rtol) in
zip(eachindex(integrator.u),
Iterators.cycle(integrator.fixedpoint_abstol),
Iterators.cycle(integrator.fixedpoint_reltol))

integrator.resid[i] = (integrator.u[i] - integrator.integrator.u[i]) /
@muladd(atol + max(abs(integrator.u[i]),
abs(integrator.integrator.u[i])) * rtol)
end
else
integrator.resid = @. (integrator.u - integrator.u_cache) /
integrator.resid = (integrator.u - integrator.integrator.u) /
@muladd(integrator.fixedpoint_abstol + max(abs(integrator.u),
abs(integrator.u_cache)) *
abs(integrator.integrator.u)) *
integrator.fixedpoint_reltol)
end

# stop fixed-point iteration when residuals are small or maximal number of steps is exceeded
fixedpointEEst = integrator.fixedpoint_norm(integrator.resid)
if fixedpointEEst < 1 || numiters > integrator.max_fixedpoint_iters

# stop fixed-point iteration when error estimate of integrator or error estimate
# of fixed-point iteration are not finite
if !isfinite(fixedpointEEst) || !isfinite(integrator.EEst)
# assure that integrator is reset to cached values
integrator.EEst = max(fixedpointEEst, integrator.EEst)
break
end

# special updates of ODE integrator after the first iteration step
# to use interpolation of ODE integrator in the next iterations
# when evaluating the history function
if numiters == 1
integrator.integrator.tprev = integrator.t
integrator.integrator.t = integrator.t + integrator.dt
integrator.integrator.uprev = integrator.u
# update cached value of error estimate of integrator with a combined error
# estimate of both integrator and fixed-point iteration
# this prevents acceptance of steps with poor performance in fixed-point
# iteration
integrator.integrator.EEst = max(fixedpointEEst, integrator.EEst)

# stop fixed-point iteration when error estimate is small or maximal number of
# steps is exceeded
if integrator.integrator.EEst <= 1 || numiters > integrator.max_fixedpoint_iters
# update error estimate with combined error estimate
integrator.EEst = integrator.integrator.EEst
break
end

# update ODE integrator
integrator.integrator.u = integrator.u
integrator.integrator.k = integrator.k

numiters += 1
end

# reset values of DDE integrator after last iteration
integrator.t = t_cache
integrator.tprev = tprev_cache
integrator.uprev = uprev_cache

# update current time of ODE integrator
integrator.integrator.t = t_cache
else # no iterations
perform_step!(integrator, integrator.cache)
# reset ODE integrator to interval [tprev, t] with corresponding values
# u(tprev) and u(t), and interpolation data k of this interval
integrator.integrator.t = integrator.t
integrator.integrator.tprev = integrator.tprev
integrator.integrator.dt = integrator.integrator.dtcache
if typeof(integrator.u) <: AbstractArray
recursivecopy!(integrator.integrator.u, integrator.uprev)
recursivecopy!(integrator.integrator.uprev, integrator.uprev_cache)
else
integrator.integrator.u = integrator.uprev
integrator.integrator.uprev = integrator.uprev_cache
end
recursivecopy!(integrator.integrator.k, integrator.k_integrator_cache)
end

#integrator.u = integrator.integrator.u
#integrator.fsallast = integrator.integrator.fsallast
#if integrator.opts.adaptive
# integrator.EEst = integrator.integrator.EEst
#end
# if error estimate of integrator is not a finite number reset it to last cached error
# estimate or 2, and reset interpolation data of integrator to interpolation data of
# interval [tprev, t] (maybe with updated entry k[1]) before calculation of current step
# then current step will not be accepted, time step dt will be decreased,
# and calculation of next step will be repeated starting with the same
# initial interpolation data
if !isfinite(integrator.EEst)
integrator.EEst = max(2, integrator.integrator.EEst) # EEst must be > 1
recursivecopy!(integrator.k, integrator.k_cache)
end
end

"""
Expand All @@ -140,8 +171,14 @@ Set initial values of `integrator`.
function initialize!(integrator::DDEIntegrator)
initialize!(integrator, integrator.cache, integrator.f)

# set also initial values of ODE integrator
initialize!(integrator.integrator, integrator.cache, integrator.f)
# interpolation data of integrator and ODE integrator have to be cached
# when next step is calculated
integrator.k_cache = recursivecopy(integrator.k)
integrator.k_integrator_cache = recursivecopy(integrator.k)

# copy interpolation data to ODE integrator
integrator.integrator.kshortsize = integrator.kshortsize
integrator.integrator.k = recursivecopy(integrator.k)
end

"""
Expand Down
29 changes: 16 additions & 13 deletions src/integrator_type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ mutable struct DDEIntegrator{algType<:OrdinaryDiffEqAlgorithm,uType,tType,absTyp
f::F
uprev::uType
tprev::tType
u_cache::uType
uprev_cache::uType
k_cache::ksEltype
k_integrator_cache::ksEltype
fixedpoint_abstol::absType
fixedpoint_reltol::relType
resid::residType # This would have to resize for resizing DDE to work
Expand Down Expand Up @@ -50,22 +52,23 @@ mutable struct DDEIntegrator{algType<:OrdinaryDiffEqAlgorithm,uType,tType,absTyp
function DDEIntegrator{algType,uType,tType,absType,relType,residType,tTypeNoUnits,
tdirType,ksEltype,SolType,rateType,F,ProgressType,CacheType,
IType,ProbType,NType,O,tstopsType}(
sol,prob,u,k,t,dt,f,uprev,tprev,u_cache,fixedpoint_abstol,
fixedpoint_reltol,resid,fixedpoint_norm,max_fixedpoint_iters,
minimal_solution,alg,rate_prototype,notsaveat_idxs,dtcache,
dtchangeable,dtpropose,tdir,EEst,qold,q11,iter,saveiter,
saveiter_dense,prog,cache,kshortsize,just_hit_tstop,
accept_step,isout,reeval_fsal,u_modified,opts,integrator,
saveat) where
sol,prob,u,k,t,dt,f,uprev,tprev,uprev_cache,k_cache,
k_integrator_cache,fixedpoint_abstol,fixedpoint_reltol,resid,
fixedpoint_norm,max_fixedpoint_iters,minimal_solution,alg,
rate_prototype,notsaveat_idxs,dtcache,dtchangeable,dtpropose,
tdir,EEst,qold,q11,iter,saveiter,saveiter_dense,prog,cache,
kshortsize,just_hit_tstop,accept_step,isout,reeval_fsal,
u_modified,opts,integrator,saveat) where
{algType<:OrdinaryDiffEqAlgorithm,uType,tType,absType,relType,residType,
tTypeNoUnits,tdirType,ksEltype,SolType,rateType,F,ProgressType,CacheType,IType,
ProbType,NType,O,tstopsType}

new(sol,prob,u,k,t,dt,f,uprev,tprev,u_cache,fixedpoint_abstol,fixedpoint_reltol,
resid,fixedpoint_norm,max_fixedpoint_iters,minimal_solution,alg,rate_prototype,
notsaveat_idxs,dtcache,dtchangeable,dtpropose,tdir,EEst,qold,q11,iter,saveiter,
saveiter_dense,prog,cache,kshortsize,just_hit_tstop,accept_step,isout,
reeval_fsal,u_modified,opts,integrator,saveat)
new(sol,prob,u,k,t,dt,f,uprev,tprev,uprev_cache,k_cache,k_integrator_cache,
fixedpoint_abstol,fixedpoint_reltol,resid,fixedpoint_norm,max_fixedpoint_iters,
minimal_solution,alg,rate_prototype,notsaveat_idxs,dtcache,dtchangeable,
dtpropose,tdir,EEst,qold,q11,iter,saveiter,saveiter_dense,prog,cache,
kshortsize,just_hit_tstop,accept_step,isout,reeval_fsal,u_modified,opts,
integrator,saveat)
end
end

Expand Down
25 changes: 25 additions & 0 deletions src/integrator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,28 @@ function build_solution_interpolation(integrator::DDEIntegrator, sol::DiffEqArra
end
end

"""
update_ode_integrator!(integrator::DDEIntegrator)

Update ODE integrator of `integrator` to current time interval, values and interpolation
data of `integrator`.
"""
function update_ode_integrator!(integrator::DDEIntegrator)
# update time interval of ODE integrator
integrator.integrator.t = integrator.t
integrator.integrator.tprev = integrator.tprev
integrator.integrator.dt = integrator.dt

# copy u(tprev) since it is overwritten by integrator at the end of apply_step!
if typeof(integrator.u) <: AbstractArray
recursivecopy!(integrator.integrator.u, integrator.u)
recursivecopy!(integrator.integrator.uprev, integrator.uprev)
else
integrator.integrator.u = integrator.u
integrator.integrator.uprev = integrator.uprev
end

# copy interpolation data (fsalfirst overwritten at the end of apply_step!, which also
# updates k[1] when using chaches for which k[1] points to fsalfirst)
recursivecopy!(integrator.integrator.k, integrator.k)
end
Loading