Skip to content
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,21 @@ version = "0.1.0"
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298"
ImageBase = "c817782e-172a-44cc-b673-b171935fbb9e"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ChainRulesCore = "1.3.0"
CoordinateTransformations = "0.6.1"
ImageBase = "0.1.5"
ImageCore = "0.9"
ImageTransformations = "0.8, 0.9"
Interpolations = "0.13.4"
Expand Down
7 changes: 5 additions & 2 deletions src/DiffImages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@ using ImageCore,
Interpolations,
ChainRulesCore,
LinearAlgebra,
Rotations
Rotations,
ImageBase

using Zygote: @adjoint
using ChainRulesCore: NoTangent

export colorify, channelify
export colorify, channelify, fdiff
include("colors/conversions.jl")
include("geometry/warp.jl")
include("geometry/adjoints.jl")
include("ImageBase.jl/fdiff.jl")
include("ImageBase.jl/statistics.jl")

end
9 changes: 9 additions & 0 deletions src/ImageBase.jl/fdiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# TODO(arcAman07): support RGB inputs, currently works only for GrayScale Images
# TODO(arcAman07): support N dimensional case, currently works only for 2 dimensional case
@adjoint function fdiff(A::AbstractArray; kwargs...)
y = fdiff!(similar(A, maybe_floattype(eltype(A))), A; kwargs...)
function pullback(Δ)
return (fill(Δ, size(A)),)
end
return (y, pullback)
end
35 changes: 35 additions & 0 deletions src/ImageBase.jl/statistics.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
@adjoint function sumfinite(A::AbstractArray{T,N}; kwargs...) where {T,N}
y = ImageBase.sumfinite(identity, A; kwargs...)
function pullback(Δ)
return (fill(Δ,size(A)),)
end
return (y, pullback)
end

@adjoint function meanfinite(A::AbstractArray{T,N}; kwargs...) where {T,N}
y = ImageBase.meanfinite(identity, A; kwargs...)
function pullback(Δ)
return (fill(Δ / length(A),size(A)),)
end
return (y, pullback)
end

@adjoint function maximum_finite(A::AbstractArray{T,N}; kwargs...) where {T,N}
y = ImageBase.maximum_finite(identity, A; kwargs...)
final = zeros(Float64, size(A))
function pullback(Δ)
final[last(findall(x -> x == y, A))] = Δ
return (final,)
end
return (y, pullback)
end

@adjoint function minimum_finite(A::AbstractArray{T,N}; kwargs...) where {T,N}
y = ImageBase.minimum_finite(identity, A; kwargs...)
final = zeros(Float64, size(A))
function pullback(Δ)
final[first(findall(x -> x == y, A))] = Δ
return (final,)
end
return (y, pullback)
end
62 changes: 62 additions & 0 deletions test/ImageBase.jl/fdiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
using ImageBase.FiniteDiff: fdiff, fdiff!
@testset "fdiff" begin
# Base.diff doesn't promote integer to float
@test ImageBase.FiniteDiff.maybe_floattype(Int) == Int
@test ImageBase.FiniteDiff.maybe_floattype(N0f8) == Float32
@test ImageBase.FiniteDiff.maybe_floattype(RGB{N0f8}) == RGB{Float32}
@testset "NumericalTests" begin
a = reshape(collect(1:9), 3, 3)
b_fd_1 = [1 1 1; 1 1 1; -2 -2 -2]
b_fd_2 = [3 3 -6; 3 3 -6; 3 3 -6]
b_bd_1 = [-2 -2 -2; 1 1 1; 1 1 1]
b_bd_2 = [-6 3 3; -6 3 3; -6 3 3]
out = similar(a)

@test fdiff(a, dims=1) == b_fd_1
@test fdiff(a, dims=2) == b_fd_2
@test fdiff(a, dims=1, rev=true) == b_bd_1
@test fdiff(a, dims=2, rev=true) == b_bd_2
fdiff!(out, a, dims=1)
@test out == b_fd_1
fdiff!(out, a, dims=2)
@test out == b_fd_2
fdiff!(out, a, dims=1, rev=true)
@test out == b_bd_1
fdiff!(out, a, dims=2, rev=true)
@test out == b_bd_2
end
@testset "Differentiability" begin
a_fd_1 = [2 4 8; 3 9 27; 4 16 64]
a_fd_2 = [3 6 9; 6 18 27; 9 27 54; 12 36 81]
@testset "Testing basic fdiff" begin
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2))[1] == ones(Float64,size(a_fd_2))
end
@testset "Testing fdiff with rev" begin
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true))[1] == ones(Float64,size(a_fd_2))
end
@testset "Testing fdiff with boundary condition" begin
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,boundary=:periodic))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,boundary=:periodic))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,boundary=:zero))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,boundary=:zero))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,boundary=:periodic))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,boundary=:periodic))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,boundary=:zero))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,boundary=:zero))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_2))
end
end
end
90 changes: 90 additions & 0 deletions test/ImageBase.jl/statistics.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
@testset "Statistics" begin
a_fd_1 = [2 4 8; 3 9 27; 4 16 64]
a_fd_2 = [3 6 9; 6 18 27; 9 27 54; 12 36 81]
a_fd_3 = rand(10, 10)
a_fd_4 = randn(6, 4)
b_fd_1 = zeros(Float64, size(a_fd_1))
b_fd_2 = zeros(Float64, size(a_fd_2))
b_fd_3 = zeros(Float64, size(a_fd_3))
b_fd_4 = zeros(Float64, size(a_fd_4))
e_fd_1 = zeros(Float64, size(a_fd_1))
e_fd_2 = zeros(Float64, size(a_fd_2))
e_fd_3 = zeros(Float64, size(a_fd_3))
e_fd_4 = zeros(Float64, size(a_fd_4))
c_fd_1 = minimum_finite(a_fd_1)
c_fd_2 = minimum_finite(a_fd_2)
c_fd_3 = minimum_finite(a_fd_3)
c_fd_4 = minimum_finite(a_fd_4)
d_fd_1 = maximum_finite(a_fd_1)
d_fd_2 = maximum_finite(a_fd_2)
d_fd_3 = maximum_finite(a_fd_3)
d_fd_4 = maximum_finite(a_fd_4)
@testset "NumericalTests" begin
@testset "Testing sumfinite" begin
@test sumfinite(a_fd_1) == sum(a_fd_1)
@test sumfinite(a_fd_2) == sum(a_fd_2)
@test sumfinite(a_fd_3) == sum(a_fd_3)
@test sumfinite(a_fd_4) == sum(a_fd_4)
@test sumfinite(a_fd_1) == 137
@test sumfinite(a_fd_2) == 288
end
@testset "Testing meanfinite" begin
@test meanfinite(a_fd_1) ≈ mean(a_fd_1)
@test meanfinite(a_fd_2) ≈ mean(a_fd_2)
@test meanfinite(a_fd_3) ≈ mean(a_fd_3)
@test meanfinite(a_fd_4) ≈ mean(a_fd_4)
@test meanfinite(a_fd_1) ≈ 15.222222222222221
@test meanfinite(a_fd_2) ≈ 24.0
end
@testset "Testing minimum_finite" begin
@test minimum_finite(a_fd_1) == minimum(a_fd_1)
@test minimum_finite(a_fd_2) == minimum(a_fd_2)
@test minimum_finite(a_fd_3) == minimum(a_fd_3)
@test minimum_finite(a_fd_4) == minimum(a_fd_4)
@test minimum_finite(a_fd_1) == 2
@test minimum_finite(a_fd_2) == 3
end
@testset "Testing maximum_finite" begin
@test maximum_finite(a_fd_1) == maximum(a_fd_1)
@test maximum_finite(a_fd_2) == maximum(a_fd_2)
@test maximum_finite(a_fd_3) == maximum(a_fd_3)
@test maximum_finite(a_fd_4) == maximum(a_fd_4)
@test maximum_finite(a_fd_1) == 64
@test maximum_finite(a_fd_2) == 81
end
end
@testset "Testing Differentiability" begin
@testset "Testing sumfinite" begin
@test Zygote.gradient(sumfinite, a_fd_1)[1] == ones(Float64, size(a_fd_1))
@test Zygote.gradient(sumfinite, a_fd_2)[1] == ones(Float64, size(a_fd_2))
@test Zygote.gradient(sumfinite, a_fd_3)[1] == ones(Float64, size(a_fd_3))
@test Zygote.gradient(sumfinite, a_fd_4)[1] == ones(Float64, size(a_fd_4))
end
@testset "Testing meanfinite" begin
@test Zygote.gradient(meanfinite, a_fd_1)[1] == fill((1 / length(a_fd_1)), size(a_fd_1))
@test Zygote.gradient(meanfinite, a_fd_2)[1] == fill((1 / length(a_fd_2)), size(a_fd_2))
@test Zygote.gradient(meanfinite, a_fd_3)[1] == fill(1 / length(a_fd_3), size(a_fd_3))
@test Zygote.gradient(meanfinite, a_fd_4)[1] == fill(1 / length(a_fd_4), size(a_fd_4))
end
@testset "Testing minimum_finite" begin
b_fd_1[first(findall(x -> x == c_fd_1, a_fd_1))] = 1
b_fd_2[first(findall(x -> x == c_fd_2, a_fd_2))] = 1
b_fd_3[first(findall(x -> x == c_fd_3, a_fd_3))] = 1
b_fd_4[first(findall(x -> x == c_fd_4, a_fd_4))] = 1
@test Zygote.gradient(minimum_finite, a_fd_1)[1] == b_fd_1
@test Zygote.gradient(minimum_finite, a_fd_2)[1] == b_fd_2
@test Zygote.gradient(minimum_finite, a_fd_3)[1] == b_fd_3
@test Zygote.gradient(minimum_finite, a_fd_4)[1] == b_fd_4
end
@testset "Testing maximum_finite" begin
e_fd_1[last(findall(x -> x == d_fd_1, a_fd_1))] = 1
e_fd_2[last(findall(x -> x == d_fd_2, a_fd_2))] = 1
e_fd_3[last(findall(x -> x == d_fd_3, a_fd_3))] = 1
e_fd_4[last(findall(x -> x == d_fd_4, a_fd_4))] = 1
@test Zygote.gradient(maximum_finite, a_fd_1)[1] == e_fd_1
@test Zygote.gradient(maximum_finite, a_fd_2)[1] == e_fd_2
@test Zygote.gradient(maximum_finite, a_fd_3)[1] == e_fd_3
@test Zygote.gradient(maximum_finite, a_fd_4)[1] == e_fd_4
end
end
end
11 changes: 10 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ using Test,
FiniteDifferences,
ChainRulesCore,
CoordinateTransformations,
Rotations
Rotations,
ImageBase,
Statistics

@testset "DiffImages" begin
@info "Testing Colorspace modules"
Expand All @@ -23,4 +25,11 @@ using Test,
@testset "Warps" begin
include("geometry/warp.jl")
end
@info "Testing ImageBase modules"
@testset "FiniteDifferences" begin
include("ImageBase.jl/fdiff.jl")
end
@testset "Statistics" begin
include("ImageBase.jl/statistics.jl")
end
end