From a3798232ce19922239fb8e8bb2933b559ac26da2 Mon Sep 17 00:00:00 2001 From: Aman Date: Tue, 10 May 2022 17:50:20 +0530 Subject: [PATCH] fdiff adjoint added --- Project.toml | 1 + src/DiffImages.jl | 6 ++++-- src/ImageBase.jl/fdiff.jl | 11 ++++++++++ test/ImageBase.jl/fdiff.jl | 44 ++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 7 +++++- 5 files changed, 66 insertions(+), 3 deletions(-) create mode 100644 src/ImageBase.jl/fdiff.jl create mode 100644 test/ImageBase.jl/fdiff.jl diff --git a/Project.toml b/Project.toml index df4ed07..5b34bbc 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.1.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298" +ImageBase = "c817782e-172a-44cc-b673-b171935fbb9e" ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534" ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" diff --git a/src/DiffImages.jl b/src/DiffImages.jl index 4d2f15d..614b150 100644 --- a/src/DiffImages.jl +++ b/src/DiffImages.jl @@ -7,14 +7,16 @@ using ImageCore, Interpolations, ChainRulesCore, LinearAlgebra, - Rotations + Rotations, + ImageBase using Zygote: @adjoint using ChainRulesCore: NoTangent -export colorify, channelify +export colorify, channelify, fdiff include("colors/conversions.jl") include("geometry/warp.jl") include("geometry/adjoints.jl") +include("ImageBase.jl/fdiff.jl") end diff --git a/src/ImageBase.jl/fdiff.jl b/src/ImageBase.jl/fdiff.jl new file mode 100644 index 0000000..d2a59f5 --- /dev/null +++ b/src/ImageBase.jl/fdiff.jl @@ -0,0 +1,11 @@ +# TODO(arcAman07): support RGB inputs, currently works only for GrayScale Images +# TODO(arcAman07): support N dimensional case, currently works only for 2 dimensional case +@adjoint function fdiff(A::AbstractArray; kwargs...) + y = fdiff!(similar(A, maybe_floattype(eltype(A))), A; kwargs...) + final = similar(A, eltype(A)) + function pullback(Δ) + fill!(final, Δ) + return (final,) + end + return (y, pullback) +end \ No newline at end of file diff --git a/test/ImageBase.jl/fdiff.jl b/test/ImageBase.jl/fdiff.jl new file mode 100644 index 0000000..6208b99 --- /dev/null +++ b/test/ImageBase.jl/fdiff.jl @@ -0,0 +1,44 @@ +using ImageBase.FiniteDiff: fdiff, fdiff! +@testset "fdiff" begin + # Base.diff doesn't promote integer to float + @test ImageBase.FiniteDiff.maybe_floattype(Int) == Int + @test ImageBase.FiniteDiff.maybe_floattype(N0f8) == Float32 + @test ImageBase.FiniteDiff.maybe_floattype(RGB{N0f8}) == RGB{Float32} + @testset "NumericalTests" begin + a = reshape(collect(1:9), 3, 3) + b_fd_1 = [1 1 1; 1 1 1; -2 -2 -2] + b_fd_2 = [3 3 -6; 3 3 -6; 3 3 -6] + b_bd_1 = [-2 -2 -2; 1 1 1; 1 1 1] + b_bd_2 = [-6 3 3; -6 3 3; -6 3 3] + out = similar(a) + + @test fdiff(a, dims=1) == b_fd_1 + @test fdiff(a, dims=2) == b_fd_2 + @test fdiff(a, dims=1, rev=true) == b_bd_1 + @test fdiff(a, dims=2, rev=true) == b_bd_2 + fdiff!(out, a, dims=1) + @test out == b_fd_1 + fdiff!(out, a, dims=2) + @test out == b_fd_2 + fdiff!(out, a, dims=1, rev=true) + @test out == b_bd_1 + fdiff!(out, a, dims=2, rev=true) + @test out == b_bd_2 + end + @testset "Differentiability" begin + a_fd_1 = [2 4 8; 3 9 27; 4 16 64] + a_fd_2 = [3 6 9 12; 6 18 27 36; 9 27 54 81; 12 36 81 144] + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1,rev=true, boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=1, boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2,rev=true, boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_1,dims=2, boundary=:zero))[1] == ones(Float64,size(a_fd_1)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=2,rev=true))[1] == ones(Float64,size(a_fd_2)) + @test Zygote.gradient(x -> sum(x),fdiff(a_fd_2,dims=1,rev=true))[1] == ones(Float64,size(a_fd_2)) + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 75e60b8..b116311 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,7 +8,8 @@ using Test, FiniteDifferences, ChainRulesCore, CoordinateTransformations, - Rotations + Rotations, + ImageBase @testset "DiffImages" begin @info "Testing Colorspace modules" @@ -23,4 +24,8 @@ using Test, @testset "Warps" begin include("geometry/warp.jl") end + @info "Testing ImageBase modules" + @testset "FiniteDifferences" begin + include("ImageBase.jl/fdiff.jl") + end end