diff --git a/Project.toml b/Project.toml index f5de9bbd..3318b46d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,12 +1,13 @@ name = "ForwardDiff" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.22" +version = "0.10.23" [deps] CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -18,8 +19,9 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Calculus = "0.2, 0.3, 0.4, 0.5" CommonSubexpressions = "0.3" DiffResults = "0.0.1, 0.0.2, 0.0.3, 0.0.4, 1.0.1" -DiffRules = "1.2.1" +DiffRules = "1.4.0" DiffTests = "0.0.1, 0.1" +LogExpFunctions = "0.3" NaNMath = "0.2.2, 0.3" Preferences = "1" SpecialFunctions = "0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 1.0" diff --git a/src/ForwardDiff.jl b/src/ForwardDiff.jl index aacdbb8a..93d3b246 100644 --- a/src/ForwardDiff.jl +++ b/src/ForwardDiff.jl @@ -12,6 +12,7 @@ using LinearAlgebra import Printf import NaNMath import SpecialFunctions +import LogExpFunctions import CommonSubexpressions include("prelude.jl") diff --git a/src/dual.jl b/src/dual.jl index e9629fe1..d52c5736 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -392,8 +392,12 @@ Base.float(d::Dual) = convert(float(typeof(d)), d) # General Mathematical Operations # ################################### -for (M, f, arity) in DiffRules.diffrules() - in((M, f), ((:Base, :^), (:NaNMath, :pow), (:Base, :/), (:Base, :+), (:Base, :-))) && continue +for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing) + if (M, f) in ((:Base, :^), (:NaNMath, :pow), (:Base, :/), (:Base, :+), (:Base, :-)) + continue # Skip methods which we define elsewhere. + elseif !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f)) + continue # Skip rules for methods not defined in the current scope + end if arity == 1 eval(unary_dual_definition(M, f)) elseif arity == 2 diff --git a/test/DualTest.jl b/test/DualTest.jl index cfdea005..4d3a5c32 100644 --- a/test/DualTest.jl +++ b/test/DualTest.jl @@ -6,7 +6,7 @@ using Random using ForwardDiff using ForwardDiff: Partials, Dual, value, partials -using NaNMath, SpecialFunctions +using NaNMath, SpecialFunctions, LogExpFunctions using DiffRules import Calculus @@ -420,12 +420,22 @@ for N in (0,3), M in (0,4), V in (Int, Float32) @test abs(NESTED_FDNUM) === NESTED_FDNUM if V != Int - for (M, f, arity) in DiffRules.diffrules() - in(f, (:hankelh1, :hankelh1x, :hankelh2, :hankelh2x, :/, :rem2pi)) && continue + for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing) + if f in (:hankelh1, :hankelh1x, :hankelh2, :hankelh2x, :/, :rem2pi) + continue # Skip these rules + elseif !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f)) + continue # Skip rules for methods not defined in the current scope + end println(" ...auto-testing $(M).$(f) with $arity arguments") if arity == 1 deriv = DiffRules.diffrule(M, f, :x) - modifier = in(f, (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)) ? one(V) : zero(V) + modifier = if in(f, (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)) + one(V) + elseif in(f, (:log1mexp, :log2mexp)) + -one(V) + else + zero(V) + end @eval begin x = rand() + $modifier dx = $M.$f(Dual{TestTag()}(x, one(x)))