-
-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Open
Labels
compiler:inference (Type inference)
Description
MWE: here is the setup code, reduced so that it depends only on ForwardDiff:
using ForwardDiff
# Maximum recursion depth for the `anyeltypedual` search. Once the depth counter reaches
# this value, the search stops descending and reports `Any` ("no Dual found").
const DUALCHECK_RECURSION_MAX = 10
"""
    reduce_tup(f::F, inds::Tuple{Vararg{Any,N}}) where {F,N}

Reduce the tuple `inds` with the binary function `f`, emitting fully unrolled
straight-line code.

`Base.reduce` on tuples goes through `afoldl`, which often fails to inline.
Elements are combined pairwise in a tree-like order rather than strictly left to right.
So `f` must be associative and commutative for the result to equal `reduce(f, inds)`.
In practice, `reduce_tup(_pick_range, inds)` tends to outperform
`reduce(_pick_range, inds)`.

A one-element tuple is returned as its only element without calling `f`. An empty tuple
produces no expressions and therefore yields `nothing`.
"""
@generated function reduce_tup(f::F, inds::Tuple{Vararg{Any, N}}) where {F, N}
    body = Expr(:block, Expr(:meta, :inline, :propagate_inbounds))
    if N == 1
        # Nothing to combine: hand back the only element.
        push!(body.args, :(inds[1]))
        return body
    end
    # Bind each tuple element to its own local `i_k`. `live` holds the symbols that still
    # need to be folded into the result.
    live = Vector{Symbol}(undef, N)
    nsym = 0
    for k in 1:N
        nsym += 1
        elem = Symbol(:i_, nsym)
        live[k] = elem
        push!(body.args, Expr(:(=), elem, Expr(:ref, :inds, k)))
    end
    # Start at half of the largest power of two <= N, then halve after every round.
    # For N == 0 this width is 0, so no reduction code is emitted.
    width = 1 << (8sizeof(N) - 2 - leading_zeros(N))
    while width > 0
        rounds = length((2width):width:length(live))
        for _ in 1:rounds
            # Fold the second `width`-sized chunk into the first, element by element.
            for w in 1:width
                nsym += 1
                acc = Symbol(:i_, nsym)
                call = Expr(:call, :f, live[w], live[w + width])
                push!(body.args, Expr(:(=), acc, call))
                live[w] = acc
            end
            deleteat!(live, (width + 1):(2width))
        end
        width >>>= 1
    end
    return body
end
"""
    promote_dual(::Type{T}, ::Type{T2})

Like the number promotion system, but always prefers a `ForwardDiff.Dual` type over
anything else. For higher-order (nested) duals it returns the "most dual" of the two,
i.e. the one whose value type is itself a `Dual`. This is used to promote `u0` into
the suspected highest differentiation space before solving the equation.

The generic method (neither argument known to be a `Dual`) keeps the first type `T`.
"""
promote_dual(::Type{T}, ::Type{T2}) where {T, T2} = T
# Only the first argument is known to be a Dual: keep it.
promote_dual(::Type{T}, ::Type{T2}) where {T <: ForwardDiff.Dual, T2} = T
# Both are Duals and no nesting information is available: keep the first.
function promote_dual(::Type{T},
        ::Type{T2}) where {T <: ForwardDiff.Dual, T2 <: ForwardDiff.Dual}
    T
end
# Only the second argument is a Dual: take it.
promote_dual(::Type{T}, ::Type{T2}) where {T, T2 <: ForwardDiff.Dual} = T2
# Both are Duals and the second one's value type is itself a Dual (nested): take it.
function promote_dual(::Type{T},
        ::Type{T2}) where {T3, T4, V, V2 <: ForwardDiff.Dual, N, N2,
        T <: ForwardDiff.Dual{T3, V, N},
        T2 <: ForwardDiff.Dual{T4, V2, N2}}
    T2
end
# Both are Duals and the first one's value type is itself a Dual (nested): take it.
function promote_dual(::Type{T},
        ::Type{T2}) where {T3, T4, V <: ForwardDiff.Dual, V2, N, N2,
        T <: ForwardDiff.Dual{T3, V, N},
        T2 <: ForwardDiff.Dual{T4, V2, N2}}
    T
end
# Both are nested Duals sharing the tag `T3` and partial count `N`: recurse into the
# value types and rebuild the outer Dual around the result.
# NOTE(review): when both value types are Duals but the tags or partial counts differ,
# the two methods above both match and neither looks more specific. That would be a
# method ambiguity (MethodError) — TODO confirm.
function promote_dual(::Type{T},
        ::Type{T2}) where {
        T3, V <: ForwardDiff.Dual, V2 <: ForwardDiff.Dual,
        N,
        T <: ForwardDiff.Dual{T3, V, N},
        T2 <: ForwardDiff.Dual{T3, V2, N}}
    ForwardDiff.Dual{T3, promote_dual(V, V2), N}
end
# `reduce` and `map` are specialized on tuples to be unrolled (via recursion).
# Therefore, they can be type stable even with heterogeneous input types.
# We also don't care about allocating any temporaries with them, as it should
# all be unrolled and optimized away.
# Being unrolled also means const prop can work for things like
# `mapreduce(f, op, propertynames(x))`
# where `f` may call `getproperty` and thus have a return type dependent
# on the particular symbol.
# `mapreduce` hasn't received any such specialization.
#
# `diffeqmapreduce(f, op, x)` therefore behaves like `mapreduce(f, op, x)`, but is fully
# unrolled via `reduce_tup` for `Tuple`s and `NamedTuple`s. Empty tuples are not
# special-cased here; the callers in this file guard against `()` before calling.
@inline diffeqmapreduce(f::F, op::OP, x::Tuple) where {F, OP} = reduce_tup(op, map(f, x))
@inline function diffeqmapreduce(f::F, op::OP, x::NamedTuple) where {F, OP}
    # `map` over a NamedTuple returns another NamedTuple. That is not a `Tuple`, so it has
    # no `reduce_tup` method; reduce over the plain tuple of values instead.
    return reduce_tup(op, map(f, values(x)))
end
# For other container types, just call `Base.mapreduce`.
@inline diffeqmapreduce(f::F, op::OP, x) where {F, OP} = mapreduce(f, op, x)
"""
    anyeltypedual(x, counter = 0)

Search through `x` (its type, element type, or properties) to see whether any of its
values are `ForwardDiff.Dual` numbers.

Returns the "most dual" type found, combined with `promote_dual`. When no Dual is found
it returns a non-`Dual` type, usually `Any`.

This is used to promote other values to match the dual type. For example, suppose a user
passes a parameter which is a `Dual` and a `u0` which is a `Float64`. After the first time
step, `f(u,p,t) = p*u` will change `u0` from `Float64` to `Dual`. Thus the state variable
always needs to be converted to a dual number before the solve. Worse still, this also
needs to be done for `f(du,u,p,t) = du[1] = p*u[1]`. So running `f` and taking the return
value is not a valid way to calculate the required state type.

Automatic differentiation with respect to the parameters implies differentiation of the
state. We therefore assume that any dual parameters imply differentiation of the state,
and attempt to upconvert `u0` to match that dual-ness.

Because this changes types, it must be decided at compile time and cannot have a
Bool-based opt-out. In the future this may be extended to use a preference system, with
an opt-out via `UPCONVERT_DUALS`. When upconversion is not done automatically, users must
upconvert all initial conditions themselves. For an example of how confusing this can be
for a user, see
https://discourse.julialang.org/t/typeerror-in-julia-turing-when-sampling-for-a-forced-differential-equation/82937

`counter` is the current recursion depth. Once it reaches `DUALCHECK_RECURSION_MAX`, the
search stops and returns `Any`.
"""
function anyeltypedual(x, counter = 0)
    if propertynames(x) === ()
        # No properties to inspect.
        Any
    elseif counter < DUALCHECK_RECURSION_MAX
        # Check every property one level deeper (`DualEltypeChecker` increments the
        # counter). `Val` lifts each property name into the type domain.
        diffeqmapreduce(DualEltypeChecker(x, counter), promote_dual,
            map(Val, propertynames(x)))
    else
        # Recursion limit reached: give up and report "no Dual".
        Any
    end
end
# Opt out of the search: ForwardDiff configs (and their types) are used for
# preallocation, not differentiation, and modules are never searched.
function anyeltypedual(x::Union{ForwardDiff.AbstractConfig, Module}, counter = 0)
    return Any
end
function anyeltypedual(x::Type{T}, counter = 0) where {T <: ForwardDiff.AbstractConfig}
    return Any
end
# Type-level search used by `anyeltypedual(::Type{T})`. It folds `anyeltypedual` with
# `promote_dual` over each element of `T.parameters`, starting from `Any`, so a type with
# no parameters yields `Any`. Types without a `parameters` field (e.g. a `UnionAll` or a
# `Union`) are returned unchanged.
# NOTE(review): `Base.@pure` is discouraged/deprecated in recent Julia versions. Its
# contract forbids calling generic functions, which this body does (`mapreduce`,
# `anyeltypedual`, `promote_dual`). That may be relevant to the inference behaviour being
# reported; it is left unchanged to keep the reproducer intact — confirm.
Base.@pure function __anyeltypedual(::Type{T}) where {T}
    hasproperty(T, :parameters) ?
        mapreduce(anyeltypedual, promote_dual, T.parameters; init = Any) : T
end
# Generic type: fold over its type parameters via the helper.
function anyeltypedual(::Type{T}, counter = 0) where {T}
    return __anyeltypedual(T)
end
# A Dual type is itself the answer.
function anyeltypedual(::Type{T}, counter = 0) where {T <: ForwardDiff.Dual}
    return T
end
# Array and Set types: only the element type matters.
anyeltypedual(::Type{T}, counter = 0) where {T <: Union{AbstractArray, Set}} =
    anyeltypedual(eltype(T))
# Type-level search for tuple types matched by `NTuple`. A concrete element type is
# returned directly, whether or not it is a Dual. Otherwise, `anyeltypedual` is folded
# with `promote_dual` over the tuple's parameters, starting from `Any`; a tuple type with
# no parameters (`Tuple{}`) yields `Any`.
# NOTE(review): `Base.@pure` is discouraged/deprecated in recent Julia versions. Its
# contract forbids calling generic functions, which this body does (`mapreduce`,
# `anyeltypedual`, `promote_dual`). It is left unchanged to keep the reproducer intact —
# confirm before relying on it.
Base.@pure function __anyeltypedual_ntuple(::Type{T}) where {T <: NTuple}
    if isconcretetype(eltype(T))
        return eltype(T)
    end
    if isempty(T.parameters)
        Any
    else
        mapreduce(anyeltypedual, promote_dual, T.parameters; init = Any)
    end
end
anyeltypedual(::Type{T}, counter = 0) where {T <: NTuple} = __anyeltypedual_ntuple(T)
# In this search, `Any` simply means "not a Dual".
# A number is checked through its type.
function anyeltypedual(x::Number, counter = 0)
    return anyeltypedual(typeof(x))
end
# Strings and symbols never carry a Dual; report their own type.
function anyeltypedual(x::Union{String, Symbol}, counter = 0)
    return typeof(x)
end
# Arrays/sets of plain scalars (numbers, symbols, strings): decide from the element type.
function anyeltypedual(x::Union{Array{T}, AbstractArray{T}, Set{T}},
        counter = 0) where {T <: Union{Number, Symbol, String}}
    return anyeltypedual(T)
end
# Arrays/sets whose elements are numeric arrays or sets: decide from the element type.
function anyeltypedual(x::Union{Array{T}, AbstractArray{T}, Set{T}},
        counter = 0) where {T <: Union{AbstractArray{<:Number}, Set{<:Number}}}
    return anyeltypedual(eltype(x))
end
# Arrays/sets whose elements are numeric tuples: decide from the element type.
function anyeltypedual(x::Union{Array{T}, AbstractArray{T}, Set{T}},
        counter = 0) where {N, T <: NTuple{N, <:Number}}
    return anyeltypedual(eltype(x))
end
# Fallback for arrays whose element type is not covered by the more specific methods.
# Try to avoid reaching this method: when `eltype(x)` is not concrete, inference cannot
# predict the result from the array's type alone.
function anyeltypedual(x::AbstractArray, counter = 0)
    if isconcretetype(eltype(x))
        # All elements share one concrete type: check the type instead of the values.
        return anyeltypedual(eltype(x))
    elseif !isempty(x) && all(i -> isassigned(x, i), eachindex(IndexLinear(), x)) &&
           counter < DUALCHECK_RECURSION_MAX
        # `eachindex(IndexLinear(), x)` yields valid linear indices even for arrays with
        # offset axes, unlike `1:length(x)`.
        # Inspect each element one level deeper. The new depth is bound to a fresh
        # variable: capturing a reassigned argument in the closure would box it and
        # defeat inference.
        depth = counter + 1
        return mapreduce(y -> anyeltypedual(y, depth), promote_dual, x)
    else
        # Empty arrays, arrays with `undef` entries (which cannot be read), or the
        # recursion limit reached: nothing can be inspected safely, so report "no Dual".
        return Any
    end
end
function anyeltypedual(x::Set, counter = 0)
    # Sets are not searched element by element. Only a concrete element type is checked;
    # anything else conservatively reports "no Dual".
    isconcretetype(eltype(x)) || return Any
    return anyeltypedual(eltype(x))
end
function anyeltypedual(x::Tuple, counter = 0)
    # `()` has nothing to inspect. Handling it here also keeps inference simple and
    # avoids reducing over an empty tuple.
    x === () && return Any
    # Forward the current depth to the elements. Dropping it would reset the recursion
    # guard, so cycles through mutable structs could recurse without bound.
    return diffeqmapreduce(y -> anyeltypedual(y, counter), promote_dual, x)
end
function anyeltypedual(x::Dict, counter = 0)
    # An empty Dict has no values to inspect; report its declared value type.
    isempty(x) && return eltype(values(x))
    # A Dict is mutable and may (indirectly) contain itself. Bound the recursion the same
    # way the array method does, instead of restarting the depth at 0.
    counter < DUALCHECK_RECURSION_MAX || return Any
    depth = counter + 1
    return mapreduce(v -> anyeltypedual(v, depth), promote_dual, values(x))
end
function anyeltypedual(x::NamedTuple, counter = 0)
    # An empty NamedTuple has nothing to inspect.
    isempty(x) && return Any
    # Forward the current depth so the recursion guard is not reset for nested values.
    return diffeqmapreduce(v -> anyeltypedual(v, counter), promote_dual, values(x))
end
# Promote the initial condition `u0` to the Dual type found in the parameters `p` (via
# `anyeltypedual`), if there is one. `t0` is not used by these methods.
# Returns `u0` unchanged when it already holds Duals, when the search reports `Any`
# ("no Dual found"), or when the type found is not a Dual.
# NOTE(review): this is the method inspected with `@code_warntype` below. For
# `p = Wrapper1(thing)` the local `T` is inferred as `Any`, while for the structurally
# identical `Wrapper2` it is inferred as `Type{Any}`.
@inline function promote_u0(u0, p, t0)
    if !(eltype(u0) <: ForwardDiff.Dual)
        T = anyeltypedual(p)
        T === Any && return u0
        if T <: ForwardDiff.Dual
            # Broadcast the Dual constructor over `u0` (this also works for a scalar).
            return T.(u0)
        end
    end
    u0
end
# Complex arrays: the Dual check looks at the real part's type. The target element type
# is `promote_type(T, eltype(u0))`, so the converted values stay complex.
@inline function promote_u0(u0::AbstractArray{<:Complex}, p, t0)
    if !(real(eltype(u0)) <: ForwardDiff.Dual)
        T = anyeltypedual(p)
        T === Any && return u0
        if T <: ForwardDiff.Dual
            Ts = promote_type(T, eltype(u0))
            return Ts.(u0)
        end
    end
    u0
end
"""
    DualEltypeChecker(x, counter)

Callable used by `anyeltypedual` to inspect the properties of `x` one at a time.

The stored `counter` is the recursion depth for the properties of `x`, i.e. the caller's
depth plus one; the constructor performs the increment.
"""
struct DualEltypeChecker{T}
    x::T
    counter::Int
    function DualEltypeChecker(x::T, counter::Int) where {T}
        depth = counter + 1
        return new{T}(x, depth)
    end
end
# Check the single property `Y` of `dec.x` at depth `dec.counter`. `Y` arrives wrapped in
# a `Val`, so each property name is a compile-time constant.
# An undefined field (e.g. an uninitialized field of a mutable struct) and reaching the
# depth limit both report `Any` ("no Dual").
function (dec::DualEltypeChecker)(::Val{Y}) where {Y}
    isdefined(dec.x, Y) || return Any
    dec.counter >= DUALCHECK_RECURSION_MAX && return Any
    anyeltypedual(getproperty(dec.x, Y), dec.counter)
end
# Same check for `Base.Pairs` (the type used for keyword-argument collections), but read
# with `getfield` instead of `getproperty`;
# see https://github.com/JuliaLang/julia/pull/39448
function (dec::DualEltypeChecker{<:Base.Pairs})(::Val{Y}) where {Y}
    isdefined(dec.x, Y) || return Any
    dec.counter >= DUALCHECK_RECURSION_MAX && return Any
    anyeltypedual(getfield(dec.x, Y), dec.counter)
end
# Reproducer fixtures. `Wrapper1` and `Wrapper2` are structurally identical wrappers, yet
# `@code_warntype promote_u0(...)` infers differently for them (see the output below).
struct Thing
    a::Float64
end
struct Wrapper1{T}
    thing::T
end
struct Wrapper2{T}
    thing::T
end
thing = Thing(1.0)
x = 1.0

Now the checks:
# Wrapper1 case: nothing reachable from `Wrapper1(thing)` is a Dual, so this returns `x`.
promote_u0(x, Wrapper1(thing), (0.0, 1.0))
@code_warntype promote_u0(x, Wrapper1(thing), (0.0, 1.0))

MethodInstance for promote_u0(::Float64, ::Wrapper1{Thing}, ::Tuple{Float64, Float64})
from promote_u0(u0, p, t0) @ Main c:\Users\accou\OneDrive\Computer\Desktop\test.jl:346
Arguments
#self#::Core.Const(promote_u0)
u0::Float64
p::Wrapper1{Thing}
t0::Tuple{Float64, Float64}
Locals
T::Any
Body::Any
1 ─ nothing
│ Core.NewvarNode(:(T))
│ %3 = Main.eltype(u0)::Core.Const(Float64)
│ %4 = ForwardDiff.Dual::Core.Const(ForwardDiff.Dual)
│ %5 = (%3 <: %4)::Core.Const(false)
│ %6 = !%5::Core.Const(true)
└── goto #6 if not %6
2 ─ (T = Main.anyeltypedual(p))
│ %9 = (T === Main.Any)::Bool
└── goto #4 if not %9
3 ─ return u0
4 ─ %12 = T::Any
│ %13 = ForwardDiff.Dual::Core.Const(ForwardDiff.Dual)
│ %14 = (%12 <: %13)::Bool
└── goto #6 if not %14
5 ─ %16 = Base.broadcasted(T, u0)::Base.Broadcast.Broadcasted{Style, Nothing} where Style<:Union{Nothing, Base.Broadcast.BroadcastStyle}
│ %17 = Base.materialize(%16)::Any
└── return %17
6 ┄ return u0
# Wrapper2 case: same structure as Wrapper1, so this also returns `x` at runtime.
promote_u0(x, Wrapper2(thing), (0.0, 1.0))
@code_warntype promote_u0(x, Wrapper2(thing), (0.0, 1.0))

MethodInstance for promote_u0(::Float64, ::Wrapper2{Thing}, ::Tuple{Float64, Float64})
from promote_u0(u0, p, t0) @ Main c:\Users\accou\OneDrive\Computer\Desktop\test.jl:346
Arguments
#self#::Core.Const(promote_u0)
u0::Float64
p::Wrapper2{Thing}
t0::Tuple{Float64, Float64}
Locals
T::Type{Any}
Body::Float64
1 ─ nothing
│ Core.NewvarNode(:(T))
│ %3 = Main.eltype(u0)::Core.Const(Float64)
│ %4 = ForwardDiff.Dual::Core.Const(ForwardDiff.Dual)
│ %5 = (%3 <: %4)::Core.Const(false)
│ %6 = !%5::Core.Const(true)
└── goto #5 if not %6
2 ─ (T = Main.anyeltypedual(p))
│ %9 = (T::Core.Const(Any) === Main.Any)::Core.Const(true)
└── goto #4 if not %9
3 ─ return u0
4 ─ Core.Const(:(T))
│ Core.Const(:(ForwardDiff.Dual))
│ Core.Const(:(%12 <: %13))
│ Core.Const(:(goto %19 if not %14))
│ Core.Const(:(Base.broadcasted(T, u0)))
│ Core.Const(:(Base.materialize(%16)))
└── Core.Const(:(return %17))
5 ┄ Core.Const(:(return u0))
First reported as SciML/DiffEqBase.jl#918
Metadata
Metadata
Assignees
Labels
compiler:inference (Type inference)