Skip to content

Weird inference failure in promote_u0 #918

@fjebaker

Description

@fjebaker

Originally opened in SciML/OrdinaryDiffEq.jl#2001

MWE:

# Minimal reproducer: call `DiffEqBase.promote_u0` with two structurally
# identical parameter wrappers and inspect what inference produces for each.
# NOTE: `@code_warntype` comes from InteractiveUtils, which the REPL loads
# automatically; a script run outside the REPL needs `using InteractiveUtils`.
using DiffEqBase

# Plain, non-parametric struct with a single concrete Float64 field.
struct Thing
    a::Float64
end
# Two parametric wrappers that differ only in name; both hold one field of type T.
struct Wrapper1{T}
    thing::T
end
struct Wrapper2{T}
    thing::T
end

thing = Thing(1.0)
x = 1.0  # scalar Float64 passed as u0

# First wrapper: run the call once, then print its inferred code.
# The statement order matters: per the report, whichever wrapper is used
# first is the one whose call fails to infer.
DiffEqBase.promote_u0(x, Wrapper1(thing), (0.0, 1.0))
@code_warntype DiffEqBase.promote_u0(x, Wrapper1(thing), (0.0, 1.0))

# Second wrapper: identical call shape, only the wrapper type differs.
DiffEqBase.promote_u0(x, Wrapper2(thing), (0.0, 1.0))
@code_warntype DiffEqBase.promote_u0(x, Wrapper2(thing), (0.0, 1.0))

Output:

MethodInstance for DiffEqBase.promote_u0(::Float64, ::Wrapper1{Thing}, ::Tuple{Float64, Float64})
  from promote_u0(u0, p, t0) @ DiffEqBase ~/.julia/packages/DiffEqBase/9rTlH/src/forwarddiff.jl:208
Arguments
  #self#::Core.Const(DiffEqBase.promote_u0)
  u0::Float64
  p::Wrapper1{Thing}
  t0::Tuple{Float64, Float64}
Locals
  T::Any
Body::Any
1 ─       nothing
│         Core.NewvarNode(:(T))
│   %3  = DiffEqBase.eltype(u0)::Core.Const(Float64)
│   %4  = ForwardDiff.Dual::Core.Const(ForwardDiff.Dual)
│   %5  = (%3 <: %4)::Core.Const(false)
│   %6  = !%5::Core.Const(true)
└──       goto #6 if not %6
2 ─       (T = DiffEqBase.anyeltypedual(p))
│   %9  = (T === DiffEqBase.Any)::Bool
└──       goto #4 if not %9
3 ─       return u0
4 ─ %12 = T::Any
│   %13 = ForwardDiff.Dual::Core.Const(ForwardDiff.Dual)
│   %14 = (%12 <: %13)::Bool
└──       goto #6 if not %14
5 ─ %16 = Base.broadcasted(T, u0)::Any
│   %17 = Base.materialize(%16)::Any
└──       return %17
6 ┄       return u0

MethodInstance for DiffEqBase.promote_u0(::Float64, ::Wrapper2{Thing}, ::Tuple{Float64, Float64})
  from promote_u0(u0, p, t0) @ DiffEqBase ~/.julia/packages/DiffEqBase/9rTlH/src/forwarddiff.jl:208
Arguments
  #self#::Core.Const(DiffEqBase.promote_u0)
  u0::Float64
  p::Wrapper2{Thing}
  t0::Tuple{Float64, Float64}
Locals
  T::Type{Any}
Body::Float64
1 ─      nothing
│        Core.NewvarNode(:(T))
│   %3 = DiffEqBase.eltype(u0)::Core.Const(Float64)
│   %4 = ForwardDiff.Dual::Core.Const(ForwardDiff.Dual)
│   %5 = (%3 <: %4)::Core.Const(false)
│   %6 = !%5::Core.Const(true)
└──      goto #5 if not %6
2 ─      (T = DiffEqBase.anyeltypedual(p))
│   %9 = (T::Core.Const(Any) === DiffEqBase.Any)::Core.Const(true)
└──      goto #4 if not %9
3 ─      return u0
4 ─      Core.Const(:(T))
│        Core.Const(:(ForwardDiff.Dual))
│        Core.Const(:(%12 <: %13))
│        Core.Const(:(goto %19 if not %14))
│        Core.Const(:(Base.broadcasted(T, u0)))
│        Core.Const(:(Base.materialize(%16)))
└──      Core.Const(:(return %17))
5 ┄      Core.Const(:(return u0))

Whichever wrapper is called first is the one that fails to infer (i.e. swapping the order so that Wrapper2 goes first makes the Wrapper2 call fail to infer instead).

I have no idea why, since `anyeltypedual` seems to be constant-folded:

# Inspect inference of `anyeltypedual` on its own for each wrapper type.
@code_warntype DiffEqBase.anyeltypedual(Wrapper1(thing))
@code_warntype DiffEqBase.anyeltypedual(Wrapper2(thing))
MethodInstance for DiffEqBase.anyeltypedual(::Wrapper1{Thing})
  from anyeltypedual(x) @ DiffEqBase ~/.julia/packages/DiffEqBase/9rTlH/src/forwarddiff.jl:101
Arguments
  #self#::Core.Const(DiffEqBase.anyeltypedual)
  x::Wrapper1{Thing}
Body::Type{Any}
1 ─ %1 = (#self#)(x, 0)::Core.Const(Any)
└──      return %1

MethodInstance for DiffEqBase.anyeltypedual(::Wrapper2{Thing})
  from anyeltypedual(x) @ DiffEqBase ~/.julia/packages/DiffEqBase/9rTlH/src/forwarddiff.jl:101
Arguments
  #self#::Core.Const(DiffEqBase.anyeltypedual)
  x::Wrapper2{Thing}
Body::Type{Any}
1 ─ %1 = (#self#)(x, 0)::Core.Const(Any)
└──      return %1

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions