Description
Sorry for the lack of a minimal example here; I'll add one when I get the chance.
I recently tried to differentiate through operations on `StaticArrays.SVector`s and hit two issues, both of which were easily fixable once diagnosed (unless these fixes cause other problems I'm not aware of):
- StaticArrays defines a catch-all `+(::StaticArray, ::AbstractArray)` and ReverseDiff defines a catch-all `+(::AbstractArray, ::TrackedArray)`. So if `+(::StaticArray, ::TrackedArray)` is called, Julia reports that the call is ambiguous, because neither method is more specific than the other. (The same happens with the operands in the other order.) The same is true for `-`, and I'm guessing also for `*`.
- This line (line 225 in 71c5ac0):
  ```julia
  capture(t::AbstractArray) = istracked(t) ? map!(capture, similar(t), t) : copy(t)
  ```
  assumes that `similar(t)` supports `setindex!`. For `StaticArray`s, this assumption breaks if `isbitstype(eltype(t))` is `false`, in particular if `t` is a `StaticArray` of `TrackedReal`s. The same issue arises on this line (line 166 in 71c5ac0):
  ```julia
  @inline deriv!(t::TrackedArray, v::AbstractArray) = (copyto!(deriv(t), v); nothing)
  ```
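The first problem (two catch-all methods overlapping) can be reproduced without either package. The sketch below is a minimal, self-contained illustration that assumes two hypothetical stand-in types, `StaticVec` (playing the role of `StaticArray`) and `TrackedVec` (playing the role of `TrackedArray`); it is not the actual StaticArrays or ReverseDiff code. Each "package" adds a catch-all `+` method, the mixed call is ambiguous, and adding a method for the exact overlapping signature resolves it.

```julia
# Hypothetical stand-in for StaticArrays.StaticArray (not the real type).
struct StaticVec <: AbstractVector{Float64}
    data::Vector{Float64}
end
Base.size(v::StaticVec) = size(v.data)
Base.getindex(v::StaticVec, i::Int) = v.data[i]

# Hypothetical stand-in for ReverseDiff.TrackedArray (not the real type).
struct TrackedVec <: AbstractVector{Float64}
    data::Vector{Float64}
end
Base.size(v::TrackedVec) = size(v.data)
Base.getindex(v::TrackedVec, i::Int) = v.data[i]

# Catch-all in the style of StaticArrays: static array on the left, any array on the right.
Base.:+(a::StaticVec, b::AbstractArray) = StaticVec(a.data .+ b)

# Catch-all in the style of ReverseDiff: any array on the left, tracked array on the right.
Base.:+(a::AbstractArray, b::TrackedVec) = TrackedVec(a .+ b.data)

s = StaticVec([1.0, 2.0])
t = TrackedVec([10.0, 20.0])

# Both methods match (StaticVec, TrackedVec) and neither is more specific,
# so Julia throws a MethodError reporting the ambiguity.
ambiguous = try
    s + t
    false
catch err
    err isa MethodError
end
println("ambiguous before fix: ", ambiguous)

# Disambiguating method for the exact overlap, mirroring the local fix.
Base.:+(a::StaticVec, b::TrackedVec) = TrackedVec(a.data .+ b.data)

fixed = s + t  # now dispatches to the most specific method
println("after fix: ", fixed.data)
```

Returning a `TrackedVec` from the disambiguating method is an arbitrary choice for this toy; in the fix described in this issue, the `StaticArray` is instead converted to an `Array` before calling ReverseDiff's method.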
I was able to fix these locally by:
- Explicitly defining `+(::StaticArray, ::TrackedArray)` and `+(::TrackedArray, ::StaticArray)` to disambiguate dispatch, and similarly for `-` (and probably also for `*`, though I didn't need that). To avoid the `setindex!` issue, I had to convert the `StaticArray` to an `Array` first.
- Redefining `ReverseDiff.capture` on `StaticArray`s to do the same thing without mutation (using `map` plus array conversion instead of `similar` plus `map!`). As an aside, if this is equally efficient, we might as well make it the default definition.
- Redefining `deriv!(::StaticArray, ::AbstractArray)` to apply `deriv!` elementwise, as it would if the argument were a `Tuple`.
- Redefining