Skip to content

Conversation

SomTambe
Copy link
Member

Through this PR, I have added support for the following -

  • LinearMap, AffineMap and RotMatrix transforms are now differentiable.
  • I modified the Homography transform to make it on the lines of CoordinateTransformations.LinearMap. Now it is differentiable without the rule I had written.
  • I added tests for all of the following, to make sure everything is consistent.

Some things that are breaking -

  • If I use a AffineMap combined with a RotMatrix, something on the lines of
recenter(RotMatrix/8), center(img))

the gradients come in a nested form compared to the ones when I take using an AffineMap composed of purely SArrays.
They are something like this -

(linear = (mat = (data = [37.49958809560448 -20.572948048408485; 35.68484930615324 -28.041924714099725],),), translation = (data = [1.1544357177976206, -2.9696789827699077],)) # used RotMatrix

(linear = [37.49958809560448 -20.572948048408485; 35.68484930615324 -28.041924714099725], translation = [1.1544357177976206, -2.9696789827699077]) # used StaticArrays

I should fix this in another PR.
This PR should solve #17 .

- Added RotMatrix, SMatrix adjoints
- Removed Homography adjoint
- Add tests for SMatrix, RotMatrix, LinearMap
- Made changes to Homography tests to check for the types
- Add breaking tests for RotMatrix vs. non-RotMatrix gradients
@codecov-commenter
Copy link

codecov-commenter commented Aug 14, 2021

Codecov Report

Merging #19 (1a1d26f) into main (fbeb6e4) will increase coverage by 5.88%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main      #19      +/-   ##
==========================================
+ Coverage   82.35%   88.23%   +5.88%     
==========================================
  Files           3        3              
  Lines         102      102              
==========================================
+ Hits           84       90       +6     
+ Misses         18       12       -6     
Impacted Files Coverage Δ
src/geometry/adjoints.jl 100.00% <100.00%> (+1.85%) ⬆️
src/geometry/warp.jl 75.00% <100.00%> (+33.82%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update fbeb6e4...1a1d26f. Read the comment docs.

Copy link
Collaborator

@johnnychen94 johnnychen94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the test passing, nothing too suspicious to me.

the gradients come in a nested form compared to the ones when I take using an AffineMap composed of purely SArrays.

I believe this is expected? If you try to get gradients for any neural network, you get something similar. @DhairyaLGandhi should know more about how to process this. I guess it's time to introduce Functors.jl to handle nested structures.

Copy link
Member

@DhairyaLGandhi DhairyaLGandhi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes seems expected. Do we need functors here? Is it even correct to use Functors here?

Btw, very happy to see the rrule of homography be deleted, and consistent code forming. Good job on this PR!

_, ∇ϕ = rrule_via_ad(Zygote.ZygoteRuleConfig(), tform, SVector(p.I))
∇h, _ = ∇ϕ(∇τ)
∇tform += Tangent{DiffImages.Homography}(H = ∇h.H)
∇tform += Tangent{typeof(tform)}(;NamedTuple{flds}(∇h)...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
∇tform += Tangent{typeof(tform)}(;NamedTuple{flds}(∇h)...)
∇tform += Tangent{typeof(tform)}(NamedTuple{flds}(∇h)...)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't do this xD. I wanted to pass the field names too, so have to pass it as kwargs :p
If I splat the namedtuple directly, the values get passed but not the keys.

return y, smat_const_pb
end

function ChainRulesCore.rrule(p::Type{RotMatrix{N, T, L}}, x) where {N, T, L}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like this would be hard to scale to other kinds of transforms if we were to write rules for every transform individually. Are there a countable few of these? Do all of them need special handling or can we AD most? Do these compose to produce other transforms?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most transforms can be written directly using an AffineMap. Others, we can use some existing transform to make that transform. This means we (hopefully) won't have to write any more rules.
We can AD through most, and yes I think we can compose to produce other transforms too.
I was planning the composing thing in the next PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be good to have a poc that demos what you're talking about.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I'll prepare it with the docs.

@DhairyaLGandhi
Copy link
Member

DhairyaLGandhi commented Aug 14, 2021

Use Optimisers.jl along with FluxML/Functors.jl#1 to optimise this. We can do a quick call and I can walk you through structural optimisation.

Otherwise make sure you have the Flux.params populated with the parameters you want to differentiate with.

tfm1 = recenter(RotMatrix/8), center(img)) # using RotMatrix
tfm2 = AffineMap(SMatrix{2, 2, Float64, 4}([cos/8) -sin/8); sin/8) cos/8)]), tfm1.translation)

zy1 = Zygote.gradient(f, tfm1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh brilliant, are the warp! rules working as expected now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup

@SomTambe
Copy link
Member Author

Since I have made the required changes, can we get this PR merged?
About the functor things, I can merge the related changes in a separate PR.
We can have a short call as Dhairya had suggested. I have been a bit busy preparing the docs and due to academic commitments.

Copy link
Collaborator

@johnnychen94 johnnychen94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me. I'll leave it to @DhairyaLGandhi since there seems one or two unresolved comments.

@SomTambe SomTambe merged commit ce21ce7 into main Aug 17, 2021
@DhairyaLGandhi DhairyaLGandhi deleted the som/coordinate-trfms branch August 17, 2021 13:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants