This is a follow-up to this discussion in JuliaDiff/ChainRules.jl#336.
JuliaDiff/ChainRules.jl#336 improves the array rules for `sum`. For example, in the case of `sum(abs2, x)`, it changes the code from `2 .* real.(ȳ) .* x` to

```julia
InplaceableThunk(
    @thunk(2 .* real.(ȳ) .* x),        # val
    dx -> dx .+= 2 .* real.(ȳ) .* x    # add!(dx)
)
```

This makes two improvements:
- (1) the `val` computation `2 .* real.(ȳ) .* x` is now thunked: `@thunk(2 .* real.(ȳ) .* x)`
- (2) the `add!` accumulation function is now `dx -> dx .+= 2 .* real.(ȳ) .* x`
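To illustrate (1) on its own: a thunk wraps the expression so that it only runs if something asks for the value. The sketch below uses a hypothetical `SketchThunk` struct and `force` function in plain Julia rather than ChainRulesCore's real `@thunk`/`unthunk`, so it says nothing about how those are actually implemented.

```julia
# Hypothetical minimal thunk (not ChainRulesCore's Thunk): it stores a
# zero-argument closure and only runs it when `force` is called.
struct SketchThunk{F}
    f::F
end
force(t::SketchThunk) = t.f()

function lazy_demo()
    calls = Ref(0)   # counts how often the broadcast actually runs
    ȳ, x = 1.5, [1.0, 2.0, 3.0]
    t = SketchThunk(() -> (calls[] += 1; 2 .* real.(ȳ) .* x))
    before = calls[]   # nothing has been computed yet
    v = force(t)       # the broadcast runs here, once
    return before, calls[], v
end

before, after, v = lazy_demo()
println((before, after, v))
```

If nothing downstream ever needs the value (for example because the gradient is accumulated through `add!` instead), the broadcast inside the thunk is never run at all.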
It took me a while to work out why (2) was an improvement. The docs on `InplaceableThunk`s say that `add!` should be defined such that `ithunk.add!(Δ) = Δ .+= ithunk.val`, but that it should do this more efficiently than simply doing this directly.
Looking at the code above, where `val = 2 .* real.(ȳ) .* x`, why is `add!(dx) = dx .+= 2 .* real.(ȳ) .* x` "more efficient" than `add!(dx) = dx .+= val`? By copying the code for `val` into the `add!` function we get a single expression, allowing the broadcast to be "fused" into one loop that writes directly into `dx`, and thereby we avoid allocating an intermediate `val = 2 .* real.(ȳ) .* x` array.
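A minimal way to see this is to write both versions of `add!` as plain functions and compare `@allocated`. The function names `add_unfused!` and `add_fused!` and the test data are made up for illustration, and `ȳ` is taken to be a real scalar.

```julia
# Two ways to write `add!` for the sum(abs2, x) rule; ȳ is a scalar cotangent.

# Unfused: materialise `val` first, then add it into dx.
function add_unfused!(dx, ȳ, x)
    val = 2 .* real.(ȳ) .* x   # allocates a temporary array the size of x
    dx .+= val
    return dx
end

# Fused: a single dotted expression, lowered to one loop that writes into dx.
function add_fused!(dx, ȳ, x)
    dx .+= 2 .* real.(ȳ) .* x  # no temporary array
    return dx
end

x = collect(range(0.0, 1.0; length = 1_000))
ȳ = 0.5
dx1, dx2 = zeros(1_000), zeros(1_000)

# First calls compile both methods so the measurements below exclude compilation.
add_unfused!(dx1, ȳ, x)
add_fused!(dx2, ȳ, x)

bytes_unfused = @allocated add_unfused!(dx1, ȳ, x)
bytes_fused = @allocated add_fused!(dx2, ȳ, x)
println("unfused: ", bytes_unfused, " bytes, fused: ", bytes_fused, " bytes")
```

Exact byte counts depend on the Julia version and on running at global scope, so no specific numbers are claimed here; the unfused version should report on the order of one temporary `Float64` array of length 1000, while the fused version should report little or nothing.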
So that's cool! (Aside: there are some good blog posts about Julia's loop fusion and broadcast magic)
But it did mean we had to copy code. This issue is to ask "can we do this without having to copy code?" i.e. it's about API / user-friendliness / reducing code / syntactic stuff (which might in turn make this performance improvement more widely used in our array rules).
I see two options, but perhaps there are others:
(A) create a macro like `@inplaceable_thunk`

If we did this, code such as

```julia
x_thunk = InplaceableThunk(
    @thunk(2 .* real.(ȳ) .* x),
    dx -> dx .+= 2 .* real.(ȳ) .* x
)
```

could instead be written more succinctly as

```julia
x_thunk = @inplaceable_thunk(2 .* real.(ȳ) .* x)
```

(B) have `@thunk` always return an `InplaceableThunk` with the `add!` function defined as above (i.e. copying in the code for `val`)
I'm not sure if (B) is a valid option. But perhaps it is, if users are expected to go via the `add!!` function (which checks `is_inplaceable_destination`).
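Both options hinge on writing the broadcast expression once and splicing it into two places. A rough plain-Julia prototype of option (A), together with an `add!!`-style dispatcher of the kind mentioned for option (B), might look like the following. Everything here is hypothetical: `SketchInplaceable` stands in for `InplaceableThunk`, `@inplaceable_sketch` for the proposed `@inplaceable_thunk`, and `sketch_is_inplaceable`/`sketch_add!!` for ChainRulesCore's `is_inplaceable_destination`/`add!!`, whose real definitions may differ (treating only `Array` as mutable is a simplification).

```julia
# Hypothetical stand-in for ChainRulesCore's InplaceableThunk.
struct SketchInplaceable{V,A}
    val::V   # zero-argument closure: computes the value lazily
    add!::A  # one-argument closure: accumulates the value into its argument
end

# Hypothetical macro for option (A): `expr` is written once by the user and
# spliced into both closures, so the `add!` body is one fused broadcast.
macro inplaceable_sketch(expr)
    return :(SketchInplaceable(() -> $(esc(expr)), dx -> (dx .+= $(esc(expr)))))
end

# Hypothetical stand-ins for `is_inplaceable_destination` and `add!!`:
# mutate when the destination allows it, otherwise add out of place.
sketch_is_inplaceable(x) = false
sketch_is_inplaceable(::Array) = true
sketch_add!!(dx, t::SketchInplaceable) =
    sketch_is_inplaceable(dx) ? t.add!(dx) : dx + t.val()

# Usage: the expression from the sum(abs2, x) rule appears only once.
ȳ, x = 1.5, [1.0, 2.0, 3.0]
t = @inplaceable_sketch(2 .* real.(ȳ) .* x)

dx = ones(3)
r1 = sketch_add!!(dx, t)    # Array: updated in place by the fused add!
r2 = sketch_add!!(1:3, t)   # ranges are immutable: falls back to (1:3) + t.val()
```

Because the macro splices the same expression into both closures, the `add!` body is still a single `.+=` broadcast and fuses just like the hand-written version. A real `@inplaceable_thunk` would presumably expand to `InplaceableThunk(@thunk(expr), dx -> dx .+= expr)` instead of this stand-in struct.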