diff --git a/.github/workflows/TagBot.yml b/.github/workflows/TagBot.yml index ed607f7..f49313b 100644 --- a/.github/workflows/TagBot.yml +++ b/.github/workflows/TagBot.yml @@ -1,7 +1,5 @@ name: TagBot on: - schedule: - - cron: 0 * * * * issue_comment: types: - created diff --git a/.gitignore b/.gitignore index 8c960ec..97e6a6f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ *.jl.cov *.jl.*.cov *.jl.mem +/Manifest.toml diff --git a/Project.toml b/Project.toml index 818be1c..5751f17 100644 --- a/Project.toml +++ b/Project.toml @@ -3,11 +3,13 @@ uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" version = "1.3.1" [deps] +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [compat] +LogExpFunctions = "0.3" NaNMath = "0.3" SpecialFunctions = "0.8, 0.9, 0.10, 1.0" julia = "1" diff --git a/docs/make.jl b/docs/make.jl index 23f89bd..ccdf86b 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,12 +1,20 @@ using Documenter, DiffRules +DocMeta.setdocmeta!( + DiffRules, + :DocTestSetup, + :(using DiffRules); + recursive=true, +) + makedocs(modules=[DiffRules], - doctest = false, sitename = "DiffRules", pages = ["Documentation" => "index.md"], format = Documenter.HTML( prettyurls = get(ENV, "CI", nothing) == "true" ), + strict=true, + checkdocs=:exports, ) deploydocs(; repo="github.com/JuliaDiff/DiffRules.jl", push_preview=true) diff --git a/src/DiffRules.jl b/src/DiffRules.jl index b70cae5..e51f4d9 100644 --- a/src/DiffRules.jl +++ b/src/DiffRules.jl @@ -2,6 +2,8 @@ __precompile__() module DiffRules +import LogExpFunctions + include("api.jl") include("rules.jl") diff --git a/src/api.jl b/src/api.jl index 53500a0..59236c4 100644 --- a/src/api.jl +++ b/src/api.jl @@ -16,12 +16,13 @@ interpolated wherever they are used on the RHS. Note that differentiation rules are purely symbolic, so no type annotations should be used. -Examples: - - @define_diffrule Base.cos(x) = :(-sin(\$x)) - @define_diffrule Base.:/(x, y) = :(inv(\$y)), :(-\$x / (\$y^2)) - @define_diffrule Base.polygamma(m, x) = :NaN, :(polygamma(\$m + 1, \$x)) +# Examples +```julia +@define_diffrule Base.cos(x) = :(-sin(\$x)) +@define_diffrule Base.:/(x, y) = :(inv(\$y)), :(-\$x / (\$y^2)) +@define_diffrule Base.polygamma(m, x) = :NaN, :(polygamma(\$m + 1, \$x)) +``` """ macro define_diffrule(def) @assert isa(def, Expr) && def.head == :(=) "Diff rule expression does not have a left and right side" @@ -50,19 +51,18 @@ interpolated into the returned expression. In the `n`-ary case, an `n`-tuple of expressions will be returned where the `i`th expression is the derivative of `f` w.r.t the `i`th argument. -Examples: - - julia> DiffRules.diffrule(:Base, :sin, 1) - :(cos(1)) +# Examples - julia> DiffRules.diffrule(:Base, :sin, :x) - :(cos(x)) +```jldoctest +julia> DiffRules.diffrule(:Base, :sin, 1) +:(cos(1)) - julia> DiffRules.diffrule(:Base, :sin, :(x * y^2)) - :(cos(x * y ^ 2)) +julia> DiffRules.diffrule(:Base, :sin, :x) +:(cos(x)) - julia> DiffRules.diffrule(:Base, :^, :(x + 2), :c) - (:(c * (x + 2) ^ (c - 1)), :((x + 2) ^ c * log(x + 2))) +julia> DiffRules.diffrule(:Base, :sin, :(x * y^2)) +:(cos(x * y ^ 2)) +``` """ diffrule(M::Union{Expr,Symbol}, f::Symbol, args...) = DEFINED_DIFFRULES[M,f,length(args)](args...) @@ -74,41 +74,109 @@ otherwise. Here, `arity` refers to the number of arguments accepted by `f`. -Examples: +# Examples - julia> DiffRules.hasdiffrule(:Base, :sin, 1) - true +```jldoctest +julia> DiffRules.hasdiffrule(:Base, :sin, 1) +true - julia> DiffRules.hasdiffrule(:Base, :sin, 2) - false +julia> DiffRules.hasdiffrule(:Base, :sin, 2) +false - julia> DiffRules.hasdiffrule(:Base, :-, 1) - true +julia> DiffRules.hasdiffrule(:Base, :-, 1) +true - julia> DiffRules.hasdiffrule(:Base, :-, 2) - true +julia> DiffRules.hasdiffrule(:Base, :-, 2) +true - julia> DiffRules.hasdiffrule(:Base, :-, 3) - false +julia> DiffRules.hasdiffrule(:Base, :-, 3) +false +``` """ hasdiffrule(M::Union{Expr,Symbol}, f::Symbol, arity::Int) = haskey(DEFINED_DIFFRULES, (M, f, arity)) +# show a deprecation warning if `filter_modules` in `diffrules()` is specified implicitly +# we use a custom singleton to figure out if the keyword argument was set explicitly +struct DefaultFilterModules end + +function deprecated_modules(modules) + return if modules isa DefaultFilterModules + Base.depwarn( + "the implicit keyword argument " * + "`filter_modules=(:Base, :SpecialFunctions, :NaNMath)` in `diffrules()` is " * + "deprecated and will be changed to `filter_modules=nothing` in an upcoming " * + "breaking release of DiffRules (i.e., `diffrules()` will return all rules " * + "defined in DiffRules)", + :diffrules, + ) + (:Base, :SpecialFunctions, :NaNMath) + else + modules + end +end + """ - diffrules() + diffrules(; filter_modules=(:Base, :SpecialFunctions, :NaNMath)) -Return a list of keys that can be used to access all defined differentiation rules. +Return a list of keys that can be used to access all defined differentiation rules for +modules in `filter_modules`. Each key is of the form `(M::Symbol, f::Symbol, arity::Int)`. - -Here, `arity` refers to the number of arguments accepted by `f`. - -Examples: - - julia> first(DiffRules.diffrules()) - (:Base, :asind, 1) - +Here, `arity` refers to the number of arguments accepted by `f` and `M` is one of the +modules in `filter_modules`. + +To include all rules, specify `filter_modules = nothing`. + +!!! note + Calling `diffrules()` with the implicit default keyword argument `filter_modules` + does *not* return all rules defined by this package but rather only rules for the + packages for which DiffRules 1.0 provided rules. This is done in order to not to + break downstream packages that assumed this list would never change. + It is planned to change `diffrules()` to return all rules, i.e., to use the + default keyword argument `filter_modules=nothing`, in an upcoming breaking release + of DiffRules. + +# Examples + +```jldoctest +julia> first(DiffRules.diffrules()) +(:Base, :log2, 1) +``` + +If you call `diffrules()`, only rules for Base, SpecialFunctions, and +NaNMath are returned but no rules for LogExpFunctions: +```jldoctest +julia> any(M === :LogExpFunctions for (M, _, _) in DiffRules.diffrules()) +false +``` + +If you set `filter_modules=nothing`, all rules defined in DiffRules are +returned and in particular also rules for LogExpFunctions: +```jldoctest +julia> any( + M === :LogExpFunctions + for (M, _, _) in DiffRules.diffrules(; filter_modules=nothing) + ) +true +``` + +If you set `filter_modules=(:Base,)` only rules for functions in Base are +returned: +```jldoctest +julia> all(M === :Base for (M, _, _) in DiffRules.diffrules(; filter_modules=(:Base,))) +true +``` """ -diffrules() = keys(DEFINED_DIFFRULES) +function diffrules(; filter_modules=DefaultFilterModules()) + modules = deprecated_modules(filter_modules) + return if modules === nothing + keys(DEFINED_DIFFRULES) + else + Iterators.filter(keys(DEFINED_DIFFRULES)) do (M, _, _) + return M in modules + end + end +end # For v0.6 and v0.7 compatibility, need to support having the diff rule function enter as a # `Expr(:quote...)` and a `QuoteNode`. When v0.6 support is dropped, the function will diff --git a/src/rules.jl b/src/rules.jl index 6ace211..88cb5af 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -232,3 +232,30 @@ end :(ifelse(($y > $x) | (signbit($y) < signbit($x)), ifelse(isnan($y), zero($y), one($y)), ifelse(isnan($x), one($y), zero($y)))) @define_diffrule NaNMath.min(x, y) = :(ifelse(($y < $x) | (signbit($y) > signbit($x)), ifelse(isnan($y), one($x), zero($x)), ifelse(isnan($x), zero($x), one($x)))), :(ifelse(($y < $x) | (signbit($y) > signbit($x)), ifelse(isnan($y), zero($y), one($y)), ifelse(isnan($x), one($x), zero($x)))) + +################### +# LogExpFunctions # +################### + +# unary +@define_diffrule LogExpFunctions.xlogx(x) = :(1 + log($x)) +@define_diffrule LogExpFunctions.logistic(x) = :(z = LogExpFunctions.logistic($x); z * (1 - z)) +@define_diffrule LogExpFunctions.logit(x) = :(inv($x * (1 - $x))) +@define_diffrule LogExpFunctions.log1psq(x) = :(2 * $x / (1 + $x^2)) +@define_diffrule LogExpFunctions.log1pexp(x) = :(LogExpFunctions.logistic($x)) +@define_diffrule LogExpFunctions.log1mexp(x) = :(-exp($x - LogExpFunctions.log1mexp($x))) +@define_diffrule LogExpFunctions.log2mexp(x) = :(-exp($x - LogExpFunctions.log2mexp($x))) +@define_diffrule LogExpFunctions.logexpm1(x) = :(exp($x - LogExpFunctions.logexpm1($x))) + +# binary +@define_diffrule LogExpFunctions.xlogy(x, y) = :(log($y)), :($x / $y) +@define_diffrule LogExpFunctions.logaddexp(x, y) = + :(exp($x - LogExpFunctions.logaddexp($x, $y))), :(exp($y - LogExpFunctions.logaddexp($x, $y))) +@define_diffrule LogExpFunctions.logsubexp(x, y) = + :(z = LogExpFunctions.logsubexp($x, $y); $x > $y ? exp($x - z) : -exp($x - z)), + :(z = LogExpFunctions.logsubexp($x, $y); $x > $y ? -exp($y - z) : exp($y - z)) + +# only defined in LogExpFunctions >= 0.3.2 +if isdefined(LogExpFunctions, :xlog1py) + @define_diffrule LogExpFunctions.xlog1py(x, y) = :(log1p($y)), :($x / (1 + $y)) +end diff --git a/test/runtests.jl b/test/runtests.jl index e3da95f..533cf93 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,49 +1,58 @@ -if VERSION < v"0.7-" - using Base.Test - srand(1) -else - using Test - import Random - Random.seed!(1) -end -import SpecialFunctions, NaNMath using DiffRules +using Test +import SpecialFunctions, NaNMath, LogExpFunctions +import Random +Random.seed!(1) function finitediff(f, x) ϵ = cbrt(eps(typeof(x))) * max(one(typeof(x)), abs(x)) return (f(x + ϵ) - f(x - ϵ)) / (ϵ + ϵ) end +@testset "DiffRules" begin +@testset "check rules" begin non_numeric_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ifelse, 3)] -for (M, f, arity) in DiffRules.diffrules() +for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) (M, f, arity) ∈ non_numeric_arg_functions && continue if arity == 1 @test DiffRules.hasdiffrule(M, f, 1) deriv = DiffRules.diffrule(M, f, :goo) - modifier = in(f, (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)) ? 1 : 0 + modifier = if f in (:asec, :acsc, :asecd, :acscd, :acosh, :acoth) + 1.0 + elseif f === :log1mexp + -1.0 + elseif f === :log2mexp + -0.5 + else + 0.0 + end @eval begin - goo = rand() + $modifier - @test isapprox($deriv, finitediff($M.$f, goo), rtol=0.05) - # test for 2pi functions - if "mod2pi" == string($M.$f) - goo = 4pi + $modifier - @test NaN === $deriv + let + goo = rand() + $modifier + @test isapprox($deriv, finitediff($M.$f, goo), rtol=0.05) + # test for 2pi functions + if "mod2pi" == string($M.$f) + goo = 4pi + $modifier + @test NaN === $deriv + end end end elseif arity == 2 @test DiffRules.hasdiffrule(M, f, 2) derivs = DiffRules.diffrule(M, f, :foo, :bar) @eval begin - foo, bar = rand(1:10), rand() - dx, dy = $(derivs[1]), $(derivs[2]) - if !(isnan(dx)) - @test isapprox(dx, finitediff(z -> $M.$f(z, bar), float(foo)), rtol=0.05) - end - if !(isnan(dy)) - @test isapprox(dy, finitediff(z -> $M.$f(foo, z), bar), rtol=0.05) + let + foo, bar = rand(1:10), rand() + dx, dy = $(derivs[1]), $(derivs[2]) + if !(isnan(dx)) + @test isapprox(dx, finitediff(z -> $M.$f(z, bar), float(foo)), rtol=0.05) + end + if !(isnan(dy)) + @test isapprox(dy, finitediff(z -> $M.$f(foo, z), bar), rtol=0.05) + end end end elseif arity == 3 @@ -72,14 +81,29 @@ derivs = DiffRules.diffrule(:Base, :rem2pi, :x, :y) for xtype in [:Float64, :BigFloat, :Int64] for mode in [:RoundUp, :RoundDown, :RoundToZero, :RoundNearest] @eval begin - x = $xtype(rand(1 : 10)) - y = $mode - dx, dy = $(derivs[1]), $(derivs[2]) - @test isapprox(dx, finitediff(z -> rem2pi(z, y), float(x)), rtol=0.05) - @test isnan(dy) + let + x = $xtype(rand(1 : 10)) + y = $mode + dx, dy = $(derivs[1]), $(derivs[2]) + @test isapprox(dx, finitediff(z -> rem2pi(z, y), float(x)), rtol=0.05) + @test isnan(dy) + end end end end +end + + @testset "diffrules" begin + rules = @test_deprecated(DiffRules.diffrules()) + @test Set(M for (M, _, _) in rules) == Set((:Base, :SpecialFunctions, :NaNMath)) + + rules = DiffRules.diffrules(; filter_modules=nothing) + @test Set(M for (M, _, _) in rules) == Set((:Base, :SpecialFunctions, :NaNMath, :LogExpFunctions)) + + rules = DiffRules.diffrules(; filter_modules=(:Base, :LogExpFunctions)) + @test Set(M for (M, _, _) in rules) == Set((:Base, :LogExpFunctions)) + end +end # Test ifelse separately as first argument is boolean #=