Add support for CoordinateTransformations transforms #19
- Added RotMatrix, SMatrix adjoints
- Removed Homography adjoint
- Add tests for SMatrix, RotMatrix, LinearMap
- Made changes to Homography tests to check for the types
- Add breaking tests for RotMatrix vs. non-RotMatrix gradients
Codecov Report
@@ Coverage Diff @@
## main #19 +/- ##
==========================================
+ Coverage 82.35% 88.23% +5.88%
==========================================
Files 3 3
Lines 102 102
==========================================
+ Hits 84 90 +6
+ Misses 18 12 -6
With the test passing, nothing too suspicious to me.
the gradients come in a nested form compared to the ones when I take using an AffineMap composed of purely SArrays.
I believe this is expected? If you try to get gradients for any neural network, you get something similar. @DhairyaLGandhi should know more about how to process this. I guess it's time to introduce Functors.jl to handle nested structures.
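To illustrate the nested form being discussed, here is a minimal sketch, assuming Zygote is installed. The `Inner` and `Outer` structs and the function `f` are hypothetical stand-ins, not types from this PR; the point is only that the gradient with respect to a struct mirrors the struct's layout.

```julia
using Zygote

# Hypothetical stand-ins for a transform that wraps another struct.
struct Inner
    w
end

struct Outer
    inner
    b
end

f(o) = sum(o.inner.w) + sum(o.b)

# The gradient w.r.t. a struct is a NamedTuple with one entry per field,
# so a struct nested inside a struct yields a nested NamedTuple.
g, = Zygote.gradient(f, Outer(Inner([1.0, 2.0]), [3.0]))

g.inner.w  # gradient for the wrapped array
g.b        # gradient for the outer array
```

A flat struct holding only arrays would give a single-level NamedTuple instead, which is the difference noted in the PR description.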
Yes seems expected. Do we need functors here? Is it even correct to use Functors here?
Btw, very happy to see the rrule of homography be deleted, and consistent code forming. Good job on this PR!
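For context on what Functors.jl would offer, a minimal sketch assuming the Functors package is installed: `fmap` walks a nested structure and applies a function to each leaf. The gradient `∇` below is a made-up NamedTuple, not output from this PR, and whether this is the right tool here is exactly the open question in this thread.

```julia
using Functors

# Hypothetical nested gradient, shaped like a NamedTuple-of-NamedTuples.
∇ = (linear = (mat = [1.0 0.0; 0.0 1.0],), translation = [2.0, 4.0])

# fmap applies the function to every leaf array and rebuilds the same nesting.
scaled = fmap(x -> 0.5 .* x, ∇)

scaled.linear.mat   # the inner matrix, scaled
scaled.translation  # the outer vector, scaled
```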
_, ∇ϕ = rrule_via_ad(Zygote.ZygoteRuleConfig(), tform, SVector(p.I))
∇h, _ = ∇ϕ(∇τ)
∇tform += Tangent{DiffImages.Homography}(H = ∇h.H)
∇tform += Tangent{typeof(tform)}(;NamedTuple{flds}(∇h)...)
∇tform += Tangent{typeof(tform)}(;NamedTuple{flds}(∇h)...)
∇tform += Tangent{typeof(tform)}(NamedTuple{flds}(∇h)...)
Can't do this xD. I wanted to pass the field names too, so have to pass it as kwargs :p
If I splat the namedtuple directly, the values get passed but not the keys.
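The difference between the two splats can be seen with plain Julia, no packages needed. This is a small sketch; `show_args`, `nt`, and `flds` are hypothetical names used only to make the behaviour visible.

```julia
# Hypothetical helper that just reports what it received.
show_args(args...; kwargs...) = (args = args, kwargs = (; kwargs...))

nt = (linear = 1.0, translation = 2.0)

# Positional splat: only the values arrive, the field names are lost.
positional = show_args(nt...)

# Keyword splat (note the leading `;`): names and values both arrive.
keyword = show_args(; nt...)

# NamedTuple{flds}(nt) selects the named fields from an existing NamedTuple.
flds = (:linear,)
selected = NamedTuple{flds}(nt)
```

Here `positional.args` holds the bare values with no keys, while `keyword.kwargs` is the original NamedTuple again, which is why the `;` is needed when the field names must reach the `Tangent` constructor.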
    return y, smat_const_pb
end

function ChainRulesCore.rrule(p::Type{RotMatrix{N, T, L}}, x) where {N, T, L}
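A rule defined on a constructor, like the one whose signature is shown above, generally has the following shape. This is a minimal sketch assuming ChainRulesCore is installed; `Wrapped` and `wrapped_pb` are hypothetical names, and this is not the rule from this PR.

```julia
using ChainRulesCore

# Hypothetical wrapper type standing in for something like RotMatrix.
struct Wrapped{T}
    mat::T
end

# rrule for the constructor: the forward pass builds the object, and the
# pullback forwards the tangent of the `mat` field back to the input `x`.
function ChainRulesCore.rrule(::Type{Wrapped{T}}, x) where {T}
    y = Wrapped{T}(x)
    function wrapped_pb(Δ)
        Δ = unthunk(Δ)
        # NoTangent() is for the type argument; Δ.mat is the tangent w.r.t. x.
        return NoTangent(), Δ.mat
    end
    return y, wrapped_pb
end

# Usage: call the rule directly and push a structural tangent through it.
y, pb = ChainRulesCore.rrule(Wrapped{Vector{Float64}}, [1.0, 2.0])
_, dx = pb(Tangent{typeof(y)}(mat = [0.5, 0.5]))
```

Each wrapper type handled this way needs its own method, which is the scaling concern raised in the review comment that follows.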
Seems like this would be hard to scale to other kinds of transforms if we were to write rules for every transform individually. Are there a countable few of these? Do all of them need special handling or can we AD most? Do these compose to produce other transforms?
Most transforms can be written directly using an AffineMap. For the others, we can build the transform from an existing one. This means we (hopefully) won't have to write any more rules.
We can AD through most, and yes I think we can compose to produce other transforms too.
I was planning the composing thing in the next PR.
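As a rough illustration of the claim about composition (not the proof of concept requested in this thread), here is a sketch assuming CoordinateTransformations and StaticArrays are installed. The angle, the center `c`, and the point `p` are arbitrary values chosen for the example.

```julia
using CoordinateTransformations, StaticArrays

θ = π / 8
R = @SMatrix [cos(θ) -sin(θ); sin(θ) cos(θ)]
c = SVector(10.0, 20.0)   # arbitrary center of rotation

# Rotation about c, built by composing translations with a linear map.
composed = Translation(c) ∘ LinearMap(R) ∘ Translation(-c)

# The same transform written directly as a single AffineMap:
# R * (p - c) + c == R * p + (c - R * c)
direct = AffineMap(R, c - R * c)

p = SVector(3.0, 4.0)
composed(p) ≈ direct(p)   # both forms map the point to the same place
```

If rules exist for the building blocks, a transform assembled this way should not need a rule of its own, which is the hope expressed above.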
It would be good to have a poc that demos what you're talking about.
Yes I'll prepare it with the docs.
Use Optimisers.jl along with FluxML/Functors.jl#1 to optimise this. We can do a quick call and I can walk you through structural optimisation. Otherwise make sure you have the
tfm1 = recenter(RotMatrix(π/8), center(img)) # using RotMatrix
tfm2 = AffineMap(SMatrix{2, 2, Float64, 4}([cos(π/8) -sin(π/8); sin(π/8) cos(π/8)]), tfm1.translation)

zy1 = Zygote.gradient(f, tfm1)
Oh brilliant, are the warp! rules working as expected now?
Yup
Since I have made the required changes, can we get this PR merged?
Looks good to me. I'll leave it to @DhairyaLGandhi since there seem to be one or two unresolved comments.
Through this PR, I have added support for the following -
- LinearMap, AffineMap and RotMatrix transforms are now differentiable.
- Reworked the Homography transform to be along the lines of CoordinateTransformations.LinearMap. Now it is differentiable without the rule I had written.

Some things that are breaking -
- With an AffineMap combined with a RotMatrix, the gradients come in a nested form compared to the ones I get when using an AffineMap composed purely of SArrays. They are something like this -
I should fix this in another PR.
This PR should solve #17 .