From 864877006c5d295cf9c207ba4740aa375a800a4f Mon Sep 17 00:00:00 2001 From: odow Date: Fri, 5 Nov 2021 10:36:08 +1300 Subject: [PATCH 1/4] Fix deprecation warning introduced by DiffRules v1.4 --- Project.toml | 4 ++-- src/dual.jl | 2 +- test/DualTest.jl | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index f5de9bbd..21a3c2b4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ForwardDiff" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.22" +version = "0.10.23" [deps] CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950" @@ -18,7 +18,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Calculus = "0.2, 0.3, 0.4, 0.5" CommonSubexpressions = "0.3" DiffResults = "0.0.1, 0.0.2, 0.0.3, 0.0.4, 1.0.1" -DiffRules = "1.2.1" +DiffRules = "1.4.0" DiffTests = "0.0.1, 0.1" NaNMath = "0.2.2, 0.3" Preferences = "1" diff --git a/src/dual.jl b/src/dual.jl index e9629fe1..74538e4a 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -392,7 +392,7 @@ Base.float(d::Dual) = convert(float(typeof(d)), d) # General Mathematical Operations # ################################### -for (M, f, arity) in DiffRules.diffrules() +for (M, f, arity) in DiffRules.diffrules(filtered_modules = (:Base, :SpecialFunctions, :NaNMath)) in((M, f), ((:Base, :^), (:NaNMath, :pow), (:Base, :/), (:Base, :+), (:Base, :-))) && continue if arity == 1 eval(unary_dual_definition(M, f)) diff --git a/test/DualTest.jl b/test/DualTest.jl index cfdea005..67fbb22d 100644 --- a/test/DualTest.jl +++ b/test/DualTest.jl @@ -420,7 +420,7 @@ for N in (0,3), M in (0,4), V in (Int, Float32) @test abs(NESTED_FDNUM) === NESTED_FDNUM if V != Int - for (M, f, arity) in DiffRules.diffrules() + for (M, f, arity) in DiffRules.diffrules(filtered_modules = (:Base, :SpecialFunctions, :NaNMath)) in(f, (:hankelh1, :hankelh1x, :hankelh2, :hankelh2x, :/, :rem2pi)) && continue println(" ...auto-testing $(M).$(f) with $arity arguments") if arity == 1 From 402dbda50495244c955dad5a97266957457ee282 Mon Sep 17 00:00:00 2001 From: odow Date: Fri, 5 Nov 2021 10:43:51 +1300 Subject: [PATCH 2/4] Fix typo --- src/dual.jl | 2 +- test/DualTest.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dual.jl b/src/dual.jl index 74538e4a..3cf9aa32 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -392,7 +392,7 @@ Base.float(d::Dual) = convert(float(typeof(d)), d) # General Mathematical Operations # ################################### -for (M, f, arity) in DiffRules.diffrules(filtered_modules = (:Base, :SpecialFunctions, :NaNMath)) +for (M, f, arity) in DiffRules.diffrules(filter_modules = (:Base, :SpecialFunctions, :NaNMath)) in((M, f), ((:Base, :^), (:NaNMath, :pow), (:Base, :/), (:Base, :+), (:Base, :-))) && continue if arity == 1 eval(unary_dual_definition(M, f)) diff --git a/test/DualTest.jl b/test/DualTest.jl index 67fbb22d..96bef598 100644 --- a/test/DualTest.jl +++ b/test/DualTest.jl @@ -420,7 +420,7 @@ for N in (0,3), M in (0,4), V in (Int, Float32) @test abs(NESTED_FDNUM) === NESTED_FDNUM if V != Int - for (M, f, arity) in DiffRules.diffrules(filtered_modules = (:Base, :SpecialFunctions, :NaNMath)) + for (M, f, arity) in DiffRules.diffrules(filter_modules = (:Base, :SpecialFunctions, :NaNMath)) in(f, (:hankelh1, :hankelh1x, :hankelh2, :hankelh2x, :/, :rem2pi)) && continue println(" ...auto-testing $(M).$(f) with $arity arguments") if arity == 1 From c742bdfa8a2e2d0511fdc54b5286414a16f8df2a Mon Sep 17 00:00:00 2001 From: odow Date: Fri, 5 Nov 2021 12:22:02 +1300 Subject: [PATCH 3/4] Add rules for LogExpFunctions --- Project.toml | 2 ++ src/ForwardDiff.jl | 1 + src/dual.jl | 8 ++++++-- test/DualTest.jl | 18 ++++++++++++++---- 4 files changed, 23 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 21a3c2b4..95914a5c 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -20,6 +21,7 @@ CommonSubexpressions = "0.3" DiffResults = "0.0.1, 0.0.2, 0.0.3, 0.0.4, 1.0.1" DiffRules = "1.4.0" DiffTests = "0.0.1, 0.1" +LogExpFunctions = "0.3" NaNMath = "0.2.2, 0.3" Preferences = "1" SpecialFunctions = "0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 1.0" diff --git a/src/ForwardDiff.jl b/src/ForwardDiff.jl index aacdbb8a..93d3b246 100644 --- a/src/ForwardDiff.jl +++ b/src/ForwardDiff.jl @@ -12,6 +12,7 @@ using LinearAlgebra import Printf import NaNMath import SpecialFunctions +import LogExpFunctions import CommonSubexpressions include("prelude.jl") diff --git a/src/dual.jl b/src/dual.jl index 3cf9aa32..d52c5736 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -392,8 +392,12 @@ Base.float(d::Dual) = convert(float(typeof(d)), d) # General Mathematical Operations # ################################### -for (M, f, arity) in DiffRules.diffrules(filter_modules = (:Base, :SpecialFunctions, :NaNMath)) - in((M, f), ((:Base, :^), (:NaNMath, :pow), (:Base, :/), (:Base, :+), (:Base, :-))) && continue +for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing) + if (M, f) in ((:Base, :^), (:NaNMath, :pow), (:Base, :/), (:Base, :+), (:Base, :-)) + continue # Skip methods which we define elsewhere. + elseif !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f)) + continue # Skip rules for methods not defined in the current scope + end if arity == 1 eval(unary_dual_definition(M, f)) elseif arity == 2 diff --git a/test/DualTest.jl b/test/DualTest.jl index 96bef598..4d3a5c32 100644 --- a/test/DualTest.jl +++ b/test/DualTest.jl @@ -6,7 +6,7 @@ using Random using ForwardDiff using ForwardDiff: Partials, Dual, value, partials -using NaNMath, SpecialFunctions +using NaNMath, SpecialFunctions, LogExpFunctions using DiffRules import Calculus @@ -420,12 +420,22 @@ for N in (0,3), M in (0,4), V in (Int, Float32) @test abs(NESTED_FDNUM) === NESTED_FDNUM if V != Int - for (M, f, arity) in DiffRules.diffrules(filter_modules = (:Base, :SpecialFunctions, :NaNMath)) - in(f, (:hankelh1, :hankelh1x, :hankelh2, :hankelh2x, :/, :rem2pi)) && continue + for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing) + if f in (:hankelh1, :hankelh1x, :hankelh2, :hankelh2x, :/, :rem2pi) + continue # Skip these rules + elseif !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f)) + continue # Skip rules for methods not defined in the current scope + end println(" ...auto-testing $(M).$(f) with $arity arguments") if arity == 1 deriv = DiffRules.diffrule(M, f, :x) - modifier = in(f, (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)) ? one(V) : zero(V) + modifier = if in(f, (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)) + one(V) + elseif in(f, (:log1mexp, :log2mexp)) + -one(V) + else + zero(V) + end @eval begin x = rand() + $modifier dx = $M.$f(Dual{TestTag()}(x, one(x))) From 7c857cb777f80a98c61f7c51fb924f1ad3f7c770 Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Wed, 10 Nov 2021 09:42:10 +1300 Subject: [PATCH 4/4] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 95914a5c..3318b46d 100644 --- a/Project.toml +++ b/Project.toml @@ -21,7 +21,7 @@ CommonSubexpressions = "0.3" DiffResults = "0.0.1, 0.0.2, 0.0.3, 0.0.4, 1.0.1" DiffRules = "1.4.0" DiffTests = "0.0.1, 0.1" -LogExpFunctions = "0.3" +LogExpFunctions = "0.3" NaNMath = "0.2.2, 0.3" Preferences = "1" SpecialFunctions = "0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 1.0"