diff --git a/Project.toml b/Project.toml index df4ed07..a26bb62 100644 --- a/Project.toml +++ b/Project.toml @@ -6,18 +6,21 @@ version = "0.1.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298" +ImageBase = "c817782e-172a-44cc-b673-b171935fbb9e" ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534" ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ChainRulesCore = "1.3.0" CoordinateTransformations = "0.6.1" +ImageBase = "0.1.5" ImageCore = "0.9" ImageTransformations = "0.8, 0.9" Interpolations = "0.13.4" diff --git a/src/DiffImages.jl b/src/DiffImages.jl index 4d2f15d..b07ecd4 100644 --- a/src/DiffImages.jl +++ b/src/DiffImages.jl @@ -7,14 +7,17 @@ using ImageCore, Interpolations, ChainRulesCore, LinearAlgebra, - Rotations + Rotations, + ImageBase using Zygote: @adjoint using ChainRulesCore: NoTangent -export colorify, channelify +export colorify, channelify, fdiff include("colors/conversions.jl") include("geometry/warp.jl") include("geometry/adjoints.jl") +include("ImageBase.jl/fdiff.jl") +include("ImageBase.jl/statistics.jl") end diff --git a/src/ImageBase.jl/fdiff.jl b/src/ImageBase.jl/fdiff.jl new file mode 100644 index 0000000..d89c90a --- /dev/null +++ b/src/ImageBase.jl/fdiff.jl @@ -0,0 +1,9 @@ +# TODO(arcAman07): support RGB inputs, currently works only for GrayScale Images +# TODO(arcAman07): support N dimensional case, currently works only for 2 dimensional case +@adjoint function fdiff(A::AbstractArray; kwargs...) + y = fdiff!(similar(A, maybe_floattype(eltype(A))), A; kwargs...) + function pullback(Δ) + return (fill(Δ, size(A)),) + end + return (y, pullback) +end \ No newline at end of file diff --git a/src/ImageBase.jl/statistics.jl b/src/ImageBase.jl/statistics.jl new file mode 100644 index 0000000..10cc61d --- /dev/null +++ b/src/ImageBase.jl/statistics.jl @@ -0,0 +1,35 @@ +@adjoint function sumfinite(A::AbstractArray{T,N}; kwargs...) where {T,N} + y = ImageBase.sumfinite(identity, A; kwargs...) + function pullback(Δ) + return (fill(Δ,size(A)),) + end + return (y, pullback) +end + +@adjoint function meanfinite(A::AbstractArray{T,N}; kwargs...) where {T,N} + y = ImageBase.meanfinite(identity, A; kwargs...) + function pullback(Δ) + return (fill(Δ / length(A),size(A)),) + end + return (y, pullback) +end + +@adjoint function maximum_finite(A::AbstractArray{T,N}; kwargs...) where {T,N} + y = ImageBase.maximum_finite(identity, A; kwargs...) + final = zeros(Float64, size(A)) + function pullback(Δ) + final[last(findall(x -> x == y, A))] = Δ + return (final,) + end + return (y, pullback) +end + +@adjoint function minimum_finite(A::AbstractArray{T,N}; kwargs...) where {T,N} + y = ImageBase.minimum_finite(identity, A; kwargs...) + final = zeros(Float64, size(A)) + function pullback(Δ) + final[first(findall(x -> x == y, A))] = Δ + return (final,) + end + return (y, pullback) +end \ No newline at end of file diff --git a/test/ImageBase.jl/fdiff.jl b/test/ImageBase.jl/fdiff.jl new file mode 100644 index 0000000..229fa68 --- /dev/null +++ b/test/ImageBase.jl/fdiff.jl @@ -0,0 +1,62 @@ +using ImageBase.FiniteDiff: fdiff, fdiff! +@testset "fdiff" begin + # Base.diff doesn't promote integer to float + @test ImageBase.FiniteDiff.maybe_floattype(Int) == Int + @test ImageBase.FiniteDiff.maybe_floattype(N0f8) == Float32 + @test ImageBase.FiniteDiff.maybe_floattype(RGB{N0f8}) == RGB{Float32} + @testset "NumericalTests" begin + a = reshape(collect(1:9), 3, 3) + b_fd_1 = [1 1 1; 1 1 1; -2 -2 -2] + b_fd_2 = [3 3 -6; 3 3 -6; 3 3 -6] + b_bd_1 = [-2 -2 -2; 1 1 1; 1 1 1] + b_bd_2 = [-6 3 3; -6 3 3; -6 3 3] + out = similar(a) + + @test fdiff(a, dims=1) == b_fd_1 + @test fdiff(a, dims=2) == b_fd_2 + @test fdiff(a, dims=1, rev=true) == b_bd_1 + @test fdiff(a, dims=2, rev=true) == b_bd_2 + fdiff!(out, a, dims=1) + @test out == b_fd_1 + fdiff!(out, a, dims=2) + @test out == b_fd_2 + fdiff!(out, a, dims=1, rev=true) + @test out == b_bd_1 + fdiff!(out, a, dims=2, rev=true) + @test out == b_bd_2 + end + @testset "Differentiability" begin + a_fd_1 = [2 4 8; 3 9 27; 4 16 64] + a_fd_2 = [3 6 9; 6 18 27; 9 27 54; 12 36 81] + @testset "Testing basic fdiff" begin + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2))[1] == ones(Float64,size(a_fd_2)) + end + @testset "Testing fdiff with rev" begin + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true))[1] == ones(Float64,size(a_fd_2)) + end + @testset "Testing fdiff with boundary condition" begin + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,boundary=:periodic))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,boundary=:periodic))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,boundary=:periodic))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,boundary=:periodic))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,boundary=:zero))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,boundary=:zero))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_2)) + end + end +end \ No newline at end of file diff --git a/test/ImageBase.jl/statistics.jl b/test/ImageBase.jl/statistics.jl new file mode 100644 index 0000000..06f14e9 --- /dev/null +++ b/test/ImageBase.jl/statistics.jl @@ -0,0 +1,90 @@ +@testset "Statistics" begin + a_fd_1 = [2 4 8; 3 9 27; 4 16 64] + a_fd_2 = [3 6 9; 6 18 27; 9 27 54; 12 36 81] + a_fd_3 = rand(10, 10) + a_fd_4 = randn(6, 4) + b_fd_1 = zeros(Float64, size(a_fd_1)) + b_fd_2 = zeros(Float64, size(a_fd_2)) + b_fd_3 = zeros(Float64, size(a_fd_3)) + b_fd_4 = zeros(Float64, size(a_fd_4)) + e_fd_1 = zeros(Float64, size(a_fd_1)) + e_fd_2 = zeros(Float64, size(a_fd_2)) + e_fd_3 = zeros(Float64, size(a_fd_3)) + e_fd_4 = zeros(Float64, size(a_fd_4)) + c_fd_1 = minimum_finite(a_fd_1) + c_fd_2 = minimum_finite(a_fd_2) + c_fd_3 = minimum_finite(a_fd_3) + c_fd_4 = minimum_finite(a_fd_4) + d_fd_1 = maximum_finite(a_fd_1) + d_fd_2 = maximum_finite(a_fd_2) + d_fd_3 = maximum_finite(a_fd_3) + d_fd_4 = maximum_finite(a_fd_4) + @testset "NumericalTests" begin + @testset "Testing sumfinite" begin + @test sumfinite(a_fd_1) == sum(a_fd_1) + @test sumfinite(a_fd_2) == sum(a_fd_2) + @test sumfinite(a_fd_3) == sum(a_fd_3) + @test sumfinite(a_fd_4) == sum(a_fd_4) + @test sumfinite(a_fd_1) == 137 + @test sumfinite(a_fd_2) == 288 + end + @testset "Testing meanfinite" begin + @test meanfinite(a_fd_1) ≈ mean(a_fd_1) + @test meanfinite(a_fd_2) ≈ mean(a_fd_2) + @test meanfinite(a_fd_3) ≈ mean(a_fd_3) + @test meanfinite(a_fd_4) ≈ mean(a_fd_4) + @test meanfinite(a_fd_1) ≈ 15.222222222222221 + @test meanfinite(a_fd_2) ≈ 24.0 + end + @testset "Testing minimum_finite" begin + @test minimum_finite(a_fd_1) == minimum(a_fd_1) + @test minimum_finite(a_fd_2) == minimum(a_fd_2) + @test minimum_finite(a_fd_3) == minimum(a_fd_3) + @test minimum_finite(a_fd_4) == minimum(a_fd_4) + @test minimum_finite(a_fd_1) == 2 + @test minimum_finite(a_fd_2) == 3 + end + @testset "Testing maximum_finite" begin + @test maximum_finite(a_fd_1) == maximum(a_fd_1) + @test maximum_finite(a_fd_2) == maximum(a_fd_2) + @test maximum_finite(a_fd_3) == maximum(a_fd_3) + @test maximum_finite(a_fd_4) == maximum(a_fd_4) + @test maximum_finite(a_fd_1) == 64 + @test maximum_finite(a_fd_2) == 81 + end + end + @testset "Testing Differentiability" begin + @testset "Testing sumfinite" begin + @test Zygote.gradient(sumfinite, a_fd_1)[1] == ones(Float64, size(a_fd_1)) + @test Zygote.gradient(sumfinite, a_fd_2)[1] == ones(Float64, size(a_fd_2)) + @test Zygote.gradient(sumfinite, a_fd_3)[1] == ones(Float64, size(a_fd_3)) + @test Zygote.gradient(sumfinite, a_fd_4)[1] == ones(Float64, size(a_fd_4)) + end + @testset "Testing meanfinite" begin + @test Zygote.gradient(meanfinite, a_fd_1)[1] == fill((1 / length(a_fd_1)), size(a_fd_1)) + @test Zygote.gradient(meanfinite, a_fd_2)[1] == fill((1 / length(a_fd_2)), size(a_fd_2)) + @test Zygote.gradient(meanfinite, a_fd_3)[1] == fill(1 / length(a_fd_3), size(a_fd_3)) + @test Zygote.gradient(meanfinite, a_fd_4)[1] == fill(1 / length(a_fd_4), size(a_fd_4)) + end + @testset "Testing minimum_finite" begin + b_fd_1[first(findall(x -> x == c_fd_1, a_fd_1))] = 1 + b_fd_2[first(findall(x -> x == c_fd_2, a_fd_2))] = 1 + b_fd_3[first(findall(x -> x == c_fd_3, a_fd_3))] = 1 + b_fd_4[first(findall(x -> x == c_fd_4, a_fd_4))] = 1 + @test Zygote.gradient(minimum_finite, a_fd_1)[1] == b_fd_1 + @test Zygote.gradient(minimum_finite, a_fd_2)[1] == b_fd_2 + @test Zygote.gradient(minimum_finite, a_fd_3)[1] == b_fd_3 + @test Zygote.gradient(minimum_finite, a_fd_4)[1] == b_fd_4 + end + @testset "Testing maximum_finite" begin + e_fd_1[last(findall(x -> x == d_fd_1, a_fd_1))] = 1 + e_fd_2[last(findall(x -> x == d_fd_2, a_fd_2))] = 1 + e_fd_3[last(findall(x -> x == d_fd_3, a_fd_3))] = 1 + e_fd_4[last(findall(x -> x == d_fd_4, a_fd_4))] = 1 + @test Zygote.gradient(maximum_finite, a_fd_1)[1] == e_fd_1 + @test Zygote.gradient(maximum_finite, a_fd_2)[1] == e_fd_2 + @test Zygote.gradient(maximum_finite, a_fd_3)[1] == e_fd_3 + @test Zygote.gradient(maximum_finite, a_fd_4)[1] == e_fd_4 + end + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 75e60b8..34173dd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,7 +8,9 @@ using Test, FiniteDifferences, ChainRulesCore, CoordinateTransformations, - Rotations + Rotations, + ImageBase, + Statistics @testset "DiffImages" begin @info "Testing Colorspace modules" @@ -23,4 +25,11 @@ using Test, @testset "Warps" begin include("geometry/warp.jl") end + @info "Testing ImageBase modules" + @testset "FiniteDifferences" begin + include("ImageBase.jl/fdiff.jl") + end + @testset "Statistics" begin + include("ImageBase.jl/statistics.jl") + end end