Skip to content

Commit 773be5a

Browse files
torfjeldedevmotion
andauthored
Fixes return-types for multiple rules (#55)
* fixed issues with return-type for several rules * initial work on inferring return type using intermediate computations * removed oftype where possible * fixed stupid mistake * missed one in fix of stupid mistake * more fixes * added tests for different types * removed a couple of overdone oftypes * moved a single paranthesis * use irrationals to simplify type-promotion * depend on IrrationalConstants.jl * fixed tests * actually fixed tests * updated rules to use constants from IrrationalConstants * Update src/rules.jl Co-authored-by: David Widmann <[email protected]> * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * reverted some changes from previous commit * fixed a typo * fixed typo * fixed type-conversion in tests * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * reverse previous commit due to JuliaMath/NaNMath.jl#47 * reverse previous commit due to JuliaMath/NaNMath.jl#47 * drop qualifications from rules * reverted unintended change to _abs_deriv * interpolate _abs_deriv * be explicit about imported irrationals * fixed tests * fixed numerical issues in tests by adopting some changes from #79 * relax rtol slightly since we're working with Float32 too here * Update Project.toml * test type of derivative for functions with 2 arguments * fixed types of derivatives for mod, rem and different bessel functions * use deg2rad Co-authored-by: David Widmann <[email protected]> * reverted changes to + and - * remove duplicate rules * add back whitespace * reverted changes to bessel functions * only test return-type having the correct promotion behavior * only test type for 2 argument functions whose derivatives aren't NaN * fixed rules of mod and rem * make each rule its own testset for easier debugging * reverted changes to multiple NaNMath rules * use more explicit promotion in tests Co-authored-by: David Widmann <[email protected]> * check promotion of real instead of specific check for Complex Co-authored-by: David Widmann <[email protected]> * reverted unnecessary change * reverted unnecessary changes * dont check if AbstractFloat in tests Co-authored-by: David Widmann <[email protected]>
1 parent 7b1c31e commit 773be5a

File tree

4 files changed

+124
-94
lines changed

4 files changed

+124
-94
lines changed

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
11
name = "DiffRules"
22
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
3-
version = "1.9.0"
3+
version = "1.9.1"
44

55
[deps]
6+
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
67
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
78
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
89
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
910
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1011

1112
[compat]
13+
IrrationalConstants = "0.1.1"
1214
LogExpFunctions = "0.3.2"
1315
NaNMath = "0.3"
1416
SpecialFunctions = "0.10, 1.0, 2"
1517
julia = "1.3"
1618

1719
[extras]
20+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
1821
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1922
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2023

2124
[targets]
22-
test = ["Test", "Random"]
25+
test = ["Test", "Random", "FiniteDifferences"]

src/DiffRules.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ __precompile__()
22

33
module DiffRules
44

5+
using IrrationalConstants: logtwo, logten, twoπ, sqrtπ, invsqrtπ
6+
57
include("api.jl")
68
include("rules.jl")
79

src/rules.jl

Lines changed: 34 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,46 +5,46 @@
55
# unary #
66
#-------#
77

8-
@define_diffrule Base.:+(x) = :( 1 )
9-
@define_diffrule Base.:-(x) = :( -1 )
8+
@define_diffrule Base.:+(x) = :( 1 )
9+
@define_diffrule Base.:-(x) = :( -1 )
1010
@define_diffrule Base.sqrt(x) = :( inv(2 * sqrt($x)) )
1111
@define_diffrule Base.cbrt(x) = :( inv(3 * cbrt($x)^2) )
1212
@define_diffrule Base.abs2(x) = :( $x + $x )
1313
@define_diffrule Base.inv(x) = :( -abs2(inv($x)) )
1414
@define_diffrule Base.log(x) = :( inv($x) )
15-
@define_diffrule Base.log10(x) = :( inv($x) / log(10) )
16-
@define_diffrule Base.log2(x) = :( inv($x) / log(2) )
15+
@define_diffrule Base.log10(x) = :( inv($x) / $logten )
16+
@define_diffrule Base.log2(x) = :( inv($x) / $logtwo )
1717
@define_diffrule Base.log1p(x) = :( inv($x + 1) )
1818
@define_diffrule Base.exp(x) = :( exp($x) )
19-
@define_diffrule Base.exp2(x) = :( exp2($x) * log(2) )
20-
@define_diffrule Base.exp10(x) = :( exp10($x) * log(10) )
19+
@define_diffrule Base.exp2(x) = :( exp2($x) * $logtwo )
20+
@define_diffrule Base.exp10(x) = :( exp10($x) * $logten )
2121
@define_diffrule Base.expm1(x) = :( exp($x) )
2222
@define_diffrule Base.sin(x) = :( cos($x) )
2323
@define_diffrule Base.cos(x) = :( -sin($x) )
2424
@define_diffrule Base.tan(x) = :( 1 + tan($x)^2 )
2525
@define_diffrule Base.sec(x) = :( sec($x) * tan($x) )
2626
@define_diffrule Base.csc(x) = :( -csc($x) * cot($x) )
2727
@define_diffrule Base.cot(x) = :( -(1 + cot($x)^2) )
28-
@define_diffrule Base.sind(x) = :( / 180) * cosd($x) )
29-
@define_diffrule Base.cosd(x) = :( -/ 180) * sind($x) )
30-
@define_diffrule Base.tand(x) = :( / 180) * (1 + tand($x)^2) )
31-
@define_diffrule Base.secd(x) = :( / 180) * secd($x) * tand($x) )
32-
@define_diffrule Base.cscd(x) = :( -/ 180) * cscd($x) * cotd($x) )
33-
@define_diffrule Base.cotd(x) = :( -/ 180) * (1 + cotd($x)^2) )
28+
@define_diffrule Base.sind(x) = :( deg2rad(cosd($x)) )
29+
@define_diffrule Base.cosd(x) = :( - deg2rad(sind($x)) )
30+
@define_diffrule Base.tand(x) = :( deg2rad(1 + tand($x)^2) )
31+
@define_diffrule Base.secd(x) = :( deg2rad(secd($x) * tand($x)) )
32+
@define_diffrule Base.cscd(x) = :( - deg2rad(cscd($x) * cotd($x)) )
33+
@define_diffrule Base.cotd(x) = :( - deg2rad(1 + cotd($x)^2) )
3434
@define_diffrule Base.sinpi(x) = :( π * cospi($x) )
35-
@define_diffrule Base.cospi(x) = :( -π * sinpi($x) )
35+
@define_diffrule Base.cospi(x) = :( -(π * sinpi($x)) )
3636
@define_diffrule Base.asin(x) = :( inv(sqrt(1 - $x^2)) )
3737
@define_diffrule Base.acos(x) = :( -inv(sqrt(1 - $x^2)) )
3838
@define_diffrule Base.atan(x) = :( inv(1 + $x^2) )
3939
@define_diffrule Base.asec(x) = :( inv(abs($x) * sqrt($x^2 - 1)) )
4040
@define_diffrule Base.acsc(x) = :( -inv(abs($x) * sqrt($x^2 - 1)) )
4141
@define_diffrule Base.acot(x) = :( -inv(1 + $x^2) )
42-
@define_diffrule Base.asind(x) = :( 180 / π / sqrt(1 - $x^2) )
43-
@define_diffrule Base.acosd(x) = :( -180 / π / sqrt(1 - $x^2) )
44-
@define_diffrule Base.atand(x) = :( 180 / π / (1 + $x^2) )
45-
@define_diffrule Base.asecd(x) = :( 180 / π / abs($x) / sqrt($x^2 - 1) )
46-
@define_diffrule Base.acscd(x) = :( -180 / π / abs($x) / sqrt($x^2 - 1) )
47-
@define_diffrule Base.acotd(x) = :( -180 / π / (1 + $x^2) )
42+
@define_diffrule Base.asind(x) = :( inv(deg2rad(sqrt(1 - $x^2))) )
43+
@define_diffrule Base.acosd(x) = :( -inv(deg2rad(sqrt(1 - $x^2))) )
44+
@define_diffrule Base.atand(x) = :( inv(deg2rad(1 + $x^2)) )
45+
@define_diffrule Base.asecd(x) = :( inv(deg2rad(abs($x) * sqrt($x^2 - 1))) )
46+
@define_diffrule Base.acscd(x) = :( -inv(deg2rad(abs($x) * sqrt($x^2 - 1))) )
47+
@define_diffrule Base.acotd(x) = :( -inv(deg2rad(1 + $x^2)) )
4848
@define_diffrule Base.sinh(x) = :( cosh($x) )
4949
@define_diffrule Base.cosh(x) = :( sinh($x) )
5050
@define_diffrule Base.tanh(x) = :( 1 - tanh($x)^2 )
@@ -58,16 +58,15 @@
5858
@define_diffrule Base.acsch(x) = :( -inv(abs($x) * sqrt(1 + $x^2)) )
5959
@define_diffrule Base.acoth(x) = :( inv(1 - $x^2) )
6060
@define_diffrule Base.sinc(x) = :( cosc($x) )
61-
@define_diffrule Base.deg2rad(x) = :( π / 180 )
62-
@define_diffrule Base.mod2pi(x) = :( isinteger($x / 2pi) ? NaN : 1 )
63-
@define_diffrule Base.rad2deg(x) = :( 180 / π )
64-
61+
@define_diffrule Base.deg2rad(x) = :( deg2rad(one($x)) )
62+
@define_diffrule Base.mod2pi(x) = :( isinteger($x / $twoπ) ? oftype(float($x), NaN) : one(float($x)) )
63+
@define_diffrule Base.rad2deg(x) = :( rad2deg(one($x)) )
6564
@define_diffrule SpecialFunctions.gamma(x) =
6665
:( SpecialFunctions.digamma($x) * SpecialFunctions.gamma($x) )
6766
@define_diffrule SpecialFunctions.loggamma(x) =
6867
:( SpecialFunctions.digamma($x) )
6968

70-
@define_diffrule Base.abs(x) = :( DiffRules._abs_deriv($x) )
69+
@define_diffrule Base.abs(x) = :( $(_abs_deriv)($x) )
7170

7271
# We provide this hook for special number types like `Interval`
7372
# that need their own special definition of `abs`.
@@ -88,8 +87,8 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x)
8887
@define_diffrule Base.log(b, x) = :( log($x) * inv(-log($b)^2 * $b) ), :( inv($x) / log($b) )
8988
@define_diffrule Base.ldexp(x, y) = :( exp2($y) ), :NaN
9089

91-
@define_diffrule Base.mod(x, y) = :( first(promote(ifelse(isinteger($x / $y), NaN, 1), NaN)) ), :( z = $x / $y; first(promote(ifelse(isinteger(z), NaN, -floor(z)), NaN)) )
92-
@define_diffrule Base.rem(x, y) = :( first(promote(ifelse(isinteger($x / $y), NaN, 1), NaN)) ), :( z = $x / $y; first(promote(ifelse(isinteger(z), NaN, -trunc(z)), NaN)) )
90+
@define_diffrule Base.mod(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), one(float(z))) ), :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), -floor(float(z))) )
91+
@define_diffrule Base.rem(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), one(float(z))) ), :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), -trunc(float(z))) )
9392
@define_diffrule Base.rem2pi(x, r) = :( 1 ), :NaN
9493
@define_diffrule Base.max(x, y) = :( $x > $y ? one($x) : zero($x) ), :( $x > $y ? zero($y) : one($y) )
9594
@define_diffrule Base.min(x, y) = :( $x > $y ? zero($x) : one($x) ), :( $x > $y ? one($y) : zero($y) )
@@ -113,17 +112,17 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x)
113112
# unary #
114113
#-------#
115114

116-
@define_diffrule SpecialFunctions.erf(x) = :( (2 / sqrt(π)) * exp(-$x * $x) )
115+
@define_diffrule SpecialFunctions.erf(x) = :( 2 * ($invsqrtπ * exp(-$x^2)) )
117116
@define_diffrule SpecialFunctions.erfinv(x) =
118-
:( (sqrt(π) / 2) * exp(SpecialFunctions.erfinv($x)^2) )
119-
@define_diffrule SpecialFunctions.erfc(x) = :( -(2 / sqrt(π)) * exp(-$x * $x) )
117+
:( ($sqrtπ * exp(SpecialFunctions.erfinv($x)^2)) / 2 )
118+
@define_diffrule SpecialFunctions.erfc(x) = :( -($invsqrtπ * exp(-$x^2) * 2) )
120119
@define_diffrule SpecialFunctions.erfcinv(x) =
121-
:( -(sqrt(π) / 2) * exp(SpecialFunctions.erfcinv($x)^2) )
122-
@define_diffrule SpecialFunctions.erfi(x) = :( (2 / sqrt(π)) * exp($x * $x) )
120+
:( -($sqrtπ * exp(SpecialFunctions.erfcinv($x)^2)) / 2 )
121+
@define_diffrule SpecialFunctions.erfi(x) = :( $invsqrtπ * exp($x^2) * 2 )
123122
@define_diffrule SpecialFunctions.erfcx(x) =
124-
:( (2 * $x * SpecialFunctions.erfcx($x)) - (2 / sqrt(π)) )
123+
:( 2 * (($x * SpecialFunctions.erfcx($x)) - $invsqrtπ) )
125124
@define_diffrule SpecialFunctions.logerfcx(x) =
126-
:( 2 * ($x - inv(SpecialFunctions.erfcx($x) * sqrt(π))) )
125+
:( 2 * ($x - inv(SpecialFunctions.erfcx($x) * $sqrtπ)) )
127126
@define_diffrule SpecialFunctions.dawson(x) =
128127
:( 1 - (2 * $x * SpecialFunctions.dawson($x)) )
129128
@define_diffrule SpecialFunctions.digamma(x) =
@@ -217,8 +216,8 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x)
217216
@define_diffrule NaNMath.acosh(x) = :( inv(NaNMath.sqrt(NaNMath.pow($x, 2) - 1)) )
218217
@define_diffrule NaNMath.atanh(x) = :( inv(1 - NaNMath.pow($x, 2)) )
219218
@define_diffrule NaNMath.log(x) = :( inv($x) )
220-
@define_diffrule NaNMath.log2(x) = :( inv($x) / NaNMath.log(2) )
221-
@define_diffrule NaNMath.log10(x) = :( inv($x) / NaNMath.log(10) )
219+
@define_diffrule NaNMath.log2(x) = :( inv($logtwo * $x) )
220+
@define_diffrule NaNMath.log10(x) = :( inv($logten * $x) )
222221
@define_diffrule NaNMath.log1p(x) = :( inv($x + 1) )
223222
@define_diffrule NaNMath.lgamma(x) = :( SpecialFunctions.digamma($x) )
224223

test/runtests.jl

Lines changed: 83 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,82 +1,108 @@
11
using DiffRules
22
using Test
3+
using FiniteDifferences
4+
5+
using IrrationalConstants: fourπ
36

47
import SpecialFunctions, NaNMath, LogExpFunctions
58
import Random
69
Random.seed!(1)
710

8-
function finitediff(f, x)
9-
ϵ = cbrt(eps(typeof(x))) * max(one(typeof(x)), abs(x))
10-
return (f(x + ϵ) - f(x - ϵ)) /+ ϵ)
11-
end
11+
# Set `max_range` to avoid domain errors.
12+
const finitediff = central_fdm(5, 1, max_range=1e-3)
1213

1314
@testset "DiffRules" begin
1415
@testset "check rules" begin
1516

1617
non_diffeable_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ldexp, 2), (:Base, :ifelse, 3)]
1718

18-
for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing)
19-
(M, f, arity) non_diffeable_arg_functions && continue
20-
if arity == 1
21-
@test DiffRules.hasdiffrule(M, f, 1)
22-
deriv = DiffRules.diffrule(M, f, :goo)
23-
modifier = if f in (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)
24-
1.0
25-
elseif f === :log1mexp
26-
-1.0
27-
elseif f === :log2mexp
28-
-0.5
29-
else
30-
0.0
31-
end
32-
@eval begin
33-
let
34-
goo = rand() + $modifier
35-
@test isapprox($deriv, finitediff($M.$f, goo), rtol=0.05)
36-
# test for 2pi functions
37-
if "mod2pi" == string($M.$f)
38-
goo = 4pi + $modifier
39-
@test NaN === $deriv
19+
@testset "($M, $f, $arity)" for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing)
20+
for T in [Float32, Float64]
21+
(M, f, arity) non_diffeable_arg_functions && continue
22+
if arity == 1
23+
@test DiffRules.hasdiffrule(M, f, 1)
24+
deriv = DiffRules.diffrule(M, f, :goo)
25+
@eval begin
26+
let
27+
goo = if $(f in (:asec, :acsc, :asecd, :acscd, :acosh, :acoth))
28+
# avoid singularities with finite differencing
29+
rand($T) + $T(1.5)
30+
elseif $(f in (:log, :airyaix, :airyaiprimex))
31+
# avoid singularities with finite differencing
32+
rand($T) + $T(0.5)
33+
elseif $(f === :log1mexp)
34+
rand($T) - one($T)
35+
elseif $(f in (:log2mexp, :erfinv))
36+
rand($T) - $T(0.5)
37+
else
38+
rand($T)
39+
end
40+
# We're happy with types with the correct promotion behavior, e.g.
41+
# it's fine to return `1` as a derivative despite input being `Float64`.
42+
@test promote_type(typeof($deriv), $T) === $T
43+
@test $deriv finitediff($M.$f, goo) rtol=1e-2 atol=1e-3
44+
# test for 2pi functions
45+
if $(f === :mod2pi)
46+
goo = 4 * pi
47+
@test NaN === $deriv
48+
end
4049
end
4150
end
42-
end
43-
elseif arity == 2
44-
@test DiffRules.hasdiffrule(M, f, 2)
45-
derivs = DiffRules.diffrule(M, f, :foo, :bar)
46-
@eval begin
47-
let
48-
if "mod" == string($M.$f)
49-
foo, bar = rand() + 13, rand() + 5 # make sure x/y is not integer
50-
else
51-
foo, bar = rand(1:10), rand()
51+
elseif arity == 2
52+
@test DiffRules.hasdiffrule(M, f, 2)
53+
derivs = DiffRules.diffrule(M, f, :foo, :bar)
54+
@eval begin
55+
let
56+
foo, bar = if $(f === :mod)
57+
rand($T) + 13, rand($T) + 5 # make sure x/y is not integer
58+
elseif $(f === :polygamma)
59+
rand(1:10), rand($T) # only supports integers as first arguments
60+
elseif $(f in (:bessely, :besselyx))
61+
# avoid singularities with finite differencing
62+
rand($T), rand($T) + $T(0.5)
63+
elseif $(f === :log)
64+
# avoid singularities with finite differencing
65+
rand($T) + $T(1.5), rand($T)
66+
elseif $(f === :^)
67+
# avoid singularities with finite differencing
68+
rand($T) + $T(0.5), rand($T)
69+
else
70+
rand($T), rand($T)
71+
end
72+
dx, dy = $(derivs[1]), $(derivs[2])
73+
if !(isnan(dx))
74+
@test dx finitediff(z -> $M.$f(z, bar), foo) rtol=1e-2 atol=1e-3
75+
76+
# Check type, if applicable.
77+
@test promote_type(typeof(real(dx)), $T) === $T
78+
end
79+
if !(isnan(dy))
80+
@test dy finitediff(z -> $M.$f(foo, z), bar) rtol=1e-2 atol=1e-3
81+
82+
# Check type, if applicable.
83+
@test promote_type(typeof(real(dy)), $T) === $T
84+
end
5285
end
53-
dx, dy = $(derivs[1]), $(derivs[2])
86+
end
87+
elseif arity == 3
88+
#=
89+
@test DiffRules.hasdiffrule(M, f, 3)
90+
derivs = DiffRules.diffrule(M, f, :foo, :bar, :goo)
91+
@eval begin
92+
foo, bar, goo = randn(3)
93+
dx, dy, dz = $(derivs[1]), $(derivs[2]), $(derivs[3])
5494
if !(isnan(dx))
55-
@test isapprox(dx, finitediff(z -> $M.$f(z, bar), float(foo)), rtol=0.05)
95+
@test isapprox(dx, finitediff(x -> $M.$f(x, bar, goo), foo), rtol=0.05)
5696
end
5797
if !(isnan(dy))
58-
@test isapprox(dy, finitediff(z -> $M.$f(foo, z), bar), rtol=0.05)
98+
@test isapprox(dy, finitediff(y -> $M.$f(foo, y, goo), bar), rtol=0.05)
99+
end
100+
if !(isnan(dz))
101+
@test isapprox(dz, finitediff(z -> $M.$f(foo, bar, z), goo), rtol=0.05)
59102
end
60103
end
104+
=#
61105
end
62-
elseif arity == 3
63-
#=
64-
@test DiffRules.hasdiffrule(M, f, 3)
65-
derivs = DiffRules.diffrule(M, f, :foo, :bar, :goo)
66-
@eval begin
67-
foo, bar, goo = randn(3)
68-
dx, dy, dz = $(derivs[1]), $(derivs[2]), $(derivs[3])
69-
if !(isnan(dx))
70-
@test isapprox(dx, finitediff(x -> $M.$f(x, bar, goo), foo), rtol=0.05)
71-
end
72-
if !(isnan(dy))
73-
@test isapprox(dy, finitediff(y -> $M.$f(foo, y, goo), bar), rtol=0.05)
74-
end
75-
if !(isnan(dz))
76-
@test isapprox(dz, finitediff(z -> $M.$f(foo, bar, z), goo), rtol=0.05)
77-
end
78-
end
79-
=#
80106
end
81107
end
82108

0 commit comments

Comments
 (0)