diff --git a/Project.toml b/Project.toml index 279bd47..21e4474 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FiniteDifferences" uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" -version = "0.12.18" +version = "0.12.19" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/grad.jl b/src/grad.jl index f02f849..ebbc533 100644 --- a/src/grad.jl +++ b/src/grad.jl @@ -85,4 +85,4 @@ end Compute the gradient of `f` for any `xs` for which [`to_vec`](@ref) is defined. """ -grad(fdm, f, xs...) = j′vp(fdm, f, 1, xs...) # `j′vp` with seed of 1 +grad(fdm, f, xs...) = j′vp(fdm, f, one(f(xs...)), xs...) # `j′vp` with seed of 1 diff --git a/src/to_vec.jl b/src/to_vec.jl index 838078d..cee962c 100644 --- a/src/to_vec.jl +++ b/src/to_vec.jl @@ -18,6 +18,11 @@ function to_vec(z::Complex) return [real(z), imag(z)], Complex_from_vec end +function to_vec(x::Integer) + Integer_from_vec(_) = x + return Bool[], Integer_from_vec +end + # Base case -- if x is already a Vector{<:Real} there's no conversion necessary. to_vec(x::Vector{<:Real}) = (x, identity) diff --git a/test/to_vec.jl b/test/to_vec.jl index 7e430c8..b490b11 100644 --- a/test/to_vec.jl +++ b/test/to_vec.jl @@ -68,6 +68,14 @@ function test_to_vec(x::T; check_inferred=true) where {T} end @testset "to_vec" begin + + # Integers are non-differentiable. to_vec should only preserve the differentiable bits + # of a type so that they can be appropriately perturbed. + @testset "Int" begin + @test isempty(to_vec(4)[1]) + test_to_vec(5) + end + @testset "$T" for T in (Float32, ComplexF32, Float64, ComplexF64) if T == Float64 test_to_vec(1.0)