diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index dd821e683..cdeee2dba 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -16,4 +16,4 @@ jobs: - name: CompatHelper.main() env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: julia -e 'using CompatHelper; CompatHelper.main()' + run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs = ["", "test"])' diff --git a/Project.toml b/Project.toml index 27662e132..ea52d71a3 100644 --- a/Project.toml +++ b/Project.toml @@ -22,15 +22,3 @@ StatsBase = "0.32, 0.33" StatsFuns = "0.8, 0.9" ZygoteRules = "0.2" julia = "1.3" - -[extras] -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" -PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" - -[targets] -test = ["Random", "Test", "FiniteDifferences", "Zygote", "PDMats", "Kronecker", "Flux"] diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index f89db37ae..5ffa58654 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -34,8 +34,8 @@ export spectral_mixture_kernel, spectral_mixture_product_kernel using Compat using Requires using Distances, LinearAlgebra -using SpecialFunctions: logabsgamma, besselk -using ZygoteRules: @adjoint +using SpecialFunctions: logabsgamma, besselk, polygamma +using ZygoteRules: @adjoint, pullback using StatsFuns: logtwo using InteractiveUtils: subtypes using StatsBase diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl index 2adda86ae..44b5eb989 100644 --- a/src/basekernels/matern.jl +++ b/src/basekernels/matern.jl @@ -17,12 +17,11 @@ end @inline function kappa(κ::MaternKernel, d::Real) ν = first(κ.ν) - iszero(d) ? 
one(d) : - exp( - (one(d) - ν) * logtwo - logabsgamma(ν)[1] + - ν * log(sqrt(2ν) * d) + - log(besselk(ν, sqrt(2ν) * d)) - ) + iszero(d) ? one(d) : _matern(ν, d) +end + +function _matern(ν::Real, d::Real) + exp((one(d) - ν) * logtwo - loggamma(ν) + ν * log(sqrt(2ν) * d) + log(besselk(ν, sqrt(2ν) * d))) end metric(::MaternKernel) = Euclidean() diff --git a/src/distances/delta.jl b/src/distances/delta.jl index b986ef73f..54da36ad5 100644 --- a/src/distances/delta.jl +++ b/src/distances/delta.jl @@ -1,12 +1,14 @@ struct Delta <: Distances.PreMetric end -@inline function Distances._evaluate(::Delta,a::AbstractVector{T},b::AbstractVector{T}) where {T} +@inline function Distances._evaluate(::Delta, a::AbstractVector, b::AbstractVector) @boundscheck if length(a) != length(b) throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b)).")) end return a == b end +Distances.result_type(::Delta, Ta::Type, Tb::Type) = promote_type(Ta, Tb) + @inline (dist::Delta)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b) -@inline (dist::Delta)(a::Number,b::Number) = a == b +@inline (dist::Delta)(a::Number, b::Number) = a == b diff --git a/src/distances/dotproduct.jl b/src/distances/dotproduct.jl index 7d75266db..880c494df 100644 --- a/src/distances/dotproduct.jl +++ b/src/distances/dotproduct.jl @@ -1,13 +1,15 @@ struct DotProduct <: Distances.PreMetric end # struct DotProduct <: Distances.UnionSemiMetric end -@inline function Distances._evaluate(::DotProduct, a::AbstractVector{T}, b::AbstractVector{T}) where {T} +@inline function Distances._evaluate(::DotProduct, a::AbstractVector, b::AbstractVector) @boundscheck if length(a) != length(b) throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b)).")) end return dot(a,b) end +Distances.result_type(::DotProduct, Ta::Type, Tb::Type) = promote_type(Ta, Tb) + @inline 
Distances.eval_op(::DotProduct, a::Real, b::Real) = a * b @inline (dist::DotProduct)(a::AbstractArray,b::AbstractArray) = Distances._evaluate(dist, a, b) @inline (dist::DotProduct)(a::Number,b::Number) = a * b diff --git a/src/distances/sinus.jl b/src/distances/sinus.jl index 7276e2e48..f4bdd6b97 100644 --- a/src/distances/sinus.jl +++ b/src/distances/sinus.jl @@ -8,7 +8,9 @@ Distances.parameters(d::Sinus) = d.r @inline (dist::Sinus)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b) @inline (dist::Sinus)(a::Number, b::Number) = abs2(sinpi(a - b) / first(dist.r)) -@inline function Distances._evaluate(d::Sinus, a::AbstractVector{T}, b::AbstractVector{T}) where {T} +Distances.result_type(::Sinus{T}, Ta::Type, Tb::Type) where {T} = promote_type(T, Ta, Tb) + +@inline function Distances._evaluate(d::Sinus, a::AbstractVector, b::AbstractVector) @boundscheck if (length(a) != length(b)) || length(a) != length(d.r) throw(DimensionMismatch("Dimensions of the inputs are not matching : a = $(length(a)), b = $(length(b)), r = $(length(d.r))")) end diff --git a/src/transform/ardtransform.jl b/src/transform/ardtransform.jl index d9bf019a9..d5231c1bf 100644 --- a/src/transform/ardtransform.jl +++ b/src/transform/ardtransform.jl @@ -24,9 +24,9 @@ dim(t::ARDTransform) = length(t.v) (t::ARDTransform)(x::Real) = first(t.v) * x (t::ARDTransform)(x) = t.v .* x -Base.map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x -Base.map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X) -Base.map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X) +_map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x +_map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X) +_map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X) Base.isequal(t::ARDTransform, t2::ARDTransform) = isequal(t.v, t2.v) diff --git a/src/transform/chaintransform.jl b/src/transform/chaintransform.jl index d8d3bc1f5..b1ed93ffb 100644 --- a/src/transform/chaintransform.jl +++ 
b/src/transform/chaintransform.jl @@ -27,7 +27,7 @@ Base.:∘(tc::ChainTransform, t::Transform) = ChainTransform(vcat(t, tc.transfor (t::ChainTransform)(x) = foldl((x, t) -> t(x), t.transforms; init=x) -function Base.map(t::ChainTransform, x::AbstractVector) +function _map(t::ChainTransform, x::AbstractVector) return foldl((x, t) -> map(t, x), t.transforms; init=x) end diff --git a/src/transform/functiontransform.jl b/src/transform/functiontransform.jl index 5c3729dc3..c1d09b418 100644 --- a/src/transform/functiontransform.jl +++ b/src/transform/functiontransform.jl @@ -15,9 +15,9 @@ end (t::FunctionTransform)(x) = t.f(x) -Base.map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x) -Base.map(t::FunctionTransform, x::ColVecs) = ColVecs(mapslices(t.f, x.X; dims=1)) -Base.map(t::FunctionTransform, x::RowVecs) = RowVecs(mapslices(t.f, x.X; dims=2)) +_map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x) +_map(t::FunctionTransform, x::ColVecs) = ColVecs(mapslices(t.f, x.X; dims=1)) +_map(t::FunctionTransform, x::RowVecs) = RowVecs(mapslices(t.f, x.X; dims=2)) duplicate(t::FunctionTransform,f) = FunctionTransform(f) diff --git a/src/transform/lineartransform.jl b/src/transform/lineartransform.jl index 43224f90c..dcbd55873 100644 --- a/src/transform/lineartransform.jl +++ b/src/transform/lineartransform.jl @@ -27,9 +27,9 @@ end (t::LinearTransform)(x::Real) = vec(t.A * x) (t::LinearTransform)(x::AbstractVector{<:Real}) = t.A * x -Base.map(t::LinearTransform, x::AbstractVector{<:Real}) = ColVecs(t.A * x') -Base.map(t::LinearTransform, x::ColVecs) = ColVecs(t.A * x.X) -Base.map(t::LinearTransform, x::RowVecs) = RowVecs(x.X * t.A') +_map(t::LinearTransform, x::AbstractVector{<:Real}) = ColVecs(t.A * x') +_map(t::LinearTransform, x::ColVecs) = ColVecs(t.A * x.X) +_map(t::LinearTransform, x::RowVecs) = RowVecs(x.X * t.A') function Base.show(io::IO, t::LinearTransform) print(io::IO, "Linear transform (size(A) = ", size(t.A), ")") diff --git 
a/src/transform/scaletransform.jl b/src/transform/scaletransform.jl index af09b27ef..37aa1fef9 100644 --- a/src/transform/scaletransform.jl +++ b/src/transform/scaletransform.jl @@ -19,9 +19,9 @@ set!(t::ScaleTransform,ρ::Real) = t.s .= [ρ] (t::ScaleTransform)(x) = first(t.s) .* x -Base.map(t::ScaleTransform, x::AbstractVector{<:Real}) = first(t.s) .* x -Base.map(t::ScaleTransform, x::ColVecs) = ColVecs(first(t.s) .* x.X) -Base.map(t::ScaleTransform, x::RowVecs) = RowVecs(first(t.s) .* x.X) +_map(t::ScaleTransform, x::AbstractVector{<:Real}) = first(t.s) .* x +_map(t::ScaleTransform, x::ColVecs) = ColVecs(first(t.s) .* x.X) +_map(t::ScaleTransform, x::RowVecs) = RowVecs(first(t.s) .* x.X) Base.isequal(t::ScaleTransform,t2::ScaleTransform) = isequal(first(t.s),first(t2.s)) diff --git a/src/transform/selecttransform.jl b/src/transform/selecttransform.jl index 66631ff13..608e55b1d 100644 --- a/src/transform/selecttransform.jl +++ b/src/transform/selecttransform.jl @@ -25,7 +25,7 @@ duplicate(t::SelectTransform,θ) = t (t::SelectTransform)(x::AbstractVector) = view(x, t.select) -Base.map(t::SelectTransform, x::ColVecs) = ColVecs(view(x.X, t.select, :)) -Base.map(t::SelectTransform, x::RowVecs) = RowVecs(view(x.X, :, t.select)) +_map(t::SelectTransform, x::ColVecs) = ColVecs(view(x.X, t.select, :)) +_map(t::SelectTransform, x::RowVecs) = RowVecs(view(x.X, :, t.select)) Base.show(io::IO, t::SelectTransform) = print(io, "Select Transform (dims: ", t.select, ")") diff --git a/src/transform/transform.jl b/src/transform/transform.jl index 7d2bbe22c..b6ab0f397 100644 --- a/src/transform/transform.jl +++ b/src/transform/transform.jl @@ -5,12 +5,8 @@ include("functiontransform.jl") include("selecttransform.jl") include("chaintransform.jl") -""" - apply(t::Transform, x; obsdim::Int=defaultobs) -Apply the transform `t` vector-wise on the array `x` -""" -apply +Base.map(t::Transform, x::AbstractVector) = _map(t, x) """ IdentityTransform() @@ -20,7 +16,7 @@ Return exactly the input 
struct IdentityTransform <: Transform end (t::IdentityTransform)(x) = x -Base.map(::IdentityTransform, x::AbstractVector) = x +_map(::IdentityTransform, x::AbstractVector) = x ### TODO Maybe defining adjoints could help but so far it's not working diff --git a/src/utils.jl b/src/utils.jl index ab738c165..ed11f2428 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,5 +1,7 @@ hadamard(x, y) = x .* y +loggamma(x) = first(logabsgamma(x)) + # Macro for checking arguments macro check_args(K, param, cond, desc=string(cond)) quote @@ -124,4 +126,3 @@ function validate_dims(x::AbstractVector, y::AbstractVector) )) end end - diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl index 7c6311477..a95be8142 100644 --- a/src/zygote_adjoints.jl +++ b/src/zygote_adjoints.jl @@ -1,9 +1,78 @@ +## Adjoints Delta +@adjoint function evaluate(s::Delta, x::AbstractVector, y::AbstractVector) + evaluate(s, x, y), Δ -> begin + (nothing, nothing, nothing) + end +end + +@adjoint function pairwise(d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2) + D = pairwise(d, X, Y; dims = dims) + if dims == 1 + return D, Δ -> (nothing, nothing, nothing) + else + return D, Δ -> (nothing, nothing, nothing) + end +end + +@adjoint function pairwise(d::Delta, X::AbstractMatrix; dims=2) + D = pairwise(d, X; dims = dims) + if dims == 1 + return D, Δ -> (nothing, nothing) + else + return D, Δ -> (nothing, nothing) + end +end + +## Adjoints DotProduct @adjoint function evaluate(s::DotProduct, x::AbstractVector, y::AbstractVector) dot(x, y), Δ -> begin (nothing, Δ .* y, Δ .* x) end end +@adjoint function pairwise(d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix; dims=2) + D = pairwise(d, X, Y; dims = dims) + if dims == 1 + return D, Δ -> (nothing, Δ * Y, (X' * Δ)') + else + return D, Δ -> (nothing, (Δ * Y')', X * Δ) + end +end + +@adjoint function pairwise(d::DotProduct, X::AbstractMatrix; dims=2) + D = pairwise(d, X; dims = dims) + if dims == 1 + return D, Δ -> (nothing, 2 * Δ * X) + else + return 
D, Δ -> (nothing, 2 * X * Δ) + end +end + +## Adjoints Sinus +@adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector) + d = (x - y) + sind = sinpi.(d) + val = sum(abs2, sind ./ s.r) + gradx = 2π .* cospi.(d) .* sind ./ (s.r .^ 2) + val, Δ -> begin + ((r = -2Δ .* abs2.(sind) ./ s.r .^ 3,), Δ * gradx, - Δ * gradx) + end +end + +@adjoint function loggamma(x) + first(logabsgamma(x)) , Δ -> (Δ .* polygamma(0, x), ) +end + +@adjoint function kappa(κ::MaternKernel, d::Real) + ν = first(κ.ν) + val, grad = pullback(_matern, ν, d) + return ((iszero(d) ? one(d) : val), + Δ -> begin + ∇ = grad(Δ) + return ((ν = [∇[1]],), iszero(d) ? zero(d) : ∇[2]) + end) +end + @adjoint function ColVecs(X::AbstractMatrix) back(Δ::NamedTuple) = (Δ.X,) back(Δ::AbstractMatrix) = (Δ,) @@ -22,10 +91,10 @@ end return RowVecs(X), back end -# @adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector) -# d = evaluate(s, x, y) -# s = sum(sin.(π*(x-y))) -# d, Δ -> begin -# (Sinus(Δ ./ s.r), 2Δ .* cos.(x - y) * d, -2Δ .* cos.(x - y) * d) -# end -# end +@adjoint function Base.map(t::Transform, X::ColVecs) + pullback(_map, t, X) +end + +@adjoint function Base.map(t::Transform, X::RowVecs) + pullback(_map, t, X) +end diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 000000000..ba243cd37 --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,24 @@ +[deps] +Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = 
"e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +Distances = "0.9" +FiniteDifferences = "0.10" +Flux = "0.10" +ForwardDiff = "0.10" +Kronecker = "0.4" +PDMats = "0.9" +ReverseDiff = "1.2" +SpecialFunctions = "0.10" +Zygote = "0.4" diff --git a/test/basekernels/constant.jl b/test/basekernels/constant.jl index 9a824e287..308fb84b6 100644 --- a/test/basekernels/constant.jl +++ b/test/basekernels/constant.jl @@ -5,6 +5,7 @@ @test kappa(k,2.0) == 0.0 @test KernelFunctions.metric(ZeroKernel()) == KernelFunctions.Delta() @test repr(k) == "Zero Kernel" + test_ADs(ZeroKernel) end @testset "WhiteKernel" begin k = WhiteKernel() @@ -14,6 +15,7 @@ @test EyeKernel == WhiteKernel @test metric(WhiteKernel()) == KernelFunctions.Delta() @test repr(k) == "White Kernel" + test_ADs(WhiteKernel) end @testset "ConstantKernel" begin c = 2.0 @@ -24,5 +26,6 @@ @test metric(ConstantKernel()) == KernelFunctions.Delta() @test metric(ConstantKernel(c=2.0)) == KernelFunctions.Delta() @test repr(k) == "Constant Kernel (c = $(c))" + test_ADs(c->ConstantKernel(c=first(c)), [c]) end end diff --git a/test/basekernels/cosine.jl b/test/basekernels/cosine.jl index 5874c6ba7..bf4c060b4 100644 --- a/test/basekernels/cosine.jl +++ b/test/basekernels/cosine.jl @@ -12,4 +12,5 @@ @test kappa(k,x) ≈ cospi(x) atol=1e-5 @test k(v1, v2) ≈ cospi(sqrt(sum(abs2.(v1-v2)))) atol=1e-5 @test repr(k) == "Cosine Kernel" + test_ADs(CosineKernel) end diff --git a/test/basekernels/exponential.jl b/test/basekernels/exponential.jl index d87289711..e890a3a15 100644 --- a/test/basekernels/exponential.jl +++ b/test/basekernels/exponential.jl @@ -14,6 +14,7 @@ @test SEKernel == SqExponentialKernel @test repr(k) == "Squared Exponential Kernel" @test KernelFunctions.iskroncompatible(k) == true + test_ADs(SEKernel) end @testset "ExponentialKernel" begin k = ExponentialKernel() @@ -24,6 +25,7 @@ @test repr(k) == "Exponential Kernel" @test LaplacianKernel == ExponentialKernel @test KernelFunctions.iskroncompatible(k) == true + 
test_ADs(ExponentialKernel) end @testset "GammaExponentialKernel" begin γ = 2.0 @@ -36,7 +38,8 @@ @test metric(GammaExponentialKernel(γ=2.0)) == SqEuclidean() @test repr(k) == "Gamma Exponential Kernel (γ = $(γ))" @test KernelFunctions.iskroncompatible(k) == true - + test_ADs(γ -> GammaExponentialKernel(gamma=first(γ)), [γ], ADs = [:ForwardDiff, :ReverseDiff]) + @test_broken "Zygote gradient given γ" #Coherence : @test GammaExponentialKernel(γ=1.0)(v1,v2) ≈ SqExponentialKernel()(v1,v2) @test GammaExponentialKernel(γ=0.5)(v1,v2) ≈ ExponentialKernel()(v1,v2) diff --git a/test/basekernels/exponentiated.jl b/test/basekernels/exponentiated.jl index 17b625a94..a8c117b3b 100644 --- a/test/basekernels/exponentiated.jl +++ b/test/basekernels/exponentiated.jl @@ -10,4 +10,5 @@ @test k(v1,v2) ≈ exp(dot(v1,v2)) @test metric(ExponentiatedKernel()) == KernelFunctions.DotProduct() @test repr(k) == "Exponentiated Kernel" + test_ADs(ExponentiatedKernel) end diff --git a/test/basekernels/fbm.jl b/test/basekernels/fbm.jl index 645fdc088..77ed3b537 100644 --- a/test/basekernels/fbm.jl +++ b/test/basekernels/fbm.jl @@ -1,12 +1,13 @@ @testset "FBM" begin + rng = MersenneTwister(42) h = 0.3 k = FBMKernel(h = h) - v1 = rand(3); v2 = rand(3) + v1 = rand(rng, 3); v2 = rand(rng, 3) @test k(v1,v2) ≈ (sqeuclidean(v1, zero(v1))^h + sqeuclidean(v2, zero(v2))^h - sqeuclidean(v1-v2, zero(v1-v2))^h)/2 atol=1e-5 # kernelmatrix tests - m1 = rand(3,3) - m2 = rand(3,3) + m1 = rand(rng, 3, 3) + m2 = rand(rng, 3, 3) Kref = kernelmatrix(k, m1, m1) @test kernelmatrix(k, m1) ≈ Kref atol=1e-5 K = zeros(3, 3) @@ -16,9 +17,11 @@ kernelmatrix!(K, k, m1) @test K ≈ Kref atol=1e-5 - x1 = rand() - x2 = rand() + x1 = rand(rng) + x2 = rand(rng) @test kernelmatrix(k, x1*ones(1,1), x2*ones(1,1))[1] ≈ k(x1, x2) atol=1e-5 @test repr(k) == "Fractional Brownian Motion Kernel (h = $(h))" + test_ADs(FBMKernel, ADs = [:ReverseDiff]) + @test_broken "Tests failing for kernelmatrix(k, x) for ForwardDiff and Zygote" end diff 
--git a/test/basekernels/gabor.jl b/test/basekernels/gabor.jl index b9d47560c..26f610cae 100644 --- a/test/basekernels/gabor.jl +++ b/test/basekernels/gabor.jl @@ -17,4 +17,6 @@ @test k.ell ≈ 1.0 atol=1e-5 @test k.p ≈ 1.0 atol=1e-5 @test repr(k) == "Gabor Kernel (ell = 1.0, p = 1.0)" + test_ADs(x -> GaborKernel(ell = x[1], p = x[2]), [ell, p], ADs = [:ForwardDiff, :ReverseDiff]) + @test_broken "Tests failing for Zygote on differentiating through ell and p" end diff --git a/test/basekernels/maha.jl b/test/basekernels/maha.jl index 748b733fc..e5ecba3d0 100644 --- a/test/basekernels/maha.jl +++ b/test/basekernels/maha.jl @@ -11,4 +11,6 @@ @test k(v1, v2) ≈ exp(-sqmahalanobis(v1, v2, P)) @test kappa(ExponentialKernel(), x) == kappa(k, x) @test repr(k) == "Mahalanobis Kernel (size(P) = $(size(P)))" + # test_ADs(P -> MahalanobisKernel(P), P) + @test_broken "Nothing passes (problem with Mahalanobis distance in Distances)" end diff --git a/test/basekernels/matern.jl b/test/basekernels/matern.jl index af58dc470..a37ea29ba 100644 --- a/test/basekernels/matern.jl +++ b/test/basekernels/matern.jl @@ -14,6 +14,8 @@ @test metric(MaternKernel()) == Euclidean() @test metric(MaternKernel(ν=2.0)) == Euclidean() @test repr(k) == "Matern Kernel (ν = $(ν))" + # test_ADs(x->MaternKernel(nu=first(x)),[ν]) + @test_broken "All fails (because of logabsgamma for ForwardDiff and ReverseDiff and because of nu for Zygote)" end @testset "Matern32Kernel" begin k = Matern32Kernel() @@ -22,6 +24,7 @@ @test kappa(Matern32Kernel(),x) == kappa(k,x) @test metric(Matern32Kernel()) == Euclidean() @test repr(k) == "Matern 3/2 Kernel" + test_ADs(Matern32Kernel) end @testset "Matern52Kernel" begin k = Matern52Kernel() @@ -30,6 +33,7 @@ @test kappa(Matern52Kernel(),x) == kappa(k,x) @test metric(Matern52Kernel()) == Euclidean() @test repr(k) == "Matern 5/2 Kernel" + test_ADs(Matern52Kernel) end @testset "Coherence Materns" begin @test kappa(MaternKernel(ν=0.5),x) ≈ kappa(ExponentialKernel(),x) diff --git 
a/test/basekernels/nn.jl b/test/basekernels/nn.jl index 4617bd47d..6d6bb272c 100644 --- a/test/basekernels/nn.jl +++ b/test/basekernels/nn.jl @@ -43,5 +43,6 @@ @test_throws DimensionMismatch kernelmatrix!(A5, k, ones(4,3), ones(3,4)) @test k([x1], [x2]) ≈ k(x1, x2) atol=1e-5 - + test_ADs(NeuralNetworkKernel, ADs = [:ForwardDiff, :ReverseDiff]) + @test_broken "Zygote uncompatible with BaseKernel" end diff --git a/test/basekernels/periodic.jl b/test/basekernels/periodic.jl index c7056f75d..a4a2459db 100644 --- a/test/basekernels/periodic.jl +++ b/test/basekernels/periodic.jl @@ -7,4 +7,6 @@ @test k(v1, v2) == k(v2, v1) @test PeriodicKernel(3)(v1, v2) == PeriodicKernel(r = ones(3))(v1, v2) @test repr(k) == "Periodic Kernel, length(r) = $(length(r)))" + # test_ADs(r->PeriodicKernel(r =exp.(r)), log.(r), ADs = [:ForwardDiff, :ReverseDiff]) + @test_broken "Undefined adjoint for Sinus metric, and failing randomly for ForwardDiff and ReverseDiff" end diff --git a/test/basekernels/piecewisepolynomial.jl b/test/basekernels/piecewisepolynomial.jl index 329d983ee..c1d0f633f 100644 --- a/test/basekernels/piecewisepolynomial.jl +++ b/test/basekernels/piecewisepolynomial.jl @@ -29,7 +29,9 @@ kerneldiagmatrix!(A3, k, m1) @test A3 == kerneldiagmatrix(k, m1) - @test repr(k) == "Piecewise Polynomial Kernel (v = $(v), size(maha) = $(size(maha)))" - @test_throws ErrorException PiecewisePolynomialKernel{4}(maha) + + @test repr(k) == "Piecewise Polynomial Kernel (v = $(v), size(maha) = $(size(maha)))" + # test_ADs(maha-> PiecewisePolynomialKernel(v=2, maha = maha), maha) + @test_broken "Nothing passes (problem with Mahalanobis distance in Distances)" end diff --git a/test/basekernels/polynomial.jl b/test/basekernels/polynomial.jl index 900378f52..9d4319ce3 100644 --- a/test/basekernels/polynomial.jl +++ b/test/basekernels/polynomial.jl @@ -12,6 +12,7 @@ @test metric(LinearKernel()) == KernelFunctions.DotProduct() @test metric(LinearKernel(c=2.0)) == KernelFunctions.DotProduct() @test 
repr(k) == "Linear Kernel (c = 0.0)" + test_ADs(x->LinearKernel(c=x[1]), [c]) end @testset "PolynomialKernel" begin k = PolynomialKernel() @@ -24,5 +25,7 @@ @test metric(PolynomialKernel()) == KernelFunctions.DotProduct() @test metric(PolynomialKernel(d=3.0)) == KernelFunctions.DotProduct() @test metric(PolynomialKernel(d=3.0,c=2.0)) == KernelFunctions.DotProduct() + # test_ADs(x->PolynomialKernel(d=x[1], c=x[2]),[2.0, c]) + @test_broken "All, because of the power" end end diff --git a/test/basekernels/rationalquad.jl b/test/basekernels/rationalquad.jl index 4ec26cf13..47839f407 100644 --- a/test/basekernels/rationalquad.jl +++ b/test/basekernels/rationalquad.jl @@ -13,6 +13,7 @@ @test metric(RationalQuadraticKernel()) == SqEuclidean() @test metric(RationalQuadraticKernel(α=2.0)) == SqEuclidean() @test repr(k) == "Rational Quadratic Kernel (α = $(α))" + test_ADs(x->RationalQuadraticKernel(alpha=x[1]),[α]) end @testset "GammaRationalQuadraticKernel" begin k = GammaRationalQuadraticKernel() @@ -23,9 +24,11 @@ @test GammaRationalQuadraticKernel(alpha=a).α == [a] @test repr(k) == "Gamma Rational Quadratic Kernel (α = 2.0, γ = 2.0)" #Coherence test - @test kappa(GammaRationalQuadraticKernel(α=a,γ=1.0),x) ≈ kappa(RationalQuadraticKernel(α=a),x) + @test kappa(GammaRationalQuadraticKernel(α=a, γ=1.0), x) ≈ kappa(RationalQuadraticKernel(α=a), x) @test metric(GammaRationalQuadraticKernel()) == SqEuclidean() @test metric(GammaRationalQuadraticKernel(γ=2.0)) == SqEuclidean() - @test metric(GammaRationalQuadraticKernel(γ=2.0,α=3.0)) == SqEuclidean() + @test metric(GammaRationalQuadraticKernel(γ=2.0, α=3.0)) == SqEuclidean() + # test_ADs(x->GammaRationalQuadraticKernel(α=x[1], γ=x[2]), [a, 2.0]) + @test_broken "All (problem with power operation)" end end diff --git a/test/basekernels/sm.jl b/test/basekernels/sm.jl index a8e0a5768..daef2bd62 100644 --- a/test/basekernels/sm.jl +++ b/test/basekernels/sm.jl @@ -21,4 +21,6 @@ @test_throws DimensionMismatch 
spectral_mixture_kernel(rand(5) ,rand(4,3), rand(4,3)) @test_throws DimensionMismatch spectral_mixture_kernel(rand(3) ,rand(4,3), rand(5,3)) @test_throws DimensionMismatch spectral_mixture_product_kernel(rand(5,3) ,rand(4,3), rand(5,3)) + # test_ADs(x->spectral_mixture_kernel(exp.(x[1:3]), reshape(x[4:18], 5, 3), reshape(x[19:end], 5, 3)), vcat(log.(αs₁), γs[:], ωs[:]), dims = [5,5]) + @test_broken "No tests passing (BaseKernel)" end diff --git a/test/basekernels/wiener.jl b/test/basekernels/wiener.jl index 3b628fc65..624837b8c 100644 --- a/test/basekernels/wiener.jl +++ b/test/basekernels/wiener.jl @@ -50,4 +50,7 @@ @test kernelmatrix(k1, x1*ones(1,1), x2*ones(1,1))[1] ≈ k1(x1, x2) atol=1e-5 @test kernelmatrix(k2, x1*ones(1,1), x2*ones(1,1))[1] ≈ k2(x1, x2) atol=1e-5 @test kernelmatrix(k3, x1*ones(1,1), x2*ones(1,1))[1] ≈ k3(x1, x2) atol=1e-5 + + # test_ADs(()->WienerKernel(i=1)) + @test_broken "No tests passing" end diff --git a/test/kernels/kernelproduct.jl b/test/kernels/kernelproduct.jl index 00d5676d0..d39e81943 100644 --- a/test/kernels/kernelproduct.jl +++ b/test/kernels/kernelproduct.jl @@ -47,4 +47,6 @@ @test kerneldiagmatrix!(tmp_diag, k, x) ≈ kerneldiagmatrix(k, x) end end + test_ADs(x->SqExponentialKernel() * LinearKernel(c= x[1]), rand(1), ADs = [:ForwardDiff, :ReverseDiff]) + @test_broken "Zygote issue" end diff --git a/test/kernels/kernelsum.jl b/test/kernels/kernelsum.jl index 310f43d00..6647fa466 100644 --- a/test/kernels/kernelsum.jl +++ b/test/kernels/kernelsum.jl @@ -53,4 +53,6 @@ @test kerneldiagmatrix!(tmp_diag, k, x) ≈ kerneldiagmatrix(k, x) end end + test_ADs(x->KernelSum([SqExponentialKernel(),LinearKernel(c= x[1])], x[2:3]), rand(3), ADs = [:ForwardDiff, :ReverseDiff]) + @test_broken "Zygote failing because of mutating array" end diff --git a/test/kernels/scaledkernel.jl b/test/kernels/scaledkernel.jl index a5bf8998e..38e6593c3 100644 --- a/test/kernels/scaledkernel.jl +++ b/test/kernels/scaledkernel.jl @@ -40,4 +40,5 @@ @test_broken 
kerneldiagmatrix!(tmp_diag, ks, x) ≈ kerneldiagmatrix(ks, x) end end + test_ADs(x->x[1] * SqExponentialKernel(), rand(1)) end diff --git a/test/kernels/tensorproduct.jl b/test/kernels/tensorproduct.jl index 8ce9d5f72..1b016a68b 100644 --- a/test/kernels/tensorproduct.jl +++ b/test/kernels/tensorproduct.jl @@ -110,4 +110,5 @@ end end end + test_ADs(()->TensorProduct(SqExponentialKernel(), LinearKernel()), dims = [2, 2]) # ADs = [:ForwardDiff, :ReverseDiff]) end diff --git a/test/kernels/transformedkernel.jl b/test/kernels/transformedkernel.jl index cabbe0008..cf49dde2d 100644 --- a/test/kernels/transformedkernel.jl +++ b/test/kernels/transformedkernel.jl @@ -47,4 +47,5 @@ @test kerneldiagmatrix!(tmp_diag, kt, x) ≈ kerneldiagmatrix(kt, x) end end + test_ADs(x->transform(SqExponentialKernel(), x[1]), rand(1))# ADs = [:ForwardDiff, :ReverseDiff]) end diff --git a/test/runtests.jl b/test/runtests.jl index 0ff326256..d0ea3e3c5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,16 +1,15 @@ using KernelFunctions using Distances -using FiniteDifferences -using Flux using Kronecker using LinearAlgebra using PDMats using Random using SpecialFunctions using Test -using Zygote +using Flux +import Zygote, ForwardDiff, ReverseDiff, FiniteDifferences -using KernelFunctions: metric, kappa +using KernelFunctions: metric, kappa, ColVecs, RowVecs # Writing tests: # 1. The file structure of the test should match precisely the file structure of src. @@ -41,16 +40,19 @@ using KernelFunctions: metric, kappa # disable tests by simply commenting them out, and makes it very clear which tests are not # currently being run. # 10. If utility files are required. 
+@info "Packages Loaded" @testset "KernelFunctions" begin include("utils.jl") + include("utils_AD.jl") @testset "distances" begin include(joinpath("distances", "dotproduct.jl")) include(joinpath("distances", "delta.jl")) include(joinpath("distances", "sinus.jl")) end + @info "Ran tests on Distances" @testset "transform" begin include(joinpath("transform", "transform.jl")) @@ -61,6 +63,7 @@ using KernelFunctions: metric, kappa include(joinpath("transform", "selecttransform.jl")) include(joinpath("transform", "chaintransform.jl")) end + @info "Ran tests on Transform" @testset "basekernels" begin include(joinpath("basekernels", "constant.jl")) @@ -79,6 +82,7 @@ using KernelFunctions: metric, kappa include(joinpath("basekernels", "sm.jl")) include(joinpath("basekernels", "wiener.jl")) end + @info "Ran tests on BaseKernel" @testset "kernels" begin include(joinpath("kernels", "kernelproduct.jl")) @@ -91,12 +95,14 @@ using KernelFunctions: metric, kappa # helpful these are. include(joinpath("kernels", "custom.jl")) end + @info "Ran tests on Kernel" @testset "matrix" begin include(joinpath("matrix", "kernelmatrix.jl")) include(joinpath("matrix", "kernelkroneckermat.jl")) include(joinpath("matrix", "kernelpdmat.jl")) end + @info "Ran tests on matrix" @testset "approximations" begin include(joinpath("approximations", "nystrom.jl")) @@ -106,9 +112,3 @@ using KernelFunctions: metric, kappa include("zygote_adjoints.jl") include("trainable.jl") end - -# These are legacy tests that I'm not getting rid of, as they appear to be useful, but -# weren't enabled on master at the time of refactoring the tests. They will need to be -# restored at some point. 
-# include("utils_AD.jl") -# include("test_AD.jl") diff --git a/test/test_AD.jl b/test/test_AD.jl deleted file mode 100644 index 9ee6e8566..000000000 --- a/test/test_AD.jl +++ /dev/null @@ -1,119 +0,0 @@ -using KernelFunctions -using Zygote, ForwardDiff -using Test, LinearAlgebra -using FiniteDifferences - -dims = [10,5] - -A = rand(dims...) -B = rand(dims...) -K = [zeros(dims[1],dims[1]),zeros(dims[2],dims[2])] -kernels_noparams = [:SqExponentialKernel,:ExponentialKernel,:GammaExponentialKernel, - :MaternKernel,:Matern32Kernel,:Matern52Kernel, - :LinearKernel,:PolynomialKernel, - :RationalQuadraticKernel,:GammaRationalQuadraticKernel, - :ExponentiatedKernel] -l = 2.0 -ds = [0.0,3.0] -vl = l*ones(dims[1]) -testfunction(k,A,B) = det(kernelmatrix(k,A,B)) -testfunction(k,A) = det(kernelmatrix(k,A)) -ADs = [:Zygote,:ForwardDiff] - -## Test kappa functions -@testset "Kappa functions" begin - for AD in ADs - @testset "$AD" begin - for k in kernels_noparams - for d in ds - @eval begin @test kappa_AD(Val(Symbol($AD)),$k(),$d) ≈ kappa_fdm($k(),$d) atol=1e-8 end - end - end - # Linear -> C - # Polynomial -> C,D - # Gamma (etc) -> gamma - # - end - end -end - -@testset "Transform Operations" begin - for AD in ADs - @testset "$AD" begin - @eval begin - # Scale Transform - transform_AD(Val(Symbol($AD)),ScaleTransform(l),A) - # ARD Transform - transform_AD(Val(Symbol($AD)),ARDTransform(vl),A) - # Linear transform - transform_AD(Val(Symbol($AD)), LinearTransform(rand(2,10)),A) - # Chain Transform - # transform_AD(Val(Symbol($AD)), LinearTransform, A) - end - end - end -end - -##TODO Eventually store real results in file -@testset "Zygote Automatic Differentiation test" begin - @testset "ARD" begin - for k in kernels - @testset "$k" begin - @test all(isapprox.(Zygote.gradient(x->testfunction(k(x),A,B),vl)[1], ForwardDiff.gradient(x->testfunction(k(x),A,B),vl))) - @test 
all(isapprox.(Zygote.gradient(x->testfunction(k(vl),x,B),A)[1],ForwardDiff.gradient(x->testfunction(k(vl),x,B),A))) - @test all(isapprox.(Zygote.gradient(x->testfunction(k(x),A),vl)[1],ForwardDiff.gradient(x->testfunction(k(x),A),vl))) - @test all(isapprox.(Zygote.gradient(x->testfunction(k(vl),x),A)[1],ForwardDiff.gradient(x->testfunction(k(vl),x),A))) - end - end - end - @testset "ISO" begin - for k in kernels - @testset "$k" begin - @test all(isapprox.(Zygote.gradient(x->testfunction(k(x),A,B),l)[1],ForwardDiff.gradient(x->testfunction(k(x[1]),A,B),[l])[1])) - @test all(isapprox.(Zygote.gradient(x->testfunction(k(l),x,B),A)[1],ForwardDiff.gradient(x->testfunction(k(l),x,B),A))) - @test all(isapprox.(Zygote.gradient(x->testfunction(k(x),A),l)[1],ForwardDiff.gradient(x->testfunction(k(x[1]),A),[l]))) - @test all(isapprox.(Zygote.gradient(x->testfunction(k(l),x),A)[1],ForwardDiff.gradient(x->testfunction(k(l[1]),x),A))) - end - end - end -end - -@testset "ForwardDiff AutomaticDifferentation test" begin - @testset "ARD" begin - for k in kernels - @test_nowarn ForwardDiff.gradient(x->testfunction(k(x),A,B),vl) - @test_nowarn ForwardDiff.gradient(x->testfunction(k(vl),x,B),A) - @test_nowarn ForwardDiff.gradient(x->testfunction(k(x),A),vl) - @test_nowarn ForwardDiff.gradient(x->testfunction(k(vl),x),A) - end - end - @testset "ISO" begin - for k in kernels - @test_nowarn ForwardDiff.gradient(x->testfunction(k(x[1]),A,B),[l]) - @test_nowarn ForwardDiff.gradient(x->testfunction(k(l),x,B),A) - @test_nowarn ForwardDiff.gradient(x->testfunction(k(x[1]),A),[l]) - @test_nowarn ForwardDiff.gradient(x->testfunction(k(l[1]),x),A) - end - end -end - - -@testset "Tracker AutomaticDifferentation test" begin - @testset "ARD" begin - for k in kernels - @test_broken all(Tracker.gradient(x->testfunction(k(x),A,B),vl)[1] .≈ ForwardDiff.gradient(x->testfunction(k(x),A,B),vl)) - @test_broken all(Tracker.gradient(x->testfunction(k(vl),x,B),A)[1] .≈ 
ForwardDiff.gradient(x->testfunction(k(vl),x,B),A)) - @test_broken all(Tracker.gradient(x->testfunction(k(x),A),vl)[1] .≈ ForwardDiff.gradient(x->testfunction(k(x),A),vl)) - @test_broken all.(Tracker.gradient(x->testfunction(k(vl),x),A) .≈ ForwardDiff.gradient(x->testfunction(k(vl),x),A)) - end - end - @testset "ISO" begin - for k in kernels - @test_broken Tracker.gradient(x->testfunction(k(x[1]),A,B),[l]) - @test_broken Tracker.gradient(x->testfunction(k(l),x,B),A) - @test_broken Tracker.gradient(x->testfunction(k(x[1]),A),[l]) - @test_broken Tracker.gradient(x->testfunction(k(l),x),A) - - end - end -end diff --git a/test/transform/ardtransform.jl b/test/transform/ardtransform.jl index 4bd10a6dc..e05f50968 100644 --- a/test/transform/ardtransform.jl +++ b/test/transform/ardtransform.jl @@ -41,4 +41,5 @@ @test_throws DimensionMismatch map(t, ColVecs(randn(rng, D + 1, 3))) @test repr(t) == "ARD Transform (dims: $D)" + test_ADs(x->transform(SEKernel(), exp.(x)), randn(rng, 3)) end diff --git a/test/transform/chaintransform.jl b/test/transform/chaintransform.jl index a13883e81..55dd13b74 100644 --- a/test/transform/chaintransform.jl +++ b/test/transform/chaintransform.jl @@ -22,8 +22,5 @@ # Verify printing works as expected. 
@test repr(tp ∘ tf) == "Chain of 2 transforms:\n\t - $(tf) |> $(tp)" + test_ADs(x->transform(SEKernel(), ScaleTransform(exp(x[1])) ∘ ARDTransform(exp.(x[2:4]))), randn(rng, 4)) end - - -Base.:∘(t::Transform, tc::ChainTransform) = ChainTransform(vcat(tc.transforms, t)) -Base.:∘(tc::ChainTransform, t::Transform) = ChainTransform(vcat(t, tc.transforms)) diff --git a/test/transform/functiontransform.jl b/test/transform/functiontransform.jl index 17ddbdb4f..f8441c38c 100644 --- a/test/transform/functiontransform.jl +++ b/test/transform/functiontransform.jl @@ -26,4 +26,7 @@ end @test repr(FunctionTransform(sin)) == "Function Transform: $(sin)" + f(a, x) = sin.(a .* x) + test_ADs(x->transform(SEKernel(), FunctionTransform(y->f(x, y))), randn(rng, 3), ADs = [:ForwardDiff, :ReverseDiff]) + @test_broken "Zygote is failing" end diff --git a/test/transform/lineartransform.jl b/test/transform/lineartransform.jl index ff65e20b4..46342bc73 100644 --- a/test/transform/lineartransform.jl +++ b/test/transform/lineartransform.jl @@ -41,4 +41,5 @@ @test_throws DimensionMismatch map(t, ColVecs(randn(rng, Din + 1, Dout))) @test repr(t) == "Linear transform (size(A) = ($Dout, $Din))" + test_ADs(x->transform(SEKernel(), LinearTransform(x)), randn(rng, 3, 3)) end diff --git a/test/transform/scaletransform.jl b/test/transform/scaletransform.jl index d9aece310..c97d937f1 100644 --- a/test/transform/scaletransform.jl +++ b/test/transform/scaletransform.jl @@ -18,4 +18,5 @@ @test t.s == [s2] @test isequal(ScaleTransform(s), ScaleTransform(s)) @test repr(t) == "Scale Transform (s = $(s2))" + test_ADs(x->transform(SEKernel(), exp(x[1])), randn(rng, 1)) end diff --git a/test/transform/selecttransform.jl b/test/transform/selecttransform.jl index 1781356b1..a34a9ab3d 100644 --- a/test/transform/selecttransform.jl +++ b/test/transform/selecttransform.jl @@ -18,4 +18,5 @@ @test t.select == select2 @test repr(t) == "Select Transform (dims: $(select2))" + test_ADs(()->transform(SEKernel(), 
SelectTransform([1,2]))) end diff --git a/test/transform/transform.jl b/test/transform/transform.jl index 0b79dcad5..6ce7c46bf 100644 --- a/test/transform/transform.jl +++ b/test/transform/transform.jl @@ -7,4 +7,5 @@ @test IdentityTransform()(x) == x @test map(IdentityTransform(), x) == x end + test_ADs(()->transform(SEKernel(), IdentityTransform())) end diff --git a/test/utils_AD.jl b/test/utils_AD.jl index 77647e6d1..1354485f9 100644 --- a/test/utils_AD.jl +++ b/test/utils_AD.jl @@ -1,39 +1,148 @@ -allapprox(x,y,tol=1e-8) = all(isapprox.(x,y,atol=tol)) -FDM = central_fdm(5,1) +const FDM = FiniteDifferences.central_fdm(5, 1) -function kappa_AD(::Val{:Zygote},k::Kernel,d::Real) - first(Zygote.gradient(x->kappa(k,x),d)) +gradient(f, s::Symbol, args) = gradient(f, Val(s), args) + +function gradient(f, ::Val{:Zygote}, args) + g = first(Zygote.gradient(f, args)) + if isnothing(g) + if args isa AbstractArray{<:Real} + return zeros(size(args)) # To respect the same output as other ADs + else + return zeros.(size.(args)) + end + else + return g + end end -function kappa_AD(::Val{:ForwardDiff},k::Kernel,d::Real) - first(ForwardDiff.gradient(x->kappa(k,first(x)),[d])) +function gradient(f, ::Val{:ForwardDiff}, args) + ForwardDiff.gradient(f, args) end -function kappa_fdm(k::Kernel,d::Real) - first(FiniteDifferences.grad(FDM,x->kappa(k,x),d)) +function gradient(f, ::Val{:ReverseDiff}, args) + ReverseDiff.gradient(f, args) end +function gradient(f, ::Val{:FiniteDiff}, args) + first(FiniteDifferences.grad(FDM, f, args)) +end -function transform_AD(::Val{:Zygote},t::Transform,A) - ps = KernelFunctions.params(t) - @test allapprox(first(Zygote.gradient(p->transform_with_duplicate(p,t,A),ps)), - first(FiniteDifferences.grad(FDM,p->transform_with_duplicate(p,t,A),ps))) - @test allapprox(first(Zygote.gradient(X->sum(transform(t,X,2)),A)), - first(FiniteDifferences.grad(FDM,X->sum(transform(t,X,2)),A))) +function compare_gradient(f, AD::Symbol, args) + grad_AD = gradient(f, AD, 
args) + grad_FD = gradient(f, :FiniteDiff, args) + @test grad_AD ≈ grad_FD atol=1e-8 rtol=1e-5 end -function transform_AD(::Val{:ForwardDiff},t::Transform,A) - ps = KernelFunctions.params(t) - if t isa ScaleTransform - @test allapprox(first(ForwardDiff.gradient(p->transform_with_duplicate(first(p),t,A),[ps])), - first(FiniteDifferences.grad(FDM,p->transform_with_duplicate(p,t,A),ps))) +testfunction(k, A, B, dim) = sum(kernelmatrix(k, A, B, obsdim = dim)) +testfunction(k, A, dim) = sum(kernelmatrix(k, A, obsdim = dim)) + +function test_ADs(kernelfunction, args = nothing; ADs = [:Zygote, :ForwardDiff, :ReverseDiff], dims = [3, 3]) + test_fd = test_FiniteDiff(kernelfunction, args, dims) + if !test_fd.anynonpass + for AD in ADs + test_AD(AD, kernelfunction, args, dims) + end + end +end + +function test_FiniteDiff(kernelfunction, args = nothing, dims = [3, 3]) + # Init arguments : + k = if args === nothing + kernelfunction() else - @test allapprox(ForwardDiff.gradient(p->transform_with_duplicate(p,t,A),ps), - first(FiniteDifferences.grad(FDM,p->transform_with_duplicate(p,t,A),ps))) + kernelfunction(args) + end + rng = MersenneTwister(42) + @testset "FiniteDifferences" begin + if k isa SimpleKernel + for d in log.([eps(), rand(rng)]) + @test_nowarn gradient(:FiniteDiff, [d]) do x + kappa(k, exp(first(x))) + end + end + end + ## Testing Kernel Functions + x = rand(rng, dims[1]) + y = rand(rng, dims[1]) + @test_nowarn gradient(:FiniteDiff, x) do x + k(x, y) + end + if !(args === nothing) + @test_nowarn gradient(:FiniteDiff, args) do p + kernelfunction(p)(x, y) + end + end + ## Testing Kernel Matrices + A = rand(rng, dims...) + B = rand(rng, dims...) 
+ for dim in 1:2 + @test_nowarn gradient(:FiniteDiff, A) do a + testfunction(k, a, dim) + end + @test_nowarn gradient(:FiniteDiff , A) do a + testfunction(k, a, B, dim) + end + @test_nowarn gradient(:FiniteDiff, B) do b + testfunction(k, A, b, dim) + end + if !(args === nothing) + @test_nowarn gradient(:FiniteDiff, args) do p + testfunction(kernelfunction(p), A, B, dim) + end + end + end end - @test allapprox(ForwardDiff.gradient(X->sum(transform(t,X,2)),A), - first(FiniteDifferences.grad(FDM,X->sum(transform(t,X,2)),A))) end -transform_with_duplicate(p,t,A) = sum(transform(KernelFunctions.duplicate(t,p),A,2)) +function test_AD(AD::Symbol, kernelfunction, args = nothing, dims = [3, 3]) + @testset "$(AD)" begin + # Test kappa function + k = if args === nothing + kernelfunction() + else + kernelfunction(args) + end + rng = MersenneTwister(42) + if k isa SimpleKernel + for d in log.([eps(), rand(rng)]) + compare_gradient(AD, [d]) do x + kappa(k, exp(x[1])) + end + end + end + # Testing kernel evaluations + x = rand(rng, dims[1]) + y = rand(rng, dims[1]) + compare_gradient(AD, x) do x + k(x, y) + end + compare_gradient(AD, y) do y + k(x, y) + end + if !(args === nothing) + compare_gradient(AD, args) do p + kernelfunction(p)(x,y) + end + end + # Testing kernel matrices + A = rand(rng, dims...) + B = rand(rng, dims...) 
+ for dim in 1:2 + compare_gradient(AD, A) do a + testfunction(k, a, dim) + end + compare_gradient(AD, A) do a + testfunction(k, a, B, dim) + end + compare_gradient(AD, B) do b + testfunction(k, A, b, dim) + end + if !(args === nothing) + compare_gradient(AD, args) do p + testfunction(kernelfunction(p), A, dim) + end + end + end + end +end diff --git a/test/zygote_adjoints.jl b/test/zygote_adjoints.jl index e81cb7097..5e9447b37 100644 --- a/test/zygote_adjoints.jl +++ b/test/zygote_adjoints.jl @@ -3,18 +3,44 @@ rng = MersenneTwister(123456) x = rand(rng, 5) y = rand(rng, 5) + r = rand(rng, 5) - gzeucl = first(Zygote.gradient(xy->evaluate(Euclidean(),xy[1],xy[2]),[x,y])) - gzsqeucl = first(Zygote.gradient(xy->evaluate(SqEuclidean(),xy[1],xy[2]),[x,y])) - gzdotprod = first(Zygote.gradient(xy->evaluate(KernelFunctions.DotProduct(),xy[1],xy[2]),[x,y])) + gzeucl = gradient(:Zygote, [x,y]) do xy + evaluate(Euclidean(), xy[1], xy[2]) + end + gzsqeucl = gradient(:Zygote, [x,y]) do xy + evaluate(SqEuclidean(), xy[1], xy[2]) + end + gzdotprod = gradient(:Zygote, [x,y]) do xy + evaluate(KernelFunctions.DotProduct(), xy[1], xy[2]) + end + gzdelta = gradient(:Zygote, [x,y]) do xy + evaluate(KernelFunctions.Delta(), xy[1], xy[2]) + end + gzsinus = gradient(:Zygote, [x,y]) do xy + evaluate(KernelFunctions.Sinus(r), xy[1], xy[2]) + end - FDM = central_fdm(5,1) + gfeucl = gradient(:FiniteDiff, [x,y]) do xy + evaluate(Euclidean(), xy[1], xy[2]) + end + gfsqeucl = gradient(:FiniteDiff, [x,y]) do xy + evaluate(SqEuclidean(), xy[1], xy[2]) + end + gfdotprod = gradient(:FiniteDiff, [x,y]) do xy + evaluate(KernelFunctions.DotProduct(), xy[1], xy[2]) + end + gfdelta = gradient(:FiniteDiff, [x,y]) do xy + evaluate(KernelFunctions.Delta(), xy[1], xy[2]) + end + gfsinus = gradient(:FiniteDiff, [x,y]) do xy + evaluate(KernelFunctions.Sinus(r), xy[1], xy[2]) + end - gfeucl = collect(first(FiniteDifferences.grad(FDM,xy->evaluate(Euclidean(),xy[1],xy[2]),(x,y)))) - gfsqeucl = 
collect(first(FiniteDifferences.grad(FDM,xy->evaluate(SqEuclidean(),xy[1],xy[2]),(x,y)))) - gfdotprod =collect(first(FiniteDifferences.grad(FDM,xy->evaluate(KernelFunctions.DotProduct(),xy[1],xy[2]),(x,y)))) @test all(gzeucl .≈ gfeucl) @test all(gzsqeucl .≈ gfsqeucl) @test all(gzdotprod .≈ gfdotprod) + @test all(gzdelta .≈ gfdelta) + @test all(gzsinus .≈ gfsinus) end