diff --git a/Project.toml b/Project.toml index 769ec33..b0e646a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,22 +1,25 @@ name = "DiffRules" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.9.0" +version = "1.9.1" [deps] +IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [compat] +IrrationalConstants = "0.1.1" LogExpFunctions = "0.3.2" NaNMath = "0.3" SpecialFunctions = "0.10, 1.0, 2" julia = "1.3" [extras] +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Random"] +test = ["Test", "Random", "FiniteDifferences"] diff --git a/src/DiffRules.jl b/src/DiffRules.jl index b70cae5..67c76f2 100644 --- a/src/DiffRules.jl +++ b/src/DiffRules.jl @@ -2,6 +2,8 @@ __precompile__() module DiffRules +using IrrationalConstants: logtwo, logten, twoπ, sqrtπ, invsqrtπ + include("api.jl") include("rules.jl") diff --git a/src/rules.jl b/src/rules.jl index ebc1ac7..d1ca0a2 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -5,19 +5,19 @@ # unary # #-------# -@define_diffrule Base.:+(x) = :( 1 ) -@define_diffrule Base.:-(x) = :( -1 ) +@define_diffrule Base.:+(x) = :( 1 ) +@define_diffrule Base.:-(x) = :( -1 ) @define_diffrule Base.sqrt(x) = :( inv(2 * sqrt($x)) ) @define_diffrule Base.cbrt(x) = :( inv(3 * cbrt($x)^2) ) @define_diffrule Base.abs2(x) = :( $x + $x ) @define_diffrule Base.inv(x) = :( -abs2(inv($x)) ) @define_diffrule Base.log(x) = :( inv($x) ) -@define_diffrule Base.log10(x) = :( inv($x) / log(10) ) -@define_diffrule Base.log2(x) = :( inv($x) / log(2) ) +@define_diffrule Base.log10(x) = :( inv($x) / $logten ) +@define_diffrule Base.log2(x) = :( inv($x) / $logtwo ) @define_diffrule Base.log1p(x) = :( inv($x + 1) ) @define_diffrule Base.exp(x) = :( exp($x) ) -@define_diffrule Base.exp2(x) = :( exp2($x) * log(2) ) -@define_diffrule Base.exp10(x) = :( exp10($x) * log(10) ) +@define_diffrule Base.exp2(x) = :( exp2($x) * $logtwo ) +@define_diffrule Base.exp10(x) = :( exp10($x) * $logten ) @define_diffrule Base.expm1(x) = :( exp($x) ) @define_diffrule Base.sin(x) = :( cos($x) ) @define_diffrule Base.cos(x) = :( -sin($x) ) @@ -25,26 +25,26 @@ @define_diffrule Base.sec(x) = :( sec($x) * tan($x) ) @define_diffrule Base.csc(x) = :( -csc($x) * cot($x) ) @define_diffrule Base.cot(x) = :( -(1 + cot($x)^2) ) -@define_diffrule Base.sind(x) = :( (π / 180) * cosd($x) ) -@define_diffrule Base.cosd(x) = :( -(π / 180) * sind($x) ) -@define_diffrule Base.tand(x) = :( (π / 180) * (1 + tand($x)^2) ) -@define_diffrule Base.secd(x) = :( (π / 180) * secd($x) * tand($x) ) -@define_diffrule Base.cscd(x) = :( -(π / 180) * cscd($x) * cotd($x) ) -@define_diffrule Base.cotd(x) = :( -(π / 180) * (1 + cotd($x)^2) ) +@define_diffrule Base.sind(x) = :( deg2rad(cosd($x)) ) +@define_diffrule Base.cosd(x) = :( - deg2rad(sind($x)) ) +@define_diffrule Base.tand(x) = :( deg2rad(1 + tand($x)^2) ) +@define_diffrule Base.secd(x) = :( deg2rad(secd($x) * tand($x)) ) +@define_diffrule Base.cscd(x) = :( - deg2rad(cscd($x) * cotd($x)) ) +@define_diffrule Base.cotd(x) = :( - deg2rad(1 + cotd($x)^2) ) @define_diffrule Base.sinpi(x) = :( π * cospi($x) ) -@define_diffrule Base.cospi(x) = :( -π * sinpi($x) ) +@define_diffrule Base.cospi(x) = :( -(π * sinpi($x)) ) @define_diffrule Base.asin(x) = :( inv(sqrt(1 - $x^2)) ) @define_diffrule Base.acos(x) = :( -inv(sqrt(1 - $x^2)) ) @define_diffrule Base.atan(x) = :( inv(1 + $x^2) ) @define_diffrule Base.asec(x) = :( inv(abs($x) * sqrt($x^2 - 1)) ) @define_diffrule Base.acsc(x) = :( -inv(abs($x) * sqrt($x^2 - 1)) ) @define_diffrule Base.acot(x) = :( -inv(1 + $x^2) ) -@define_diffrule Base.asind(x) = :( 180 / π / sqrt(1 - $x^2) ) -@define_diffrule Base.acosd(x) = :( -180 / π / sqrt(1 - $x^2) ) -@define_diffrule Base.atand(x) = :( 180 / π / (1 + $x^2) ) -@define_diffrule Base.asecd(x) = :( 180 / π / abs($x) / sqrt($x^2 - 1) ) -@define_diffrule Base.acscd(x) = :( -180 / π / abs($x) / sqrt($x^2 - 1) ) -@define_diffrule Base.acotd(x) = :( -180 / π / (1 + $x^2) ) +@define_diffrule Base.asind(x) = :( inv(deg2rad(sqrt(1 - $x^2))) ) +@define_diffrule Base.acosd(x) = :( -inv(deg2rad(sqrt(1 - $x^2))) ) +@define_diffrule Base.atand(x) = :( inv(deg2rad(1 + $x^2)) ) +@define_diffrule Base.asecd(x) = :( inv(deg2rad(abs($x) * sqrt($x^2 - 1))) ) +@define_diffrule Base.acscd(x) = :( -inv(deg2rad(abs($x) * sqrt($x^2 - 1))) ) +@define_diffrule Base.acotd(x) = :( -inv(deg2rad(1 + $x^2)) ) @define_diffrule Base.sinh(x) = :( cosh($x) ) @define_diffrule Base.cosh(x) = :( sinh($x) ) @define_diffrule Base.tanh(x) = :( 1 - tanh($x)^2 ) @@ -58,16 +58,15 @@ @define_diffrule Base.acsch(x) = :( -inv(abs($x) * sqrt(1 + $x^2)) ) @define_diffrule Base.acoth(x) = :( inv(1 - $x^2) ) @define_diffrule Base.sinc(x) = :( cosc($x) ) -@define_diffrule Base.deg2rad(x) = :( π / 180 ) -@define_diffrule Base.mod2pi(x) = :( isinteger($x / 2pi) ? NaN : 1 ) -@define_diffrule Base.rad2deg(x) = :( 180 / π ) - +@define_diffrule Base.deg2rad(x) = :( deg2rad(one($x)) ) +@define_diffrule Base.mod2pi(x) = :( isinteger($x / $twoπ) ? oftype(float($x), NaN) : one(float($x)) ) +@define_diffrule Base.rad2deg(x) = :( rad2deg(one($x)) ) @define_diffrule SpecialFunctions.gamma(x) = :( SpecialFunctions.digamma($x) * SpecialFunctions.gamma($x) ) @define_diffrule SpecialFunctions.loggamma(x) = :( SpecialFunctions.digamma($x) ) -@define_diffrule Base.abs(x) = :( DiffRules._abs_deriv($x) ) +@define_diffrule Base.abs(x) = :( $(_abs_deriv)($x) ) # We provide this hook for special number types like `Interval` # that need their own special definition of `abs`. @@ -88,8 +87,8 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule Base.log(b, x) = :( log($x) * inv(-log($b)^2 * $b) ), :( inv($x) / log($b) ) @define_diffrule Base.ldexp(x, y) = :( exp2($y) ), :NaN -@define_diffrule Base.mod(x, y) = :( first(promote(ifelse(isinteger($x / $y), NaN, 1), NaN)) ), :( z = $x / $y; first(promote(ifelse(isinteger(z), NaN, -floor(z)), NaN)) ) -@define_diffrule Base.rem(x, y) = :( first(promote(ifelse(isinteger($x / $y), NaN, 1), NaN)) ), :( z = $x / $y; first(promote(ifelse(isinteger(z), NaN, -trunc(z)), NaN)) ) +@define_diffrule Base.mod(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), one(float(z))) ), :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), -floor(float(z))) ) +@define_diffrule Base.rem(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), one(float(z))) ), :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), -trunc(float(z))) ) @define_diffrule Base.rem2pi(x, r) = :( 1 ), :NaN @define_diffrule Base.max(x, y) = :( $x > $y ? one($x) : zero($x) ), :( $x > $y ? zero($y) : one($y) ) @define_diffrule Base.min(x, y) = :( $x > $y ? zero($x) : one($x) ), :( $x > $y ? one($y) : zero($y) ) @@ -113,17 +112,17 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) # unary # #-------# -@define_diffrule SpecialFunctions.erf(x) = :( (2 / sqrt(π)) * exp(-$x * $x) ) +@define_diffrule SpecialFunctions.erf(x) = :( 2 * ($invsqrtπ * exp(-$x^2)) ) @define_diffrule SpecialFunctions.erfinv(x) = - :( (sqrt(π) / 2) * exp(SpecialFunctions.erfinv($x)^2) ) -@define_diffrule SpecialFunctions.erfc(x) = :( -(2 / sqrt(π)) * exp(-$x * $x) ) + :( ($sqrtπ * exp(SpecialFunctions.erfinv($x)^2)) / 2 ) +@define_diffrule SpecialFunctions.erfc(x) = :( -($invsqrtπ * exp(-$x^2) * 2) ) @define_diffrule SpecialFunctions.erfcinv(x) = - :( -(sqrt(π) / 2) * exp(SpecialFunctions.erfcinv($x)^2) ) -@define_diffrule SpecialFunctions.erfi(x) = :( (2 / sqrt(π)) * exp($x * $x) ) + :( -($sqrtπ * exp(SpecialFunctions.erfcinv($x)^2)) / 2 ) +@define_diffrule SpecialFunctions.erfi(x) = :( $invsqrtπ * exp($x^2) * 2 ) @define_diffrule SpecialFunctions.erfcx(x) = - :( (2 * $x * SpecialFunctions.erfcx($x)) - (2 / sqrt(π)) ) + :( 2 * (($x * SpecialFunctions.erfcx($x)) - $invsqrtπ) ) @define_diffrule SpecialFunctions.logerfcx(x) = - :( 2 * ($x - inv(SpecialFunctions.erfcx($x) * sqrt(π))) ) + :( 2 * ($x - inv(SpecialFunctions.erfcx($x) * $sqrtπ)) ) @define_diffrule SpecialFunctions.dawson(x) = :( 1 - (2 * $x * SpecialFunctions.dawson($x)) ) @define_diffrule SpecialFunctions.digamma(x) = @@ -217,8 +216,8 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule NaNMath.acosh(x) = :( inv(NaNMath.sqrt(NaNMath.pow($x, 2) - 1)) ) @define_diffrule NaNMath.atanh(x) = :( inv(1 - NaNMath.pow($x, 2)) ) @define_diffrule NaNMath.log(x) = :( inv($x) ) -@define_diffrule NaNMath.log2(x) = :( inv($x) / NaNMath.log(2) ) -@define_diffrule NaNMath.log10(x) = :( inv($x) / NaNMath.log(10) ) +@define_diffrule NaNMath.log2(x) = :( inv($logtwo * $x) ) +@define_diffrule NaNMath.log10(x) = :( inv($logten * $x) ) @define_diffrule NaNMath.log1p(x) = :( inv($x + 1) ) @define_diffrule NaNMath.lgamma(x) = :( SpecialFunctions.digamma($x) ) diff --git a/test/runtests.jl b/test/runtests.jl index 072c962..6618a36 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,82 +1,108 @@ using DiffRules using Test +using FiniteDifferences + +using IrrationalConstants: fourπ import SpecialFunctions, NaNMath, LogExpFunctions import Random Random.seed!(1) -function finitediff(f, x) - ϵ = cbrt(eps(typeof(x))) * max(one(typeof(x)), abs(x)) - return (f(x + ϵ) - f(x - ϵ)) / (ϵ + ϵ) -end +# Set `max_range` to avoid domain errors. +const finitediff = central_fdm(5, 1, max_range=1e-3) @testset "DiffRules" begin @testset "check rules" begin non_diffeable_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ldexp, 2), (:Base, :ifelse, 3)] -for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) - (M, f, arity) ∈ non_diffeable_arg_functions && continue - if arity == 1 - @test DiffRules.hasdiffrule(M, f, 1) - deriv = DiffRules.diffrule(M, f, :goo) - modifier = if f in (:asec, :acsc, :asecd, :acscd, :acosh, :acoth) - 1.0 - elseif f === :log1mexp - -1.0 - elseif f === :log2mexp - -0.5 - else - 0.0 - end - @eval begin - let - goo = rand() + $modifier - @test isapprox($deriv, finitediff($M.$f, goo), rtol=0.05) - # test for 2pi functions - if "mod2pi" == string($M.$f) - goo = 4pi + $modifier - @test NaN === $deriv +@testset "($M, $f, $arity)" for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) + for T in [Float32, Float64] + (M, f, arity) ∈ non_diffeable_arg_functions && continue + if arity == 1 + @test DiffRules.hasdiffrule(M, f, 1) + deriv = DiffRules.diffrule(M, f, :goo) + @eval begin + let + goo = if $(f in (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)) + # avoid singularities with finite differencing + rand($T) + $T(1.5) + elseif $(f in (:log, :airyaix, :airyaiprimex)) + # avoid singularities with finite differencing + rand($T) + $T(0.5) + elseif $(f === :log1mexp) + rand($T) - one($T) + elseif $(f in (:log2mexp, :erfinv)) + rand($T) - $T(0.5) + else + rand($T) + end + # We're happy with types with the correct promotion behavior, e.g. + # it's fine to return `1` as a derivative despite input being `Float64`. + @test promote_type(typeof($deriv), $T) === $T + @test $deriv ≈ finitediff($M.$f, goo) rtol=1e-2 atol=1e-3 + # test for 2pi functions + if $(f === :mod2pi) + goo = 4 * pi + @test NaN === $deriv + end end end - end - elseif arity == 2 - @test DiffRules.hasdiffrule(M, f, 2) - derivs = DiffRules.diffrule(M, f, :foo, :bar) - @eval begin - let - if "mod" == string($M.$f) - foo, bar = rand() + 13, rand() + 5 # make sure x/y is not integer - else - foo, bar = rand(1:10), rand() + elseif arity == 2 + @test DiffRules.hasdiffrule(M, f, 2) + derivs = DiffRules.diffrule(M, f, :foo, :bar) + @eval begin + let + foo, bar = if $(f === :mod) + rand($T) + 13, rand($T) + 5 # make sure x/y is not integer + elseif $(f === :polygamma) + rand(1:10), rand($T) # only supports integers as first arguments + elseif $(f in (:bessely, :besselyx)) + # avoid singularities with finite differencing + rand($T), rand($T) + $T(0.5) + elseif $(f === :log) + # avoid singularities with finite differencing + rand($T) + $T(1.5), rand($T) + elseif $(f === :^) + # avoid singularities with finite differencing + rand($T) + $T(0.5), rand($T) + else + rand($T), rand($T) + end + dx, dy = $(derivs[1]), $(derivs[2]) + if !(isnan(dx)) + @test dx ≈ finitediff(z -> $M.$f(z, bar), foo) rtol=1e-2 atol=1e-3 + + # Check type, if applicable. + @test promote_type(typeof(real(dx)), $T) === $T + end + if !(isnan(dy)) + @test dy ≈ finitediff(z -> $M.$f(foo, z), bar) rtol=1e-2 atol=1e-3 + + # Check type, if applicable. + @test promote_type(typeof(real(dy)), $T) === $T + end end - dx, dy = $(derivs[1]), $(derivs[2]) + end + elseif arity == 3 + #= + @test DiffRules.hasdiffrule(M, f, 3) + derivs = DiffRules.diffrule(M, f, :foo, :bar, :goo) + @eval begin + foo, bar, goo = randn(3) + dx, dy, dz = $(derivs[1]), $(derivs[2]), $(derivs[3]) if !(isnan(dx)) - @test isapprox(dx, finitediff(z -> $M.$f(z, bar), float(foo)), rtol=0.05) + @test isapprox(dx, finitediff(x -> $M.$f(x, bar, goo), foo), rtol=0.05) end if !(isnan(dy)) - @test isapprox(dy, finitediff(z -> $M.$f(foo, z), bar), rtol=0.05) + @test isapprox(dy, finitediff(y -> $M.$f(foo, y, goo), bar), rtol=0.05) + end + if !(isnan(dz)) + @test isapprox(dz, finitediff(z -> $M.$f(foo, bar, z), goo), rtol=0.05) end end + =# end - elseif arity == 3 - #= - @test DiffRules.hasdiffrule(M, f, 3) - derivs = DiffRules.diffrule(M, f, :foo, :bar, :goo) - @eval begin - foo, bar, goo = randn(3) - dx, dy, dz = $(derivs[1]), $(derivs[2]), $(derivs[3]) - if !(isnan(dx)) - @test isapprox(dx, finitediff(x -> $M.$f(x, bar, goo), foo), rtol=0.05) - end - if !(isnan(dy)) - @test isapprox(dy, finitediff(y -> $M.$f(foo, y, goo), bar), rtol=0.05) - end - if !(isnan(dz)) - @test isapprox(dz, finitediff(z -> $M.$f(foo, bar, z), goo), rtol=0.05) - end - end - =# end end