diff --git a/Project.toml b/Project.toml index f5f7d8d25..5b1cb4cb0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.10.0" +version = "1.11.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 741afe569..9bdca588c 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -218,3 +218,96 @@ function ∇prod_one_zero!(dx, x, dy::Number=1) # Assumes exactly one x is zero dx[i_zero] += p_rest * dy return end + +##### +##### `cumprod` +##### + +function rrule(::typeof(cumprod), x::AbstractVector{<:Real}; dims::Integer=1) + y = cumprod(x; dims=dims) # does nothing unless dims == 1 + project_x = ProjectTo(x) + function cumprod_pullback_1(dy_raw) + dy = unthunk(dy_raw) + dx_thunk = InplaceableThunk( + dx -> if dims == 1 + ∇cumprod!(dx, x, dy, y) + else + dx .+= dy + end + , + @thunk project_x(if dims == 1 + ∇cumprod(x, dy, y) + else + dy + end) + ) + return (NoTangent(), dx_thunk) + end + return y, cumprod_pullback_1 +end + +function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims::Integer) + y = cumprod(x; dims=dims) + project_x = ProjectTo(x) + function cumprod_pullback_2(dy_raw) + dy = unthunk(dy_raw) + dx_thunk = InplaceableThunk( + dx -> if dims <= ndims(x) + vald = Val(Int(dims)) + ∇cumprod_dim!(dx, vald, x, dy, y) + else + dx .+= dy + end + , + @thunk project_x(if dims <= ndims(x) + vald = Val(Int(dims)) + ∇cumprod_dim(vald, x, dy, y) + else + dy + end) + ) + return (NoTangent(), dx_thunk) + end + return y, cumprod_pullback_2 +end + +function ∇cumprod_dim(vald::Val{dim}, x::AbstractArray, dy=fill!(zero(x),1), y=cumprod(x; dims=dim)) where {dim} + T = promote_type(eltype(x), eltype(dy)) + dx = fill!(similar(x, T, axes(x)), zero(T)) + ∇cumprod_dim!(dx, vald, x, dy, y) + return dx + end + +@inline function ∇cumprod_dim!(dx::AbstractArray, ::Val{dim}, x::AbstractArray, dy, y) where {dim} + iters = ntuple(k -> k==dim ? Ref(:) : axes(x,k), ndims(x)) + for ind in Iterators.product(iters...) + @views ∇cumprod!(dx[ind...], x[ind...], dy[ind...], y[ind...]) + end + return dx +end + +function ∇cumprod(x::AbstractVector, dy=one(x), y=cumprod(x)) + T = promote_type(eltype(x), eltype(dy)) # really needs to allow dy * y / x + dx = fill!(similar(x, T, axes(x)), zero(T)) # axes(x) makes MArray on StaticArrays, Array for structured matrices + ∇cumprod!(dx, x, dy, y) + return dx +end + +@inline function ∇cumprod!(dx::AbstractVector, x::AbstractVector, dy, y) + lo, hi = firstindex(x), lastindex(x) + z = something(findfirst(iszero, x), hi+1) + acc = zero(eltype(dy)) + @inbounds for k in z-1:-1:lo + acc += y[k] * dy[k] + dx[k] += acc / x[k] + end + @inbounds if z != hi+1 + yk = z==1 ? one(eltype(y)) : y[z-1] # will be prod(x[j] for j=1:k if j!=z) + dx[z] += yk * dy[z] + for k in (z+1):hi + yk *= x[k] + dx[z] += yk * dy[k] + end + end + return dx +end diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 91926a671..082b80270 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -154,3 +154,41 @@ end end # prod end + +@testset "Accumulations" begin + @testset "cumprod" begin + v = round.(10 .* randn(9), sigdigits=3) + test_rrule(cumprod, v) + v[3] = 0 + test_rrule(cumprod, v) + v[6] = 0 + test_rrule(cumprod, v) + + @testset "higher dimensions, dims=$dims" for dims in (1,2,3) + m = round.(10 .* randn(4,5), sigdigits=3) + test_rrule(cumprod, m; fkwargs=(;dims=dims), atol=0.1) + m[2,2] = 0 + m[2,4] = 0 + test_rrule(cumprod, m; fkwargs=(;dims=dims)) + + t = round.(10 .* randn(3,3,3), sigdigits=3) + test_rrule(cumprod, t; fkwargs=(;dims=dims)) + t[2,2,2] = 0 + t[2,3,3] = 0 + test_rrule(cumprod, t; fkwargs=(;dims=dims)) + end + + @testset "types" begin + back = rrule(cumprod, [1, 2, 3])[2] # rule allows integer input, but test_rrule does not + @test unthunk(back(fill(0.5, 3))[2]) == [9/2, 2, 1] + + back = rrule(cumprod, PermutedDimsArray([1 2; 3 4], (2,1)); dims=1)[2] + @test unthunk(back(ones(Float32, 2,2))[2]) == [3 5; 1 3] + + @test_throws Exception cumprod(Symmetric([1 2; 3 4]), dims=1) # forward pass fails, so can't test gradient + + back = rrule(cumprod, Diagonal([1, 2]); dims=1)[2] + @test unthunk(back(fill(0.5, 2, 2))[2]) ≈ [1/2 0; 0 0] # ProjectTo'd to Diagonal now + end + end +end