From f3f43b46d1c1d184dce65905d48b9a11f2dd1fb7 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 26 Dec 2020 16:29:25 +0100 Subject: [PATCH 1/5] tweaks to sum rules --- src/rulesets/Base/mapreduce.jl | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 33eae96d1..0885a4fa5 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -10,9 +10,10 @@ function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number} y = sum(x; dims=dims) function sum_pullback(ȳ) # broadcasting the two works out the size no-matter `dims` - x̄ = broadcast(x, ȳ) do xi, ȳi - ȳi - end + x̄ = InplaceableThunk( + @thunk(broadcast((_,y1)->y1, x, ȳ)), # last∘tuple + x -> x .+= x̄ + ) return (NO_FIELDS, x̄) end return y, sum_pullback @@ -29,7 +30,9 @@ function frule( ∂y = if dims isa Colon 2 * real(dot(x, ẋ)) elseif VERSION ≥ v"1.2" # multi-iterator mapreduce introduced in v1.2 - 2 * mapreduce(_realconjtimes, +, x, ẋ; dims=dims) + mapreduce(+, x, ẋ; dims=dims) do xi, dxi + 2 * _realconjtimes(xi, dxi) + end else 2 * sum(_realconjtimes.(x, ẋ); dims=dims) end @@ -44,7 +47,11 @@ function rrule( ) where {T<:Union{Real,Complex}} y = sum(abs2, x; dims=dims) function sum_abs2_pullback(ȳ) - return (NO_FIELDS, DoesNotExist(), 2 .* real.(ȳ) .* x) + x_thunk = InplaceableThunk( + @thunk(2 .* real.(ȳ) .* x), + dx -> dx .+= 2 .* real.(ȳ) .* x + ) + return (NO_FIELDS, DoesNotExist(), x_thunk) end return y, sum_abs2_pullback end From 7d2050c84fe34d958ef3347f280c171e890c1f31 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 6 May 2021 19:11:54 -0400 Subject: [PATCH 2/5] Apply suggestions from code review Co-authored-by: Lyndon White --- src/rulesets/Base/mapreduce.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 0885a4fa5..bf8b57d35 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -11,8 +11,8 @@ function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number} function sum_pullback(ȳ) # broadcasting the two works out the size no-matter `dims` x̄ = InplaceableThunk( - @thunk(broadcast((_,y1)->y1, x, ȳ)), # last∘tuple - x -> x .+= x̄ + @thunk(broadcast(last∘tuple, x, ȳ)) + x -> x .+= ȳ ) return (NO_FIELDS, x̄) end From 8b9584887ad10779cac2d8923898d4775d79b0a5 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 6 May 2021 19:30:29 -0400 Subject: [PATCH 3/5] comma --- src/rulesets/Base/mapreduce.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index bf8b57d35..2f5b5003a 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -11,7 +11,7 @@ function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number} function sum_pullback(ȳ) # broadcasting the two works out the size no-matter `dims` x̄ = InplaceableThunk( - @thunk(broadcast(last∘tuple, x, ȳ)) + @thunk(broadcast(last∘tuple, x, ȳ)), x -> x .+= ȳ ) return (NO_FIELDS, x̄) From 2b882c1f155f2ad54291f80281d55127b3edce8a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 6 May 2021 20:27:54 -0400 Subject: [PATCH 4/5] v0.7.62 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 091b85f08..2c4004059 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.41" +version = "0.7.62" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From 082c73ccc19f633f140a99de968b09c51e2afa62 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 7 May 2021 13:02:09 -0400 Subject: [PATCH 5/5] gone cold one more time... Co-authored-by: David Widmann --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9a732804b..0a7537005 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.62" +version = "0.7.63" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"