Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
2abc4cc
fixed issues with return-type for several rules
torfjelde Jan 17, 2021
184ba54
initial work on inferring return type using intermediate computations
torfjelde Mar 11, 2021
c60fd7b
removed oftype where possible
torfjelde Apr 21, 2021
b440cdd
fixed stupid mistake
torfjelde Apr 21, 2021
65c3559
missed one in fix of stupid mistake
torfjelde Apr 21, 2021
dcba293
more fixes
torfjelde Apr 21, 2021
3f71077
added tests for different types
torfjelde Apr 21, 2021
ee42f7f
removed a couple of overdone oftypes
torfjelde Apr 21, 2021
548dc2c
moved a single paranthesis
torfjelde Apr 21, 2021
ee6fe3e
use irrationals to simplify type-promotion
torfjelde Jun 20, 2021
b149ba5
depend on IrrationalConstants.jl
torfjelde Jul 16, 2021
743b643
Merge branch 'master' into tor/return-type-fix
torfjelde Aug 11, 2021
38b7fd4
fixed tests
torfjelde Aug 11, 2021
6618cb1
actually fixed tests
torfjelde Aug 11, 2021
09fbd51
Merge branch 'master' into tor/return-type-fix
torfjelde Dec 13, 2021
b89221a
updated rules to use constants from IrrationalConstants
torfjelde Dec 13, 2021
a48a523
Update src/rules.jl
torfjelde Dec 14, 2021
1aaec9f
Apply suggestions from code review
torfjelde Dec 14, 2021
c8e6443
reverted some changes from previous commit
torfjelde Dec 14, 2021
34892e0
fixed a typo
torfjelde Dec 14, 2021
97445d4
fixed typo
torfjelde Dec 14, 2021
2dee9c3
fixed type-conversion in tests
torfjelde Dec 14, 2021
0f5b52c
Apply suggestions from code review
torfjelde Dec 14, 2021
a0b40e1
reverse previous commit due to https://github.com/mlubin/NaNMath.jl/i…
torfjelde Dec 14, 2021
746a157
reverse previous commit due to https://github.com/mlubin/NaNMath.jl/i…
torfjelde Dec 14, 2021
164ad73
drop qualifications from rules
torfjelde Dec 14, 2021
75b0ff6
reverted unintended change to _abs_deriv
torfjelde Dec 14, 2021
34ee3ca
interpolate _abs_deriv
torfjelde Dec 14, 2021
a58951c
be explicit about imported irrationals
torfjelde Dec 14, 2021
e64f9e0
fixed tests
torfjelde Dec 14, 2021
e383cfd
fixed numerical issues in tests by adopting some changes from #79
torfjelde Jan 15, 2022
4903915
relax rtol slightly since we're working with Float32 too here
torfjelde Jan 15, 2022
0111861
Update Project.toml
torfjelde Jan 27, 2022
657c0d5
test type of derivative for functions with 2 arguments
torfjelde Jan 27, 2022
f236d23
fixed types of derivatives for mod, rem and different bessel functions
torfjelde Jan 27, 2022
40e9133
use deg2rad
torfjelde Jan 27, 2022
e457830
reverted changes to + and -
torfjelde Jan 27, 2022
56dba97
remove duplicate rules
torfjelde Jan 27, 2022
80ded04
add back whitespace
torfjelde Jan 27, 2022
7149dff
reverted changes to bessel functions
torfjelde Jan 27, 2022
9823897
only test return-type having the correct promotion behavior
torfjelde Jan 27, 2022
97117ef
only test type for 2 argument functions whose derivatives aren't NaN
torfjelde Jan 27, 2022
e29056a
fixed rules of mod and rem
torfjelde Jan 27, 2022
773a572
make each rule its own testset for easier debugging
torfjelde Jan 27, 2022
cdbec02
reverted changes to multiple NaNMath rules
torfjelde Jan 28, 2022
6ddff44
use more explicit promotion in tests
torfjelde Jan 28, 2022
777d1a7
check promotion of real instead of specific check for Complex
torfjelde Jan 28, 2022
f937593
reverted unnecessary change
torfjelde Jan 28, 2022
5d7b635
reverted unnecessary changes
torfjelde Jan 28, 2022
6e564bc
dont check if AbstractFloat in tests
torfjelde Jan 28, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
name = "DiffRules"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "1.9.0"
version = "1.9.1"

[deps]
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[compat]
IrrationalConstants = "0.1.1"
LogExpFunctions = "0.3.2"
NaNMath = "0.3"
SpecialFunctions = "0.10, 1.0, 2"
julia = "1.3"

[extras]
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Random"]
test = ["Test", "Random", "FiniteDifferences"]
2 changes: 2 additions & 0 deletions src/DiffRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ __precompile__()

module DiffRules

using IrrationalConstants: logtwo, logten, twoπ, sqrtπ, invsqrtπ

include("api.jl")
include("rules.jl")

Expand Down
69 changes: 34 additions & 35 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,46 +5,46 @@
# unary #
#-------#

@define_diffrule Base.:+(x) = :( 1 )
@define_diffrule Base.:-(x) = :( -1 )
@define_diffrule Base.:+(x) = :( 1 )
@define_diffrule Base.:-(x) = :( -1 )
@define_diffrule Base.sqrt(x) = :( inv(2 * sqrt($x)) )
@define_diffrule Base.cbrt(x) = :( inv(3 * cbrt($x)^2) )
@define_diffrule Base.abs2(x) = :( $x + $x )
@define_diffrule Base.inv(x) = :( -abs2(inv($x)) )
@define_diffrule Base.log(x) = :( inv($x) )
@define_diffrule Base.log10(x) = :( inv($x) / log(10) )
@define_diffrule Base.log2(x) = :( inv($x) / log(2) )
@define_diffrule Base.log10(x) = :( inv($x) / $logten )
@define_diffrule Base.log2(x) = :( inv($x) / $logtwo )
@define_diffrule Base.log1p(x) = :( inv($x + 1) )
@define_diffrule Base.exp(x) = :( exp($x) )
@define_diffrule Base.exp2(x) = :( exp2($x) * log(2) )
@define_diffrule Base.exp10(x) = :( exp10($x) * log(10) )
@define_diffrule Base.exp2(x) = :( exp2($x) * $logtwo )
@define_diffrule Base.exp10(x) = :( exp10($x) * $logten )
@define_diffrule Base.expm1(x) = :( exp($x) )
@define_diffrule Base.sin(x) = :( cos($x) )
@define_diffrule Base.cos(x) = :( -sin($x) )
@define_diffrule Base.tan(x) = :( 1 + tan($x)^2 )
@define_diffrule Base.sec(x) = :( sec($x) * tan($x) )
@define_diffrule Base.csc(x) = :( -csc($x) * cot($x) )
@define_diffrule Base.cot(x) = :( -(1 + cot($x)^2) )
@define_diffrule Base.sind(x) = :( (π / 180) * cosd($x) )
@define_diffrule Base.cosd(x) = :( -(π / 180) * sind($x) )
@define_diffrule Base.tand(x) = :( (π / 180) * (1 + tand($x)^2) )
@define_diffrule Base.secd(x) = :( (π / 180) * secd($x) * tand($x) )
@define_diffrule Base.cscd(x) = :( -(π / 180) * cscd($x) * cotd($x) )
@define_diffrule Base.cotd(x) = :( -(π / 180) * (1 + cotd($x)^2) )
@define_diffrule Base.sind(x) = :( deg2rad(cosd($x)) )
@define_diffrule Base.cosd(x) = :( - deg2rad(sind($x)) )
@define_diffrule Base.tand(x) = :( deg2rad(1 + tand($x)^2) )
@define_diffrule Base.secd(x) = :( deg2rad(secd($x) * tand($x)) )
@define_diffrule Base.cscd(x) = :( - deg2rad(cscd($x) * cotd($x)) )
@define_diffrule Base.cotd(x) = :( - deg2rad(1 + cotd($x)^2) )
@define_diffrule Base.sinpi(x) = :( π * cospi($x) )
@define_diffrule Base.cospi(x) = :( -π * sinpi($x) )
@define_diffrule Base.cospi(x) = :( -(π * sinpi($x)) )
@define_diffrule Base.asin(x) = :( inv(sqrt(1 - $x^2)) )
@define_diffrule Base.acos(x) = :( -inv(sqrt(1 - $x^2)) )
@define_diffrule Base.atan(x) = :( inv(1 + $x^2) )
@define_diffrule Base.asec(x) = :( inv(abs($x) * sqrt($x^2 - 1)) )
@define_diffrule Base.acsc(x) = :( -inv(abs($x) * sqrt($x^2 - 1)) )
@define_diffrule Base.acot(x) = :( -inv(1 + $x^2) )
@define_diffrule Base.asind(x) = :( 180 / π / sqrt(1 - $x^2) )
@define_diffrule Base.acosd(x) = :( -180 / π / sqrt(1 - $x^2) )
@define_diffrule Base.atand(x) = :( 180 / π / (1 + $x^2) )
@define_diffrule Base.asecd(x) = :( 180 / π / abs($x) / sqrt($x^2 - 1) )
@define_diffrule Base.acscd(x) = :( -180 / π / abs($x) / sqrt($x^2 - 1) )
@define_diffrule Base.acotd(x) = :( -180 / π / (1 + $x^2) )
@define_diffrule Base.asind(x) = :( inv(deg2rad(sqrt(1 - $x^2))) )
@define_diffrule Base.acosd(x) = :( -inv(deg2rad(sqrt(1 - $x^2))) )
@define_diffrule Base.atand(x) = :( inv(deg2rad(1 + $x^2)) )
@define_diffrule Base.asecd(x) = :( inv(deg2rad(abs($x) * sqrt($x^2 - 1))) )
@define_diffrule Base.acscd(x) = :( -inv(deg2rad(abs($x) * sqrt($x^2 - 1))) )
@define_diffrule Base.acotd(x) = :( -inv(deg2rad(1 + $x^2)) )
@define_diffrule Base.sinh(x) = :( cosh($x) )
@define_diffrule Base.cosh(x) = :( sinh($x) )
@define_diffrule Base.tanh(x) = :( 1 - tanh($x)^2 )
Expand All @@ -58,16 +58,15 @@
@define_diffrule Base.acsch(x) = :( -inv(abs($x) * sqrt(1 + $x^2)) )
@define_diffrule Base.acoth(x) = :( inv(1 - $x^2) )
@define_diffrule Base.sinc(x) = :( cosc($x) )
@define_diffrule Base.deg2rad(x) = :( π / 180 )
@define_diffrule Base.mod2pi(x) = :( isinteger($x / 2pi) ? NaN : 1 )
@define_diffrule Base.rad2deg(x) = :( 180 / π )

@define_diffrule Base.deg2rad(x) = :( deg2rad(one($x)) )
@define_diffrule Base.mod2pi(x) = :( isinteger($x / $twoπ) ? oftype(float($x), NaN) : one(float($x)) )
@define_diffrule Base.rad2deg(x) = :( rad2deg(one($x)) )
@define_diffrule SpecialFunctions.gamma(x) =
:( SpecialFunctions.digamma($x) * SpecialFunctions.gamma($x) )
@define_diffrule SpecialFunctions.loggamma(x) =
:( SpecialFunctions.digamma($x) )

@define_diffrule Base.abs(x) = :( DiffRules._abs_deriv($x) )
@define_diffrule Base.abs(x) = :( $(_abs_deriv)($x) )

# We provide this hook for special number types like `Interval`
# that need their own special definition of `abs`.
Expand All @@ -88,8 +87,8 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x)
@define_diffrule Base.log(b, x) = :( log($x) * inv(-log($b)^2 * $b) ), :( inv($x) / log($b) )
@define_diffrule Base.ldexp(x, y) = :( exp2($y) ), :NaN

@define_diffrule Base.mod(x, y) = :( first(promote(ifelse(isinteger($x / $y), NaN, 1), NaN)) ), :( z = $x / $y; first(promote(ifelse(isinteger(z), NaN, -floor(z)), NaN)) )
@define_diffrule Base.rem(x, y) = :( first(promote(ifelse(isinteger($x / $y), NaN, 1), NaN)) ), :( z = $x / $y; first(promote(ifelse(isinteger(z), NaN, -trunc(z)), NaN)) )
@define_diffrule Base.mod(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), one(float(z))) ), :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), -floor(float(z))) )
@define_diffrule Base.rem(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), one(float(z))) ), :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), -trunc(float(z))) )
@define_diffrule Base.rem2pi(x, r) = :( 1 ), :NaN
@define_diffrule Base.max(x, y) = :( $x > $y ? one($x) : zero($x) ), :( $x > $y ? zero($y) : one($y) )
@define_diffrule Base.min(x, y) = :( $x > $y ? zero($x) : one($x) ), :( $x > $y ? one($y) : zero($y) )
Expand All @@ -113,17 +112,17 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x)
# unary #
#-------#

@define_diffrule SpecialFunctions.erf(x) = :( (2 / sqrt(π)) * exp(-$x * $x) )
@define_diffrule SpecialFunctions.erf(x) = :( 2 * ($invsqrtπ * exp(-$x^2)) )
@define_diffrule SpecialFunctions.erfinv(x) =
:( (sqrt(π) / 2) * exp(SpecialFunctions.erfinv($x)^2) )
@define_diffrule SpecialFunctions.erfc(x) = :( -(2 / sqrt(π)) * exp(-$x * $x) )
:( ($sqrtπ * exp(SpecialFunctions.erfinv($x)^2)) / 2 )
@define_diffrule SpecialFunctions.erfc(x) = :( -($invsqrtπ * exp(-$x^2) * 2) )
@define_diffrule SpecialFunctions.erfcinv(x) =
:( -(sqrt(π) / 2) * exp(SpecialFunctions.erfcinv($x)^2) )
@define_diffrule SpecialFunctions.erfi(x) = :( (2 / sqrt(π)) * exp($x * $x) )
:( -($sqrtπ * exp(SpecialFunctions.erfcinv($x)^2)) / 2 )
@define_diffrule SpecialFunctions.erfi(x) = :( $invsqrtπ * exp($x^2) * 2 )
@define_diffrule SpecialFunctions.erfcx(x) =
:( (2 * $x * SpecialFunctions.erfcx($x)) - (2 / sqrt(π)) )
:( 2 * (($x * SpecialFunctions.erfcx($x)) - $invsqrtπ) )
@define_diffrule SpecialFunctions.logerfcx(x) =
:( 2 * ($x - inv(SpecialFunctions.erfcx($x) * sqrt(π))) )
:( 2 * ($x - inv(SpecialFunctions.erfcx($x) * $sqrtπ)) )
@define_diffrule SpecialFunctions.dawson(x) =
:( 1 - (2 * $x * SpecialFunctions.dawson($x)) )
@define_diffrule SpecialFunctions.digamma(x) =
Expand Down Expand Up @@ -217,8 +216,8 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x)
@define_diffrule NaNMath.acosh(x) = :( inv(NaNMath.sqrt(NaNMath.pow($x, 2) - 1)) )
@define_diffrule NaNMath.atanh(x) = :( inv(1 - NaNMath.pow($x, 2)) )
@define_diffrule NaNMath.log(x) = :( inv($x) )
@define_diffrule NaNMath.log2(x) = :( inv($x) / NaNMath.log(2) )
@define_diffrule NaNMath.log10(x) = :( inv($x) / NaNMath.log(10) )
@define_diffrule NaNMath.log2(x) = :( inv($logtwo * $x) )
@define_diffrule NaNMath.log10(x) = :( inv($logten * $x) )
@define_diffrule NaNMath.log1p(x) = :( inv($x + 1) )
@define_diffrule NaNMath.lgamma(x) = :( SpecialFunctions.digamma($x) )

Expand Down
140 changes: 83 additions & 57 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,82 +1,108 @@
using DiffRules
using Test
using FiniteDifferences

using IrrationalConstants: fourπ

import SpecialFunctions, NaNMath, LogExpFunctions
import Random
Random.seed!(1)

function finitediff(f, x)
ϵ = cbrt(eps(typeof(x))) * max(one(typeof(x)), abs(x))
return (f(x + ϵ) - f(x - ϵ)) / (ϵ + ϵ)
end
# Set `max_range` to avoid domain errors.
const finitediff = central_fdm(5, 1, max_range=1e-3)

@testset "DiffRules" begin
@testset "check rules" begin

non_diffeable_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ldexp, 2), (:Base, :ifelse, 3)]

for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing)
(M, f, arity) ∈ non_diffeable_arg_functions && continue
if arity == 1
@test DiffRules.hasdiffrule(M, f, 1)
deriv = DiffRules.diffrule(M, f, :goo)
modifier = if f in (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)
1.0
elseif f === :log1mexp
-1.0
elseif f === :log2mexp
-0.5
else
0.0
end
@eval begin
let
goo = rand() + $modifier
@test isapprox($deriv, finitediff($M.$f, goo), rtol=0.05)
# test for 2pi functions
if "mod2pi" == string($M.$f)
goo = 4pi + $modifier
@test NaN === $deriv
@testset "($M, $f, $arity)" for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing)
for T in [Float32, Float64]
(M, f, arity) ∈ non_diffeable_arg_functions && continue
if arity == 1
@test DiffRules.hasdiffrule(M, f, 1)
deriv = DiffRules.diffrule(M, f, :goo)
@eval begin
let
goo = if $(f in (:asec, :acsc, :asecd, :acscd, :acosh, :acoth))
# avoid singularities with finite differencing
rand($T) + $T(1.5)
elseif $(f in (:log, :airyaix, :airyaiprimex))
# avoid singularities with finite differencing
rand($T) + $T(0.5)
elseif $(f === :log1mexp)
rand($T) - one($T)
elseif $(f in (:log2mexp, :erfinv))
rand($T) - $T(0.5)
else
rand($T)
end
# We're happy with types with the correct promotion behavior, e.g.
# it's fine to return `1` as a derivative despite input being `Float64`.
@test promote_type(typeof($deriv), $T) === $T
@test $deriv ≈ finitediff($M.$f, goo) rtol=1e-2 atol=1e-3
# test for 2pi functions
if $(f === :mod2pi)
goo = 4 * pi
@test NaN === $deriv
end
end
end
end
elseif arity == 2
@test DiffRules.hasdiffrule(M, f, 2)
derivs = DiffRules.diffrule(M, f, :foo, :bar)
@eval begin
let
if "mod" == string($M.$f)
foo, bar = rand() + 13, rand() + 5 # make sure x/y is not integer
else
foo, bar = rand(1:10), rand()
elseif arity == 2
@test DiffRules.hasdiffrule(M, f, 2)
derivs = DiffRules.diffrule(M, f, :foo, :bar)
@eval begin
let
foo, bar = if $(f === :mod)
rand($T) + 13, rand($T) + 5 # make sure x/y is not integer
elseif $(f === :polygamma)
rand(1:10), rand($T) # only supports integers as first arguments
elseif $(f in (:bessely, :besselyx))
# avoid singularities with finite differencing
rand($T), rand($T) + $T(0.5)
elseif $(f === :log)
# avoid singularities with finite differencing
rand($T) + $T(1.5), rand($T)
elseif $(f === :^)
# avoid singularities with finite differencing
rand($T) + $T(0.5), rand($T)
else
rand($T), rand($T)
end
dx, dy = $(derivs[1]), $(derivs[2])
if !(isnan(dx))
@test dx ≈ finitediff(z -> $M.$f(z, bar), foo) rtol=1e-2 atol=1e-3

# Check type, if applicable.
@test promote_type(typeof(real(dx)), $T) === $T
end
if !(isnan(dy))
@test dy ≈ finitediff(z -> $M.$f(foo, z), bar) rtol=1e-2 atol=1e-3

# Check type, if applicable.
@test promote_type(typeof(real(dy)), $T) === $T
end
end
dx, dy = $(derivs[1]), $(derivs[2])
end
elseif arity == 3
#=
@test DiffRules.hasdiffrule(M, f, 3)
derivs = DiffRules.diffrule(M, f, :foo, :bar, :goo)
@eval begin
foo, bar, goo = randn(3)
dx, dy, dz = $(derivs[1]), $(derivs[2]), $(derivs[3])
if !(isnan(dx))
@test isapprox(dx, finitediff(z -> $M.$f(z, bar), float(foo)), rtol=0.05)
@test isapprox(dx, finitediff(x -> $M.$f(x, bar, goo), foo), rtol=0.05)
end
if !(isnan(dy))
@test isapprox(dy, finitediff(z -> $M.$f(foo, z), bar), rtol=0.05)
@test isapprox(dy, finitediff(y -> $M.$f(foo, y, goo), bar), rtol=0.05)
end
if !(isnan(dz))
@test isapprox(dz, finitediff(z -> $M.$f(foo, bar, z), goo), rtol=0.05)
end
end
=#
end
elseif arity == 3
#=
@test DiffRules.hasdiffrule(M, f, 3)
derivs = DiffRules.diffrule(M, f, :foo, :bar, :goo)
@eval begin
foo, bar, goo = randn(3)
dx, dy, dz = $(derivs[1]), $(derivs[2]), $(derivs[3])
if !(isnan(dx))
@test isapprox(dx, finitediff(x -> $M.$f(x, bar, goo), foo), rtol=0.05)
end
if !(isnan(dy))
@test isapprox(dy, finitediff(y -> $M.$f(foo, y, goo), bar), rtol=0.05)
end
if !(isnan(dz))
@test isapprox(dz, finitediff(z -> $M.$f(foo, bar, z), goo), rtol=0.05)
end
end
=#
end
end

Expand Down