Skip to content

Non-deterministic inference results? #50735

@ChrisRackauckas

Description

@ChrisRackauckas

MWE, here's the setup code isolated to just ForwardDiff:

using ForwardDiff
const DUALCHECK_RECURSION_MAX = 10

"""
  reduce_tup(f::F, inds::Tuple{Vararg{Any,N}}) where {F,N}

An optimized `reduce` for tuples. `Base.reduce`'s `afoldl` will often not inline.
Additionally, `reduce_tup` attempts to order the reduction in an optimal manner.

More importantly, `reduce_tup(_pick_range, inds)` often performs better than `reduce(_pick_range, inds)`.
"""
@generated function reduce_tup(f::F, inds::Tuple{Vararg{Any, N}}) where {F, N}
    q = Expr(:block, Expr(:meta, :inline, :propagate_inbounds))
    if N == 1
        push!(q.args, :(inds[1]))
        return q
    end
    syms = Vector{Symbol}(undef, N)
    i = 0
    for n in 1:N
        syms[n] = iₙ = Symbol(:i_, (i += 1))
        push!(q.args, Expr(:(=), iₙ, Expr(:ref, :inds, n)))
    end
    W = 1 << (8sizeof(N) - 2 - leading_zeros(N))
    while W > 0
        _N = length(syms)
        for _ in (2W):W:_N
            for w in 1:W
                new_sym = Symbol(:i_, (i += 1))
                push!(q.args, Expr(:(=), new_sym, Expr(:call, :f, syms[w], syms[w + W])))
                syms[w] = new_sym
            end
            deleteat!(syms, (1 + W):(2W))
        end
        W >>>= 1
    end
    q
end

"""
    promote_dual(::Type{T},::Type{T2})


Is like the number promotion system, but always prefers a dual number type above
anything else. For higher order differentiation, it returns the most dualiest of
them all. This is then used to promote `u0` into the suspected highest differentiation
space for solving the equation.
"""
promote_dual(::Type{T}, ::Type{T2}) where {T, T2} = T
promote_dual(::Type{T}, ::Type{T2}) where {T <: ForwardDiff.Dual, T2} = T
function promote_dual(::Type{T},
    ::Type{T2}) where {T <: ForwardDiff.Dual, T2 <: ForwardDiff.Dual}
    T
end
promote_dual(::Type{T}, ::Type{T2}) where {T, T2 <: ForwardDiff.Dual} = T2

function promote_dual(::Type{T},
    ::Type{T2}) where {T3, T4, V, V2 <: ForwardDiff.Dual, N, N2,
    T <: ForwardDiff.Dual{T3, V, N},
    T2 <: ForwardDiff.Dual{T4, V2, N2}}
    T2
end
function promote_dual(::Type{T},
    ::Type{T2}) where {T3, T4, V <: ForwardDiff.Dual, V2, N, N2,
    T <: ForwardDiff.Dual{T3, V, N},
    T2 <: ForwardDiff.Dual{T4, V2, N2}}
    T
end
function promote_dual(::Type{T},
    ::Type{T2}) where {
    T3, V <: ForwardDiff.Dual, V2 <: ForwardDiff.Dual,
    N,
    T <: ForwardDiff.Dual{T3, V, N},
    T2 <: ForwardDiff.Dual{T3, V2, N}}
    ForwardDiff.Dual{T3, promote_dual(V, V2), N}
end

# `reduce` and `map` are specialized on tuples to be unrolled (via recursion)
# Therefore, they can be type stable even with heterogeneous input types.
# We also don't care about allocating any temporaries with them, as it should
# all be unrolled and optimized away.
# Being unrolled also means const prop can work for things like
# `mapreduce(f, op, propertynames(x))`
# where `f` may call `getproperty` and thus have return type dependent
# on the particular symbol.
# `mapreduce` hasn't received any such specialization.
@inline diffeqmapreduce(f::F, op::OP, x::Tuple) where {F, OP} = reduce_tup(op, map(f, x))
@inline function diffeqmapreduce(f::F, op::OP, x::NamedTuple) where {F, OP}
    reduce_tup(op, map(f, x))
end
# For other container types, we probably just want to call `mapreduce`
@inline diffeqmapreduce(f::F, op::OP, x) where {F, OP} = mapreduce(f, op, x)

"""
    anyeltypedual(x)


Searches through a type to see if any of its values are parameters. This is used to
then promote other values to match the dual type. For example, if a user passes a parameter

which is a `Dual` and a `u0` which is a `Float64`, after the first time step, `f(u,p,t) = p*u`
will change `u0` from `Float64` to `Dual`. Thus the state variable always needs to be converted
to a dual number before the solve. Worse still, this needs to be done in the case of
`f(du,u,p,t) = du[1] = p*u[1]`, and thus running `f` and taking the return value is not a valid
way to calculate the required state type.

But given the properties of automatic differentiation requiring that differentiation of parameters
implies differentiation of state, we assume any dual parameters implies differentiation of state
and then attempt to upconvert `u0` to match that dual-ness. Because this changes types, this needs
to be specified at compiled time and thus cannot have a Bool-based opt out, so in the future this
may be extended to use a preference system to opt-out with a `UPCONVERT_DUALS`. In the case where
upconversion is not done automatically, the user is required to upconvert all initial conditions
themselves, for an example of how this can be confusing to a user see
https://discourse.julialang.org/t/typeerror-in-julia-turing-when-sampling-for-a-forced-differential-equation/82937
"""
function anyeltypedual(x, counter = 0)
    if propertynames(x) === ()
        Any
    elseif counter < DUALCHECK_RECURSION_MAX
        diffeqmapreduce(DualEltypeChecker(x, counter), promote_dual,
            map(Val, propertynames(x)))
    else
        Any
    end
end

# Opt out since these are using for preallocation, not differentiation
anyeltypedual(x::Union{ForwardDiff.AbstractConfig, Module}, counter = 0) = Any
anyeltypedual(x::Type{T}, counter = 0) where {T <: ForwardDiff.AbstractConfig} = Any

Base.@pure function __anyeltypedual(::Type{T}) where {T}
    hasproperty(T, :parameters) ?
    mapreduce(anyeltypedual, promote_dual, T.parameters; init = Any) : T
end
anyeltypedual(::Type{T}, counter = 0) where {T} = __anyeltypedual(T)
anyeltypedual(::Type{T}, counter = 0) where {T <: ForwardDiff.Dual} = T
function anyeltypedual(::Type{T}, counter = 0) where {T <: Union{AbstractArray, Set}}
    anyeltypedual(eltype(T))
end
Base.@pure function __anyeltypedual_ntuple(::Type{T}) where {T <: NTuple}
    if isconcretetype(eltype(T))
        return eltype(T)
    end
    if isempty(T.parameters)
        Any
    else
        mapreduce(anyeltypedual, promote_dual, T.parameters; init = Any)
    end
end
anyeltypedual(::Type{T}, counter = 0) where {T <: NTuple} = __anyeltypedual_ntuple(T)

# Any in this context just means not Dual
anyeltypedual(x::Number, counter = 0) = anyeltypedual(typeof(x))
anyeltypedual(x::Union{String, Symbol}, counter = 0) = typeof(x)
function anyeltypedual(x::Union{Array{T}, AbstractArray{T}, Set{T}},
    counter = 0) where {
    T <:
    Union{Number,
        Symbol,
        String}}
    anyeltypedual(T)
end
function anyeltypedual(x::Union{Array{T}, AbstractArray{T}, Set{T}},
    counter = 0) where {
    T <: Union{
        AbstractArray{
            <:Number,
        },
        Set{
            <:Number,
        }}}
    anyeltypedual(eltype(x))
end
function anyeltypedual(x::Union{Array{T}, AbstractArray{T}, Set{T}},
    counter = 0) where {N, T <: NTuple{N, <:Number}}
    anyeltypedual(eltype(x))
end

# Try to avoid this dispatch because it can lead to type inference issues when !isconcrete(eltype(x))
function anyeltypedual(x::AbstractArray, counter = 0)
    if isconcretetype(eltype(x))
        anyeltypedual(eltype(x))
    elseif !isempty(x) && all(i -> isassigned(x, i), 1:length(x)) &&
           counter < DUALCHECK_RECURSION_MAX
        counter += 1
        mapreduce(y -> anyeltypedual(y, counter), promote_dual, x)
    else
        # This fallback to Any is required since otherwise we cannot handle `undef` in all cases
        #  misses cases of
        Any
    end
end

function anyeltypedual(x::Set, counter = 0)
    if isconcretetype(eltype(x))
        anyeltypedual(eltype(x))
    else
        # This fallback to Any is required since otherwise we cannot handle `undef` in all cases
        Any
    end
end

function anyeltypedual(x::Tuple, counter = 0)
    # Handle the empty tuple case separately for inference and to avoid mapreduce error
    if x === ()
        Any
    else
        diffeqmapreduce(anyeltypedual, promote_dual, x)
    end
end
function anyeltypedual(x::Dict, counter = 0)
    isempty(x) ? eltype(values(x)) : mapreduce(anyeltypedual, promote_dual, values(x))
end
function anyeltypedual(x::NamedTuple, counter = 0)
    isempty(x) ? Any : diffeqmapreduce(anyeltypedual, promote_dual, values(x))
end
@inline function promote_u0(u0, p, t0)
    if !(eltype(u0) <: ForwardDiff.Dual)
        T = anyeltypedual(p)
        T === Any && return u0
        if T <: ForwardDiff.Dual
            return T.(u0)
        end
    end
    u0
end

@inline function promote_u0(u0::AbstractArray{<:Complex}, p, t0)
    if !(real(eltype(u0)) <: ForwardDiff.Dual)
        T = anyeltypedual(p)
        T === Any && return u0
        if T <: ForwardDiff.Dual
            Ts = promote_type(T, eltype(u0))
            return Ts.(u0)
        end
    end
    u0
end

struct DualEltypeChecker{T}
    x::T
    counter::Int
    DualEltypeChecker(x::T, counter::Int) where {T} = new{T}(x, counter + 1)
end

function (dec::DualEltypeChecker)(::Val{Y}) where {Y}
    isdefined(dec.x, Y) || return Any
    dec.counter >= DUALCHECK_RECURSION_MAX && return Any
    anyeltypedual(getproperty(dec.x, Y), dec.counter)
end

# use `getfield` on `Pairs`, see https://github.com/JuliaLang/julia/pull/39448
function (dec::DualEltypeChecker{<:Base.Pairs})(::Val{Y}) where {Y}
    isdefined(dec.x, Y) || return Any
    dec.counter >= DUALCHECK_RECURSION_MAX && return Any
    anyeltypedual(getfield(dec.x, Y), dec.counter)
end

struct Thing
    a::Float64
end
struct Wrapper1{T}
    thing::T
end
struct Wrapper2{T}
    thing::T
end

thing = Thing(1.0)
x = 1.0

Now the checks:

promote_u0(x, Wrapper1(thing), (0.0, 1.0))
@code_warntype promote_u0(x, Wrapper1(thing), (0.0, 1.0))
MethodInstance for promote_u0(::Float64, ::Wrapper1{Thing}, ::Tuple{Float64, Float64})
  from promote_u0(u0, p, t0) @ Main c:\Users\accou\OneDrive\Computer\Desktop\test.jl:346
Arguments
  #self#::Core.Const(promote_u0)
  u0::Float64
  p::Wrapper1{Thing}
  t0::Tuple{Float64, Float64}
Locals
  T::Any
Body::Any
1 ─       nothing
│         Core.NewvarNode(:(T))
│   %3  = Main.eltype(u0)::Core.Const(Float64)
│   %4  = ForwardDiff.Dual::Core.Const(ForwardDiff.Dual)
│   %5  = (%3 <: %4)::Core.Const(false)
│   %6  = !%5::Core.Const(true)
└──       goto #6 if not %6
2 ─       (T = Main.anyeltypedual(p))
│   %9  = (T === Main.Any)::Bool
└──       goto #4 if not %9
3 ─       return u0
4 ─ %12 = T::Any
│   %13 = ForwardDiff.Dual::Core.Const(ForwardDiff.Dual)
│   %14 = (%12 <: %13)::Bool
└──       goto #6 if not %14
5 ─ %16 = Base.broadcasted(T, u0)::Base.Broadcast.Broadcasted{Style, Nothing} where Style<:Union{Nothing, Base.Broadcast.BroadcastStyle}
│   %17 = Base.materialize(%16)::Any
└──       return %17
6 ┄       return u0
promote_u0(x, Wrapper2(thing), (0.0, 1.0))
@code_warntype promote_u0(x, Wrapper2(thing), (0.0, 1.0))
MethodInstance for promote_u0(::Float64, ::Wrapper2{Thing}, ::Tuple{Float64, Float64})
  from promote_u0(u0, p, t0) @ Main c:\Users\accou\OneDrive\Computer\Desktop\test.jl:346
Arguments
  #self#::Core.Const(promote_u0)
  u0::Float64
  p::Wrapper2{Thing}
  t0::Tuple{Float64, Float64}
Locals
  T::Type{Any}
Body::Float64
1 ─      nothing
│        Core.NewvarNode(:(T))
│   %3 = Main.eltype(u0)::Core.Const(Float64)
│   %4 = ForwardDiff.Dual::Core.Const(ForwardDiff.Dual)
│   %5 = (%3 <: %4)::Core.Const(false)
│   %6 = !%5::Core.Const(true)
└──      goto #5 if not %6
2 ─      (T = Main.anyeltypedual(p))
│   %9 = (T::Core.Const(Any) === Main.Any)::Core.Const(true)
└──      goto #4 if not %9
3 ─      return u0
4 ─      Core.Const(:(T))
│        Core.Const(:(ForwardDiff.Dual))
│        Core.Const(:(%12 <: %13))
│        Core.Const(:(goto %19 if not %14))
│        Core.Const(:(Base.broadcasted(T, u0)))
│        Core.Const(:(Base.materialize(%16)))
└──      Core.Const(:(return %17))
5 ┄      Core.Const(:(return u0))

First reported as SciML/DiffEqBase.jl#918

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions