From 2abc4ccf0bc0131c1303c87fe1feb1eabf37fb29 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 17 Jan 2021 15:13:28 +0000 Subject: [PATCH 01/48] fixed issues with return-type for several rules --- src/rules.jl | 44 ++++++++++++++++++++++---------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index e31f5d4..adfae78 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -25,26 +25,26 @@ @define_diffrule Base.sec(x) = :( sec($x) * tan($x) ) @define_diffrule Base.csc(x) = :( -csc($x) * cot($x) ) @define_diffrule Base.cot(x) = :( -(1 + cot($x)^2) ) -@define_diffrule Base.sind(x) = :( (π / 180) * cosd($x) ) -@define_diffrule Base.cosd(x) = :( -(π / 180) * sind($x) ) -@define_diffrule Base.tand(x) = :( (π / 180) * (1 + tand($x)^2) ) -@define_diffrule Base.secd(x) = :( (π / 180) * secd($x) * tand($x) ) -@define_diffrule Base.cscd(x) = :( -(π / 180) * cscd($x) * cotd($x) ) -@define_diffrule Base.cotd(x) = :( -(π / 180) * (1 + cotd($x)^2) ) -@define_diffrule Base.sinpi(x) = :( π * cospi($x) ) -@define_diffrule Base.cospi(x) = :( -π * sinpi($x) ) +@define_diffrule Base.sind(x) = :( (oftype($x, π) / 180) * cosd($x) ) +@define_diffrule Base.cosd(x) = :( -(oftype($x, π) / 180) * sind($x) ) +@define_diffrule Base.tand(x) = :( (oftype($x, π) / 180) * (1 + tand($x)^2) ) +@define_diffrule Base.secd(x) = :( (oftype($x, π) / 180) * secd($x) * tand($x) ) +@define_diffrule Base.cscd(x) = :( -(oftype($x, π) / 180) * cscd($x) * cotd($x) ) +@define_diffrule Base.cotd(x) = :( -(oftype($x, π) / 180) * (1 + cotd($x)^2) ) +@define_diffrule Base.sinpi(x) = :( oftype($x, π) * cospi($x) ) +@define_diffrule Base.cospi(x) = :( -oftype($x, π) * sinpi($x) ) @define_diffrule Base.asin(x) = :( inv(sqrt(1 - $x^2)) ) @define_diffrule Base.acos(x) = :( -inv(sqrt(1 - $x^2)) ) @define_diffrule Base.atan(x) = :( inv(1 + $x^2) ) @define_diffrule Base.asec(x) = :( inv(abs($x) * sqrt($x^2 - 1)) ) @define_diffrule Base.acsc(x) = :( -inv(abs($x) * sqrt($x^2 - 1)) ) @define_diffrule Base.acot(x) = :( -inv(1 + $x^2) ) -@define_diffrule Base.asind(x) = :( 180 / π / sqrt(1 - $x^2) ) -@define_diffrule Base.acosd(x) = :( -180 / π / sqrt(1 - $x^2) ) -@define_diffrule Base.atand(x) = :( 180 / π / (1 + $x^2) ) -@define_diffrule Base.asecd(x) = :( 180 / π / abs($x) / sqrt($x^2 - 1) ) -@define_diffrule Base.acscd(x) = :( -180 / π / abs($x) / sqrt($x^2 - 1) ) -@define_diffrule Base.acotd(x) = :( -180 / π / (1 + $x^2) ) +@define_diffrule Base.asind(x) = :( 180 / oftype($x, π) / sqrt(1 - $x^2) ) +@define_diffrule Base.acosd(x) = :( -180 / oftype($x, π) / sqrt(1 - $x^2) ) +@define_diffrule Base.atand(x) = :( 180 / oftype($x, π) / (1 + $x^2) ) +@define_diffrule Base.asecd(x) = :( 180 / oftype($x, π) / abs($x) / sqrt($x^2 - 1) ) +@define_diffrule Base.acscd(x) = :( -180 / oftype($x, π) / abs($x) / sqrt($x^2 - 1) ) +@define_diffrule Base.acotd(x) = :( -180 / oftype($x, π) / (1 + $x^2) ) @define_diffrule Base.sinh(x) = :( cosh($x) ) @define_diffrule Base.cosh(x) = :( sinh($x) ) @define_diffrule Base.tanh(x) = :( 1 - tanh($x)^2 ) @@ -57,9 +57,9 @@ @define_diffrule Base.asech(x) = :( -inv($x * sqrt(1 - $x^2)) ) @define_diffrule Base.acsch(x) = :( -inv(abs($x) * sqrt(1 + $x^2)) ) @define_diffrule Base.acoth(x) = :( inv(1 - $x^2) ) -@define_diffrule Base.deg2rad(x) = :( π / 180 ) +@define_diffrule Base.deg2rad(x) = :( oftype($x, π) / 180 ) @define_diffrule Base.mod2pi(x) = :( isinteger($x / 2pi) ? NaN : 1 ) -@define_diffrule Base.rad2deg(x) = :( 180 / π ) +@define_diffrule Base.rad2deg(x) = :( 180 / oftype($x, π) ) @define_diffrule SpecialFunctions.gamma(x) = :( SpecialFunctions.digamma($x) * SpecialFunctions.gamma($x) ) @define_diffrule SpecialFunctions.loggamma(x) = @@ -100,15 +100,15 @@ end # unary # #-------# -@define_diffrule SpecialFunctions.erf(x) = :( (2 / sqrt(π)) * exp(-$x * $x) ) +@define_diffrule SpecialFunctions.erf(x) = :( (2 / sqrt(oftype($x, π))) * exp(-$x * $x) ) @define_diffrule SpecialFunctions.erfinv(x) = - :( (sqrt(π) / 2) * exp(SpecialFunctions.erfinv($x)^2) ) -@define_diffrule SpecialFunctions.erfc(x) = :( -(2 / sqrt(π)) * exp(-$x * $x) ) + :( (sqrt(oftype($x, π)) / 2) * exp(SpecialFunctions.erfinv($x)^2) ) +@define_diffrule SpecialFunctions.erfc(x) = :( -(2 / sqrt(oftype($x, π))) * exp(-$x * $x) ) @define_diffrule SpecialFunctions.erfcinv(x) = - :( -(sqrt(π) / 2) * exp(SpecialFunctions.erfcinv($x)^2) ) -@define_diffrule SpecialFunctions.erfi(x) = :( (2 / sqrt(π)) * exp($x * $x) ) + :( -(sqrt(oftype($x, π)) / 2) * exp(SpecialFunctions.erfcinv($x)^2) ) +@define_diffrule SpecialFunctions.erfi(x) = :( (2 / sqrt(oftype($x, π))) * exp($x * $x) ) @define_diffrule SpecialFunctions.erfcx(x) = - :( (2 * $x * SpecialFunctions.erfcx($x)) - (2 / sqrt(π)) ) + :( (2 * $x * SpecialFunctions.erfcx($x)) - (2 / sqrt(oftype($x, π))) ) @define_diffrule SpecialFunctions.dawson(x) = :( 1 - (2 * $x * SpecialFunctions.dawson($x)) ) @define_diffrule SpecialFunctions.digamma(x) = From 184ba543049e804484a8ed6e8647ce2cb14a790f Mon Sep 17 00:00:00 2001 From: tor Date: Thu, 11 Mar 2021 15:40:34 +0100 Subject: [PATCH 02/48] initial work on inferring return type using intermediate computations --- src/rules.jl | 38 +++++++++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index adfae78..bda81b7 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -99,16 +99,36 @@ end # unary # #-------# +@define_diffrule SpecialFunctions.erf(x) = quote + tmp = exp(-$x * $x) + (oftype(tmp, 2 / sqrt(π)) * tmp) +end + +@define_diffrule SpecialFunctions.erfinv(x) = quote + tmp = exp(SpecialFunctions.erfinv($x)^2) + (oftype(tmp, sqrt(π)) / 2) * tmp +end + +@define_diffrule SpecialFunctions.erfc(x) = quote + tmp = exp(-$x * $x) + -oftype(tmp, (2 / sqrt(π))) * tmp +end + +@define_diffrule SpecialFunctions.erfcinv(x) = quote + tmp = exp(SpecialFunctions.erfcinv($x)^2) + -(oftype(tmp, sqrt(π)) / 2) * tmp +end + +@define_diffrule SpecialFunctions.erfi(x) = quote + tmp = exp($x * $x) + oftype(tmp, (2 / sqrt(π))) * tmp +end + +@define_diffrule SpecialFunctions.erfcx(x) = quote + tmp = (2 * $x * SpecialFunctions.erfcx($x)) + tmp - oftype(tmp, (2 / sqrt(π))) +end -@define_diffrule SpecialFunctions.erf(x) = :( (2 / sqrt(oftype($x, π))) * exp(-$x * $x) ) -@define_diffrule SpecialFunctions.erfinv(x) = - :( (sqrt(oftype($x, π)) / 2) * exp(SpecialFunctions.erfinv($x)^2) ) -@define_diffrule SpecialFunctions.erfc(x) = :( -(2 / sqrt(oftype($x, π))) * exp(-$x * $x) ) -@define_diffrule SpecialFunctions.erfcinv(x) = - :( -(sqrt(oftype($x, π)) / 2) * exp(SpecialFunctions.erfcinv($x)^2) ) -@define_diffrule SpecialFunctions.erfi(x) = :( (2 / sqrt(oftype($x, π))) * exp($x * $x) ) -@define_diffrule SpecialFunctions.erfcx(x) = - :( (2 * $x * SpecialFunctions.erfcx($x)) - (2 / sqrt(oftype($x, π))) ) @define_diffrule SpecialFunctions.dawson(x) = :( 1 - (2 * $x * SpecialFunctions.dawson($x)) ) @define_diffrule SpecialFunctions.digamma(x) = From c60fd7ba60dc7f67bb2e5980443b92d42f46369a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Apr 2021 06:03:28 +0200 Subject: [PATCH 03/48] removed oftype where possible --- src/rules.jl | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index bda81b7..90d2217 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -25,26 +25,26 @@ @define_diffrule Base.sec(x) = :( sec($x) * tan($x) ) @define_diffrule Base.csc(x) = :( -csc($x) * cot($x) ) @define_diffrule Base.cot(x) = :( -(1 + cot($x)^2) ) -@define_diffrule Base.sind(x) = :( (oftype($x, π) / 180) * cosd($x) ) -@define_diffrule Base.cosd(x) = :( -(oftype($x, π) / 180) * sind($x) ) -@define_diffrule Base.tand(x) = :( (oftype($x, π) / 180) * (1 + tand($x)^2) ) -@define_diffrule Base.secd(x) = :( (oftype($x, π) / 180) * secd($x) * tand($x) ) -@define_diffrule Base.cscd(x) = :( -(oftype($x, π) / 180) * cscd($x) * cotd($x) ) -@define_diffrule Base.cotd(x) = :( -(oftype($x, π) / 180) * (1 + cotd($x)^2) ) -@define_diffrule Base.sinpi(x) = :( oftype($x, π) * cospi($x) ) -@define_diffrule Base.cospi(x) = :( -oftype($x, π) * sinpi($x) ) +@define_diffrule Base.sind(x) = :( π * cosd($x) / 180 ) +@define_diffrule Base.cosd(x) = :( -π * sind($x) / 180 ) +@define_diffrule Base.tand(x) = :( π * (1 + tand($x)^2) / 180 ) +@define_diffrule Base.secd(x) = :( π * secd($x) * tand($x) / 180 ) +@define_diffrule Base.cscd(x) = :( -π * cscd($x) * cotd($x) / 180 ) +@define_diffrule Base.cotd(x) = :( -π * (1 + cotd($x)^2) / 180 ) +@define_diffrule Base.sinpi(x) = :( π * cospi($x) ) +@define_diffrule Base.cospi(x) = :( -π * sinpi($x) ) @define_diffrule Base.asin(x) = :( inv(sqrt(1 - $x^2)) ) @define_diffrule Base.acos(x) = :( -inv(sqrt(1 - $x^2)) ) @define_diffrule Base.atan(x) = :( inv(1 + $x^2) ) @define_diffrule Base.asec(x) = :( inv(abs($x) * sqrt($x^2 - 1)) ) @define_diffrule Base.acsc(x) = :( -inv(abs($x) * sqrt($x^2 - 1)) ) @define_diffrule Base.acot(x) = :( -inv(1 + $x^2) ) -@define_diffrule Base.asind(x) = :( 180 / oftype($x, π) / sqrt(1 - $x^2) ) -@define_diffrule Base.acosd(x) = :( -180 / oftype($x, π) / sqrt(1 - $x^2) ) -@define_diffrule Base.atand(x) = :( 180 / oftype($x, π) / (1 + $x^2) ) -@define_diffrule Base.asecd(x) = :( 180 / oftype($x, π) / abs($x) / sqrt($x^2 - 1) ) -@define_diffrule Base.acscd(x) = :( -180 / oftype($x, π) / abs($x) / sqrt($x^2 - 1) ) -@define_diffrule Base.acotd(x) = :( -180 / oftype($x, π) / (1 + $x^2) ) +@define_diffrule Base.asind(x) = :( 180 / (π / sqrt(1 - $x^2)) ) +@define_diffrule Base.acosd(x) = :( -180 / (π / sqrt(1 - $x^2)) ) +@define_diffrule Base.atand(x) = :( 180 / (π / (1 + $x^2)) ) +@define_diffrule Base.asecd(x) = :( 180 / (π / abs($x) / sqrt($x^2 - 1)) ) +@define_diffrule Base.acscd(x) = :( -180 / (π / abs($x) / sqrt($x^2 - 1)) ) +@define_diffrule Base.acotd(x) = :( -180 / (π / (1 + $x^2)) ) @define_diffrule Base.sinh(x) = :( cosh($x) ) @define_diffrule Base.cosh(x) = :( sinh($x) ) @define_diffrule Base.tanh(x) = :( 1 - tanh($x)^2 ) @@ -57,9 +57,9 @@ @define_diffrule Base.asech(x) = :( -inv($x * sqrt(1 - $x^2)) ) @define_diffrule Base.acsch(x) = :( -inv(abs($x) * sqrt(1 + $x^2)) ) @define_diffrule Base.acoth(x) = :( inv(1 - $x^2) ) -@define_diffrule Base.deg2rad(x) = :( oftype($x, π) / 180 ) +@define_diffrule Base.deg2rad(x) = :( oftype($x, π) / 180 ) @define_diffrule Base.mod2pi(x) = :( isinteger($x / 2pi) ? NaN : 1 ) -@define_diffrule Base.rad2deg(x) = :( 180 / oftype($x, π) ) +@define_diffrule Base.rad2deg(x) = :( 180 / oftype($x, π) ) @define_diffrule SpecialFunctions.gamma(x) = :( SpecialFunctions.digamma($x) * SpecialFunctions.gamma($x) ) @define_diffrule SpecialFunctions.loggamma(x) = From b440cdda4d3489400fe43d3e03e79ca0a4846cc1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Apr 2021 06:13:29 +0200 Subject: [PATCH 04/48] fixed stupid mistake --- src/rules.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 90d2217..ce6d3c1 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -39,12 +39,12 @@ @define_diffrule Base.asec(x) = :( inv(abs($x) * sqrt($x^2 - 1)) ) @define_diffrule Base.acsc(x) = :( -inv(abs($x) * sqrt($x^2 - 1)) ) @define_diffrule Base.acot(x) = :( -inv(1 + $x^2) ) -@define_diffrule Base.asind(x) = :( 180 / (π / sqrt(1 - $x^2)) ) -@define_diffrule Base.acosd(x) = :( -180 / (π / sqrt(1 - $x^2)) ) -@define_diffrule Base.atand(x) = :( 180 / (π / (1 + $x^2)) ) -@define_diffrule Base.asecd(x) = :( 180 / (π / abs($x) / sqrt($x^2 - 1)) ) -@define_diffrule Base.acscd(x) = :( -180 / (π / abs($x) / sqrt($x^2 - 1)) ) -@define_diffrule Base.acotd(x) = :( -180 / (π / (1 + $x^2)) ) +@define_diffrule Base.asind(x) = :( 180 / (π * sqrt(1 - $x^2)) ) +@define_diffrule Base.acosd(x) = :( -180 / (π * sqrt(1 - $x^2)) ) +@define_diffrule Base.atand(x) = :( 180 / (π * (1 + $x^2)) ) +@define_diffrule Base.asecd(x) = :( 180 / (π * abs($x) / sqrt($x^2 - 1)) ) +@define_diffrule Base.acscd(x) = :( -180 / (π * abs($x) / sqrt($x^2 - 1)) ) +@define_diffrule Base.acotd(x) = :( -180 / (π * (1 + $x^2)) ) @define_diffrule Base.sinh(x) = :( cosh($x) ) @define_diffrule Base.cosh(x) = :( sinh($x) ) @define_diffrule Base.tanh(x) = :( 1 - tanh($x)^2 ) From 65c35596280698e48b787ed6c053ba076253d19f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Apr 2021 06:14:58 +0200 Subject: [PATCH 05/48] missed one in fix of stupid mistake --- src/rules.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index ce6d3c1..085fe2f 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -42,8 +42,8 @@ @define_diffrule Base.asind(x) = :( 180 / (π * sqrt(1 - $x^2)) ) @define_diffrule Base.acosd(x) = :( -180 / (π * sqrt(1 - $x^2)) ) @define_diffrule Base.atand(x) = :( 180 / (π * (1 + $x^2)) ) -@define_diffrule Base.asecd(x) = :( 180 / (π * abs($x) / sqrt($x^2 - 1)) ) -@define_diffrule Base.acscd(x) = :( -180 / (π * abs($x) / sqrt($x^2 - 1)) ) +@define_diffrule Base.asecd(x) = :( 180 / (π * abs($x) * sqrt($x^2 - 1)) ) +@define_diffrule Base.acscd(x) = :( -180 / (π * abs($x) * sqrt($x^2 - 1)) ) @define_diffrule Base.acotd(x) = :( -180 / (π * (1 + $x^2)) ) @define_diffrule Base.sinh(x) = :( cosh($x) ) @define_diffrule Base.cosh(x) = :( sinh($x) ) From dcba293b03253a2ab3591a79a32bc0f9816c2a5c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Apr 2021 06:31:49 +0200 Subject: [PATCH 06/48] more fixes --- src/rules.jl | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 085fe2f..6259822 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -5,19 +5,19 @@ # unary # #-------# -@define_diffrule Base.:+(x) = :( 1 ) -@define_diffrule Base.:-(x) = :( -1 ) +@define_diffrule Base.:+(x) = :( one($x) ) +@define_diffrule Base.:-(x) = :( -one($x) ) @define_diffrule Base.sqrt(x) = :( inv(2 * sqrt($x)) ) @define_diffrule Base.cbrt(x) = :( inv(3 * cbrt($x)^2) ) @define_diffrule Base.abs2(x) = :( $x + $x ) @define_diffrule Base.inv(x) = :( -abs2(inv($x)) ) @define_diffrule Base.log(x) = :( inv($x) ) -@define_diffrule Base.log10(x) = :( inv($x) / log(10) ) -@define_diffrule Base.log2(x) = :( inv($x) / log(2) ) +@define_diffrule Base.log10(x) = :( inv($x) / log(oftype($x, 10)) ) +@define_diffrule Base.log2(x) = :( inv($x) / log(oftype($x, 2)) ) @define_diffrule Base.log1p(x) = :( inv($x + 1) ) @define_diffrule Base.exp(x) = :( exp($x) ) -@define_diffrule Base.exp2(x) = :( exp2($x) * log(2) ) -@define_diffrule Base.exp10(x) = :( exp10($x) * log(10) ) +@define_diffrule Base.exp2(x) = :( exp2($x) * log(oftype($x, 2)) ) +@define_diffrule Base.exp10(x) = :( exp10($x) * log(oftype($x, 10)) ) @define_diffrule Base.expm1(x) = :( exp($x) ) @define_diffrule Base.sin(x) = :( cos($x) ) @define_diffrule Base.cos(x) = :( -sin($x) ) @@ -26,13 +26,13 @@ @define_diffrule Base.csc(x) = :( -csc($x) * cot($x) ) @define_diffrule Base.cot(x) = :( -(1 + cot($x)^2) ) @define_diffrule Base.sind(x) = :( π * cosd($x) / 180 ) -@define_diffrule Base.cosd(x) = :( -π * sind($x) / 180 ) +@define_diffrule Base.cosd(x) = :( -(π * sind($x)) / 180 ) @define_diffrule Base.tand(x) = :( π * (1 + tand($x)^2) / 180 ) @define_diffrule Base.secd(x) = :( π * secd($x) * tand($x) / 180 ) -@define_diffrule Base.cscd(x) = :( -π * cscd($x) * cotd($x) / 180 ) -@define_diffrule Base.cotd(x) = :( -π * (1 + cotd($x)^2) / 180 ) +@define_diffrule Base.cscd(x) = :( -(π * cscd($x) * cotd($x) / 180) ) +@define_diffrule Base.cotd(x) = :( -(π * (1 + cotd($x)^2) / 180) ) @define_diffrule Base.sinpi(x) = :( π * cospi($x) ) -@define_diffrule Base.cospi(x) = :( -π * sinpi($x) ) +@define_diffrule Base.cospi(x) = :( -(π * sinpi($x)) ) @define_diffrule Base.asin(x) = :( inv(sqrt(1 - $x^2)) ) @define_diffrule Base.acos(x) = :( -inv(sqrt(1 - $x^2)) ) @define_diffrule Base.atan(x) = :( inv(1 + $x^2) ) @@ -58,13 +58,13 @@ @define_diffrule Base.acsch(x) = :( -inv(abs($x) * sqrt(1 + $x^2)) ) @define_diffrule Base.acoth(x) = :( inv(1 - $x^2) ) @define_diffrule Base.deg2rad(x) = :( oftype($x, π) / 180 ) -@define_diffrule Base.mod2pi(x) = :( isinteger($x / 2pi) ? NaN : 1 ) +@define_diffrule Base.mod2pi(x) = :( isinteger($x / 2pi) ? oftype($x, NaN) : one($x) ) @define_diffrule Base.rad2deg(x) = :( 180 / oftype($x, π) ) @define_diffrule SpecialFunctions.gamma(x) = :( SpecialFunctions.digamma($x) * SpecialFunctions.gamma($x) ) @define_diffrule SpecialFunctions.loggamma(x) = :( SpecialFunctions.digamma($x) ) -@define_diffrule Base.transpose(x) = :( 1 ) +@define_diffrule Base.transpose(x) = :( one($x) ) @define_diffrule Base.abs(x) = :( DiffRules._abs_deriv($x) ) # We provide this hook for special number types like `Interval` @@ -216,14 +216,14 @@ end @define_diffrule NaNMath.sqrt(x) = :( inv(2 * NaNMath.sqrt($x)) ) @define_diffrule NaNMath.sin(x) = :( NaNMath.cos($x) ) @define_diffrule NaNMath.cos(x) = :( -NaNMath.sin($x) ) -@define_diffrule NaNMath.tan(x) = :( 1 + NaNMath.pow(NaNMath.tan($x), 2) ) -@define_diffrule NaNMath.asin(x) = :( inv(NaNMath.sqrt(1 - NaNMath.pow($x, 2))) ) -@define_diffrule NaNMath.acos(x) = :( -inv(NaNMath.sqrt(1 - NaNMath.pow($x, 2))) ) -@define_diffrule NaNMath.acosh(x) = :( inv(NaNMath.sqrt(NaNMath.pow($x, 2) - 1)) ) -@define_diffrule NaNMath.atanh(x) = :( inv(1 - NaNMath.pow($x, 2)) ) +@define_diffrule NaNMath.tan(x) = :( 1 + NaNMath.pow(NaNMath.tan($x), oftype($x, 2)) ) +@define_diffrule NaNMath.asin(x) = :( inv(NaNMath.sqrt(1 - NaNMath.pow($x, oftype($x, 2)))) ) +@define_diffrule NaNMath.acos(x) = :( -inv(NaNMath.sqrt(1 - NaNMath.pow($x, oftype($x, 2)))) ) +@define_diffrule NaNMath.acosh(x) = :( inv(NaNMath.sqrt(NaNMath.pow($x, oftype($x, 2)) - 1)) ) +@define_diffrule NaNMath.atanh(x) = :( inv(1 - NaNMath.pow($x, oftype($x, 2))) ) @define_diffrule NaNMath.log(x) = :( inv($x) ) -@define_diffrule NaNMath.log2(x) = :( inv($x) / NaNMath.log(2) ) -@define_diffrule NaNMath.log10(x) = :( inv($x) / NaNMath.log(10) ) +@define_diffrule NaNMath.log2(x) = :( inv($x) / NaNMath.log(oftype($x, oftype($x, 2))) ) +@define_diffrule NaNMath.log10(x) = :( inv($x) / NaNMath.log(oftype($x, oftype($x, 10))) ) @define_diffrule NaNMath.log1p(x) = :( inv($x + 1) ) @define_diffrule NaNMath.lgamma(x) = :( SpecialFunctions.digamma($x) ) From 3f71077f976bcc5ee23a4560e06b621dcf512e3d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Apr 2021 06:31:55 +0200 Subject: [PATCH 07/48] added tests for different types --- test/runtests.jl | 58 ++++++++++++++++++++++++++---------------------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index ede8617..3bc89d1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,32 +18,38 @@ end non_numeric_arg_functions = [(:Base, :rem2pi, 2)] -for (M, f, arity) in DiffRules.diffrules() - (M, f, arity) ∈ non_numeric_arg_functions && continue - if arity == 1 - @test DiffRules.hasdiffrule(M, f, 1) - deriv = DiffRules.diffrule(M, f, :goo) - modifier = in(f, (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)) ? 1 : 0 - @eval begin - goo = rand() + $modifier - @test isapprox($deriv, finitediff($M.$f, goo), rtol=0.05) - # test for 2pi functions - if "mod2pi" == string($M.$f) - goo = 4pi + $modifier - @test NaN === $deriv - end - end - elseif arity == 2 - @test DiffRules.hasdiffrule(M, f, 2) - derivs = DiffRules.diffrule(M, f, :foo, :bar) - @eval begin - foo, bar = rand(1:10), rand() - dx, dy = $(derivs[1]), $(derivs[2]) - if !(isnan(dx)) - @test isapprox(dx, finitediff(z -> $M.$f(z, bar), float(foo)), rtol=0.05) - end - if !(isnan(dy)) - @test isapprox(dy, finitediff(z -> $M.$f(foo, z), bar), rtol=0.05) +for T in [Float32, Float64] + for (M, f, arity) in DiffRules.diffrules() + @testset "$M.$f $(arity)" begin + (M, f, arity) ∈ non_numeric_arg_functions && continue + if arity == 1 + @test DiffRules.hasdiffrule(M, f, 1) + deriv = DiffRules.diffrule(M, f, :goo) + modifier = in(f, (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)) ? 1 : 0 + @eval begin + goo = $T(rand() + $modifier) + @test $deriv isa $T + @test isapprox($deriv, finitediff($M.$f, goo), rtol=0.05) + # test for 2pi functions + if "mod2pi" == string($M.$f) + goo = 4pi + $modifier + @test NaN === $deriv + end + end + elseif arity == 2 + # TODO: Add test for types. + @test DiffRules.hasdiffrule(M, f, 2) + derivs = DiffRules.diffrule(M, f, :foo, :bar) + @eval begin + foo, bar = rand(1:10), rand() + dx, dy = $(derivs[1]), $(derivs[2]) + if !(isnan(dx)) + @test isapprox(dx, finitediff(z -> $M.$f(z, bar), float(foo)), rtol=0.05) + end + if !(isnan(dy)) + @test isapprox(dy, finitediff(z -> $M.$f(foo, z), bar), rtol=0.05) + end + end end end end From ee42f7ff2ed81c66d39b84e09c0d184c7b6bd7fb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Apr 2021 06:34:45 +0200 Subject: [PATCH 08/48] removed a couple of overdone oftypes --- src/rules.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 6259822..a84b369 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -222,8 +222,8 @@ end @define_diffrule NaNMath.acosh(x) = :( inv(NaNMath.sqrt(NaNMath.pow($x, oftype($x, 2)) - 1)) ) @define_diffrule NaNMath.atanh(x) = :( inv(1 - NaNMath.pow($x, oftype($x, 2))) ) @define_diffrule NaNMath.log(x) = :( inv($x) ) -@define_diffrule NaNMath.log2(x) = :( inv($x) / NaNMath.log(oftype($x, oftype($x, 2))) ) -@define_diffrule NaNMath.log10(x) = :( inv($x) / NaNMath.log(oftype($x, oftype($x, 10))) ) +@define_diffrule NaNMath.log2(x) = :( inv($x) / NaNMath.log(oftype($x, 2)) ) +@define_diffrule NaNMath.log10(x) = :( inv($x) / NaNMath.log(oftype($x, 10)) ) @define_diffrule NaNMath.log1p(x) = :( inv($x + 1) ) @define_diffrule NaNMath.lgamma(x) = :( SpecialFunctions.digamma($x) ) From 548dc2cffb3189088720ebc2ea7b620124216f37 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Apr 2021 06:40:01 +0200 Subject: [PATCH 09/48] moved a single paranthesis --- src/rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index a84b369..2fb0fa6 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -26,7 +26,7 @@ @define_diffrule Base.csc(x) = :( -csc($x) * cot($x) ) @define_diffrule Base.cot(x) = :( -(1 + cot($x)^2) ) @define_diffrule Base.sind(x) = :( π * cosd($x) / 180 ) -@define_diffrule Base.cosd(x) = :( -(π * sind($x)) / 180 ) +@define_diffrule Base.cosd(x) = :( -(π * sind($x) / 180) ) @define_diffrule Base.tand(x) = :( π * (1 + tand($x)^2) / 180 ) @define_diffrule Base.secd(x) = :( π * secd($x) * tand($x) / 180 ) @define_diffrule Base.cscd(x) = :( -(π * cscd($x) * cotd($x) / 180) ) From ee6fe3e71d97ef2bcee169dd4fa903d6d2d3210c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 15:42:56 +0100 Subject: [PATCH 10/48] use irrationals to simplify type-promotion --- src/DiffRules.jl | 1 + src/constants.jl | 7 +++++++ src/rules.jl | 47 ++++++++++++----------------------------------- 3 files changed, 20 insertions(+), 35 deletions(-) create mode 100644 src/constants.jl diff --git a/src/DiffRules.jl b/src/DiffRules.jl index b70cae5..7a02c93 100644 --- a/src/DiffRules.jl +++ b/src/DiffRules.jl @@ -3,6 +3,7 @@ __precompile__() module DiffRules include("api.jl") +include("constants.jl") include("rules.jl") end # module diff --git a/src/constants.jl b/src/constants.jl new file mode 100644 index 0000000..be7eb9e --- /dev/null +++ b/src/constants.jl @@ -0,0 +1,7 @@ +import Base: @irrational + +@irrational logtwo 0.6931471805599453094 log(big(2)) +@irrational logten 2.302585092994046 log(big(10)) + +@irrational twoinvsqrtπ 1.1283791670955126 2 / sqrt(big(π)) +@irrational halfsqrtπ 0.886226925452758 sqrt(big(π)) / 2 diff --git a/src/rules.jl b/src/rules.jl index 2fb0fa6..518ff2d 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -12,12 +12,12 @@ @define_diffrule Base.abs2(x) = :( $x + $x ) @define_diffrule Base.inv(x) = :( -abs2(inv($x)) ) @define_diffrule Base.log(x) = :( inv($x) ) -@define_diffrule Base.log10(x) = :( inv($x) / log(oftype($x, 10)) ) -@define_diffrule Base.log2(x) = :( inv($x) / log(oftype($x, 2)) ) +@define_diffrule Base.log10(x) = :( inv($x) / $(DiffRules.logten) ) +@define_diffrule Base.log2(x) = :( inv($x) / $(DiffRules.logtwo) ) @define_diffrule Base.log1p(x) = :( inv($x + 1) ) @define_diffrule Base.exp(x) = :( exp($x) ) -@define_diffrule Base.exp2(x) = :( exp2($x) * log(oftype($x, 2)) ) -@define_diffrule Base.exp10(x) = :( exp10($x) * log(oftype($x, 10)) ) +@define_diffrule Base.exp2(x) = :( exp2($x) * $(DiffRules.logtwo) ) +@define_diffrule Base.exp10(x) = :( exp10($x) * $(DiffRules.logten) ) @define_diffrule Base.expm1(x) = :( exp($x) ) @define_diffrule Base.sin(x) = :( cos($x) ) @define_diffrule Base.cos(x) = :( -sin($x) ) @@ -57,9 +57,9 @@ @define_diffrule Base.asech(x) = :( -inv($x * sqrt(1 - $x^2)) ) @define_diffrule Base.acsch(x) = :( -inv(abs($x) * sqrt(1 + $x^2)) ) @define_diffrule Base.acoth(x) = :( inv(1 - $x^2) ) -@define_diffrule Base.deg2rad(x) = :( oftype($x, π) / 180 ) +@define_diffrule Base.deg2rad(x) = :( oftype($x, π / 180) ) @define_diffrule Base.mod2pi(x) = :( isinteger($x / 2pi) ? oftype($x, NaN) : one($x) ) -@define_diffrule Base.rad2deg(x) = :( 180 / oftype($x, π) ) +@define_diffrule Base.rad2deg(x) = :( oftype($x, 180 / π) ) @define_diffrule SpecialFunctions.gamma(x) = :( SpecialFunctions.digamma($x) * SpecialFunctions.gamma($x) ) @define_diffrule SpecialFunctions.loggamma(x) = @@ -99,35 +99,12 @@ end # unary # #-------# -@define_diffrule SpecialFunctions.erf(x) = quote - tmp = exp(-$x * $x) - (oftype(tmp, 2 / sqrt(π)) * tmp) -end - -@define_diffrule SpecialFunctions.erfinv(x) = quote - tmp = exp(SpecialFunctions.erfinv($x)^2) - (oftype(tmp, sqrt(π)) / 2) * tmp -end - -@define_diffrule SpecialFunctions.erfc(x) = quote - tmp = exp(-$x * $x) - -oftype(tmp, (2 / sqrt(π))) * tmp -end - -@define_diffrule SpecialFunctions.erfcinv(x) = quote - tmp = exp(SpecialFunctions.erfcinv($x)^2) - -(oftype(tmp, sqrt(π)) / 2) * tmp -end - -@define_diffrule SpecialFunctions.erfi(x) = quote - tmp = exp($x * $x) - oftype(tmp, (2 / sqrt(π))) * tmp -end - -@define_diffrule SpecialFunctions.erfcx(x) = quote - tmp = (2 * $x * SpecialFunctions.erfcx($x)) - tmp - oftype(tmp, (2 / sqrt(π))) -end +@define_diffrule SpecialFunctions.erf(x) = :( exp(-$x * $x) * $(DiffRules.twoinvsqrtπ) ) +@define_diffrule SpecialFunctions.erfinv(x) = :( exp(SpecialFunctions.erfinv($x)^2) * $(DiffRules.halfsqrtπ) ) +@define_diffrule SpecialFunctions.erfc(x) = :( -exp(-$x * $x) * $(DiffRules.twoinvsqrtπ) ) +@define_diffrule SpecialFunctions.erfcinv(x) = :( -exp(SpecialFunctions.erfcinv($x)^2) * $(DiffRules.halfsqrtπ) ) +@define_diffrule SpecialFunctions.erfi(x) = :( exp($x * $x) / $(DiffRules.halfsqrtπ) ) +@define_diffrule SpecialFunctions.erfcx(x) = :( 2 * $x * SpecialFunctions.erfcx($x) - $(DiffRules.twoinvsqrtπ) ) @define_diffrule SpecialFunctions.dawson(x) = :( 1 - (2 * $x * SpecialFunctions.dawson($x)) ) From b149ba5543bc6622390a1b4cce81f22fcd5c92dc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 16 Jul 2021 08:19:41 +0100 Subject: [PATCH 11/48] depend on IrrationalConstants.jl --- Project.toml | 2 ++ src/constants.jl | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a93c816..d0cec68 100644 --- a/Project.toml +++ b/Project.toml @@ -3,11 +3,13 @@ uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" version = "1.0.2" [deps] +IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [compat] +IrrationalConstants = "0.1" NaNMath = "0.3" SpecialFunctions = "0.8, 0.9, 0.10, 1.0" julia = "1" diff --git a/src/constants.jl b/src/constants.jl index be7eb9e..1ccdee8 100644 --- a/src/constants.jl +++ b/src/constants.jl @@ -1,6 +1,6 @@ import Base: @irrational +import IrrationalConstants: logtwo -@irrational logtwo 0.6931471805599453094 log(big(2)) @irrational logten 2.302585092994046 log(big(10)) @irrational twoinvsqrtπ 1.1283791670955126 2 / sqrt(big(π)) From 38b7fd40e99d1d869f8d4421e663ae2a42aa6f17 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 11 Aug 2021 21:39:21 +0100 Subject: [PATCH 12/48] fixed tests --- test/runtests.jl | 50 ++++++++++++++++++++++++------------------------ 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 2b243c1..57ee8f2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,9 +18,9 @@ end non_numeric_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ifelse, 3)] -for T in [Float32, Float64] - for (M, f, arity) in DiffRules.diffrules() - @testset "$M.$f $(arity)" begin + +@testset "$M.$f $(arity)" for (M, f, arity) in DiffRules.diffrules() + for T in [Float32, Float64] (M, f, arity) ∈ non_numeric_arg_functions && continue if arity == 1 @test DiffRules.hasdiffrule(M, f, 1) @@ -50,26 +50,26 @@ for T in [Float32, Float64] @test isapprox(dy, finitediff(z -> $M.$f(foo, z), bar), rtol=0.05) end end - end - end - elseif arity == 3 - #= - @test DiffRules.hasdiffrule(M, f, 3) - derivs = DiffRules.diffrule(M, f, :foo, :bar, :goo) - @eval begin - foo, bar, goo = randn(3) - dx, dy, dz = $(derivs[1]), $(derivs[2]), $(derivs[3]) - if !(isnan(dx)) + elseif arity == 3 + #= + @test DiffRules.hasdiffrule(M, f, 3) + derivs = DiffRules.diffrule(M, f, :foo, :bar, :goo) + @eval begin + foo, bar, goo = randn(3) + dx, dy, dz = $(derivs[1]), $(derivs[2]), $(derivs[3]) + if !(isnan(dx)) @test isapprox(dx, finitediff(x -> $M.$f(x, bar, goo), foo), rtol=0.05) - end - if !(isnan(dy)) + end + if !(isnan(dy)) @test isapprox(dy, finitediff(y -> $M.$f(foo, y, goo), bar), rtol=0.05) - end - if !(isnan(dz)) + end + if !(isnan(dz)) @test isapprox(dz, finitediff(z -> $M.$f(foo, bar, z), goo), rtol=0.05) + end + end + =# end end - =# end end @@ -92,12 +92,12 @@ end @test DiffRules.hasdiffrule(:Base, :ifelse, 3) derivs = DiffRules.diffrule(:Base, :ifelse, :foo, :bar, :goo) for cond in [true, false] - @eval begin - foo = $cond - bar, gee = randn(2) - dx, dy, dz = $(derivs[1]), $(derivs[2]), $(derivs[3]) - @test isapprox(dy, finitediff(y -> ifelse(foo, y, goo), bar), rtol=0.05) - @test isapprox(dz, finitediff(z -> ifelse(foo, bar, z), goo), rtol=0.05) - end +@eval begin +foo = $cond +bar, gee = randn(2) +dx, dy, dz = $(derivs[1]), $(derivs[2]), $(derivs[3]) +@test isapprox(dy, finitediff(y -> ifelse(foo, y, goo), bar), rtol=0.05) +@test isapprox(dz, finitediff(z -> ifelse(foo, bar, z), goo), rtol=0.05) +end end =# From 6618cb14619edf03f3c5558e552311748e2f1f45 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 11 Aug 2021 21:39:49 +0100 Subject: [PATCH 13/48] actually fixed tests --- test/runtests.jl | 83 ++++++++++++++++++++++++------------------------ 1 file changed, 41 insertions(+), 42 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 57ee8f2..e304dc3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -21,54 +21,53 @@ non_numeric_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ifelse, 3)] @testset "$M.$f $(arity)" for (M, f, arity) in DiffRules.diffrules() for T in [Float32, Float64] - (M, f, arity) ∈ non_numeric_arg_functions && continue - if arity == 1 - @test DiffRules.hasdiffrule(M, f, 1) - deriv = DiffRules.diffrule(M, f, :goo) - modifier = in(f, (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)) ? 1 : 0 - @eval begin - goo = $T(rand() + $modifier) - @test $deriv isa $T - @test isapprox($deriv, finitediff($M.$f, goo), rtol=0.05) - # test for 2pi functions - if "mod2pi" == string($M.$f) - goo = 4pi + $modifier - @test NaN === $deriv - end + (M, f, arity) ∈ non_numeric_arg_functions && continue + if arity == 1 + @test DiffRules.hasdiffrule(M, f, 1) + deriv = DiffRules.diffrule(M, f, :goo) + modifier = in(f, (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)) ? 1 : 0 + @eval begin + goo = $T(rand() + $modifier) + @test $deriv isa $T + @test isapprox($deriv, finitediff($M.$f, goo), rtol=0.05) + # test for 2pi functions + if "mod2pi" == string($M.$f) + goo = 4pi + $modifier + @test NaN === $deriv end - elseif arity == 2 - # TODO: Add test for types. - @test DiffRules.hasdiffrule(M, f, 2) - derivs = DiffRules.diffrule(M, f, :foo, :bar) - @eval begin - foo, bar = rand(1:10), rand() - dx, dy = $(derivs[1]), $(derivs[2]) - if !(isnan(dx)) - @test isapprox(dx, finitediff(z -> $M.$f(z, bar), float(foo)), rtol=0.05) - end - if !(isnan(dy)) - @test isapprox(dy, finitediff(z -> $M.$f(foo, z), bar), rtol=0.05) - end - end - elseif arity == 3 - #= - @test DiffRules.hasdiffrule(M, f, 3) - derivs = DiffRules.diffrule(M, f, :foo, :bar, :goo) - @eval begin - foo, bar, goo = randn(3) - dx, dy, dz = $(derivs[1]), $(derivs[2]), $(derivs[3]) + end + elseif arity == 2 + # TODO: Add test for types. + @test DiffRules.hasdiffrule(M, f, 2) + derivs = DiffRules.diffrule(M, f, :foo, :bar) + @eval begin + foo, bar = rand(1:10), rand() + dx, dy = $(derivs[1]), $(derivs[2]) if !(isnan(dx)) - @test isapprox(dx, finitediff(x -> $M.$f(x, bar, goo), foo), rtol=0.05) + @test isapprox(dx, finitediff(z -> $M.$f(z, bar), float(foo)), rtol=0.05) end if !(isnan(dy)) - @test isapprox(dy, finitediff(y -> $M.$f(foo, y, goo), bar), rtol=0.05) - end - if !(isnan(dz)) - @test isapprox(dz, finitediff(z -> $M.$f(foo, bar, z), goo), rtol=0.05) - end + @test isapprox(dy, finitediff(z -> $M.$f(foo, z), bar), rtol=0.05) end - =# end + elseif arity == 3 + #= + @test DiffRules.hasdiffrule(M, f, 3) + derivs = DiffRules.diffrule(M, f, :foo, :bar, :goo) + @eval begin + foo, bar, goo = randn(3) + dx, dy, dz = $(derivs[1]), $(derivs[2]), $(derivs[3]) + if !(isnan(dx)) + @test isapprox(dx, finitediff(x -> $M.$f(x, bar, goo), foo), rtol=0.05) + end + if !(isnan(dy)) + @test isapprox(dy, finitediff(y -> $M.$f(foo, y, goo), bar), rtol=0.05) + end + if !(isnan(dz)) + @test isapprox(dz, finitediff(z -> $M.$f(foo, bar, z), goo), rtol=0.05) + end + end + =# end end end From b89221ad00bce760ac5ccf5c6b0988fccbcab9ef Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 13 Dec 2021 10:41:35 +0000 Subject: [PATCH 14/48] updated rules to use constants from IrrationalConstants --- src/DiffRules.jl | 2 ++ src/rules.jl | 26 +++++++++++++------------- test/runtests.jl | 6 +++--- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/DiffRules.jl b/src/DiffRules.jl index b70cae5..558a603 100644 --- a/src/DiffRules.jl +++ b/src/DiffRules.jl @@ -2,6 +2,8 @@ __precompile__() module DiffRules +using IrrationalConstants + include("api.jl") include("rules.jl") diff --git a/src/rules.jl b/src/rules.jl index 12f90f8..e0cd80e 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -111,22 +111,22 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) # unary # #-------# -@define_diffrule SpecialFunctions.erf(x) = :( exp(-$x * $x) * $(DiffRules.twoinvsqrtπ) ) -@define_diffrule SpecialFunctions.erfinv(x) = :( exp(SpecialFunctions.erfinv($x)^2) * $(DiffRules.halfsqrtπ) ) -@define_diffrule SpecialFunctions.erfc(x) = :( -exp(-$x * $x) * $(DiffRules.twoinvsqrtπ) ) -@define_diffrule SpecialFunctions.erfcinv(x) = :( -exp(SpecialFunctions.erfcinv($x)^2) * $(DiffRules.halfsqrtπ) ) -@define_diffrule SpecialFunctions.erfi(x) = :( exp($x * $x) / $(DiffRules.halfsqrtπ) ) -@define_diffrule SpecialFunctions.erfcx(x) = :( 2 * $x * SpecialFunctions.erfcx($x) - $(DiffRules.twoinvsqrtπ) ) - -@define_diffrule SpecialFunctions.erf(x) = :( DiffRules.twoinvsqrtπ * exp(-$x * $x) ) +@define_diffrule SpecialFunctions.erf(x) = :( 2 * exp(-$x * $x) * $(DiffRules.invsqrtπ) ) +@define_diffrule SpecialFunctions.erfinv(x) = :( (exp(SpecialFunctions.erfinv($x)^2) * $(DiffRules.sqrtπ)) / 2 ) +@define_diffrule SpecialFunctions.erfc(x) = :( -(exp(-$x * $x) * $(DiffRules.invsqrtπ)) * 2 ) +@define_diffrule SpecialFunctions.erfcinv(x) = :( -(exp(SpecialFunctions.erfcinv($x)^2) * $(DiffRules.sqrtπ)) / 2 ) +@define_diffrule SpecialFunctions.erfi(x) = :( 2 * exp($x * $x) * $(DiffRules.invsqrtπ) ) +@define_diffrule SpecialFunctions.erfcx(x) = :( 2 * ($x * SpecialFunctions.erfcx($x) - $(DiffRules.invsqrtπ)) ) + +@define_diffrule SpecialFunctions.erf(x) = :( 2 * (DiffRules.invsqrtπ * exp(-$x * $x)) ) @define_diffrule SpecialFunctions.erfinv(x) = - :( DiffRules.halfsqrtπ * exp(SpecialFunctions.erfinv($x)^2) ) -@define_diffrule SpecialFunctions.erfc(x) = :( -(DiffRules.twoinvsqrtπ * exp(-$x * $x)) ) + :( (DiffRules.sqrtπ * exp(SpecialFunctions.erfinv($x)^2)) / 2 ) +@define_diffrule SpecialFunctions.erfc(x) = :( -(DiffRules.invsqrtπ * exp(-$x * $x) * 2) ) @define_diffrule SpecialFunctions.erfcinv(x) = - :( -(DiffRules.halfsqrtπDiffrules. * exp(SpecialFunctions.erfcinv($x)^2)) ) -@define_diffrule SpecialFunctions.erfi(x) = :( DiffRules.twoinvsqrtπ * exp($x * $x) ) + :( -(DiffRules.sqrtπ * exp(SpecialFunctions.erfcinv($x)^2)) / 2 ) +@define_diffrule SpecialFunctions.erfi(x) = :( DiffRules.invsqrtπ * exp($x * $x) * 2 ) @define_diffrule SpecialFunctions.erfcx(x) = - :( (2 * $x * SpecialFunctions.erfcx($x)) - DiffRules.twoinvsqrtπ ) + :( 2 * (($x * SpecialFunctions.erfcx($x)) - DiffRules.invsqrtπ) ) @define_diffrule SpecialFunctions.logerfcx(x) = :( 2 * ($x - inv(SpecialFunctions.erfcx($x) * DiffRules.sqrtπ)) ) @define_diffrule SpecialFunctions.dawson(x) = diff --git a/test/runtests.jl b/test/runtests.jl index c543597..a82289d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -32,7 +32,7 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) end @eval begin let - goo = T(rand() + $modifier) + goo = $T(rand() + $modifier) @test isapprox($deriv, finitediff($M.$f, goo), rtol=0.05) # test for 2pi functions if "mod2pi" == string($M.$f) @@ -47,9 +47,9 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) @eval begin let if "mod" == string($M.$f) - foo, bar = T(rand() + 13), T(rand() + 5) # make sure x/y is not integer + foo, bar = $T(rand() + 13), $T(rand() + 5) # make sure x/y is not integer else - foo, bar = rand(1:10), T(rand()) + foo, bar = rand(1:10), $T(rand()) end dx, dy = $(derivs[1]), $(derivs[2]) if !(isnan(dx)) From a48a523452850b956b59cd1e6db6108cc60a2d11 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 14 Dec 2021 17:18:22 +0000 Subject: [PATCH 15/48] Update src/rules.jl Co-authored-by: David Widmann --- src/rules.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index e0cd80e..35aa434 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -39,12 +39,12 @@ @define_diffrule Base.asec(x) = :( inv(abs($x) * sqrt($x^2 - 1)) ) @define_diffrule Base.acsc(x) = :( -inv(abs($x) * sqrt($x^2 - 1)) ) @define_diffrule Base.acot(x) = :( -inv(1 + $x^2) ) -@define_diffrule Base.asind(x) = :( 180 / (π * sqrt(1 - $x^2)) ) -@define_diffrule Base.acosd(x) = :( -180 / (π * sqrt(1 - $x^2)) ) -@define_diffrule Base.atand(x) = :( 180 / (π * (1 + $x^2)) ) -@define_diffrule Base.asecd(x) = :( 180 / (π * abs($x) * sqrt($x^2 - 1)) ) -@define_diffrule Base.acscd(x) = :( -180 / (π * abs($x) * sqrt($x^2 - 1)) ) -@define_diffrule Base.acotd(x) = :( -180 / (π * (1 + $x^2)) ) +@define_diffrule Base.asind(x) = :( inv(deg2rad(sqrt(1 - $x^2))) ) +@define_diffrule Base.acosd(x) = :( -inv(deg2rad(sqrt(1 - $x^2))) ) +@define_diffrule Base.atand(x) = :( inv(deg2rad(1 + $x^2)) ) +@define_diffrule Base.asecd(x) = :( inv(deg2rad(abs($x) * sqrt($x^2 - 1))) ) +@define_diffrule Base.acscd(x) = :( -inv(deg2rad(abs($x) * sqrt($x^2 - 1))) ) +@define_diffrule Base.acotd(x) = :( -inv(deg2rad(1 + $x^2)) ) @define_diffrule Base.sinh(x) = :( cosh($x) ) @define_diffrule Base.cosh(x) = :( sinh($x) ) @define_diffrule Base.tanh(x) = :( 1 - tanh($x)^2 ) From 1aaec9f9e8e0ae844d86ebc1fde17d8dd5b9e411 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 14 Dec 2021 17:22:37 +0000 Subject: [PATCH 16/48] Apply suggestions from code review Co-authored-by: David Widmann --- src/rules.jl | 14 +++++++------- test/runtests.jl | 7 ++++--- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 35aa434..85c1602 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -12,7 +12,7 @@ @define_diffrule Base.abs2(x) = :( $x + $x ) @define_diffrule Base.inv(x) = :( -abs2(inv($x)) ) @define_diffrule Base.log(x) = :( inv($x) ) -@define_diffrule Base.log10(x) = :( inv($x) / $(DiffRules.logten) ) +@define_diffrule Base.log10(x) = :( inv($x) / $logten ) @define_diffrule Base.log2(x) = :( inv($x) / $(DiffRules.logtwo) ) @define_diffrule Base.log1p(x) = :( inv($x + 1) ) @define_diffrule Base.exp(x) = :( exp($x) ) @@ -58,9 +58,9 @@ @define_diffrule Base.acsch(x) = :( -inv(abs($x) * sqrt(1 + $x^2)) ) @define_diffrule Base.acoth(x) = :( inv(1 - $x^2) ) @define_diffrule Base.sinc(x) = :( cosc($x) ) -@define_diffrule Base.deg2rad(x) = :( oftype($x, π / 180) ) -@define_diffrule Base.mod2pi(x) = :( isinteger($x / 2pi) ? oftype($x, NaN) : one($x) ) -@define_diffrule Base.rad2deg(x) = :( oftype($x, 180 / π) ) +@define_diffrule Base.deg2rad(x) = :( deg2rad(one($x)) ) +@define_diffrule Base.mod2pi(x) = :( isinteger($x / $twoπ) ? oftype(float($x), NaN) : one(float($x)) ) +@define_diffrule Base.rad2deg(x) = :( rand2deg(one($x)) ) @define_diffrule SpecialFunctions.gamma(x) = :( SpecialFunctions.digamma($x) * SpecialFunctions.gamma($x) ) @define_diffrule SpecialFunctions.loggamma(x) = @@ -111,7 +111,7 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) # unary # #-------# -@define_diffrule SpecialFunctions.erf(x) = :( 2 * exp(-$x * $x) * $(DiffRules.invsqrtπ) ) +@define_diffrule SpecialFunctions.erf(x) = :( 2 * exp(-$x^2) * $invsqrtπ ) @define_diffrule SpecialFunctions.erfinv(x) = :( (exp(SpecialFunctions.erfinv($x)^2) * $(DiffRules.sqrtπ)) / 2 ) @define_diffrule SpecialFunctions.erfc(x) = :( -(exp(-$x * $x) * $(DiffRules.invsqrtπ)) * 2 ) @define_diffrule SpecialFunctions.erfcinv(x) = :( -(exp(SpecialFunctions.erfcinv($x)^2) * $(DiffRules.sqrtπ)) / 2 ) @@ -222,8 +222,8 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule NaNMath.acosh(x) = :( inv(NaNMath.sqrt(NaNMath.pow($x, oftype($x, 2)) - 1)) ) @define_diffrule NaNMath.atanh(x) = :( inv(1 - NaNMath.pow($x, oftype($x, 2))) ) @define_diffrule NaNMath.log(x) = :( inv($x) ) -@define_diffrule NaNMath.log2(x) = :( inv($x) / NaNMath.log(oftype($x, 2)) ) -@define_diffrule NaNMath.log10(x) = :( inv($x) / NaNMath.log(oftype($x, 10)) ) +@define_diffrule NaNMath.log2(x) = :( inv($logtwo * $x) ) +@define_diffrule NaNMath.log10(x) = :( inv($logten * $x) ) @define_diffrule NaNMath.log1p(x) = :( inv($x + 1) ) @define_diffrule NaNMath.lgamma(x) = :( SpecialFunctions.digamma($x) ) diff --git a/test/runtests.jl b/test/runtests.jl index a82289d..6b3533c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -32,7 +32,8 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) end @eval begin let - goo = $T(rand() + $modifier) + goo = rand($T) + $modifier + @test $deriv isa $T @test isapprox($deriv, finitediff($M.$f, goo), rtol=0.05) # test for 2pi functions if "mod2pi" == string($M.$f) @@ -47,9 +48,9 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) @eval begin let if "mod" == string($M.$f) - foo, bar = $T(rand() + 13), $T(rand() + 5) # make sure x/y is not integer + foo, bar = rand($T) + 13, rand($T) + 5 # make sure x/y is not integer else - foo, bar = rand(1:10), $T(rand()) + foo, bar = rand(1:10), rand($T) end dx, dy = $(derivs[1]), $(derivs[2]) if !(isnan(dx)) From c8e644347acce165d83055f305564b198fbf0c70 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 14 Dec 2021 17:28:02 +0000 Subject: [PATCH 17/48] reverted some changes from previous commit --- src/rules.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 85c1602..192e0e0 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -12,7 +12,7 @@ @define_diffrule Base.abs2(x) = :( $x + $x ) @define_diffrule Base.inv(x) = :( -abs2(inv($x)) ) @define_diffrule Base.log(x) = :( inv($x) ) -@define_diffrule Base.log10(x) = :( inv($x) / $logten ) +@define_diffrule Base.log10(x) = :( inv($x) / $(DiffRules.logten) ) @define_diffrule Base.log2(x) = :( inv($x) / $(DiffRules.logtwo) ) @define_diffrule Base.log1p(x) = :( inv($x + 1) ) @define_diffrule Base.exp(x) = :( exp($x) ) @@ -111,7 +111,7 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) # unary # #-------# -@define_diffrule SpecialFunctions.erf(x) = :( 2 * exp(-$x^2) * $invsqrtπ ) +@define_diffrule SpecialFunctions.erf(x) = :( 2 * exp(-$x^2) * $(DiffRules.invsqrtπ) ) @define_diffrule SpecialFunctions.erfinv(x) = :( (exp(SpecialFunctions.erfinv($x)^2) * $(DiffRules.sqrtπ)) / 2 ) @define_diffrule SpecialFunctions.erfc(x) = :( -(exp(-$x * $x) * $(DiffRules.invsqrtπ)) * 2 ) @define_diffrule SpecialFunctions.erfcinv(x) = :( -(exp(SpecialFunctions.erfcinv($x)^2) * $(DiffRules.sqrtπ)) / 2 ) @@ -222,8 +222,8 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule NaNMath.acosh(x) = :( inv(NaNMath.sqrt(NaNMath.pow($x, oftype($x, 2)) - 1)) ) @define_diffrule NaNMath.atanh(x) = :( inv(1 - NaNMath.pow($x, oftype($x, 2))) ) @define_diffrule NaNMath.log(x) = :( inv($x) ) -@define_diffrule NaNMath.log2(x) = :( inv($logtwo * $x) ) -@define_diffrule NaNMath.log10(x) = :( inv($logten * $x) ) +@define_diffrule NaNMath.log2(x) = :( inv($(DiffRules.logtwo) * $x) ) +@define_diffrule NaNMath.log10(x) = :( inv($(DiffRules.logten) * $x) ) @define_diffrule NaNMath.log1p(x) = :( inv($x + 1) ) @define_diffrule NaNMath.lgamma(x) = :( SpecialFunctions.digamma($x) ) From 34892e0024836bd55d237b39e204720ee1ad3ae2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 14 Dec 2021 17:29:17 +0000 Subject: [PATCH 18/48] fixed a typo --- src/rules.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 192e0e0..6636831 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -118,13 +118,13 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule SpecialFunctions.erfi(x) = :( 2 * exp($x * $x) * $(DiffRules.invsqrtπ) ) @define_diffrule SpecialFunctions.erfcx(x) = :( 2 * ($x * SpecialFunctions.erfcx($x) - $(DiffRules.invsqrtπ)) ) -@define_diffrule SpecialFunctions.erf(x) = :( 2 * (DiffRules.invsqrtπ * exp(-$x * $x)) ) +@define_diffrule SpecialFunctions.erf(x) = :( 2 * ($(DiffRules.invsqrtπ) * exp(-$x^2)) ) @define_diffrule SpecialFunctions.erfinv(x) = :( (DiffRules.sqrtπ * exp(SpecialFunctions.erfinv($x)^2)) / 2 ) -@define_diffrule SpecialFunctions.erfc(x) = :( -(DiffRules.invsqrtπ * exp(-$x * $x) * 2) ) +@define_diffrule SpecialFunctions.erfc(x) = :( -(DiffRules.invsqrtπ * exp(-$x^2) * 2) ) @define_diffrule SpecialFunctions.erfcinv(x) = :( -(DiffRules.sqrtπ * exp(SpecialFunctions.erfcinv($x)^2)) / 2 ) -@define_diffrule SpecialFunctions.erfi(x) = :( DiffRules.invsqrtπ * exp($x * $x) * 2 ) +@define_diffrule SpecialFunctions.erfi(x) = :( DiffRules.invsqrtπ * exp($x^2) * 2 ) @define_diffrule SpecialFunctions.erfcx(x) = :( 2 * (($x * SpecialFunctions.erfcx($x)) - DiffRules.invsqrtπ) ) @define_diffrule SpecialFunctions.logerfcx(x) = From 97445d4cd915280ef5f0de92327ef33ab8021d0c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 14 Dec 2021 17:40:10 +0000 Subject: [PATCH 19/48] fixed typo --- src/rules.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 6636831..e860643 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -59,8 +59,8 @@ @define_diffrule Base.acoth(x) = :( inv(1 - $x^2) ) @define_diffrule Base.sinc(x) = :( cosc($x) ) @define_diffrule Base.deg2rad(x) = :( deg2rad(one($x)) ) -@define_diffrule Base.mod2pi(x) = :( isinteger($x / $twoπ) ? oftype(float($x), NaN) : one(float($x)) ) -@define_diffrule Base.rad2deg(x) = :( rand2deg(one($x)) ) +@define_diffrule Base.mod2pi(x) = :( isinteger($x / $(DiffRules.twoπ)) ? oftype(float($x), NaN) : one(float($x)) ) +@define_diffrule Base.rad2deg(x) = :( rad2deg(one($x)) ) @define_diffrule SpecialFunctions.gamma(x) = :( SpecialFunctions.digamma($x) * SpecialFunctions.gamma($x) ) @define_diffrule SpecialFunctions.loggamma(x) = From 2dee9c371998fb9263c3c26a9d135efa72b3bb2d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 14 Dec 2021 17:40:16 +0000 Subject: [PATCH 20/48] fixed type-conversion in tests --- test/runtests.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 6b3533c..90a3aa4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,13 +22,13 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) @test DiffRules.hasdiffrule(M, f, 1) deriv = DiffRules.diffrule(M, f, :goo) modifier = if f in (:asec, :acsc, :asecd, :acscd, :acosh, :acoth) - 1.0 + one(T) elseif f === :log1mexp - -1.0 + -one(T) elseif f === :log2mexp - -0.5 + -(one(T) / 2) else - 0.0 + zero(T) end @eval begin let @@ -37,8 +37,8 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) @test isapprox($deriv, finitediff($M.$f, goo), rtol=0.05) # test for 2pi functions if "mod2pi" == string($M.$f) - goo = 4pi + $modifier - @test NaN === $deriv + goo = $(DiffRules.fourπ) + $modifier + @test $T(NaN) === $deriv end end end From 0f5b52c0643bda9753e99bf506091820b10539f7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 14 Dec 2021 17:42:52 +0000 Subject: [PATCH 21/48] Apply suggestions from code review Co-authored-by: David Widmann --- src/rules.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index e860643..c48b045 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -216,11 +216,11 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule NaNMath.sqrt(x) = :( inv(2 * NaNMath.sqrt($x)) ) @define_diffrule NaNMath.sin(x) = :( NaNMath.cos($x) ) @define_diffrule NaNMath.cos(x) = :( -NaNMath.sin($x) ) -@define_diffrule NaNMath.tan(x) = :( 1 + NaNMath.pow(NaNMath.tan($x), oftype($x, 2)) ) -@define_diffrule NaNMath.asin(x) = :( inv(NaNMath.sqrt(1 - NaNMath.pow($x, oftype($x, 2)))) ) -@define_diffrule NaNMath.acos(x) = :( -inv(NaNMath.sqrt(1 - NaNMath.pow($x, oftype($x, 2)))) ) -@define_diffrule NaNMath.acosh(x) = :( inv(NaNMath.sqrt(NaNMath.pow($x, oftype($x, 2)) - 1)) ) -@define_diffrule NaNMath.atanh(x) = :( inv(1 - NaNMath.pow($x, oftype($x, 2))) ) +@define_diffrule NaNMath.tan(x) = :( 1 + NaNMath.pow(NaNMath.tan($x), 2) ) +@define_diffrule NaNMath.asin(x) = :( inv(NaNMath.sqrt(1 - NaNMath.pow($x, 2))) ) +@define_diffrule NaNMath.acos(x) = :( -inv(NaNMath.sqrt(1 - NaNMath.pow($x, 2))) ) +@define_diffrule NaNMath.acosh(x) = :( inv(NaNMath.sqrt(NaNMath.pow($x, 2) - 1)) ) +@define_diffrule NaNMath.atanh(x) = :( inv(1 - NaNMath.pow($x, 2)) ) @define_diffrule NaNMath.log(x) = :( inv($x) ) @define_diffrule NaNMath.log2(x) = :( inv($(DiffRules.logtwo) * $x) ) @define_diffrule NaNMath.log10(x) = :( inv($(DiffRules.logten) * $x) ) From a0b40e12eeec5fded16955dbdd25df81725692bd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 14 Dec 2021 17:46:01 +0000 Subject: [PATCH 22/48] reverse previous commit due to https://github.com/mlubin/NaNMath.jl/issues/47 --- src/rules.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index c48b045..e860643 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -216,11 +216,11 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule NaNMath.sqrt(x) = :( inv(2 * NaNMath.sqrt($x)) ) @define_diffrule NaNMath.sin(x) = :( NaNMath.cos($x) ) @define_diffrule NaNMath.cos(x) = :( -NaNMath.sin($x) ) -@define_diffrule NaNMath.tan(x) = :( 1 + NaNMath.pow(NaNMath.tan($x), 2) ) -@define_diffrule NaNMath.asin(x) = :( inv(NaNMath.sqrt(1 - NaNMath.pow($x, 2))) ) -@define_diffrule NaNMath.acos(x) = :( -inv(NaNMath.sqrt(1 - NaNMath.pow($x, 2))) ) -@define_diffrule NaNMath.acosh(x) = :( inv(NaNMath.sqrt(NaNMath.pow($x, 2) - 1)) ) -@define_diffrule NaNMath.atanh(x) = :( inv(1 - NaNMath.pow($x, 2)) ) +@define_diffrule NaNMath.tan(x) = :( 1 + NaNMath.pow(NaNMath.tan($x), oftype($x, 2)) ) +@define_diffrule NaNMath.asin(x) = :( inv(NaNMath.sqrt(1 - NaNMath.pow($x, oftype($x, 2)))) ) +@define_diffrule NaNMath.acos(x) = :( -inv(NaNMath.sqrt(1 - NaNMath.pow($x, oftype($x, 2)))) ) +@define_diffrule NaNMath.acosh(x) = :( inv(NaNMath.sqrt(NaNMath.pow($x, oftype($x, 2)) - 1)) ) +@define_diffrule NaNMath.atanh(x) = :( inv(1 - NaNMath.pow($x, oftype($x, 2))) ) @define_diffrule NaNMath.log(x) = :( inv($x) ) @define_diffrule NaNMath.log2(x) = :( inv($(DiffRules.logtwo) * $x) ) @define_diffrule NaNMath.log10(x) = :( inv($(DiffRules.logten) * $x) ) From 746a1572dfb6a6bc7cfda88fcde82d907fc863ab Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 14 Dec 2021 17:46:12 +0000 Subject: [PATCH 23/48] reverse previous commit due to https://github.com/mlubin/NaNMath.jl/issues/47 --- src/rules.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index e860643..c48b045 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -216,11 +216,11 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule NaNMath.sqrt(x) = :( inv(2 * NaNMath.sqrt($x)) ) @define_diffrule NaNMath.sin(x) = :( NaNMath.cos($x) ) @define_diffrule NaNMath.cos(x) = :( -NaNMath.sin($x) ) -@define_diffrule NaNMath.tan(x) = :( 1 + NaNMath.pow(NaNMath.tan($x), oftype($x, 2)) ) -@define_diffrule NaNMath.asin(x) = :( inv(NaNMath.sqrt(1 - NaNMath.pow($x, oftype($x, 2)))) ) -@define_diffrule NaNMath.acos(x) = :( -inv(NaNMath.sqrt(1 - NaNMath.pow($x, oftype($x, 2)))) ) -@define_diffrule NaNMath.acosh(x) = :( inv(NaNMath.sqrt(NaNMath.pow($x, oftype($x, 2)) - 1)) ) -@define_diffrule NaNMath.atanh(x) = :( inv(1 - NaNMath.pow($x, oftype($x, 2))) ) +@define_diffrule NaNMath.tan(x) = :( 1 + NaNMath.pow(NaNMath.tan($x), 2) ) +@define_diffrule NaNMath.asin(x) = :( inv(NaNMath.sqrt(1 - NaNMath.pow($x, 2))) ) +@define_diffrule NaNMath.acos(x) = :( -inv(NaNMath.sqrt(1 - NaNMath.pow($x, 2))) ) +@define_diffrule NaNMath.acosh(x) = :( inv(NaNMath.sqrt(NaNMath.pow($x, 2) - 1)) ) +@define_diffrule NaNMath.atanh(x) = :( inv(1 - NaNMath.pow($x, 2)) ) @define_diffrule NaNMath.log(x) = :( inv($x) ) @define_diffrule NaNMath.log2(x) = :( inv($(DiffRules.logtwo) * $x) ) @define_diffrule NaNMath.log10(x) = :( inv($(DiffRules.logten) * $x) ) From 164ad73fb0f53a62038ca0b7b62caba793e36006 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 14 Dec 2021 19:30:05 +0000 Subject: [PATCH 24/48] drop qualifications from rules --- src/rules.jl | 54 ++++++++++++++++++++++++++-------------------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index c48b045..dfaa9ea 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -12,12 +12,12 @@ @define_diffrule Base.abs2(x) = :( $x + $x ) @define_diffrule Base.inv(x) = :( -abs2(inv($x)) ) @define_diffrule Base.log(x) = :( inv($x) ) -@define_diffrule Base.log10(x) = :( inv($x) / $(DiffRules.logten) ) -@define_diffrule Base.log2(x) = :( inv($x) / $(DiffRules.logtwo) ) +@define_diffrule Base.log10(x) = :( inv($x) / $logten ) +@define_diffrule Base.log2(x) = :( inv($x) / $logtwo ) @define_diffrule Base.log1p(x) = :( inv($x + 1) ) @define_diffrule Base.exp(x) = :( exp($x) ) -@define_diffrule Base.exp2(x) = :( exp2($x) * $(DiffRules.logtwo) ) -@define_diffrule Base.exp10(x) = :( exp10($x) * $(DiffRules.logten) ) +@define_diffrule Base.exp2(x) = :( exp2($x) * $logtwo ) +@define_diffrule Base.exp10(x) = :( exp10($x) * $logten ) @define_diffrule Base.expm1(x) = :( exp($x) ) @define_diffrule Base.sin(x) = :( cos($x) ) @define_diffrule Base.cos(x) = :( -sin($x) ) @@ -59,14 +59,14 @@ @define_diffrule Base.acoth(x) = :( inv(1 - $x^2) ) @define_diffrule Base.sinc(x) = :( cosc($x) ) @define_diffrule Base.deg2rad(x) = :( deg2rad(one($x)) ) -@define_diffrule Base.mod2pi(x) = :( isinteger($x / $(DiffRules.twoπ)) ? oftype(float($x), NaN) : one(float($x)) ) +@define_diffrule Base.mod2pi(x) = :( isinteger($x / $twoπ) ? oftype(float($x), NaN) : one(float($x)) ) @define_diffrule Base.rad2deg(x) = :( rad2deg(one($x)) ) @define_diffrule SpecialFunctions.gamma(x) = :( SpecialFunctions.digamma($x) * SpecialFunctions.gamma($x) ) @define_diffrule SpecialFunctions.loggamma(x) = :( SpecialFunctions.digamma($x) ) -@define_diffrule Base.abs(x) = :( DiffRules._abs_deriv($x) ) +@define_diffrule Base.abs(x) = :( _abs_deriv($x) ) # We provide this hook for special number types like `Interval` # that need their own special definition of `abs`. @@ -111,24 +111,24 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) # unary # #-------# -@define_diffrule SpecialFunctions.erf(x) = :( 2 * exp(-$x^2) * $(DiffRules.invsqrtπ) ) -@define_diffrule SpecialFunctions.erfinv(x) = :( (exp(SpecialFunctions.erfinv($x)^2) * $(DiffRules.sqrtπ)) / 2 ) -@define_diffrule SpecialFunctions.erfc(x) = :( -(exp(-$x * $x) * $(DiffRules.invsqrtπ)) * 2 ) -@define_diffrule SpecialFunctions.erfcinv(x) = :( -(exp(SpecialFunctions.erfcinv($x)^2) * $(DiffRules.sqrtπ)) / 2 ) -@define_diffrule SpecialFunctions.erfi(x) = :( 2 * exp($x * $x) * $(DiffRules.invsqrtπ) ) -@define_diffrule SpecialFunctions.erfcx(x) = :( 2 * ($x * SpecialFunctions.erfcx($x) - $(DiffRules.invsqrtπ)) ) - -@define_diffrule SpecialFunctions.erf(x) = :( 2 * ($(DiffRules.invsqrtπ) * exp(-$x^2)) ) +@define_diffrule SpecialFunctions.erf(x) = :( 2 * exp(-$x^2) * $invsqrtπ ) +@define_diffrule SpecialFunctions.erfinv(x) = :( (exp(SpecialFunctions.erfinv($x)^2) * $sqrtπ) / 2 ) +@define_diffrule SpecialFunctions.erfc(x) = :( -(exp(-$x * $x) * $invsqrtπ) * 2 ) +@define_diffrule SpecialFunctions.erfcinv(x) = :( -(exp(SpecialFunctions.erfcinv($x)^2) * $sqrtπ) / 2 ) +@define_diffrule SpecialFunctions.erfi(x) = :( 2 * exp($x * $x) * $invsqrtπ ) +@define_diffrule SpecialFunctions.erfcx(x) = :( 2 * ($x * SpecialFunctions.erfcx($x) - $invsqrtπ) ) + +@define_diffrule SpecialFunctions.erf(x) = :( 2 * ($invsqrtπ * exp(-$x^2)) ) @define_diffrule SpecialFunctions.erfinv(x) = - :( (DiffRules.sqrtπ * exp(SpecialFunctions.erfinv($x)^2)) / 2 ) -@define_diffrule SpecialFunctions.erfc(x) = :( -(DiffRules.invsqrtπ * exp(-$x^2) * 2) ) + :( ($sqrtπ * exp(SpecialFunctions.erfinv($x)^2)) / 2 ) +@define_diffrule SpecialFunctions.erfc(x) = :( -($invsqrtπ * exp(-$x^2) * 2) ) @define_diffrule SpecialFunctions.erfcinv(x) = - :( -(DiffRules.sqrtπ * exp(SpecialFunctions.erfcinv($x)^2)) / 2 ) -@define_diffrule SpecialFunctions.erfi(x) = :( DiffRules.invsqrtπ * exp($x^2) * 2 ) + :( -($sqrtπ * exp(SpecialFunctions.erfcinv($x)^2)) / 2 ) +@define_diffrule SpecialFunctions.erfi(x) = :( $invsqrtπ * exp($x^2) * 2 ) @define_diffrule SpecialFunctions.erfcx(x) = - :( 2 * (($x * SpecialFunctions.erfcx($x)) - DiffRules.invsqrtπ) ) + :( 2 * (($x * SpecialFunctions.erfcx($x)) - $invsqrtπ) ) @define_diffrule SpecialFunctions.logerfcx(x) = - :( 2 * ($x - inv(SpecialFunctions.erfcx($x) * DiffRules.sqrtπ)) ) + :( 2 * ($x - inv(SpecialFunctions.erfcx($x) * $sqrtπ)) ) @define_diffrule SpecialFunctions.dawson(x) = :( 1 - (2 * $x * SpecialFunctions.dawson($x)) ) @define_diffrule SpecialFunctions.digamma(x) = @@ -216,14 +216,14 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule NaNMath.sqrt(x) = :( inv(2 * NaNMath.sqrt($x)) ) @define_diffrule NaNMath.sin(x) = :( NaNMath.cos($x) ) @define_diffrule NaNMath.cos(x) = :( -NaNMath.sin($x) ) -@define_diffrule NaNMath.tan(x) = :( 1 + NaNMath.pow(NaNMath.tan($x), 2) ) -@define_diffrule NaNMath.asin(x) = :( inv(NaNMath.sqrt(1 - NaNMath.pow($x, 2))) ) -@define_diffrule NaNMath.acos(x) = :( -inv(NaNMath.sqrt(1 - NaNMath.pow($x, 2))) ) -@define_diffrule NaNMath.acosh(x) = :( inv(NaNMath.sqrt(NaNMath.pow($x, 2) - 1)) ) -@define_diffrule NaNMath.atanh(x) = :( inv(1 - NaNMath.pow($x, 2)) ) +@define_diffrule NaNMath.tan(x) = :( 1 + NaNMath.pow(NaNMath.tan($x), oftype($x, 2)) ) +@define_diffrule NaNMath.asin(x) = :( inv(NaNMath.sqrt(1 - NaNMath.pow($x, oftype($x, 2)))) ) +@define_diffrule NaNMath.acos(x) = :( -inv(NaNMath.sqrt(1 - NaNMath.pow($x, oftype($x, 2)))) ) +@define_diffrule NaNMath.acosh(x) = :( inv(NaNMath.sqrt(NaNMath.pow($x, oftype($x, 2)) - 1)) ) +@define_diffrule NaNMath.atanh(x) = :( inv(1 - NaNMath.pow($x, oftype($x, 2))) ) @define_diffrule NaNMath.log(x) = :( inv($x) ) -@define_diffrule NaNMath.log2(x) = :( inv($(DiffRules.logtwo) * $x) ) -@define_diffrule NaNMath.log10(x) = :( inv($(DiffRules.logten) * $x) ) +@define_diffrule NaNMath.log2(x) = :( inv($logtwo * $x) ) +@define_diffrule NaNMath.log10(x) = :( inv($logten * $x) ) @define_diffrule NaNMath.log1p(x) = :( inv($x + 1) ) @define_diffrule NaNMath.lgamma(x) = :( SpecialFunctions.digamma($x) ) From 75b0ff6f3d730fad4eb45b3012daa55e97ac2dbb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 14 Dec 2021 19:31:32 +0000 Subject: [PATCH 25/48] reverted unintended change to _abs_deriv --- src/rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index dfaa9ea..4722b7f 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -66,7 +66,7 @@ @define_diffrule SpecialFunctions.loggamma(x) = :( SpecialFunctions.digamma($x) ) -@define_diffrule Base.abs(x) = :( _abs_deriv($x) ) +@define_diffrule Base.abs(x) = :( DiffRules._abs_deriv($x) ) # We provide this hook for special number types like `Interval` # that need their own special definition of `abs`. From 34ee3caf6ac22c3654d9928cfd2ac47dbefd8802 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 14 Dec 2021 20:00:04 +0000 Subject: [PATCH 26/48] interpolate _abs_deriv --- src/rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index 4722b7f..499cf82 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -66,7 +66,7 @@ @define_diffrule SpecialFunctions.loggamma(x) = :( SpecialFunctions.digamma($x) ) -@define_diffrule Base.abs(x) = :( DiffRules._abs_deriv($x) ) +@define_diffrule Base.abs(x) = :( $(_abs_deriv)($x) ) # We provide this hook for special number types like `Interval` # that need their own special definition of `abs`. From a58951c47d0d820068e5f44992b3a64a960c56d8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 14 Dec 2021 20:00:12 +0000 Subject: [PATCH 27/48] be explicit about imported irrationals --- src/DiffRules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DiffRules.jl b/src/DiffRules.jl index 558a603..67c76f2 100644 --- a/src/DiffRules.jl +++ b/src/DiffRules.jl @@ -2,7 +2,7 @@ __precompile__() module DiffRules -using IrrationalConstants +using IrrationalConstants: logtwo, logten, twoπ, sqrtπ, invsqrtπ include("api.jl") include("rules.jl") From e64f9e01a2d1ea568e644ab4648d6f4b01e7d865 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 14 Dec 2021 20:02:09 +0000 Subject: [PATCH 28/48] fixed tests --- test/runtests.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 90a3aa4..9754c0a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,8 @@ using DiffRules using Test +using IrrationalConstants: fourπ + import SpecialFunctions, NaNMath, LogExpFunctions import Random Random.seed!(1) @@ -37,7 +39,7 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) @test isapprox($deriv, finitediff($M.$f, goo), rtol=0.05) # test for 2pi functions if "mod2pi" == string($M.$f) - goo = $(DiffRules.fourπ) + $modifier + goo = $(fourπ) + $modifier @test $T(NaN) === $deriv end end From e383cfde1f4e7a0a37b83fe44905e26725d43329 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 15 Jan 2022 10:46:28 +0000 Subject: [PATCH 29/48] fixed numerical issues in tests by adopting some changes from #79 --- Project.toml | 3 ++- test/runtests.jl | 59 +++++++++++++++++++++++++++++------------------- 2 files changed, 38 insertions(+), 24 deletions(-) diff --git a/Project.toml b/Project.toml index e8bdfe3..6deb636 100644 --- a/Project.toml +++ b/Project.toml @@ -17,8 +17,9 @@ SpecialFunctions = "0.10, 1.0, 2" julia = "1.3" [extras] +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Random"] +test = ["Test", "Random", "FiniteDifferences"] diff --git a/test/runtests.jl b/test/runtests.jl index 9754c0a..6b3c686 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using DiffRules using Test +using FiniteDifferences using IrrationalConstants: fourπ @@ -7,10 +8,8 @@ import SpecialFunctions, NaNMath, LogExpFunctions import Random Random.seed!(1) -function finitediff(f, x) - ϵ = cbrt(eps(typeof(x))) * max(one(typeof(x)), abs(x)) - return (f(x + ϵ) - f(x - ϵ)) / (ϵ + ϵ) -end +# Set `max_range` to avoid domain errors. +const finitediff = central_fdm(5, 1, max_range=1e-3) @testset "DiffRules" begin @testset "check rules" begin @@ -23,24 +22,27 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) if arity == 1 @test DiffRules.hasdiffrule(M, f, 1) deriv = DiffRules.diffrule(M, f, :goo) - modifier = if f in (:asec, :acsc, :asecd, :acscd, :acosh, :acoth) - one(T) - elseif f === :log1mexp - -one(T) - elseif f === :log2mexp - -(one(T) / 2) - else - zero(T) - end @eval begin let - goo = rand($T) + $modifier + goo = if $(f in (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)) + # avoid singularities with finite differencing + rand($T) + $T(1.5) + elseif $(f in (:log, :airyaix, :airyaiprimex)) + # avoid singularities with finite differencing + rand($T) + $T(0.5) + elseif $(f === :log1mexp) + rand($T) - one($T) + elseif $(f in (:log2mexp, :erfinv)) + rand($T) - $T(0.5) + else + rand($T) + end @test $deriv isa $T - @test isapprox($deriv, finitediff($M.$f, goo), rtol=0.05) + @test $deriv ≈ finitediff($M.$f, goo) rtol=1e-3 atol=1e-3 # test for 2pi functions - if "mod2pi" == string($M.$f) - goo = $(fourπ) + $modifier - @test $T(NaN) === $deriv + if $(f === :mod2pi) + goo = 4 * pi + @test NaN === $deriv end end end @@ -49,17 +51,28 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) derivs = DiffRules.diffrule(M, f, :foo, :bar) @eval begin let - if "mod" == string($M.$f) - foo, bar = rand($T) + 13, rand($T) + 5 # make sure x/y is not integer + foo, bar = if $(f === :mod) + rand() + 13, rand() + 5 # make sure x/y is not integer + elseif $(f === :polygamma) + rand(1:10), rand() # only supports integers as first arguments + elseif $(f in (:bessely, :besselyx)) + # avoid singularities with finite differencing + rand(), rand() + 0.5 + elseif $(f === :log) + # avoid singularities with finite differencing + rand() + 1.5, rand() + elseif $(f === :^) + # avoid singularities with finite differencing + rand() + 0.5, rand() else - foo, bar = rand(1:10), rand($T) + rand(), rand() end dx, dy = $(derivs[1]), $(derivs[2]) if !(isnan(dx)) - @test isapprox(dx, finitediff(z -> $M.$f(z, bar), float(foo)), rtol=0.05) + @test dx ≈ finitediff(z -> $M.$f(z, bar), foo) rtol=1e-3 atol=1e-3 end if !(isnan(dy)) - @test isapprox(dy, finitediff(z -> $M.$f(foo, z), bar), rtol=0.05) + @test dy ≈ finitediff(z -> $M.$f(foo, z), bar) rtol=1e-3 atol=1e-3 end end end From 49039158e5265fd83e312d9f32e78a2cc7e58b23 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 15 Jan 2022 10:55:36 +0000 Subject: [PATCH 30/48] relax rtol slightly since we're working with Float32 too here --- test/runtests.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 6b3c686..55bd956 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,7 +38,7 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) rand($T) end @test $deriv isa $T - @test $deriv ≈ finitediff($M.$f, goo) rtol=1e-3 atol=1e-3 + @test $deriv ≈ finitediff($M.$f, goo) rtol=1e-2 atol=1e-3 # test for 2pi functions if $(f === :mod2pi) goo = 4 * pi @@ -69,10 +69,10 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) end dx, dy = $(derivs[1]), $(derivs[2]) if !(isnan(dx)) - @test dx ≈ finitediff(z -> $M.$f(z, bar), foo) rtol=1e-3 atol=1e-3 + @test dx ≈ finitediff(z -> $M.$f(z, bar), foo) rtol=1e-2 atol=1e-3 end if !(isnan(dy)) - @test dy ≈ finitediff(z -> $M.$f(foo, z), bar) rtol=1e-3 atol=1e-3 + @test dy ≈ finitediff(z -> $M.$f(foo, z), bar) rtol=1e-2 atol=1e-3 end end end From 0111861046bc51fecfcbf23b992a2bfe1dc0bb9e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 14:09:24 +0000 Subject: [PATCH 31/48] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 6deb636..b0e646a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DiffRules" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.9.0" +version = "1.9.1" [deps] IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" From 657c0d5befcf71743c2c22ced2a02380dd13df0f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 14:47:26 +0000 Subject: [PATCH 32/48] test type of derivative for functions with 2 arguments --- test/runtests.jl | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 55bd956..5b74fbd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -52,22 +52,39 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) @eval begin let foo, bar = if $(f === :mod) - rand() + 13, rand() + 5 # make sure x/y is not integer + rand($T) + 13, rand($T) + 5 # make sure x/y is not integer elseif $(f === :polygamma) - rand(1:10), rand() # only supports integers as first arguments + rand(1:10), rand($T) # only supports integers as first arguments elseif $(f in (:bessely, :besselyx)) # avoid singularities with finite differencing - rand(), rand() + 0.5 + rand($T), rand($T) + $T(0.5) elseif $(f === :log) # avoid singularities with finite differencing - rand() + 1.5, rand() + rand($T) + $T(1.5), rand($T) elseif $(f === :^) # avoid singularities with finite differencing - rand() + 0.5, rand() + rand($T) + $T(0.5), rand($T) else - rand(), rand() + rand($T), rand($T) end dx, dy = $(derivs[1]), $(derivs[2]) + + if foo isa AbstractFloat + if dx isa Complex + @test dx isa Complex{$T} + else + @test dx isa $T + end + end + + if bar isa AbstractFloat + if dy isa Complex + @test dy isa Complex{$T} + else + @test dy isa $T + end + end + if !(isnan(dx)) @test dx ≈ finitediff(z -> $M.$f(z, bar), foo) rtol=1e-2 atol=1e-3 end From f236d23aff6a43122c76d024238bcc77cda4eccf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 14:47:49 +0000 Subject: [PATCH 33/48] fixed types of derivatives for mod, rem and different bessel functions --- src/rules.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 499cf82..c1a684f 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -87,8 +87,8 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule Base.log(b, x) = :( log($x) * inv(-log($b)^2 * $b) ), :( inv($x) / log($b) ) @define_diffrule Base.ldexp(x, y) = :( exp2($y) ), :NaN -@define_diffrule Base.mod(x, y) = :( first(promote(ifelse(isinteger($x / $y), NaN, 1), NaN)) ), :( z = $x / $y; first(promote(ifelse(isinteger(z), NaN, -floor(z)), NaN)) ) -@define_diffrule Base.rem(x, y) = :( first(promote(ifelse(isinteger($x / $y), NaN, 1), NaN)) ), :( z = $x / $y; first(promote(ifelse(isinteger(z), NaN, -trunc(z)), NaN)) ) +@define_diffrule Base.mod(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(z, NaN), one(z)) ), :( z = $x / $y; ifelse(isinteger(z), oftype(z, NaN), -floor(z)) ) +@define_diffrule Base.rem(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(z, NaN), one(z)) ), :( z = $x / $y; ifelse(isinteger(z), oftype(z, NaN), -trunc(z)) ) @define_diffrule Base.rem2pi(x, r) = :( 1 ), :NaN @define_diffrule Base.max(x, y) = :( $x > $y ? one($x) : zero($x) ), :( $x > $y ? zero($y) : one($y) ) @define_diffrule Base.min(x, y) = :( $x > $y ? zero($x) : one($x) ), :( $x > $y ? one($y) : zero($y) ) @@ -167,19 +167,19 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) #--------# @define_diffrule SpecialFunctions.besselj(ν, x) = - :NaN, :( (SpecialFunctions.besselj($ν - 1, $x) - SpecialFunctions.besselj($ν + 1, $x)) / 2 ) + :(oftype(float($ν), NaN)), :( (SpecialFunctions.besselj($ν - 1, $x) - SpecialFunctions.besselj($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.besseli(ν, x) = - :NaN, :( (SpecialFunctions.besseli($ν - 1, $x) + SpecialFunctions.besseli($ν + 1, $x)) / 2 ) + :(oftype(float($ν), NaN)), :( (SpecialFunctions.besseli($ν - 1, $x) + SpecialFunctions.besseli($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.bessely(ν, x) = - :NaN, :( (SpecialFunctions.bessely($ν - 1, $x) - SpecialFunctions.bessely($ν + 1, $x)) / 2 ) + :(oftype(float($ν), NaN)), :( (SpecialFunctions.bessely($ν - 1, $x) - SpecialFunctions.bessely($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.besselk(ν, x) = - :NaN, :( -(SpecialFunctions.besselk($ν - 1, $x) + SpecialFunctions.besselk($ν + 1, $x)) / 2 ) + :(oftype(float($ν), NaN)), :( -(SpecialFunctions.besselk($ν - 1, $x) + SpecialFunctions.besselk($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.hankelh1(ν, x) = - :NaN, :( (SpecialFunctions.hankelh1($ν - 1, $x) - SpecialFunctions.hankelh1($ν + 1, $x)) / 2 ) + :(oftype(float($ν), NaN)), :( (SpecialFunctions.hankelh1($ν - 1, $x) - SpecialFunctions.hankelh1($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.hankelh2(ν, x) = - :NaN, :( (SpecialFunctions.hankelh2($ν - 1, $x) - SpecialFunctions.hankelh2($ν + 1, $x)) / 2 ) + :(oftype(float($ν), NaN)), :( (SpecialFunctions.hankelh2($ν - 1, $x) - SpecialFunctions.hankelh2($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.polygamma(m, x) = - :NaN, :( SpecialFunctions.polygamma($m + 1, $x) ) + :(oftype(float($m), NaN)), :( SpecialFunctions.polygamma($m + 1, $x) ) @define_diffrule SpecialFunctions.beta(a, b) = :( SpecialFunctions.beta($a, $b)*(SpecialFunctions.digamma($a) - SpecialFunctions.digamma($a + $b)) ), :( SpecialFunctions.beta($a, $b)*(SpecialFunctions.digamma($b) - SpecialFunctions.digamma($a + $b)) ) @define_diffrule SpecialFunctions.logbeta(a, b) = From 40e91333cf729d021ce493bae640ae101054b566 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 14:49:33 +0000 Subject: [PATCH 34/48] use deg2rad Co-authored-by: David Widmann --- src/rules.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index c1a684f..4c8506f 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -25,12 +25,12 @@ @define_diffrule Base.sec(x) = :( sec($x) * tan($x) ) @define_diffrule Base.csc(x) = :( -csc($x) * cot($x) ) @define_diffrule Base.cot(x) = :( -(1 + cot($x)^2) ) -@define_diffrule Base.sind(x) = :( π * cosd($x) / 180 ) -@define_diffrule Base.cosd(x) = :( -(π * sind($x) / 180) ) -@define_diffrule Base.tand(x) = :( π * (1 + tand($x)^2) / 180 ) -@define_diffrule Base.secd(x) = :( π * secd($x) * tand($x) / 180 ) -@define_diffrule Base.cscd(x) = :( -(π * cscd($x) * cotd($x) / 180) ) -@define_diffrule Base.cotd(x) = :( -(π * (1 + cotd($x)^2) / 180) ) +@define_diffrule Base.sind(x) = :( deg2rad(cosd($x)) ) +@define_diffrule Base.cosd(x) = :( - deg2rad(sind($x)) ) +@define_diffrule Base.tand(x) = :( deg2rad(1 + tand($x)^2) ) +@define_diffrule Base.secd(x) = :( deg2rad(secd($x) * tand($x)) ) +@define_diffrule Base.cscd(x) = :( - deg2rad(cscd($x) * cotd($x)) ) +@define_diffrule Base.cotd(x) = :( - deg2rad(1 + cotd($x)^2) ) @define_diffrule Base.sinpi(x) = :( π * cospi($x) ) @define_diffrule Base.cospi(x) = :( -(π * sinpi($x)) ) @define_diffrule Base.asin(x) = :( inv(sqrt(1 - $x^2)) ) From e4578308d79071c0c15885c0f403ff3dec17e344 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 14:59:33 +0000 Subject: [PATCH 35/48] reverted changes to + and - --- src/rules.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 4c8506f..3c6b1fb 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -5,8 +5,8 @@ # unary # #-------# -@define_diffrule Base.:+(x) = :( one($x) ) -@define_diffrule Base.:-(x) = :( -one($x) ) +@define_diffrule Base.:+(x) = :( 1 ) +@define_diffrule Base.:-(x) = :( -1 ) @define_diffrule Base.sqrt(x) = :( inv(2 * sqrt($x)) ) @define_diffrule Base.cbrt(x) = :( inv(3 * cbrt($x)^2) ) @define_diffrule Base.abs2(x) = :( $x + $x ) From 56dba972b15a6edfb24fc14860bc411104eb03e5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 19:09:14 +0000 Subject: [PATCH 36/48] remove duplicate rules --- src/rules.jl | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 3c6b1fb..c35f0c0 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -111,13 +111,6 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) # unary # #-------# -@define_diffrule SpecialFunctions.erf(x) = :( 2 * exp(-$x^2) * $invsqrtπ ) -@define_diffrule SpecialFunctions.erfinv(x) = :( (exp(SpecialFunctions.erfinv($x)^2) * $sqrtπ) / 2 ) -@define_diffrule SpecialFunctions.erfc(x) = :( -(exp(-$x * $x) * $invsqrtπ) * 2 ) -@define_diffrule SpecialFunctions.erfcinv(x) = :( -(exp(SpecialFunctions.erfcinv($x)^2) * $sqrtπ) / 2 ) -@define_diffrule SpecialFunctions.erfi(x) = :( 2 * exp($x * $x) * $invsqrtπ ) -@define_diffrule SpecialFunctions.erfcx(x) = :( 2 * ($x * SpecialFunctions.erfcx($x) - $invsqrtπ) ) - @define_diffrule SpecialFunctions.erf(x) = :( 2 * ($invsqrtπ * exp(-$x^2)) ) @define_diffrule SpecialFunctions.erfinv(x) = :( ($sqrtπ * exp(SpecialFunctions.erfinv($x)^2)) / 2 ) From 80ded04c7a9514ff24fce465e40fc832eadced35 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 19:10:20 +0000 Subject: [PATCH 37/48] add back whitespace --- src/rules.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/rules.jl b/src/rules.jl index c35f0c0..b727dd4 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -111,6 +111,7 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) # unary # #-------# + @define_diffrule SpecialFunctions.erf(x) = :( 2 * ($invsqrtπ * exp(-$x^2)) ) @define_diffrule SpecialFunctions.erfinv(x) = :( ($sqrtπ * exp(SpecialFunctions.erfinv($x)^2)) / 2 ) From 7149dff1b8d03ed524af3853eec58faff54f05d3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 20:01:25 +0000 Subject: [PATCH 38/48] reverted changes to bessel functions --- src/rules.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index b727dd4..0a0cec4 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -161,19 +161,19 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) #--------# @define_diffrule SpecialFunctions.besselj(ν, x) = - :(oftype(float($ν), NaN)), :( (SpecialFunctions.besselj($ν - 1, $x) - SpecialFunctions.besselj($ν + 1, $x)) / 2 ) + :NaN, :( (SpecialFunctions.besselj($ν - 1, $x) - SpecialFunctions.besselj($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.besseli(ν, x) = - :(oftype(float($ν), NaN)), :( (SpecialFunctions.besseli($ν - 1, $x) + SpecialFunctions.besseli($ν + 1, $x)) / 2 ) + :NaN, :( (SpecialFunctions.besseli($ν - 1, $x) + SpecialFunctions.besseli($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.bessely(ν, x) = - :(oftype(float($ν), NaN)), :( (SpecialFunctions.bessely($ν - 1, $x) - SpecialFunctions.bessely($ν + 1, $x)) / 2 ) + :NaN, :( (SpecialFunctions.bessely($ν - 1, $x) - SpecialFunctions.bessely($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.besselk(ν, x) = - :(oftype(float($ν), NaN)), :( -(SpecialFunctions.besselk($ν - 1, $x) + SpecialFunctions.besselk($ν + 1, $x)) / 2 ) + :NaN, :( -(SpecialFunctions.besselk($ν - 1, $x) + SpecialFunctions.besselk($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.hankelh1(ν, x) = - :(oftype(float($ν), NaN)), :( (SpecialFunctions.hankelh1($ν - 1, $x) - SpecialFunctions.hankelh1($ν + 1, $x)) / 2 ) + :NaN, :( (SpecialFunctions.hankelh1($ν - 1, $x) - SpecialFunctions.hankelh1($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.hankelh2(ν, x) = - :(oftype(float($ν), NaN)), :( (SpecialFunctions.hankelh2($ν - 1, $x) - SpecialFunctions.hankelh2($ν + 1, $x)) / 2 ) + :NaN, :( (SpecialFunctions.hankelh2($ν - 1, $x) - SpecialFunctions.hankelh2($ν + 1, $x)) / 2 ) @define_diffrule SpecialFunctions.polygamma(m, x) = - :(oftype(float($m), NaN)), :( SpecialFunctions.polygamma($m + 1, $x) ) + :NaN, :( SpecialFunctions.polygamma($m + 1, $x) ) @define_diffrule SpecialFunctions.beta(a, b) = :( SpecialFunctions.beta($a, $b)*(SpecialFunctions.digamma($a) - SpecialFunctions.digamma($a + $b)) ), :( SpecialFunctions.beta($a, $b)*(SpecialFunctions.digamma($b) - SpecialFunctions.digamma($a + $b)) ) @define_diffrule SpecialFunctions.logbeta(a, b) = From 98238974d4483c03f59deea1ba03a72664eb9002 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 20:24:22 +0000 Subject: [PATCH 39/48] only test return-type having the correct promotion behavior --- test/runtests.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 5b74fbd..483a923 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -37,7 +37,9 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) else rand($T) end - @test $deriv isa $T + # We're happy with types with the correct promotion behavior, e.g. + # it's fine to return `1` as a derivative despite input being `Float64`. + @test one($T) * $deriv isa $T @test $deriv ≈ finitediff($M.$f, goo) rtol=1e-2 atol=1e-3 # test for 2pi functions if $(f === :mod2pi) From 97117efe5ff9f639b5b49e452171bbd7c05effbf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 20:24:38 +0000 Subject: [PATCH 40/48] only test type for 2 argument functions whose derivatives aren't NaN --- test/runtests.jl | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 483a923..63f68a9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -70,28 +70,29 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) rand($T), rand($T) end dx, dy = $(derivs[1]), $(derivs[2]) - - if foo isa AbstractFloat - if dx isa Complex - @test dx isa Complex{$T} - else - @test dx isa $T - end - end - - if bar isa AbstractFloat - if dy isa Complex - @test dy isa Complex{$T} - else - @test dy isa $T - end - end - if !(isnan(dx)) @test dx ≈ finitediff(z -> $M.$f(z, bar), foo) rtol=1e-2 atol=1e-3 + + # Check type, if applicable. + if foo isa AbstractFloat + if dx isa Complex + @test dx isa Complex{$T} + else + @test dx isa $T + end + end end if !(isnan(dy)) @test dy ≈ finitediff(z -> $M.$f(foo, z), bar) rtol=1e-2 atol=1e-3 + + # Check type, if applicable. + if bar isa AbstractFloat + if dy isa Complex + @test dy isa Complex{$T} + else + @test dy isa $T + end + end end end end From e29056aece978cda2210abde671b67a286ea6286 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 20:49:56 +0000 Subject: [PATCH 41/48] fixed rules of mod and rem --- src/rules.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 0a0cec4..46183f8 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -87,8 +87,8 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule Base.log(b, x) = :( log($x) * inv(-log($b)^2 * $b) ), :( inv($x) / log($b) ) @define_diffrule Base.ldexp(x, y) = :( exp2($y) ), :NaN -@define_diffrule Base.mod(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(z, NaN), one(z)) ), :( z = $x / $y; ifelse(isinteger(z), oftype(z, NaN), -floor(z)) ) -@define_diffrule Base.rem(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(z, NaN), one(z)) ), :( z = $x / $y; ifelse(isinteger(z), oftype(z, NaN), -trunc(z)) ) +@define_diffrule Base.mod(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), one(float(z))) ), :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), -floor(float(z))) ) +@define_diffrule Base.rem(x, y) = :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), one(float(z))) ), :( z = $x / $y; ifelse(isinteger(z), oftype(float(z), NaN), -trunc(float(z))) ) @define_diffrule Base.rem2pi(x, r) = :( 1 ), :NaN @define_diffrule Base.max(x, y) = :( $x > $y ? one($x) : zero($x) ), :( $x > $y ? zero($y) : one($y) ) @define_diffrule Base.min(x, y) = :( $x > $y ? zero($x) : one($x) ), :( $x > $y ? one($y) : zero($y) ) From 773a572fe41704ac4d0a1bd169faa612c1f30e54 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 27 Jan 2022 20:50:27 +0000 Subject: [PATCH 42/48] make each rule its own testset for easier debugging --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 63f68a9..b87e77c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,7 +16,7 @@ const finitediff = central_fdm(5, 1, max_range=1e-3) non_diffeable_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ldexp, 2), (:Base, :ifelse, 3)] -for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) +@testset "($M, $f, $arity)" for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) for T in [Float32, Float64] (M, f, arity) ∈ non_diffeable_arg_functions && continue if arity == 1 From cdbec02b3c0c88a8c7d3f750c920c4a9fc596b3b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 28 Jan 2022 15:20:53 +0000 Subject: [PATCH 43/48] reverted changes to multiple NaNMath rules --- src/rules.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 46183f8..d1ca0a2 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -210,14 +210,14 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule NaNMath.sqrt(x) = :( inv(2 * NaNMath.sqrt($x)) ) @define_diffrule NaNMath.sin(x) = :( NaNMath.cos($x) ) @define_diffrule NaNMath.cos(x) = :( -NaNMath.sin($x) ) -@define_diffrule NaNMath.tan(x) = :( 1 + NaNMath.pow(NaNMath.tan($x), oftype($x, 2)) ) -@define_diffrule NaNMath.asin(x) = :( inv(NaNMath.sqrt(1 - NaNMath.pow($x, oftype($x, 2)))) ) -@define_diffrule NaNMath.acos(x) = :( -inv(NaNMath.sqrt(1 - NaNMath.pow($x, oftype($x, 2)))) ) -@define_diffrule NaNMath.acosh(x) = :( inv(NaNMath.sqrt(NaNMath.pow($x, oftype($x, 2)) - 1)) ) -@define_diffrule NaNMath.atanh(x) = :( inv(1 - NaNMath.pow($x, oftype($x, 2))) ) +@define_diffrule NaNMath.tan(x) = :( 1 + NaNMath.pow(NaNMath.tan($x), 2) ) +@define_diffrule NaNMath.asin(x) = :( inv(NaNMath.sqrt(1 - NaNMath.pow($x, 2))) ) +@define_diffrule NaNMath.acos(x) = :( -inv(NaNMath.sqrt(1 - NaNMath.pow($x, 2))) ) +@define_diffrule NaNMath.acosh(x) = :( inv(NaNMath.sqrt(NaNMath.pow($x, 2) - 1)) ) +@define_diffrule NaNMath.atanh(x) = :( inv(1 - NaNMath.pow($x, 2)) ) @define_diffrule NaNMath.log(x) = :( inv($x) ) -@define_diffrule NaNMath.log2(x) = :( inv($logtwo * $x) ) -@define_diffrule NaNMath.log10(x) = :( inv($logten * $x) ) +@define_diffrule NaNMath.log2(x) = :( inv($logtwo * $x) ) +@define_diffrule NaNMath.log10(x) = :( inv($logten * $x) ) @define_diffrule NaNMath.log1p(x) = :( inv($x + 1) ) @define_diffrule NaNMath.lgamma(x) = :( SpecialFunctions.digamma($x) ) From 6ddff442945493ec7c6dc293b1aa09d89b030e83 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 28 Jan 2022 16:02:27 +0000 Subject: [PATCH 44/48] use more explicit promotion in tests Co-authored-by: David Widmann --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index b87e77c..3195ec1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -39,7 +39,7 @@ non_diffeable_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ldexp, 2), (:Base, end # We're happy with types with the correct promotion behavior, e.g. # it's fine to return `1` as a derivative despite input being `Float64`. - @test one($T) * $deriv isa $T + @test promote_type(typeof($deriv), $T) === $T @test $deriv ≈ finitediff($M.$f, goo) rtol=1e-2 atol=1e-3 # test for 2pi functions if $(f === :mod2pi) From 777d1a7b072f3130d39fe6d23ff925c20677b606 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 28 Jan 2022 16:03:48 +0000 Subject: [PATCH 45/48] check promotion of real instead of specific check for Complex Co-authored-by: David Widmann --- test/runtests.jl | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 3195ec1..eaa1bae 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -75,11 +75,7 @@ non_diffeable_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ldexp, 2), (:Base, # Check type, if applicable. if foo isa AbstractFloat - if dx isa Complex - @test dx isa Complex{$T} - else - @test dx isa $T - end + @test promote_type(typeof(real(dx)), $T) === $T end end if !(isnan(dy)) @@ -87,11 +83,7 @@ non_diffeable_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ldexp, 2), (:Base, # Check type, if applicable. if bar isa AbstractFloat - if dy isa Complex - @test dy isa Complex{$T} - else - @test dy isa $T - end + @test promote_type(typeof(real(dy)), $T) === $T end end end From f937593ff328d5f8f5bc7264fbd39dca4c60910d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 28 Jan 2022 16:08:49 +0000 Subject: [PATCH 46/48] reverted unnecessary change --- test/runtests.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index eaa1bae..491d5d0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -161,12 +161,12 @@ end @test DiffRules.hasdiffrule(:Base, :ifelse, 3) derivs = DiffRules.diffrule(:Base, :ifelse, :foo, :bar, :goo) for cond in [true, false] -@eval begin -foo = $cond -bar, gee = randn(2) -dx, dy, dz = $(derivs[1]), $(derivs[2]), $(derivs[3]) -@test isapprox(dy, finitediff(y -> ifelse(foo, y, goo), bar), rtol=0.05) -@test isapprox(dz, finitediff(z -> ifelse(foo, bar, z), goo), rtol=0.05) -end + @eval begin + foo = $cond + bar, gee = randn(2) + dx, dy, dz = $(derivs[1]), $(derivs[2]), $(derivs[3]) + @test isapprox(dy, finitediff(y -> ifelse(foo, y, goo), bar), rtol=0.05) + @test isapprox(dz, finitediff(z -> ifelse(foo, bar, z), goo), rtol=0.05) + end end =# From 5d7b63560926b2dab44550ee0daa5b14aaf9a795 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 28 Jan 2022 16:10:10 +0000 Subject: [PATCH 47/48] reverted unnecessary changes --- test/runtests.jl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 491d5d0..7f86108 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -93,17 +93,17 @@ non_diffeable_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ldexp, 2), (:Base, @test DiffRules.hasdiffrule(M, f, 3) derivs = DiffRules.diffrule(M, f, :foo, :bar, :goo) @eval begin - foo, bar, goo = randn(3) - dx, dy, dz = $(derivs[1]), $(derivs[2]), $(derivs[3]) - if !(isnan(dx)) - @test isapprox(dx, finitediff(x -> $M.$f(x, bar, goo), foo), rtol=0.05) - end - if !(isnan(dy)) - @test isapprox(dy, finitediff(y -> $M.$f(foo, y, goo), bar), rtol=0.05) - end - if !(isnan(dz)) - @test isapprox(dz, finitediff(z -> $M.$f(foo, bar, z), goo), rtol=0.05) - end + foo, bar, goo = randn(3) + dx, dy, dz = $(derivs[1]), $(derivs[2]), $(derivs[3]) + if !(isnan(dx)) + @test isapprox(dx, finitediff(x -> $M.$f(x, bar, goo), foo), rtol=0.05) + end + if !(isnan(dy)) + @test isapprox(dy, finitediff(y -> $M.$f(foo, y, goo), bar), rtol=0.05) + end + if !(isnan(dz)) + @test isapprox(dz, finitediff(z -> $M.$f(foo, bar, z), goo), rtol=0.05) + end end =# end From 6e564bc799bd5a54cc613ba23857e22dba79bd8f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 28 Jan 2022 18:28:22 +0000 Subject: [PATCH 48/48] dont check if AbstractFloat in tests --- test/runtests.jl | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 7f86108..6618a36 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -74,17 +74,13 @@ non_diffeable_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ldexp, 2), (:Base, @test dx ≈ finitediff(z -> $M.$f(z, bar), foo) rtol=1e-2 atol=1e-3 # Check type, if applicable. - if foo isa AbstractFloat - @test promote_type(typeof(real(dx)), $T) === $T - end + @test promote_type(typeof(real(dx)), $T) === $T end if !(isnan(dy)) @test dy ≈ finitediff(z -> $M.$f(foo, z), bar) rtol=1e-2 atol=1e-3 # Check type, if applicable. - if bar isa AbstractFloat - @test promote_type(typeof(real(dy)), $T) === $T - end + @test promote_type(typeof(real(dy)), $T) === $T end end end