Skip to content

Commit cf9f4bb

Browse files
committed
Fix tests
1 parent 161d227 commit cf9f4bb

File tree

2 files changed

+31
-27
lines changed

2 files changed

+31
-27
lines changed

src/apiutils.jl

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -41,46 +41,55 @@ end
4141
end
4242

4343
# Only seed indices that are structurally non-zero
44-
_structural_nonzero_indices(x::AbstractArray) = eachindex(x)
45-
function _structural_nonzero_indices(x::UpperTriangular)
44+
structural_eachindex(x::AbstractArray) = structural_eachindex(x, x)
45+
function structural_eachindex(x::AbstractArray, y::AbstractArray)
46+
require_one_based_indexing(x, y)
47+
eachindex(x, y)
48+
end
49+
function structural_eachindex(x::UpperTriangular, y::AbstractArray)
50+
require_one_based_indexing(x, y)
51+
if size(x) != size(y)
52+
throw(DimensionMismatch())
53+
end
4654
n = size(x, 1)
4755
return (CartesianIndex(i, j) for j in 1:n for i in 1:j)
4856
end
49-
function _structural_nonzero_indices(x::LowerTriangular)
57+
function structural_eachindex(x::LowerTriangular, y::AbstractArray)
58+
require_one_based_indexing(x, y)
59+
if size(x) != size(y)
60+
throw(DimensionMismatch())
61+
end
5062
n = size(x, 1)
5163
return (CartesianIndex(i, j) for j in 1:n for i in j:n)
5264
end
53-
_structural_nonzero_indices(x::Diagonal) = diagind(x)
65+
function structural_eachindex(x::Diagonal, y::AbstractArray)
66+
require_one_based_indexing(x, y)
67+
if size(x) != size(y)
68+
throw(DimensionMismatch())
69+
end
70+
return diagind(x)
71+
end
5472

5573
function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
5674
seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
57-
if eachindex(duals) != eachindex(x)
58-
throw(ArgumentError("indices of input array and array of duals are not identical"))
59-
end
60-
for idx in _structural_nonzero_indices(duals)
75+
for idx in structural_eachindex(duals, x)
6176
duals[idx] = Dual{T,V,N}(x[idx], seed)
6277
end
6378
return duals
6479
end
6580

6681
function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
6782
seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
68-
if eachindex(duals) != eachindex(x)
69-
throw(ArgumentError("indices of input array and array of duals are not identical"))
70-
end
71-
for (i, idx) in enumerate(_structural_nonzero_indices(duals))
83+
for (i, idx) in zip(1:N, structural_eachindex(duals, x))
7284
duals[idx] = Dual{T,V,N}(x[idx], seeds[i])
7385
end
7486
return duals
7587
end
7688

7789
function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
7890
seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
79-
if eachindex(duals) != eachindex(x)
80-
throw(ArgumentError("indices of input array and array of duals are not identical"))
81-
end
8291
offset = index - 1
83-
idxs = Iterators.drop(_structural_nonzero_indices(duals), offset)
92+
idxs = Iterators.drop(structural_eachindex(duals, x), offset)
8493
for idx in idxs
8594
duals[idx] = Dual{T,V,N}(x[idx], seed)
8695
end
@@ -89,13 +98,9 @@ end
8998

9099
function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
91100
seeds::NTuple{N,Partials{N,V}}, chunksize = N) where {T,V,N}
92-
if eachindex(duals) != eachindex(x)
93-
throw(ArgumentError("indices of input array and array of duals are not identical"))
94-
end
95101
offset = index - 1
96-
idxs = Iterators.drop(_structural_nonzero_indices(duals), offset)
97-
for (i, idx) in enumerate(idxs)
98-
i > chunksize && break
102+
idxs = Iterators.drop(structural_eachindex(duals, x), offset)
103+
for (i, idx) in zip(1:chunksize, idxs)
99104
duals[idx] = Dual{T,V,N}(x[idx], seeds[i])
100105
end
101106
return duals

src/gradient.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,17 @@ end
6464

6565
extract_gradient!(::Type{T}, result::AbstractArray, y::Real) where {T} = fill!(result, zero(y))
6666
function extract_gradient!(::Type{T}, result::AbstractArray, dual::Dual) where {T}
67-
idxs = _structural_nonzero_indices(result)
68-
for (i, idx) in enumerate(idxs)
67+
idxs = structural_eachindex(result)
68+
for (i, idx) in zip(1:npartials(dual), idxs)
6969
result[idx] = partials(T, dual, i)
7070
end
7171
return result
7272
end
7373

7474
function extract_gradient_chunk!(::Type{T}, result, dual, index, chunksize) where {T}
7575
offset = index - 1
76-
idxs = Iterators.drop(_structural_nonzero_indices(result), offset)
77-
for (i, idx) in enumerate(idxs)
78-
i > chunksize && break
76+
idxs = Iterators.drop(structural_eachindex(result), offset)
77+
for (i, idx) in zip(1:chunksize, idxs)
7978
result[idx] = partials(T, dual, i)
8079
end
8180
return result

0 commit comments

Comments
 (0)