Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"

[weakdeps]
Expand All @@ -39,8 +40,8 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8"
GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
Expand All @@ -55,8 +56,8 @@ DiffEqBaseChainRulesCoreExt = "ChainRulesCore"
DiffEqBaseDistributionsExt = "Distributions"
DiffEqBaseEnzymeExt = ["ChainRulesCore", "Enzyme"]
DiffEqBaseForwardDiffExt = ["ForwardDiff"]
DiffEqBaseGeneralizedGeneratedExt = "GeneralizedGenerated"
DiffEqBaseGTPSAExt = "GTPSA"
DiffEqBaseGeneralizedGeneratedExt = "GeneralizedGenerated"
DiffEqBaseMPIExt = "MPI"
DiffEqBaseMeasurementsExt = "Measurements"
DiffEqBaseMonteCarloMeasurementsExt = "MonteCarloMeasurements"
Expand All @@ -82,8 +83,8 @@ FastPower = "1.1"
ForwardDiff = "0.10, 1"
FunctionWrappers = "1.0"
FunctionWrappersWrappers = "0.1"
GeneralizedGenerated = "0.3"
GTPSA = "1.4"
GeneralizedGenerated = "0.3"
LinearAlgebra = "1.9"
Logging = "1.9"
MPI = "0.20"
Expand All @@ -105,6 +106,7 @@ SparseArrays = "1.9"
Static = "1"
StaticArraysCore = "1.4"
Statistics = "1"
SymbolicIndexingInterface = "0.3.39"
Tracker = "0.2"
TruncatedStacktraces = "1"
Unitful = "1"
Expand All @@ -129,10 +131,9 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["Distributed", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "SymbolicIndexingInterface", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Plots", "Pkg", "Random", "ReverseDiff", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions", "Aqua", "BenchmarkTools"]
test = ["Distributed", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Plots", "Pkg", "Random", "ReverseDiff", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions", "Aqua", "BenchmarkTools"]
2 changes: 2 additions & 0 deletions src/DiffEqBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ Reexport.@reexport using SciMLBase

SciMLBase.isfunctionwrapper(x::FunctionWrapper) = true

import SymbolicIndexingInterface as SII

## Extension Functions

eltypedual(x) = false
Expand Down
34 changes: 33 additions & 1 deletion src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,32 @@ function Base.showerror(io::IO, e::LateBindingTstopsNotSupportedError)
println(io, TruncatedStacktraces.VERBOSE_MSG)
end

"""
$(TYPEDSIGNATURES)

Given the index provider `indp` used to construct the problem `prob` being solved, return
an updated `prob` to be used for solving. All implementations should accept arbitrary
keyword arguments.

Should be called before the problem is solved, after performing type-promotion on the
problem.
"""
function get_updated_symbolic_problem(indp, prob; kw...)
return prob
end

"""
$(TYPEDSIGNATURES)

Get the innermost index provider using `SII.symbolic_container`.
"""
function _get_root_indp(indp)
if hasmethod(SII.symbolic_container, Tuple{typeof(indp)}) && (sc = SII.symbolic_container(indp)) !== indp
return _get_root_indp(sc)
end
return indp
end

function init_call(_prob, args...; merge_callbacks = true, kwargshandle = nothing,
kwargs...)
kwargshandle = kwargshandle === nothing ? KeywordArgError : kwargshandle
Expand Down Expand Up @@ -1213,24 +1239,27 @@ function checkkwargs(kwargshandle; kwargs...)
end

function get_concrete_problem(prob::AbstractJumpProblem, isadapt; kwargs...)
prob
get_updated_symbolic_problem(_get_root_indp(prob), prob)
end

function get_concrete_problem(prob::SteadyStateProblem, isadapt; kwargs...)
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob)
p = get_concrete_p(prob, kwargs)
u0 = get_concrete_u0(prob, isadapt, Inf, kwargs)
u0 = promote_u0(u0, p, nothing)
remake(prob; u0 = u0, p = p)
end

function get_concrete_problem(prob::NonlinearProblem, isadapt; kwargs...)
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob)
p = get_concrete_p(prob, kwargs)
u0 = get_concrete_u0(prob, isadapt, nothing, kwargs)
u0 = promote_u0(u0, p, nothing)
remake(prob; u0 = u0, p = p)
end

function get_concrete_problem(prob::NonlinearLeastSquaresProblem, isadapt; kwargs...)
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob)
p = get_concrete_p(prob, kwargs)
u0 = get_concrete_u0(prob, isadapt, nothing, kwargs)
u0 = promote_u0(u0, p, nothing)
Expand All @@ -1252,6 +1281,7 @@ function init(prob::PDEProblem, alg::AbstractDEAlgorithm, args...;
end

function get_concrete_problem(prob, isadapt; kwargs...)
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob)
p = get_concrete_p(prob, kwargs)
tspan = get_concrete_tspan(prob, isadapt, kwargs, p)
u0 = get_concrete_u0(prob, isadapt, tspan[1], kwargs)
Expand All @@ -1270,6 +1300,7 @@ function get_concrete_problem(prob, isadapt; kwargs...)
end

function get_concrete_problem(prob::DAEProblem, isadapt; kwargs...)
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob)
p = get_concrete_p(prob, kwargs)
tspan = get_concrete_tspan(prob, isadapt, kwargs, p)
u0 = get_concrete_u0(prob, isadapt, tspan[1], kwargs)
Expand All @@ -1293,6 +1324,7 @@ function get_concrete_problem(prob::DAEProblem, isadapt; kwargs...)
end

function get_concrete_problem(prob::DDEProblem, isadapt; kwargs...)
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob)
p = get_concrete_p(prob, kwargs)
tspan = get_concrete_tspan(prob, isadapt, kwargs, p)
u0 = get_concrete_u0(prob, isadapt, tspan[1], kwargs)
Expand Down
Loading