diff --git a/Project.toml b/Project.toml index c0e8e71df..1c45e5a0f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.62" +version = "0.7.63" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 33eae96d1..2f5b5003a 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -10,9 +10,10 @@ function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number} y = sum(x; dims=dims) function sum_pullback(ȳ) # broadcasting the two works out the size no-matter `dims` - x̄ = broadcast(x, ȳ) do xi, ȳi - ȳi - end + x̄ = InplaceableThunk( + @thunk(broadcast(last∘tuple, x, ȳ)), + x -> x .+= ȳ + ) return (NO_FIELDS, x̄) end return y, sum_pullback @@ -29,7 +30,9 @@ function frule( ∂y = if dims isa Colon 2 * real(dot(x, ẋ)) elseif VERSION ≥ v"1.2" # multi-iterator mapreduce introduced in v1.2 - 2 * mapreduce(_realconjtimes, +, x, ẋ; dims=dims) + mapreduce(+, x, ẋ; dims=dims) do xi, dxi + 2 * _realconjtimes(xi, dxi) + end else 2 * sum(_realconjtimes.(x, ẋ); dims=dims) end @@ -44,7 +47,11 @@ function rrule( ) where {T<:Union{Real,Complex}} y = sum(abs2, x; dims=dims) function sum_abs2_pullback(ȳ) - return (NO_FIELDS, DoesNotExist(), 2 .* real.(ȳ) .* x) + x_thunk = InplaceableThunk( + @thunk(2 .* real.(ȳ) .* x), + dx -> dx .+= 2 .* real.(ȳ) .* x + ) + return (NO_FIELDS, DoesNotExist(), x_thunk) end return y, sum_abs2_pullback end