Closed
Description
Because some rules involve irrational numbers, their adjoints can silently promote the numerical precision of the derivative/gradient. This happens because certain rules first apply a function to the irrational constant, and by default that converts it to Float64. For example, the rule for erfc first calls sqrt(π), which returns a Float64. So instead of π being converted to the expected output type, the output is promoted to Float64 whenever the input is a lower-precision float such as Float32:
```julia
julia> using SpecialFunctions, ChainRulesCore

julia> y, ȳ = ChainRulesCore.frule((ChainRulesCore.NO_FIELDS, 1f0), SpecialFunctions.erfc, 1f0)
(0.1572992f0, -0.41510750774498784)

julia> typeof(y), typeof(ȳ)
(Float32, Float64)
```

This is essentially the same issue as in DiffRules (JuliaDiff/DiffRules.jl#55).
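The promotion can be reproduced in plain Julia without ChainRulesCore. This minimal sketch assumes the erfc rule computes a derivative of the form `-2 / sqrt(π) * exp(-x^2)`. That form is an illustrative assumption, not code copied from SpecialFunctions.jl:

```julia
# Assumed shape of the erfc derivative, for illustration only.
x = 1f0                                   # Float32 input

a = sqrt(π)                               # π is converted to Float64 before the root is taken
b = -2 / a * exp(-x^2)                    # Float64 * Float32 promotes the result to Float64

c = -2 / sqrt(oftype(x, π)) * exp(-x^2)   # π converted to Float32 first, so everything stays Float32

println(typeof(a))  # Float64
println(typeof(b))  # Float64
println(typeof(c))  # Float32
```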
Does anyone have a better idea of what to do here, or should I just open a similar PR against SpecialFunctions.jl?
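One possible direction is to convert the irrational constant to the input's floating-point type before doing any arithmetic on it. The helper `erfc_derivative` below is hypothetical and only illustrates that idea. It is not part of SpecialFunctions.jl or ChainRulesCore:

```julia
# Hypothetical helper (not from SpecialFunctions.jl): derivative of erfc
# written so the constant adopts the floating-point type T of x
# instead of defaulting to Float64.
function erfc_derivative(x::T) where {T<:AbstractFloat}
    # convert(T, π) rounds π to T's precision before sqrt is applied
    return -2 / sqrt(convert(T, π)) * exp(-x^2)
end

for T in (Float16, Float32, Float64, BigFloat)
    d = erfc_derivative(one(T))
    println(T, " => ", typeof(d))
end
```

Converting π first, rather than writing `oftype(x, 2 / sqrt(π))`, also keeps full precision for `BigFloat`, since `2 / sqrt(π)` would otherwise be evaluated in Float64 before the conversion.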