From 7dc0e85ddb8bc2796b379dc7c247ae55dbee8841 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 12 May 2021 01:21:39 -0400 Subject: [PATCH 01/15] cumprod, take 1 --- src/rulesets/Base/mapreduce.jl | 90 +++++++++++++++++++++++++++++++++ test/rulesets/Base/mapreduce.jl | 35 +++++++++++++ 2 files changed, 125 insertions(+) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 741afe569..56b6fdbbb 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -218,3 +218,93 @@ function ∇prod_one_zero!(dx, x, dy::Number=1) # Assumes exactly one x is zero dx[i_zero] += p_rest * dy return end + +##### +##### `cumprod` +##### + +function rrule(::typeof(cumprod), x::AbstractVector{<:Real}; dims=1) + y = cumprod(x; dims=dims) # does nothing unless dims == 1 + function cumprod_pullback_1(dy) + dx_thunk = InplaceableThunk( + @thunk if dims == 1 + ∇cumprod(x, dy, y) + else + dy + end + , + dx -> if dims == 1 + ∇cumprod!(dx, x, dy, y) + else + dx .+= dy + end + ) + return (NO_FIELDS, dx_thunk) + end + return y, cumprod_pullback_1 +end + +function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims) + y = cumprod(x; dims=dims) + @assert dims isa Integer + # vald = Val(dims) + function cumprod_pullback_2(dy) + dx_thunk = InplaceableThunk( + @thunk if dims <= ndims(x) + ∇cumprod_dim(dims, x, dy, y) + else + dy + end + , + dx -> if dims <= ndims(x) + ∇cumprod_dim!(dx, dims, x, dy, y) + else + dx .+= dy + end + ) + return (NO_FIELDS, dx_thunk) + end + return y, cumprod_pullback_2 +end + +function ∇cumprod_dim(dim::Integer, x::AbstractArray, dy=fill!(zero(x),1), y=cumprod(x; dims=dim)) + T = promote_type(eltype(x), eltype(dy)) + dx = fill!(similar(x, T, axes(x)), zero(T)) + ∇cumprod_dim!(dx, dim, x, dy, y) + return dx + end + +function ∇cumprod_dim!(dx::AbstractArray, dim::Integer, x::AbstractArray, dy, y) + iters = ntuple(k -> k==dim ? Ref(:) : axes(x,k), ndims(x)) # type instability! + for ind in Iterators.product(iters...) + @views ∇cumprod!(dx[ind...], x[ind...], dy[ind...], y[ind...]) + end + return dx +end + +function ∇cumprod(x::AbstractVector, dy=one(x), y=cumprod(x)) + T = promote_type(eltype(x), eltype(dy)) + dx = fill!(similar(x, T, axes(x)), zero(T)) # axes(x) makes MArray on StaticArrays, Array for structured matrices + ∇cumprod!(dx, x, dy, y) + return dx +end + +function ∇cumprod!(dx::AbstractVector, x::AbstractVector, dy, y) + lo, hi = firstindex(x), lastindex(x) + z = something(findfirst(iszero, x), hi+1) + @inbounds for i in lo:z-1 + ixi = 1/x[i] + for k in i:z-1 + dx[i] += y[k] * dy[k] * ixi + end + end + @inbounds if z != hi+1 + yk = z==1 ? one(eltype(y)) : y[z-1] # will be prod(x[j] for j=1:k if j!=z) + dx[z] += yk * dy[z] + for k in (z+1):hi + yk *= x[k] + dx[z] += yk * dy[k] + end + end + return dx +end diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 91926a671..a8bb5f806 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -154,3 +154,38 @@ end end # prod end + + + + @testset "cumprod" begin + v = randn(9) + test_rrule(cumprod, v) + v[3] = 0 + test_rrule(cumprod, v) + v[6] = 0 + test_rrule(cumprod, v) + + @testset "higher dimensions, dims=$dims" for dims in (1,2,3) + m = rand(4,5) + test_rrule(cumprod, m; fkwargs=(;dims=dims)) + m[2,2] = 0 + m[2,4] = 0 + test_rrule(cumprod, m; fkwargs=(;dims=dims)) + + t = randn(3,3,3) + test_rrule(cumprod, x; fkwargs=(;dims=dims)) + end + + @testset "types" begin + back = unthunk(rrule(cumprod, [1, 2, 3])[2]) + @test unthunk(back(fill(0.5, 3))[2]) == [9/2, 2, 1] + + back = unthunk(rrule(cumprod, PermutedDimsArray([1 2; 3 4], (2,1)); dims=1)[2]) + @test unthunk(back(ones(Float32, 2,2))[2]) == [3 5; 1 3] + + @test_throws Exception cumprod(Symmetric([1 2; 3 4]), dims=1) + + back = unthunk(rrule(cumprod, Diagonal([1, 2]); dims=1)[2]) + @test unthunk(back(fill(0.5, 2, 2))[2]) ≈ [1/2 3/2; 1/2 0] + end + end From 4b4c3470d77a99e034b9e35aa49a257293c87513 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 12 May 2021 10:26:18 -0400 Subject: [PATCH 02/15] fix a type instability --- src/rulesets/Base/mapreduce.jl | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 56b6fdbbb..99bd85651 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -247,17 +247,17 @@ end function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims) y = cumprod(x; dims=dims) @assert dims isa Integer - # vald = Val(dims) + vald = Val(Int(dims)) # else ∇cumprod_dim! will be type unstable function cumprod_pullback_2(dy) dx_thunk = InplaceableThunk( @thunk if dims <= ndims(x) - ∇cumprod_dim(dims, x, dy, y) + ∇cumprod_dim(vald, x, dy, y) else dy end , dx -> if dims <= ndims(x) - ∇cumprod_dim!(dx, dims, x, dy, y) + ∇cumprod_dim!(dx, vald, x, dy, y) else dx .+= dy end @@ -267,21 +267,34 @@ function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims) return y, cumprod_pullback_2 end -function ∇cumprod_dim(dim::Integer, x::AbstractArray, dy=fill!(zero(x),1), y=cumprod(x; dims=dim)) +function ∇cumprod_dim(vald::Val{dim}, x::AbstractArray, dy=fill!(zero(x),1), y=cumprod(x; dims=dim)) where {dim} T = promote_type(eltype(x), eltype(dy)) dx = fill!(similar(x, T, axes(x)), zero(T)) - ∇cumprod_dim!(dx, dim, x, dy, y) + ∇cumprod_dim!(dx, vald, x, dy, y) return dx end -function ∇cumprod_dim!(dx::AbstractArray, dim::Integer, x::AbstractArray, dy, y) - iters = ntuple(k -> k==dim ? Ref(:) : axes(x,k), ndims(x)) # type instability! +function ∇cumprod_dim!(dx::AbstractArray, ::Val{dim}, x::AbstractArray, dy, y) where {dim} + iters = ntuple(k -> k==dim ? Ref(:) : axes(x,k), ndims(x)) for ind in Iterators.product(iters...) @views ∇cumprod!(dx[ind...], x[ind...], dy[ind...], y[ind...]) end return dx end +#= + +julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(10,100)))[1] + 86.333 μs (2007 allocations: 67.54 KiB) + +julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(10,100)))[1] # with 1 hard-coded + 5.417 μs (6 allocations: 15.95 KiB) + +julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(10,100)))[1] # with Val(dim) + 5.423 μs (6 allocations: 15.95 KiB) + +=# + function ∇cumprod(x::AbstractVector, dy=one(x), y=cumprod(x)) T = promote_type(eltype(x), eltype(dy)) dx = fill!(similar(x, T, axes(x)), zero(T)) # axes(x) makes MArray on StaticArrays, Array for structured matrices From c5db48774c059d9548fd30147db777aad37a3e93 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 27 May 2021 22:10:33 -0400 Subject: [PATCH 03/15] tidy & fix tests --- src/rulesets/Base/mapreduce.jl | 16 ++-------------- test/rulesets/Base/mapreduce.jl | 17 ++++++++--------- 2 files changed, 10 insertions(+), 23 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 99bd85651..746a3c4a6 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -247,16 +247,17 @@ end function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims) y = cumprod(x; dims=dims) @assert dims isa Integer - vald = Val(Int(dims)) # else ∇cumprod_dim! will be type unstable function cumprod_pullback_2(dy) dx_thunk = InplaceableThunk( @thunk if dims <= ndims(x) + vald = Val(Int(dims)) ∇cumprod_dim(vald, x, dy, y) else dy end , dx -> if dims <= ndims(x) + vald = Val(Int(dims)) ∇cumprod_dim!(dx, vald, x, dy, y) else dx .+= dy @@ -282,19 +283,6 @@ function ∇cumprod_dim!(dx::AbstractArray, ::Val{dim}, x::AbstractArray, dy, y) return dx end -#= - -julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(10,100)))[1] - 86.333 μs (2007 allocations: 67.54 KiB) - -julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(10,100)))[1] # with 1 hard-coded - 5.417 μs (6 allocations: 15.95 KiB) - -julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(10,100)))[1] # with Val(dim) - 5.423 μs (6 allocations: 15.95 KiB) - -=# - function ∇cumprod(x::AbstractVector, dy=one(x), y=cumprod(x)) T = promote_type(eltype(x), eltype(dy)) dx = fill!(similar(x, T, axes(x)), zero(T)) # axes(x) makes MArray on StaticArrays, Array for structured matrices diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index a8bb5f806..1c9e3edbb 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -31,7 +31,6 @@ dy = sum(x; dims=dims) ddy = rrule(ChainRules._unsum, x, dy, dims)[2](x)[3] @test size(ddy) == size(dy) - end end @testset "sum abs2" begin @@ -155,10 +154,9 @@ end # prod end - - +@testset "Accumulations" begin @testset "cumprod" begin - v = randn(9) + v = round.(10 .* randn(9), sigdigits=3) test_rrule(cumprod, v) v[3] = 0 test_rrule(cumprod, v) @@ -166,26 +164,27 @@ end test_rrule(cumprod, v) @testset "higher dimensions, dims=$dims" for dims in (1,2,3) - m = rand(4,5) + m = round.(10 .* randn(4,5), sigdigits=3) test_rrule(cumprod, m; fkwargs=(;dims=dims)) m[2,2] = 0 m[2,4] = 0 test_rrule(cumprod, m; fkwargs=(;dims=dims)) - t = randn(3,3,3) - test_rrule(cumprod, x; fkwargs=(;dims=dims)) + t = round.(10 .* randn(3,3,3), sigdigits=3) + test_rrule(cumprod, t; fkwargs=(;dims=dims)) end @testset "types" begin - back = unthunk(rrule(cumprod, [1, 2, 3])[2]) + back = unthunk(rrule(cumprod, [1, 2, 3])[2]) # allow integer input @test unthunk(back(fill(0.5, 3))[2]) == [9/2, 2, 1] back = unthunk(rrule(cumprod, PermutedDimsArray([1 2; 3 4], (2,1)); dims=1)[2]) @test unthunk(back(ones(Float32, 2,2))[2]) == [3 5; 1 3] - @test_throws Exception cumprod(Symmetric([1 2; 3 4]), dims=1) + @test_throws Exception cumprod(Symmetric([1 2; 3 4]), dims=1) # forward pass fails back = unthunk(rrule(cumprod, Diagonal([1, 2]); dims=1)[2]) @test unthunk(back(fill(0.5, 2, 2))[2]) ≈ [1/2 3/2; 1/2 0] end end +end From fac481cbaf34938ad80636604ce7da7be2f17f56 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 27 May 2021 22:51:50 -0400 Subject: [PATCH 04/15] two important at-inline-s --- src/rulesets/Base/mapreduce.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 746a3c4a6..81ef05a87 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -275,7 +275,7 @@ function ∇cumprod_dim(vald::Val{dim}, x::AbstractArray, dy=fill!(zero(x),1), y return dx end -function ∇cumprod_dim!(dx::AbstractArray, ::Val{dim}, x::AbstractArray, dy, y) where {dim} +@inline function ∇cumprod_dim!(dx::AbstractArray, ::Val{dim}, x::AbstractArray, dy, y) where {dim} iters = ntuple(k -> k==dim ? Ref(:) : axes(x,k), ndims(x)) for ind in Iterators.product(iters...) @views ∇cumprod!(dx[ind...], x[ind...], dy[ind...], y[ind...]) @@ -290,7 +290,7 @@ function ∇cumprod(x::AbstractVector, dy=one(x), y=cumprod(x)) return dx end -function ∇cumprod!(dx::AbstractVector, x::AbstractVector, dy, y) +@inline function ∇cumprod!(dx::AbstractVector, x::AbstractVector, dy, y) lo, hi = firstindex(x), lastindex(x) z = something(findfirst(iszero, x), hi+1) @inbounds for i in lo:z-1 From da566cd6fcb2367afa87d4374ccd794e1974d4a4 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 28 May 2021 00:05:18 -0400 Subject: [PATCH 05/15] borrow fast path from Zygote 294 --- src/rulesets/Base/mapreduce.jl | 21 ++++++++++++++++++--- test/rulesets/Base/mapreduce.jl | 5 ++++- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 81ef05a87..3f88175bb 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -276,13 +276,28 @@ function ∇cumprod_dim(vald::Val{dim}, x::AbstractArray, dy=fill!(zero(x),1), y end @inline function ∇cumprod_dim!(dx::AbstractArray, ::Val{dim}, x::AbstractArray, dy, y) where {dim} - iters = ntuple(k -> k==dim ? Ref(:) : axes(x,k), ndims(x)) - for ind in Iterators.product(iters...) - @views ∇cumprod!(dx[ind...], x[ind...], dy[ind...], y[ind...]) + if any(iszero, x) + iters = ntuple(k -> k==dim ? Ref(:) : axes(x,k), ndims(x)) + for ind in Iterators.product(iters...) + @views ∇cumprod!(dx[ind...], x[ind...], dy[ind...], y[ind...]) + end + else + step1 = y .* dy # _rscale!!(y, dy) # is it safe to mutate y? + step2 = _reverse!!(_cumsum!!(_reverse!!(step1, dim), dim), dim) + dx .+= step2 ./ x end return dx end +# _rscale!!(A, β) = A .* β +# _rscale!!(A::StridedArray, β) = A .*= β + +_reverse!!(x, dims=1) = reverse(x; dims=dims) +_reverse!!(x::StridedArray, dims=1) = reverse!(x; dims=dims) + +_cumsum!!(x, dims=1) = cumsum(x; dims=dims) +_cumsum!!(x::StridedArray, dims=1) = cumsum!(x, x; dims=dims) + function ∇cumprod(x::AbstractVector, dy=one(x), y=cumprod(x)) T = promote_type(eltype(x), eltype(dy)) dx = fill!(similar(x, T, axes(x)), zero(T)) # axes(x) makes MArray on StaticArrays, Array for structured matrices diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 1c9e3edbb..38edb490f 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -165,13 +165,16 @@ end @testset "higher dimensions, dims=$dims" for dims in (1,2,3) m = round.(10 .* randn(4,5), sigdigits=3) - test_rrule(cumprod, m; fkwargs=(;dims=dims)) + test_rrule(cumprod, m; fkwargs=(;dims=dims), atol=0.1) m[2,2] = 0 m[2,4] = 0 test_rrule(cumprod, m; fkwargs=(;dims=dims)) t = round.(10 .* randn(3,3,3), sigdigits=3) test_rrule(cumprod, t; fkwargs=(;dims=dims)) + t[2,2,2] = 0 + t[2,3,3] = 0 + test_rrule(cumprod, t; fkwargs=(;dims=dims)) end @testset "types" begin From a7eede20cd3309cac6eb41188e0412937c676f92 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 28 May 2021 08:27:14 -0400 Subject: [PATCH 06/15] fix 1.0 --- src/rulesets/Base/mapreduce.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 3f88175bb..9eb2c6b6d 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -293,7 +293,11 @@ end # _rscale!!(A::StridedArray, β) = A .*= β _reverse!!(x, dims=1) = reverse(x; dims=dims) -_reverse!!(x::StridedArray, dims=1) = reverse!(x; dims=dims) +if VERSION >= v"1.6" + _reverse!!(x::StridedArray, dims=1) = reverse!(x; dims=dims) +else + _reverse!!(x::StridedArray, dims=1) = dims==1 ? reverse!(x) : reverse(x; dims=dims) +end _cumsum!!(x, dims=1) = cumsum(x; dims=dims) _cumsum!!(x::StridedArray, dims=1) = cumsum!(x, x; dims=dims) From 276b51b3079b467be42a7532d4ccf46a2069879c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 28 May 2021 09:52:29 -0400 Subject: [PATCH 07/15] try again for 1.0 --- src/rulesets/Base/mapreduce.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 9eb2c6b6d..0b80a891c 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -296,7 +296,7 @@ _reverse!!(x, dims=1) = reverse(x; dims=dims) if VERSION >= v"1.6" _reverse!!(x::StridedArray, dims=1) = reverse!(x; dims=dims) else - _reverse!!(x::StridedArray, dims=1) = dims==1 ? reverse!(x) : reverse(x; dims=dims) + _reverse!!(x::StridedVector, dims=1) = dims==1 ? reverse!(x) : x end _cumsum!!(x, dims=1) = cumsum(x; dims=dims) From 2f34b50e0b45229f5bf5559b47bfd6fd2ea429f6 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 28 May 2021 18:01:54 -0400 Subject: [PATCH 08/15] remove an accidentally quadratic algorithm --- src/rulesets/Base/mapreduce.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 0b80a891c..daf5f4ac5 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -312,11 +312,10 @@ end @inline function ∇cumprod!(dx::AbstractVector, x::AbstractVector, dy, y) lo, hi = firstindex(x), lastindex(x) z = something(findfirst(iszero, x), hi+1) - @inbounds for i in lo:z-1 - ixi = 1/x[i] - for k in i:z-1 - dx[i] += y[k] * dy[k] * ixi - end + acc = zero(eltype(dy)) + @inbounds for k in z-1:-1:lo + acc += y[k] * dy[k] + dx[k] += acc / x[k] end @inbounds if z != hi+1 yk = z==1 ? one(eltype(y)) : y[z-1] # will be prod(x[j] for j=1:k if j!=z) From 895d442cbb467fa00301da807ed2e3aaf033fad0 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 28 May 2021 18:05:51 -0400 Subject: [PATCH 09/15] ...after which, the fast path isn't faster anymore, so delete it. --- src/rulesets/Base/mapreduce.jl | 27 ++++----------------------- 1 file changed, 4 insertions(+), 23 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index daf5f4ac5..628498ca0 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -276,34 +276,15 @@ function ∇cumprod_dim(vald::Val{dim}, x::AbstractArray, dy=fill!(zero(x),1), y end @inline function ∇cumprod_dim!(dx::AbstractArray, ::Val{dim}, x::AbstractArray, dy, y) where {dim} - if any(iszero, x) - iters = ntuple(k -> k==dim ? Ref(:) : axes(x,k), ndims(x)) - for ind in Iterators.product(iters...) - @views ∇cumprod!(dx[ind...], x[ind...], dy[ind...], y[ind...]) - end - else - step1 = y .* dy # _rscale!!(y, dy) # is it safe to mutate y? - step2 = _reverse!!(_cumsum!!(_reverse!!(step1, dim), dim), dim) - dx .+= step2 ./ x + iters = ntuple(k -> k==dim ? Ref(:) : axes(x,k), ndims(x)) + for ind in Iterators.product(iters...) + @views ∇cumprod!(dx[ind...], x[ind...], dy[ind...], y[ind...]) end return dx end -# _rscale!!(A, β) = A .* β -# _rscale!!(A::StridedArray, β) = A .*= β - -_reverse!!(x, dims=1) = reverse(x; dims=dims) -if VERSION >= v"1.6" - _reverse!!(x::StridedArray, dims=1) = reverse!(x; dims=dims) -else - _reverse!!(x::StridedVector, dims=1) = dims==1 ? reverse!(x) : x -end - -_cumsum!!(x, dims=1) = cumsum(x; dims=dims) -_cumsum!!(x::StridedArray, dims=1) = cumsum!(x, x; dims=dims) - function ∇cumprod(x::AbstractVector, dy=one(x), y=cumprod(x)) - T = promote_type(eltype(x), eltype(dy)) + T = promote_type(eltype(x), eltype(dy)) # really needs to allow dy * y / x dx = fill!(similar(x, T, axes(x)), zero(T)) # axes(x) makes MArray on StaticArrays, Array for structured matrices ∇cumprod!(dx, x, dy, y) return dx From bf2071f2317142dbc9473b4581372f57c7438c54 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 24 Jun 2021 09:18:55 -0400 Subject: [PATCH 10/15] rm some un-thunks --- test/rulesets/Base/mapreduce.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 38edb490f..e3c5cde4f 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -178,15 +178,15 @@ end end @testset "types" begin - back = unthunk(rrule(cumprod, [1, 2, 3])[2]) # allow integer input + back = rrule(cumprod, [1, 2, 3])[2] # rule allows integer input, but test_rrule does not @test unthunk(back(fill(0.5, 3))[2]) == [9/2, 2, 1] - back = unthunk(rrule(cumprod, PermutedDimsArray([1 2; 3 4], (2,1)); dims=1)[2]) + back = rrule(cumprod, PermutedDimsArray([1 2; 3 4], (2,1)); dims=1)[2] @test unthunk(back(ones(Float32, 2,2))[2]) == [3 5; 1 3] - @test_throws Exception cumprod(Symmetric([1 2; 3 4]), dims=1) # forward pass fails + @test_throws Exception cumprod(Symmetric([1 2; 3 4]), dims=1) # forward pass fails, so can't test gradient - back = unthunk(rrule(cumprod, Diagonal([1, 2]); dims=1)[2]) + back = rrule(cumprod, Diagonal([1, 2]); dims=1)[2] @test unthunk(back(fill(0.5, 2, 2))[2]) ≈ [1/2 3/2; 1/2 0] end end From e110d74258cefa1d11952039e6c004391ebef421 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 24 Jun 2021 10:41:09 -0400 Subject: [PATCH 11/15] kwarg types --- src/rulesets/Base/mapreduce.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 628498ca0..1095a921c 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -223,7 +223,7 @@ end ##### `cumprod` ##### -function rrule(::typeof(cumprod), x::AbstractVector{<:Real}; dims=1) +function rrule(::typeof(cumprod), x::AbstractVector{<:Real}; dims::Integer=1) y = cumprod(x; dims=dims) # does nothing unless dims == 1 function cumprod_pullback_1(dy) dx_thunk = InplaceableThunk( @@ -244,9 +244,8 @@ function rrule(::typeof(cumprod), x::AbstractVector{<:Real}; dims=1) return y, cumprod_pullback_1 end -function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims) +function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims::Integer) y = cumprod(x; dims=dims) - @assert dims isa Integer function cumprod_pullback_2(dy) dx_thunk = InplaceableThunk( @thunk if dims <= ndims(x) From d2d708eecf6d6c47b01a26e27794eda5e85cc59a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 30 Jul 2021 16:46:13 -0400 Subject: [PATCH 12/15] update for 1.0 --- src/rulesets/Base/mapreduce.jl | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 1095a921c..dc66134e2 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -225,19 +225,20 @@ end function rrule(::typeof(cumprod), x::AbstractVector{<:Real}; dims::Integer=1) y = cumprod(x; dims=dims) # does nothing unless dims == 1 + project_x = ProjectTo(x) function cumprod_pullback_1(dy) dx_thunk = InplaceableThunk( - @thunk if dims == 1 - ∇cumprod(x, dy, y) - else - dy - end - , dx -> if dims == 1 ∇cumprod!(dx, x, dy, y) else dx .+= dy end + , + @thunk project_x(if dims == 1 + ∇cumprod(x, dy, y) + else + dy + end) ) return (NO_FIELDS, dx_thunk) end @@ -246,21 +247,22 @@ end function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims::Integer) y = cumprod(x; dims=dims) + project_x = ProjectTo(x) function cumprod_pullback_2(dy) dx_thunk = InplaceableThunk( - @thunk if dims <= ndims(x) - vald = Val(Int(dims)) - ∇cumprod_dim(vald, x, dy, y) - else - dy - end - , dx -> if dims <= ndims(x) vald = Val(Int(dims)) ∇cumprod_dim!(dx, vald, x, dy, y) else dx .+= dy end + , + @thunk project_x(if dims <= ndims(x) + vald = Val(Int(dims)) + ∇cumprod_dim(vald, x, dy, y) + else + dy + end) ) return (NO_FIELDS, dx_thunk) end From a988dee20812c935454107486659531e6d7cf13a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 30 Jul 2021 16:56:53 -0400 Subject: [PATCH 13/15] fixup --- src/rulesets/Base/mapreduce.jl | 10 ++++++---- test/rulesets/Base/mapreduce.jl | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index dc66134e2..9bdca588c 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -226,7 +226,8 @@ end function rrule(::typeof(cumprod), x::AbstractVector{<:Real}; dims::Integer=1) y = cumprod(x; dims=dims) # does nothing unless dims == 1 project_x = ProjectTo(x) - function cumprod_pullback_1(dy) + function cumprod_pullback_1(dy_raw) + dy = unthunk(dy_raw) dx_thunk = InplaceableThunk( dx -> if dims == 1 ∇cumprod!(dx, x, dy, y) @@ -240,7 +241,7 @@ function rrule(::typeof(cumprod), x::AbstractVector{<:Real}; dims::Integer=1) dy end) ) - return (NO_FIELDS, dx_thunk) + return (NoTangent(), dx_thunk) end return y, cumprod_pullback_1 end @@ -248,7 +249,8 @@ end function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims::Integer) y = cumprod(x; dims=dims) project_x = ProjectTo(x) - function cumprod_pullback_2(dy) + function cumprod_pullback_2(dy_raw) + dy = unthunk(dy_raw) dx_thunk = InplaceableThunk( dx -> if dims <= ndims(x) vald = Val(Int(dims)) @@ -264,7 +266,7 @@ function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims::Integer) dy end) ) - return (NO_FIELDS, dx_thunk) + return (NoTangent(), dx_thunk) end return y, cumprod_pullback_2 end diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index e3c5cde4f..7613a22d0 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -187,7 +187,7 @@ end @test_throws Exception cumprod(Symmetric([1 2; 3 4]), dims=1) # forward pass fails, so can't test gradient back = rrule(cumprod, Diagonal([1, 2]); dims=1)[2] - @test unthunk(back(fill(0.5, 2, 2))[2]) ≈ [1/2 3/2; 1/2 0] + @test unthunk(back(fill(0.5, 2, 2))[2]) ≈ [1/2 0; 0 0] # ProjectTo'd to Diagonal now end end end From 4747dd1462fb7f505e63d99ed6e8508fd7ba0bc7 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 25 Aug 2021 21:22:24 -0400 Subject: [PATCH 14/15] missing end --- test/rulesets/Base/mapreduce.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 7613a22d0..082b80270 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -31,6 +31,7 @@ dy = sum(x; dims=dims) ddy = rrule(ChainRules._unsum, x, dy, dims)[2](x)[3] @test size(ddy) == size(dy) + end end @testset "sum abs2" begin From 841b802e80dcdfed88d1ca7eef77ec8076976225 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 27 Aug 2021 09:02:57 -0400 Subject: [PATCH 15/15] v1.11 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f5f7d8d25..5b1cb4cb0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.10.0" +version = "1.11.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"