Skip to content

Commit 696837b

Browse files
committed
Don't fetch results that are discarded
1 parent cfc7dfb commit 696837b

File tree

6 files changed

+47
-57
lines changed

6 files changed

+47
-57
lines changed

src/broadcast.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,18 +70,18 @@ end
7070
# This will turn local AbstractArrays into DArrays
7171
dbc = bcdistribute(bc)
7272

73-
asyncmap(procs(dest)) do p
74-
remotecall_fetch(p) do
73+
@sync for p in procs(dest)
74+
@async remotecall_wait(p) do
7575
# get the indices for the localpart
7676
lpidx = localpartindex(dest)
7777
@assert lpidx != 0
7878
# create a local version of the broadcast, by constructing views
7979
# Note: creates copies of the argument
8080
lbc = bclocal(dbc, dest.indices[lpidx])
8181
Base.copyto!(localpart(dest), lbc)
82-
return nothing
8382
end
8483
end
84+
8585
return dest
8686
end
8787

src/core.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ function close_by_id(id, pids)
2121
global refs
2222
@sync begin
2323
for p in pids
24-
@async remotecall_fetch(release_localpart, p, id)
24+
@async remotecall_wait(release_localpart, p, id)
2525
end
2626
if !(myid() in pids)
2727
release_localpart(id)

src/darray.jl

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ function DArray(id, init, dims, pids, idxs, cuts)
9797

9898
if length(unique(localtypes)) != 1
9999
@sync for p in pids
100-
@async remotecall_fetch(release_localpart, p, id)
100+
@async remotecall_wait(release_localpart, p, id)
101101
end
102102
throw(ErrorException("Constructed localparts have different `eltype`: $(localtypes)"))
103103
end
@@ -147,8 +147,8 @@ function ddata(;T::Type=Any, init::Function=I->nothing, pids=workers(), data::Ve
147147
end
148148
end
149149

150-
@sync for i = 1:length(pids)
151-
@async remotecall_fetch(construct_localparts, pids[i], init, id, (npids,), pids, idxs, cuts; T=T, A=T)
150+
@sync for p in pids
151+
@async remotecall_wait(construct_localparts, p, init, id, (npids,), pids, idxs, cuts; T=T, A=T)
152152
end
153153

154154
if myid() in pids
@@ -161,9 +161,10 @@ function ddata(;T::Type=Any, init::Function=I->nothing, pids=workers(), data::Ve
161161
end
162162

163163
function gather(d::DArray{T,1,T}) where T
164-
a=Array{T}(undef, length(procs(d)))
165-
@sync for (i,p) in enumerate(procs(d))
166-
@async a[i] = remotecall_fetch(localpart, p, d)
164+
pids = procs(d)
165+
a = Vector{T}(undef, length(pids))
166+
asyncmap!(a, pids) do p
167+
remotecall_fetch(localpart, p, d)
167168
end
168169
a
169170
end
@@ -195,12 +196,9 @@ function DArray(refs)
195196
dimdist = size(refs)
196197
id = next_did()
197198

198-
npids = [r.where for r in refs]
199199
nsizes = Array{Tuple}(undef, dimdist)
200-
@sync for i in 1:length(refs)
201-
let i=i
202-
@async nsizes[i] = remotecall_fetch(sz_localpart_ref, npids[i], refs[i], id)
203-
end
200+
asyncmap!(nsizes, refs) do r
201+
remotecall_fetch(sz_localpart_ref, r.where, r, id)
204202
end
205203

206204
nindices = Array{NTuple{length(dimdist),UnitRange{Int}}}(undef, dimdist...)
@@ -223,7 +221,7 @@ function DArray(refs)
223221
ncuts = Array{Int,1}[pushfirst!(sort(unique(lastidxs[x,:])), 1) for x in 1:length(dimdist)]
224222
ndims = tuple([sort(unique(lastidxs[x,:]))[end]-1 for x in 1:length(dimdist)]...)
225223

226-
DArray(id, refs, ndims, reshape(npids, dimdist), nindices, ncuts)
224+
DArray(id, refs, ndims, map(r -> r.where, refs), nindices, ncuts)
227225
end
228226

229227
macro DArray(ex0::Expr)
@@ -683,8 +681,8 @@ Base.copy(d::SubDArray) = copyto!(similar(d), d)
683681
Base.copy(d::SubDArray{<:Any,2}) = copyto!(similar(d), d)
684682

685683
function Base.copyto!(dest::SubOrDArray, src::AbstractArray)
686-
asyncmap(procs(dest)) do p
687-
remotecall_fetch(p) do
684+
@sync for p in procs(dest)
685+
@async remotecall_wait(p) do
688686
ldest = localpart(dest)
689687
copyto!(ldest, view(src, localindices(dest)...))
690688
end
@@ -694,8 +692,8 @@ end
694692

695693
function Base.deepcopy(src::DArray)
696694
dest = similar(src)
697-
asyncmap(procs(src)) do p
698-
remotecall_fetch(p) do
695+
@sync for p in procs(src)
696+
@async remotecall_wait(p) do
699697
dest[:L] = deepcopy(src[:L])
700698
end
701699
end
@@ -846,16 +844,17 @@ end
846844

847845
function Base.fill!(A::DArray, x)
848846
@sync for p in procs(A)
849-
@async remotecall_fetch((A,x)->(fill!(localpart(A), x); nothing), p, A, x)
847+
@async remotecall_wait((A,x)->fill!(localpart(A), x), p, A, x)
850848
end
851849
return A
852850
end
853851

854852
using Random
855853

856854
function Random.rand!(A::DArray, ::Type{T}) where T
857-
asyncmap(procs(A)) do p
858-
remotecall_wait((A, T)->rand!(localpart(A), T), p, A, T)
855+
@sync for p in procs(A)
856+
@async remotecall_wait((A, T)->rand!(localpart(A), T), p, A, T)
859857
end
858+
return A
860859
end
861860

src/linalg.jl

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,9 @@ function axpy!(α, x::DArray, y::DArray)
2525
if length(x) != length(y)
2626
throw(DimensionMismatch("vectors must have same length"))
2727
end
28-
asyncmap(procs(y)) do p
29-
@async remotecall_fetch(p) do
28+
@sync for p in procs(y)
29+
@async remotecall_wait(p) do
3030
axpy!(α, localpart(x), localpart(y))
31-
return nothing
3231
end
3332
end
3433
return y
@@ -39,26 +38,22 @@ function dot(x::DVector, y::DVector)
3938
throw(DimensionMismatch(""))
4039
end
4140

42-
results=Any[]
43-
asyncmap(procs(x)) do p
44-
push!(results, remotecall_fetch((x, y) -> dot(localpart(x), makelocal(y, localindices(x)...)), p, x, y))
41+
results = asyncmap(procs(x)) do p
42+
remotecall_fetch((x, y) -> dot(localpart(x), makelocal(y, localindices(x)...)), p, x, y)
4543
end
4644
return reduce(+, results)
4745
end
4846

4947
function norm(x::DArray, p::Real = 2)
50-
results = []
51-
@sync begin
52-
for pp in procs(x)
53-
@async push!(results, remotecall_fetch(() -> norm(localpart(x), p), pp))
54-
end
48+
results = asyncmap(procs(x)) do pp
49+
remotecall_fetch(() -> norm(localpart(x), p), pp)
5550
end
5651
return norm(results, p)
5752
end
5853

5954
function LinearAlgebra.rmul!(A::DArray, x::Number)
6055
@sync for p in procs(A)
61-
@async remotecall_fetch((A,x)->(rmul!(localpart(A), x); nothing), p, A, x)
56+
@async remotecall_wait((A,x)->rmul!(localpart(A), x), p, A, x)
6257
end
6358
return A
6459
end
@@ -104,13 +99,12 @@ function LinearAlgebra.mul!(y::DVector, A::DMatrix, x::AbstractVector, α::Numbe
10499
# Scale y if necessary
105100
if β != one(β)
106101
asyncmap(procs(y)) do p
107-
remotecall_fetch(p) do
102+
remotecall_wait(p) do
108103
if !iszero(β)
109104
rmul!(localpart(y), β)
110105
else
111106
fill!(localpart(y), 0)
112107
end
113-
return nothing
114108
end
115109
end
116110
end
@@ -120,7 +114,7 @@ function LinearAlgebra.mul!(y::DVector, A::DMatrix, x::AbstractVector, α::Numbe
120114
p = y.pids[i]
121115
for j = 1:size(R, 2)
122116
rij = R[i,j]
123-
@async remotecall_fetch(() -> (add!(localpart(y), fetch(rij), α); nothing), p)
117+
@async remotecall_wait(() -> add!(localpart(y), fetch(rij), α), p)
124118
end
125119
end
126120

@@ -150,14 +144,13 @@ function LinearAlgebra.mul!(y::DVector, adjA::Adjoint{<:Number,<:DMatrix}, x::Ab
150144

151145
# Scale y if necessary
152146
if β != one(β)
153-
asyncmap(procs(y)) do p
154-
remotecall_fetch(p) do
147+
@sync for p in procs(y)
148+
@async remotecall_wait(p) do
155149
if !iszero(β)
156150
rmul!(localpart(y), β)
157151
else
158152
fill!(localpart(y), 0)
159153
end
160-
return nothing
161154
end
162155
end
163156
end
@@ -167,7 +160,7 @@ function LinearAlgebra.mul!(y::DVector, adjA::Adjoint{<:Number,<:DMatrix}, x::Ab
167160
p = y.pids[i]
168161
for j = 1:size(R, 2)
169162
rij = R[i,j]
170-
@async remotecall_fetch(() -> (add!(localpart(y), fetch(rij), α); nothing), p)
163+
@async remotecall_wait(() -> add!(localpart(y), fetch(rij), α), p)
171164
end
172165
end
173166
return y
@@ -238,10 +231,10 @@ function _matmatmul!(C::DMatrix, A::DMatrix, B::AbstractMatrix, α::Number, β::
238231
# Scale C if necessary
239232
if β != one(β)
240233
@sync for p in C.pids
241-
if β != zero(β)
242-
@async remotecall_fetch(() -> (rmul!(localpart(C), β); nothing), p)
234+
if iszero(β)
235+
@async remotecall_wait(() -> fill!(localpart(C), 0), p)
243236
else
244-
@async remotecall_fetch(() -> (fill!(localpart(C), 0); nothing), p)
237+
@async remotecall_wait(() -> rmul!(localpart(C), β), p)
245238
end
246239
end
247240
end
@@ -252,7 +245,7 @@ function _matmatmul!(C::DMatrix, A::DMatrix, B::AbstractMatrix, α::Number, β::
252245
p = C.pids[i,k]
253246
for j = 1:size(R, 2)
254247
rijk = R[i,j,k]
255-
@async remotecall_fetch(d -> (add!(localpart(d), fetch(rijk), α); nothing), p, C)
248+
@async remotecall_wait(d -> add!(localpart(d), fetch(rijk), α), p, C)
256249
end
257250
end
258251
end

src/mapreduce.jl

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@ import SparseArrays: nnz
66
Base.map(f, d0::DArray, ds::AbstractArray...) = broadcast(f, d0, ds...)
77

88
function Base.map!(f::F, dest::DArray, src::DArray{<:Any,<:Any,A}) where {F,A}
9-
asyncmap(procs(dest)) do p
10-
remotecall_fetch(p) do
9+
@sync for p in procs(dest)
10+
@async remotecall_fetch(p) do
1111
map!(f, localpart(dest), makelocal(src, localindices(dest)...))
12-
return nothing
1312
end
1413
end
1514
return dest
@@ -41,8 +40,8 @@ function Base.reducedim_initarray(A::DArray, region, v0, ::Type{R}) where {R}
4140
# Store reduction on lowest pids
4241
pids = A.pids[ntuple(i -> i in region ? (1:1) : (:), ndims(A))...]
4342
chunks = similar(pids, Future)
44-
@sync for i in eachindex(pids)
45-
@async chunks[i...] = remotecall_wait(() -> Base.reducedim_initarray(localpart(A), region, v0, R), pids[i...])
43+
asyncmap!(chunks, pids) do p
44+
remotecall_wait(() -> Base.reducedim_initarray(localpart(A), region, v0, R), p)
4645
end
4746
return DArray(chunks)
4847
end
@@ -67,13 +66,12 @@ end
6766
# has been run on each localpart with mapreducedim_within. Eventually, we might
6867
# want to write mapreducedim_between! as a binary reduction.
6968
function mapreducedim_between!(f, op, R::DArray, A::DArray, region)
70-
asyncmap(procs(R)) do p
71-
remotecall_fetch(p, f, op, R, A, region) do f, op, R, A, region
69+
@sync for p in procs(R)
70+
@async remotecall_wait(p, f, op, R, A, region) do f, op, R, A, region
7271
localind = [r for r = localindices(A)]
7372
localind[[region...]] = [1:n for n = size(A)[[region...]]]
7473
B = convert(Array, A[localind...])
7574
Base.mapreducedim!(f, op, localpart(R), B)
76-
nothing
7775
end
7876
end
7977
return R
@@ -170,8 +168,8 @@ function map_localparts(f::Callable, A::Array, DA::DArray)
170168
end
171169

172170
function map_localparts!(f::Callable, d::DArray)
173-
asyncmap(procs(d)) do p
174-
remotecall_fetch((f,d)->(f(localpart(d)); nothing), p, f, d)
171+
@sync for p in procs(d)
172+
@async remotecall_wait((f,d)->f(localpart(d)), p, f, d)
175173
end
176174
return d
177175
end

src/spmd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ function spmd(f, args...; pids=procs(), context=nothing)
243243
ctxt_id = context.id
244244
end
245245
@sync for p in pids
246-
@async remotecall_fetch(spmd_local, p, f_noarg, ctxt_id, clear_ctxt)
246+
@async remotecall_wait(spmd_local, p, f_noarg, ctxt_id, clear_ctxt)
247247
end
248248
nothing
249249
end

0 commit comments

Comments
 (0)