Skip to content

Commit 88411b9

Browse files
authored
Merge branch 'master' into compathelper/new_version/2020-11-18-00-11-22-279-2492702306
2 parents 4fe1645 + 5f86408 commit 88411b9

File tree

4 files changed

+24
-6
lines changed

4 files changed

+24
-6
lines changed

Project.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ReverseDiff"
22
uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
3-
version = "1.4.3"
3+
version = "1.4.4"
44

55
[deps]
66
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
@@ -23,12 +23,13 @@ FunctionWrappers = "1"
2323
MacroTools = "0.5"
2424
NaNMath = "0.3"
2525
SpecialFunctions = "0.8, 0.9, 0.10, 1.0"
26-
StaticArrays = "0.10, 0.11, 0.12"
26+
StaticArrays = "0.10, 0.11, 0.12, 1.0"
2727
julia = "1"
2828

2929
[extras]
3030
DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
31+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
3132
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3233

3334
[targets]
34-
test = ["DiffTests", "Test"]
35+
test = ["DiffTests", "FillArrays", "Test"]

src/derivatives/broadcast.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,17 +85,17 @@ function get_implementation(bc, f, T, args)
8585
end
8686
function Base.copy(_bc::Broadcasted{TrackedStyle})
8787
bc = remove_not_tracked(_bc)
88-
flattened_bc = Broadcast.flatten(bc)
88+
flattened_bc = Base.Broadcast.flatten(bc)
8989
untracked_bc = broadcast_rebuild(bc)
90-
flattened_untracked_bc = Broadcast.flatten(untracked_bc)
9190
T = Core.Compiler.return_type(copy, Tuple{typeof(untracked_bc)})
92-
f, args = flattened_untracked_bc.f, flattened_bc.args
91+
f, args = flattened_bc.f, flattened_bc.args
9392
implementation = get_implementation(_bc, f, T, args)
9493
if implementation isa Val{:reversediff}
9594
return ∇broadcast(f, args...)
9695
elseif implementation isa Val{:tracker}
9796
return tracker_∇broadcast(f, args...)
9897
else
98+
flattened_untracked_bc = Base.Broadcast.flatten(untracked_bc)
9999
style, axes = getstyle(flattened_untracked_bc), flattened_bc.axes
100100
return copy(Broadcasted{style, typeof(axes), typeof(f), typeof(args)}(f, args, axes))
101101
end

test/compat/CompatTests.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module CompatTests
2+
3+
using FillArrays, ReverseDiff, Test
4+
5+
@test ReverseDiff.gradient(fill(2.0, 3)) do x
6+
sum(abs2.(x .- Zeros(3)))
7+
end == fill(4.0, 3)
8+
9+
@test ReverseDiff.gradient(fill(2.0, 3)) do x
10+
sum(abs2.(x .- (1:3)))
11+
end == [2, 0, -2]
12+
13+
end

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,7 @@ println("done (took $t seconds).")
4545
println("running ConfigTests...")
4646
t = @elapsed include(joinpath(TESTDIR, "api/ConfigTests.jl"))
4747
println("done (took $t seconds).")
48+
49+
println("running CompatTests...")
50+
t = @elapsed include(joinpath(TESTDIR, "compat/CompatTests.jl"))
51+
println("done (took $t seconds).")

0 commit comments

Comments
 (0)