diff --git a/Project.toml b/Project.toml index f99a94c..2f5cb01 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesTestUtils" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "1.9.2" +version = "1.9.3" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/check_result.jl b/src/check_result.jl index fee2943..0f192a4 100644 --- a/src/check_result.jl +++ b/src/check_result.jl @@ -40,12 +40,17 @@ for (T1, T2) in end test_approx(::AbstractZero, x, msg=""; kwargs...) = test_approx(zero(x), x, msg; kwargs...) -test_approx(::AbstractZero, x::AbstractArray{<:AbstractArray}, msg=""; kwargs...) = test_approx(map(zero, x), x, msg; kwargs...) test_approx(x, ::AbstractZero, msg=""; kwargs...) = test_approx(x, zero(x), msg; kwargs...) -test_approx(x::AbstractArray{<:AbstractArray}, ::AbstractZero, msg=""; kwargs...) = test_approx(x, map(zero, x), msg; kwargs...) test_approx(x::ZeroTangent, y::ZeroTangent, msg=""; kwargs...) = @test true test_approx(x::NoTangent, y::NoTangent, msg=""; kwargs...) = @test true +function test_approx(z::AbstractZero, x::AbstractArray{<:AbstractArray}, msg=""; kwargs...) + for el in x + test_approx(el, z, msg; kwargs...) + end +end +test_approx(x::AbstractArray{<:AbstractArray}, z::AbstractZero, msg=""; kwargs...) = test_approx(z, x, msg; kwargs...) + # remove once https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113 test_approx(x::NoTangent, y::Nothing, msg=""; kwargs...) = @test true test_approx(x::Nothing, y::NoTangent, msg=""; kwargs...) = @test true @@ -134,8 +139,8 @@ function test_approx(actual::Tangent{P,T}, expected, msg=""; kwargs...) where {T end test_approx(x, y::Tangent, msg=""; kwargs...) = test_approx(y, x, msg; kwargs...) -test_approx(z::NoTangent, t::Tangent, msg=""; kwargs...) = all(==(NoTangent()), t) -test_approx(t::Tangent, z::NoTangent, msg=""; kwargs...) = all(==(NoTangent()), t) +test_approx(z::NoTangent, t::Tangent, msg=""; kwargs...) = @test all(==(NoTangent()), t) +test_approx(t::Tangent, z::NoTangent, msg=""; kwargs...) = @test all(==(NoTangent()), t) # This catches comparisons of Tangents and Tuples/NamedTuple # and gives an error message complaining about that. the `@test` will definitely fail diff --git a/test/check_result.jl b/test/check_result.jl index a708990..d50dcbc 100644 --- a/test/check_result.jl +++ b/test/check_result.jl @@ -38,6 +38,8 @@ end test_approx([[1.0], [2.0]], [[1.0], [2.0]]) test_approx([[0.0], [0.0]], ZeroTangent()) test_approx(ZeroTangent(), [[0.0], [0.0]]) + test_approx(ZeroTangent(), [[0.0, 0.0], [[0.0, 0.0], [0.0, 0.0]]]) + test_approx([[0.0, 0.0], [[0.0, 0.0], [0.0, 0.0]]], NoTangent()) test_approx(Broadcast.broadcasted(identity, [1.0 2.0; 3.0 4.0]), [1.0 2.0; 3.0 4.0]) test_approx(@thunk(10 * 0.1 * [[1.0], [2.0]]), [[1.0], [2.0]]) @@ -112,6 +114,17 @@ end @test fails(() -> test_approx([[1.0], [2.0]], [[1.1], [2.0]])) @test fails(() -> test_approx([[0.0], [0.1]], ZeroTangent())) @test fails(() -> test_approx(ZeroTangent(), [[0.1], [0.0]])) + @test fails(() -> test_approx([[0.0], [0.0], [[0.0, 0.1], [0.0]]], ZeroTangent())) + @test fails(() -> test_approx(ZeroTangent(), [[0.0], [0.0], [[0.0, 0.1], [0.0]]])) + + @test fails(() -> test_approx( + Tangent{Tuple{Float64,Float64}}(NoTangent(), 0.1), + NoTangent(), + )) + @test fails(() -> test_approx( + NoTangent(), + Tangent{Tuple{Float64,Float64}}(NoTangent(), 0.1), + )) @test fails(() -> test_approx(@thunk(10 * [[1.0], [2.0]]), [[1.0], [2.0]]))