diff --git a/Project.toml b/Project.toml index af3b030ab..680ead345 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.41" +version = "0.9.42" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/accumulation.jl b/src/accumulation.jl index 5c608550b..4bcc5c33f 100644 --- a/src/accumulation.jl +++ b/src/accumulation.jl @@ -24,6 +24,8 @@ function add!!(x, t::InplaceableThunk) end end +add!!(x::AbstractArray, y::Thunk) = add!!(x, unthunk(y)) + function add!!(x::AbstractArray{<:Any, N}, y::AbstractArray{<:Any, N}) where N return if is_inplaceable_destination(x) x .+= y diff --git a/test/accumulation.jl b/test/accumulation.jl index 8b711d523..1ede27473 100644 --- a/test/accumulation.jl +++ b/test/accumulation.jl @@ -87,15 +87,13 @@ end end - @testset "InplaceableThunk" begin - ithunk = InplaceableThunk( - @thunk(-1.0*ones(2, 2)), - x -> x .-= ones(2, 2) - ) - + @testset "AbstractThunk $(typeof(thunk))" for thunk in ( + @thunk(-1.0*ones(2, 2)), + InplaceableThunk(@thunk(-1.0*ones(2, 2)), x -> x .-= ones(2, 2)), + ) @testset "in place" begin accumuland = [1.0 2.0; 3.0 4.0] - ret = add!!(accumuland, ithunk) + ret = add!!(accumuland, thunk) @test ret == [0.0 1.0; 2.0 3.0] # must return right answer @test ret === accumuland # must be same object end @@ -103,7 +101,7 @@ @testset "out of place" begin accumuland = @SMatrix [1.0 2.0; 3.0 4.0] - ret = add!!(accumuland, ithunk) + ret = add!!(accumuland, thunk) @test ret == [0.0 1.0; 2.0 3.0] # must return right answer @test ret !== accumuland # must not be same object @test accumuland == [1.0 2.0; 3.0 4.0] # must not have mutated