diff --git a/src/chainrules.jl b/src/chainrules.jl index 4b69a827f..dde352f8c 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -118,7 +118,7 @@ function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector) return val, evaluate_pullback end -## Reverse Rulse SqMahalanobis +## Reverse Rules SqMahalanobis function ChainRulesCore.rrule( dist::Distances.SqMahalanobis, a::AbstractVector, b::AbstractVector diff --git a/test/Project.toml b/test/Project.toml index 7afce657a..4c04f2c42 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -19,7 +19,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AxisArrays = "0.4.3" Compat = "3" -Distances = "= 0.10.0, = 0.10.1, = 0.10.2, = 0.10.3, = 0.10.4" +Distances = "0.10" Documenter = "0.25, 0.26, 0.27" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10" diff --git a/test/chainrules.jl b/test/chainrules.jl index 51a545ba1..03c2c3b1f 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -21,7 +21,11 @@ compare_gradient(:Zygote, [x, y]) do xy KernelFunctions.Sinus(r)(xy[1], xy[2]) end - compare_gradient(:Zygote, [Q, x, y]) do xy - SqMahalanobis(xy[1])(xy[2], xy[3]) + if VERSION < v"1.6" + @test_broken "Chain rule of SqMahalanobis is broken in Julia pre-1.6" + else + compare_gradient(:Zygote, [Q, x, y]) do Qxy + SqMahalanobis(Qxy[1])(Qxy[2], Qxy[3]) + end end end