diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 33eae96d1..9a8e84834 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -48,3 +48,17 @@ function rrule( end return y, sum_abs2_pullback end + +#### +#### prod +#### + +function rrule(::typeof(prod), x::AbstractArray{T}; dims=:) where {T<:CommutativeMulNumber} + y = prod(x; dims=dims) + function prod_pullback(ȳ) + # broadcasting the two works out the size no-matter `dims` + x̄ = y .* ȳ ./ x + return (NO_FIELDS, x̄) + end + return y, prod_pullback +end diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 8e53ffdec..abecbc05d 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -20,4 +20,11 @@ end end end # sum abs2 + + @testset "prod" begin + test_rrule(prod, randn(5)) + test_rrule(prod, randn(5, 6)) + test_rrule(prod, randn(5, 6); fkwargs=(;dims=2)) + test_rrule(prod, randn(5, 6); fkwargs=(;dims=1)) + end end