Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.1.0"
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298"
ImageBase = "c817782e-172a-44cc-b673-b171935fbb9e"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
Expand Down
6 changes: 4 additions & 2 deletions src/DiffImages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ using ImageCore,
Interpolations,
ChainRulesCore,
LinearAlgebra,
Rotations
Rotations,
ImageBase

using Zygote: @adjoint
using ChainRulesCore: NoTangent

export colorify, channelify
export colorify, channelify, fdiff
include("colors/conversions.jl")
include("geometry/warp.jl")
include("geometry/adjoints.jl")
include("ImageBase.jl/fdiff.jl")

end
11 changes: 11 additions & 0 deletions src/ImageBase.jl/fdiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# TODO(arcAman07): support RGB inputs, currently works only for GrayScale Images
# TODO(arcAman07): support N dimensional case, currently works only for 2 dimensional case
@adjoint function fdiff(A::AbstractArray; kwargs...)
y = fdiff!(similar(A, maybe_floattype(eltype(A))), A; kwargs...)
final = similar(A, eltype(A))
function pullback(Δ)
fill!(final, Δ)
return (final,)
end
return (y, pullback)
end
44 changes: 44 additions & 0 deletions test/ImageBase.jl/fdiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using ImageBase.FiniteDiff: fdiff, fdiff!
@testset "fdiff" begin
# Base.diff doesn't promote integer to float
@test ImageBase.FiniteDiff.maybe_floattype(Int) == Int
@test ImageBase.FiniteDiff.maybe_floattype(N0f8) == Float32
@test ImageBase.FiniteDiff.maybe_floattype(RGB{N0f8}) == RGB{Float32}
@testset "NumericalTests" begin
a = reshape(collect(1:9), 3, 3)
b_fd_1 = [1 1 1; 1 1 1; -2 -2 -2]
b_fd_2 = [3 3 -6; 3 3 -6; 3 3 -6]
b_bd_1 = [-2 -2 -2; 1 1 1; 1 1 1]
b_bd_2 = [-6 3 3; -6 3 3; -6 3 3]
out = similar(a)

@test fdiff(a, dims=1) == b_fd_1
@test fdiff(a, dims=2) == b_fd_2
@test fdiff(a, dims=1, rev=true) == b_bd_1
@test fdiff(a, dims=2, rev=true) == b_bd_2
fdiff!(out, a, dims=1)
@test out == b_fd_1
fdiff!(out, a, dims=2)
@test out == b_fd_2
fdiff!(out, a, dims=1, rev=true)
@test out == b_bd_1
fdiff!(out, a, dims=2, rev=true)
@test out == b_bd_2
end
@testset "Differentiability" begin
a_fd_1 = [2 4 8; 3 9 27; 4 16 64]
a_fd_2 = [3 6 9 12; 6 18 27 36; 9 27 54 81; 12 36 81 144]
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true, boundary=:zero))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1, boundary=:zero))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true, boundary=:zero))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2, boundary=:zero))[1] == ones(Float64,size(a_fd_1))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true))[1] == ones(Float64,size(a_fd_2))
@test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true))[1] == ones(Float64,size(a_fd_2))
end
end
7 changes: 6 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ using Test,
FiniteDifferences,
ChainRulesCore,
CoordinateTransformations,
Rotations
Rotations,
ImageBase

@testset "DiffImages" begin
@info "Testing Colorspace modules"
Expand All @@ -23,4 +24,8 @@ using Test,
@testset "Warps" begin
include("geometry/warp.jl")
end
@info "Testing ImageBase modules"
@testset "FiniteDifferences" begin
include("ImageBase.jl/fdiff.jl")
end
end