Skip to content

Commit 7b1c31e

Browse files
Added ldexp to list of diffrules (#73)
* Update rules.jl Added a rule for the function ldexp https://docs.julialang.org/en/v1/base/math/#Base.Math.ldexp * Update src/rules.jl use exp2 instead of 2^x to avoid overflows, indicate the derivative w.r.t. the second argument of ldexp as not defined since that argument is of integer type. Co-authored-by: David Widmann <[email protected]> * added special test for ldexp since its second argument is required to be an integer and thus non-differentiable. Changed variable name from "non_numeric_arg_functions" to "non_diffable_arg_functions" to be more general and account for the numeric but non-differentiable nature of the second argument of ldexp. * Delete .vscode directory * Update tests * Bump version Co-authored-by: David Widmann <[email protected]> Co-authored-by: David Widmann <[email protected]>
1 parent ea79b94 commit 7b1c31e

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DiffRules"
22
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
3-
version = "1.8.0"
3+
version = "1.9.0"
44

55
[deps]
66
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"

src/rules.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x)
8686
@define_diffrule Base.atan(x, y) = :( $y / ($x^2 + $y^2) ), :( -$x / ($x^2 + $y^2) )
8787
@define_diffrule Base.hypot(x, y) = :( $x / hypot($x, $y) ), :( $y / hypot($x, $y) )
8888
@define_diffrule Base.log(b, x) = :( log($x) * inv(-log($b)^2 * $b) ), :( inv($x) / log($b) )
89+
@define_diffrule Base.ldexp(x, y) = :( exp2($y) ), :NaN
8990

9091
@define_diffrule Base.mod(x, y) = :( first(promote(ifelse(isinteger($x / $y), NaN, 1), NaN)) ), :( z = $x / $y; first(promote(ifelse(isinteger(z), NaN, -floor(z)), NaN)) )
9192
@define_diffrule Base.rem(x, y) = :( first(promote(ifelse(isinteger($x / $y), NaN, 1), NaN)) ), :( z = $x / $y; first(promote(ifelse(isinteger(z), NaN, -trunc(z)), NaN)) )

test/runtests.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ end
1313
@testset "DiffRules" begin
1414
@testset "check rules" begin
1515

16-
non_numeric_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ifelse, 3)]
16+
non_diffeable_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ldexp, 2), (:Base, :ifelse, 3)]
1717

1818
for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing)
19-
(M, f, arity) non_numeric_arg_functions && continue
19+
(M, f, arity) non_diffeable_arg_functions && continue
2020
if arity == 1
2121
@test DiffRules.hasdiffrule(M, f, 1)
2222
deriv = DiffRules.diffrule(M, f, :goo)
@@ -95,6 +95,23 @@ for xtype in [:Float64, :BigFloat, :Int64]
9595
end
9696
end
9797
end
98+
99+
# Treat ldexp separately because of its integer second argument:
100+
derivs = DiffRules.diffrule(:Base, :ldexp, :x, :y)
101+
for xtype in [:Float64, :BigFloat]
102+
for ytype in [:Integer, :UInt64, :Int64]
103+
@eval begin
104+
let
105+
x = rand($xtype)
106+
y = $ytype(rand(1 : 10))
107+
dx, dy = $(derivs[1]), $(derivs[2])
108+
@test isapprox(dx, finitediff(z -> ldexp(z, y), x), rtol=0.05)
109+
@test isnan(dy)
110+
end
111+
end
112+
end
113+
end
114+
98115
end
99116

100117
@testset "diffrules" begin

0 commit comments

Comments
 (0)