Skip to content
Open
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.1.0"
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298"
ImageBase = "c817782e-172a-44cc-b673-b171935fbb9e"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
Expand All @@ -24,6 +25,7 @@ Interpolations = "0.13.4"
Rotations = "1.0.2"
StaticArrays = "1.2"
Zygote = "0.6.17"
ImageBase = "0.1.5"

[extras]
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Expand Down
6 changes: 4 additions & 2 deletions src/DiffImages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ using ImageCore,
Interpolations,
ChainRulesCore,
LinearAlgebra,
Rotations
Rotations,
ImageBase

using Zygote: @adjoint
using ChainRulesCore: NoTangent

export colorify, channelify
export colorify, channelify, fdiff
include("colors/conversions.jl")
include("geometry/warp.jl")
include("geometry/adjoints.jl")
include("ImageBase.jl/fdiff.jl")

end
9 changes: 9 additions & 0 deletions src/ImageBase.jl/fdiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# TODO(arcAman07): support RGB inputs, currently works only for GrayScale Images
# TODO(arcAman07): support N dimensional case, currently works only for 2 dimensional case
@adjoint function fdiff(A::AbstractArray; kwargs...)
y = fdiff!(similar(A, maybe_floattype(eltype(A))), A; kwargs...)
function pullback(Δ)
return (fill(Δ, size(A)),)
end
return (y, pullback)
end
62 changes: 62 additions & 0 deletions test/ImageBase.jl/fdiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
using ImageBase.FiniteDiff: fdiff, fdiff!
@testset "fdiff" begin
# Base.diff doesn't promote integer to float
@test ImageBase.FiniteDiff.maybe_floattype(Int) == Int
@test ImageBase.FiniteDiff.maybe_floattype(N0f8) == Float32
@test ImageBase.FiniteDiff.maybe_floattype(RGB{N0f8}) == RGB{Float32}
@testset "NumericalTests" begin
a = reshape(collect(1:9), 3, 3)
b_fd_1 = [1 1 1; 1 1 1; -2 -2 -2]
b_fd_2 = [3 3 -6; 3 3 -6; 3 3 -6]
b_bd_1 = [-2 -2 -2; 1 1 1; 1 1 1]
b_bd_2 = [-6 3 3; -6 3 3; -6 3 3]
out = similar(a)

@test fdiff(a, dims=1) == b_fd_1
@test fdiff(a, dims=2) == b_fd_2
@test fdiff(a, dims=1, rev=true) == b_bd_1
@test fdiff(a, dims=2, rev=true) == b_bd_2
fdiff!(out, a, dims=1)
@test out == b_fd_1
fdiff!(out, a, dims=2)
@test out == b_fd_2
fdiff!(out, a, dims=1, rev=true)
@test out == b_bd_1
fdiff!(out, a, dims=2, rev=true)
@test out == b_bd_2
end
@testset "Differentiability" begin
a_fd_1 = [2 4 8; 3 9 27; 4 16 64]
a_fd_2 = [3 6 9; 6 18 27; 9 27 54; 12 36 81]
@testset "Testing basic fdiff" begin
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2))[1] == ones(Float64,size(a_fd_2))
end
@testset "Testing fdiff with rev" begin
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true))[1] == ones(Float64,size(a_fd_2))
end
@testset "Testing fdiff with boundary condition" begin
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,boundary=:periodic))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,boundary=:periodic))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,boundary=:zero))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,boundary=:zero))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,boundary=:periodic))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,boundary=:periodic))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,boundary=:zero))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,boundary=:zero))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true,boundary=:periodic))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true,boundary=:zero))[1] == ones(Float64,size(a_fd_2))
end
end
end
7 changes: 6 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ using Test,
FiniteDifferences,
ChainRulesCore,
CoordinateTransformations,
Rotations
Rotations,
ImageBase

@testset "DiffImages" begin
@info "Testing Colorspace modules"
Expand All @@ -23,4 +24,8 @@ using Test,
@testset "Warps" begin
include("geometry/warp.jl")
end
@info "Testing ImageBase modules"
@testset "FiniteDifferences" begin
include("ImageBase.jl/fdiff.jl")
end
end