Skip to content

Commit ba8468d

Browse files
Merge pull request #20247 from stevengj/viewsfix
fixes for at-view and at-views
2 parents 073ad9f + 7aa23e0 commit ba8468d

File tree

2 files changed

+91
-26
lines changed

2 files changed

+91
-26
lines changed

base/subarray.jl

Lines changed: 73 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -333,25 +333,32 @@ should transform to
333333
A[B[endof(B)]]
334334
335335
"""
336-
function replace_ref_end!(ex,withex=nothing)
336+
replace_ref_end!(ex) = replace_ref_end_!(ex, nothing)[1]
337+
# replace_ref_end_!(ex,withex) returns (new ex, whether withex was used)
338+
function replace_ref_end_!(ex, withex)
339+
used_withex = false
337340
if isa(ex,Symbol) && ex == :end
338341
withex === nothing && error("Invalid use of end")
339-
return withex
342+
return withex, true
340343
elseif isa(ex,Expr)
341344
if ex.head == :ref
342-
S = ex.args[1] = replace_ref_end!(ex.args[1],withex)
345+
ex.args[1], used_withex = replace_ref_end_!(ex.args[1],withex)
346+
S = isa(ex.args[1],Symbol) ? ex.args[1]::Symbol : gensym(:S) # temp var to cache ex.args[1] if needed
347+
used_S = false # whether we actually need S
343348
# new :ref, so redefine withex
344349
nargs = length(ex.args)-1
345350
if nargs == 0
346-
return ex
351+
return ex, used_withex
347352
elseif nargs == 1
348353
# replace with endof(S)
349-
ex.args[2] = replace_ref_end!(ex.args[2],:(Base.endof($S)))
354+
ex.args[2], used_S = replace_ref_end_!(ex.args[2],:($endof($S)))
350355
else
351356
n = 1
352357
J = endof(ex.args)
353358
for j = 2:J-1
354-
exj = ex.args[j] = replace_ref_end!(ex.args[j],:(Base.size($S,$n)))
359+
exj, used = replace_ref_end_!(ex.args[j],:($size($S,$n)))
360+
used_S |= used
361+
ex.args[j] = exj
355362
if isa(exj,Expr) && exj.head == :...
356363
# splatted object
357364
exjs = exj.args[1]
@@ -364,16 +371,23 @@ function replace_ref_end!(ex,withex=nothing)
364371
n += 1
365372
end
366373
end
367-
ex.args[J] = replace_ref_end!(ex.args[J],:(Base.trailingsize($S,$n)))
374+
ex.args[J], used = replace_ref_end_!(ex.args[J],:($trailingsize($S,$n)))
375+
used_S |= used
376+
end
377+
if used_S && S !== ex.args[1]
378+
S0 = ex.args[1]
379+
ex.args[1] = S
380+
ex = Expr(:let, ex, :($S = $S0))
368381
end
369382
else
370383
# recursive search
371384
for i = eachindex(ex.args)
372-
ex.args[i] = replace_ref_end!(ex.args[i],withex)
385+
ex.args[i], used = replace_ref_end_!(ex.args[i],withex)
386+
used_withex |= used
373387
end
374388
end
375389
end
376-
ex
390+
ex, used_withex
377391
end
378392

379393
"""
@@ -385,9 +399,15 @@ an assignment (e.g. `@view(A[1,2:end]) = ...`). See also [`@views`](@ref)
385399
to switch an entire block of code to use views for slicing.
386400
"""
387401
macro view(ex)
388-
if isa(ex, Expr) && ex.head == :ref
402+
if Meta.isexpr(ex, :ref)
389403
ex = replace_ref_end!(ex)
390-
Expr(:&&, true, esc(Expr(:call,:(Base.view),ex.args...)))
404+
if Meta.isexpr(ex, :ref)
405+
ex = Expr(:call, view, ex.args...)
406+
else # ex replaced by let ...; foo[...]; end
407+
assert(Meta.isexpr(ex, :let) && Meta.isexpr(ex.args[1], :ref))
408+
ex.args[1] = Expr(:call, view, ex.args[1].args...)
409+
end
410+
Expr(:&&, true, esc(ex))
391411
else
392412
throw(ArgumentError("Invalid use of @view macro: argument must be a reference expression A[...]."))
393413
end
@@ -404,21 +424,53 @@ end
404424
@propagate_inbounds maybeview(A) = getindex(A)
405425
@propagate_inbounds maybeview(A::AbstractArray) = getindex(A)
406426

427+
# _views implements the transformation for the @views macro.
428+
# @views calls esc(_views(...)) to work around #20241,
429+
# so any function calls we insert (to maybeview, or to
430+
# size and endof in replace_ref_end!) must be interpolated
431+
# as values rather than as symbols to ensure that they are called
432+
# from Base rather than from the caller's scope.
407433
_views(x) = x
408-
_views(x::Symbol) = esc(x)
409434
function _views(ex::Expr)
410435
if ex.head in (:(=), :(.=))
411-
# don't use view on the lhs of an assignment
412-
Expr(ex.head, esc(ex.args[1]), _views(ex.args[2]))
436+
# don't use view for ref on the lhs of an assignment,
437+
# but still use views for the args of the ref:
438+
lhs = ex.args[1]
439+
Expr(ex.head, Meta.isexpr(lhs, :ref) ?
440+
Expr(:ref, _views.(lhs.args)...) : _views(lhs),
441+
_views(ex.args[2]))
413442
elseif ex.head == :ref
414-
ex = replace_ref_end!(ex)
415-
Expr(:call, :maybeview, _views.(ex.args)...)
443+
Expr(:call, maybeview, _views.(ex.args)...)
416444
else
417445
h = string(ex.head)
418-
if last(h) == '='
419-
# don't use view on the lhs of an op-assignment
420-
Expr(first(h) == '.' ? :(.=) : :(=), esc(ex.args[1]),
421-
Expr(:call, esc(Symbol(h[1:end-1])), _views.(ex.args)...))
446+
# don't use view on the lhs of an op-assignment a[i...] += ...
447+
if last(h) == '=' && Meta.isexpr(ex.args[1], :ref)
448+
lhs = ex.args[1]
449+
450+
# temp vars to avoid recomputing a and i,
451+
# which will be assigned in a let block:
452+
a = gensym(:a)
453+
i = [gensym(:i) for k = 1:length(lhs.args)-1]
454+
455+
# for splatted indices like a[i, j...], we need to
456+
# splat the corresponding temp var.
457+
I = similar(i, Any)
458+
for k = 1:length(i)
459+
if Meta.isexpr(lhs.args[k+1], :...)
460+
I[k] = Expr(:..., i[k])
461+
lhs.args[k+1] = lhs.args[k+1].args[1] # unsplat
462+
else
463+
I[k] = i[k]
464+
end
465+
end
466+
467+
Expr(:let,
468+
Expr(first(h) == '.' ? :(.=) : :(=), :($a[$(I...)]),
469+
Expr(:call, Symbol(h[1:end-1]),
470+
:($maybeview($a, $(I...))),
471+
_views.(ex.args[2:end])...)),
472+
:($a = $(_views(lhs.args[1]))),
473+
[:($(i[k]) = $(_views(lhs.args[k+1]))) for k=1:length(i)]...)
422474
else
423475
Expr(ex.head, _views.(ex.args)...)
424476
end
@@ -439,5 +491,5 @@ that appear explicitly in the given `expression`, not array slicing that
439491
occurs in functions called by that code.
440492
"""
441493
macro views(x)
442-
_views(x)
494+
esc(_views(replace_ref_end!(x)))
443495
end

test/subarray.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -473,8 +473,6 @@ end
473473
@test collect(view(view(reshape(1:13^3, 13, 13, 13), 3:7, 6:6, :), 1:2:5, :, 1:2:5)) ==
474474
cat(3,[68,70,72],[406,408,410],[744,746,748])
475475

476-
477-
478476
# tests @view (and replace_ref_end!)
479477
X = reshape(1:24,2,3,4)
480478
Y = 4:-1:1
@@ -494,10 +492,16 @@ u = (1,2:3)
494492
@test X[(1,)...,(2,)...,2:end] == @view X[(1,)...,(2,)...,2:end]
495493

496494
# test macro hygiene
497-
let size=(x,y)-> error("should not happen")
495+
let size=(x,y)-> error("should not happen"), Base=nothing
498496
@test X[1:end,2,2] == @view X[1:end,2,2]
499497
end
500498

499+
# test that side effects occur only once
500+
let foo = [X]
501+
@test X[2:end-1] == @view (push!(foo,X)[1])[2:end-1]
502+
@test foo == [X, X]
503+
end
504+
501505
# test @views macro
502506
@views let f!(x) = x[1:end-1] .+= x[2:end].^2
503507
x = [1,2,3,4]
@@ -512,6 +516,16 @@ end
512516
@test x == [5,6,19,4]
513517
f!(x[3:end])
514518
@test x == [5,6,35,4]
519+
x[Y[2:3]] .= 7:8
520+
@test x == [5,8,7,4]
521+
x[(3,)..., ()...] .+= 3
522+
@test x == [5,8,10,4]
523+
i = Int[]
524+
# test that lhs expressions in update operations are evaluated only once:
525+
x[push!(i,4)[1]] += 5
526+
@test x == [5,8,10,9] && i == [4]
527+
x[push!(i,3)[end]] += 2
528+
@test x == [5,8,12,9] && i == [4,3]
515529
end
516530
@views @test isa(X[1:3], SubArray)
517531
@test X[1:end] == @views X[1:end]
@@ -523,9 +537,8 @@ end
523537
@test X[1:end,2,Y[2:end]] == @views X[1:end,2,Y[2:end]]
524538
@test X[u...,2:end] == @views X[u...,2:end]
525539
@test X[(1,)...,(2,)...,2:end] == @views X[(1,)...,(2,)...,2:end]
526-
527540
# test macro hygiene
528-
let size=(x,y)-> error("should not happen")
541+
let size=(x,y)-> error("should not happen"), Base=nothing
529542
@test X[1:end,2,2] == @views X[1:end,2,2]
530543
end
531544

0 commit comments

Comments
 (0)