Skip to content

Commit c4dec1a

Browse files
committed
add an inlining heuristic that helps avoid allocations
Don't inline into a function `f` if doing so would put it over the inlining threshold, and if inlining `f` itself would help avoid tuple allocations.
1 parent 508ed9f commit c4dec1a

File tree

1 file changed

+82
-69
lines changed

1 file changed

+82
-69
lines changed

base/inference.jl

Lines changed: 82 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -2875,9 +2875,7 @@ function isinlineable(m::Method, src::CodeInfo)
28752875
end
28762876
end
28772877
if !inlineable
2878-
body = Expr(:block)
2879-
body.args = src.code
2880-
inlineable = inline_worthy(body, cost)
2878+
inlineable = inline_worthy_stmts(src.code, cost)
28812879
end
28822880
return inlineable
28832881
end
@@ -3647,7 +3645,10 @@ end
36473645
# static parameters are ok if all the static parameter values are leaf types,
36483646
# meaning they are fully known.
36493647
# `ft` is the type of the function. `f` is the exact function if known, or else `nothing`.
3650-
function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::InferenceState)
3648+
# `pending_stmts` is an array of statements from functions inlined so far, so
3649+
# we can estimate the total size of the enclosing function after inlining.
3650+
function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::InferenceState,
3651+
pending_stmts)
36513652
argexprs = e.args
36523653

36533654
if (f === typeassert || ft ⊑ typeof(typeassert)) && length(atypes)==3
@@ -3918,6 +3919,31 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
39183919
invoke_data)
39193920
end
39203921

3922+
if !isa(ast, Array{Any,1})
3923+
ast = ccall(:jl_uncompress_ast, Any, (Any, Any), method, ast)
3924+
else
3925+
ast = copy_exprargs(ast)
3926+
end
3927+
ast = ast::Array{Any,1}
3928+
3929+
if sv.bestguess ⊑ Tuple # check for non-isbits Tuple return
3930+
if !isbits(widenconst(sv.bestguess))
3931+
# See if inlining this call would change the enclosing function
3932+
# from inlineable to not inlineable.
3933+
# This heuristic is applied to functions that return non-bits
3934+
# tuples, since we want to be able to inline those functions to
3935+
# avoid the tuple allocation.
3936+
current_stmts = vcat(sv.src.code, pending_stmts)
3937+
if inline_worthy_stmts(current_stmts)
3938+
append!(current_stmts, ast)
3939+
if !inline_worthy_stmts(current_stmts)
3940+
return invoke_NF(argexprs0, e.typ, atypes, sv, atype_unlimited,
3941+
invoke_data)
3942+
end
3943+
end
3944+
end
3945+
end
3946+
39213947
# create the backedge
39223948
if isa(frame, InferenceState) && !frame.inferred && frame.cached
39233949
# in this case, the actual backedge linfo hasn't been computed
@@ -3940,13 +3966,6 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
39403966

39413967
nm = length(unwrap_unionall(metharg).parameters)
39423968

3943-
if !isa(ast, Array{Any,1})
3944-
ast = ccall(:jl_uncompress_ast, Any, (Any, Any), method, ast)
3945-
else
3946-
ast = copy_exprargs(ast)
3947-
end
3948-
ast = ast::Array{Any,1}
3949-
39503969
body = Expr(:block)
39513970
body.args = ast
39523971
propagate_inbounds = src.propagate_inbounds
@@ -4173,10 +4192,13 @@ function inline_ignore(ex::ANY)
41734192
return isa(ex, Expr) && is_meta_expr(ex::Expr)
41744193
end
41754194

4195+
# Convenience wrapper around `inline_worthy` for a bare statement list:
# wraps `stmts` in a fresh `Expr(:block)` and returns whatever
# `inline_worthy` returns for that block at the given `cost`
# (default 1000, the same default `inline_worthy` uses).
# Note: `stmts` is assigned directly as the block's `args`, not copied, so
# the temporary block aliases the caller's vector.
function inline_worthy_stmts(stmts::Vector{Any}, cost::Integer = 1000)
4196+
body = Expr(:block)
4197+
body.args = stmts
4198+
return inline_worthy(body, cost)
4199+
end
4200+
41764201
function inline_worthy(body::Expr, cost::Integer=1000) # precondition: 0 < cost; nominal cost = 1000
4177-
if popmeta!(body, :noinline)[1]
4178-
return false
4179-
end
41804202
symlim = 1000 + 5_000_000 ÷ cost
41814203
nstmt = 0
41824204
for stmt in body.args
@@ -4224,17 +4246,15 @@ end
42244246
function inlining_pass!(sv::InferenceState)
42254247
eargs = sv.src.code
42264248
i = 1
4249+
stmtbuf = []
42274250
while i <= length(eargs)
42284251
ei = eargs[i]
42294252
if isa(ei, Expr)
4230-
res = inlining_pass(ei, sv)
4231-
eargs[i] = res[1]
4232-
if isa(res[2], Array)
4233-
sts = res[2]::Array{Any,1}
4234-
for j = 1:length(sts)
4235-
insert!(eargs, i, sts[j])
4236-
i += 1
4237-
end
4253+
eargs[i] = inlining_pass(ei, sv, stmtbuf, 1)
4254+
if !isempty(stmtbuf)
4255+
splice!(eargs, i:i-1, stmtbuf)
4256+
i += length(stmtbuf)
4257+
empty!(stmtbuf)
42384258
end
42394259
end
42404260
i += 1
@@ -4243,16 +4263,17 @@ end
42434263

42444264
const corenumtype = Union{Int32, Int64, Float32, Float64}
42454265

4246-
function inlining_pass(e::Expr, sv::InferenceState)
4266+
# return inlined replacement for `e`, inserting new needed statements
4267+
# at index `ins` in `stmts`.
4268+
function inlining_pass(e::Expr, sv::InferenceState, stmts, ins)
42474269
if e.head === :method
42484270
# avoid running the inlining pass on function definitions
4249-
return (e, ())
4271+
return e
42504272
end
42514273
eargs = e.args
42524274
if length(eargs) < 1
4253-
return (e, ())
4275+
return e
42544276
end
4255-
stmts = []
42564277
arg1 = eargs[1]
42574278
isccall = false
42584279
i0 = 1
@@ -4267,6 +4288,7 @@ function inlining_pass(e::Expr, sv::InferenceState)
42674288
i0 = 5
42684289
end
42694290
has_stmts = false # needed to preserve order-of-execution
4291+
prev_stmts_length = length(stmts)
42704292
for _i = length(eargs):-1:i0
42714293
if isccall && _i == 3
42724294
i = 1
@@ -4289,40 +4311,33 @@ function inlining_pass(e::Expr, sv::InferenceState)
42894311
else
42904312
argloc = eargs
42914313
end
4292-
res = inlining_pass(ei, sv)
4293-
res1 = res[1]
4294-
res2 = res[2]
4295-
has_new_stmts = isa(res2, Array) && !isempty(res2::Array{Any,1})
4314+
sl0 = length(stmts)
4315+
res = inlining_pass(ei, sv, stmts, ins)
4316+
ns = length(stmts) - sl0 # number of new statements just added
42964317
if isccallee
4297-
restype = exprtype(res1, sv.src, sv.mod)
4318+
restype = exprtype(res, sv.src, sv.mod)
42984319
if isa(restype, Const)
42994320
argloc[i] = restype.val
4300-
if !effect_free(res1, sv.src, sv.mod, false)
4301-
insert!(stmts, 1, res1)
4302-
end
4303-
if has_new_stmts
4304-
prepend!(stmts, res2::Array{Any,1})
4321+
if !effect_free(res, sv.src, sv.mod, false)
4322+
insert!(stmts, ins+ns, res)
43054323
end
43064324
# Assume this is the last argument to process
43074325
break
43084326
end
43094327
end
4310-
if has_stmts && !effect_free(res1, sv.src, sv.mod, false)
4311-
restype = exprtype(res1, sv.src, sv.mod)
4328+
if has_stmts && !effect_free(res, sv.src, sv.mod, false)
4329+
restype = exprtype(res, sv.src, sv.mod)
43124330
vnew = newvar!(sv, restype)
43134331
argloc[i] = vnew
4314-
unshift!(stmts, Expr(:(=), vnew, res1))
4332+
insert!(stmts, ins+ns, Expr(:(=), vnew, res))
43154333
else
4316-
argloc[i] = res1
4317-
end
4318-
if has_new_stmts
4319-
res2 = res2::Array{Any,1}
4320-
prepend!(stmts, res2)
4321-
if !has_stmts && !(_i == i0)
4322-
for stmt in res2
4323-
if !effect_free(stmt, sv.src, sv.mod, true)
4324-
has_stmts = true
4325-
end
4334+
argloc[i] = res
4335+
end
4336+
if !has_stmts && ns > 0 && !(_i == i0)
4337+
for s = ins:ins+ns-1
4338+
stmt = stmts[s]
4339+
if !effect_free(stmt, sv.src, sv.mod, true)
4340+
has_stmts = true; break
43264341
end
43274342
end
43284343
end
@@ -4337,7 +4352,7 @@ function inlining_pass(e::Expr, sv::InferenceState)
43374352
end
43384353
end
43394354
if e.head !== :call
4340-
return (e, stmts)
4355+
return e
43414356
end
43424357

43434358
ft = exprtype(arg1, sv.src, sv.mod)
@@ -4349,10 +4364,12 @@ function inlining_pass(e::Expr, sv::InferenceState)
43494364
else
43504365
f = nothing
43514366
if !( isleaftype(ft) || ft<:Type )
4352-
return (e, stmts)
4367+
return e
43534368
end
43544369
end
43554370

4371+
ins += (length(stmts) - prev_stmts_length)
4372+
43564373
if sv.params.inlining
43574374
if isdefined(Main, :Base) &&
43584375
((isdefined(Main.Base, :^) && f === Main.Base.:^) ||
@@ -4376,19 +4393,13 @@ function inlining_pass(e::Expr, sv::InferenceState)
43764393
exprtype(a1, sv.src, sv.mod) ⊑ basenumtype)
43774394
if square
43784395
e.args = Any[GlobalRef(Main.Base,:*), a1, a1]
4379-
res = inlining_pass(e, sv)
4396+
res = inlining_pass(e, sv, stmts, ins)
43804397
else
43814398
e.args = Any[GlobalRef(Main.Base,:*), Expr(:call, GlobalRef(Main.Base,:*), a1, a1), a1]
43824399
e.args[2].typ = e.typ
4383-
res = inlining_pass(e, sv)
4384-
end
4385-
if isa(res, Tuple)
4386-
if isa(res[2], Array) && !isempty(res[2])
4387-
append!(stmts, res[2])
4388-
end
4389-
res = res[1]
4400+
res = inlining_pass(e, sv, stmts, ins)
43904401
end
4391-
return (res, stmts)
4402+
return res
43924403
end
43934404
end
43944405
end
@@ -4399,13 +4410,14 @@ function inlining_pass(e::Expr, sv::InferenceState)
43994410
ata[1] = ft
44004411
for i = 2:length(e.args)
44014412
a = exprtype(e.args[i], sv.src, sv.mod)
4402-
(a === Bottom || isvarargtype(a)) && return (e, stmts)
4413+
(a === Bottom || isvarargtype(a)) && return e
44034414
ata[i] = a
44044415
end
4405-
res = inlineable(f, ft, e, ata, sv)
4416+
res = inlineable(f, ft, e, ata, sv, stmts)
44064417
if isa(res,Tuple)
44074418
if isa(res[2],Array) && !isempty(res[2])
4408-
append!(stmts,res[2])
4419+
splice!(stmts, ins:ins-1, res[2])
4420+
ins += length(res[2])
44094421
end
44104422
res = res[1]
44114423
end
@@ -4417,7 +4429,7 @@ function inlining_pass(e::Expr, sv::InferenceState)
44174429
e = res::Expr
44184430
f = _apply; ft = abstract_eval_constant(f)
44194431
else
4420-
return (res,stmts)
4432+
return res
44214433
end
44224434
end
44234435

@@ -4439,7 +4451,7 @@ function inlining_pass(e::Expr, sv::InferenceState)
44394451
newargs[i-2] = Any[ mk_getfield(aarg,j,tp[j]) for j=1:length(tp) ]
44404452
else
44414453
# not all args expandable
4442-
return (e,stmts)
4454+
return e
44434455
end
44444456
end
44454457
e.args = [Any[e.args[2]]; newargs...]
@@ -4454,14 +4466,14 @@ function inlining_pass(e::Expr, sv::InferenceState)
44544466
else
44554467
f = nothing
44564468
if !( isleaftype(ft) || ft<:Type )
4457-
return (e,stmts)
4469+
return e
44584470
end
44594471
end
44604472
else
4461-
return (e,stmts)
4473+
return e
44624474
end
44634475
end
4464-
return (e,stmts)
4476+
return e
44654477
end
44664478

44674479
const compiler_temp_sym = Symbol("#temp#")
@@ -4562,7 +4574,8 @@ normslot(s::TypedSlot) = SlotNumber(slot_id(s))
45624574
function get_replacement(table, var::Union{SlotNumber, SSAValue}, init::ANY, nargs, slottypes, ssavaluetypes)
45634575
#if isa(init, QuoteNode) # this can cause slight code size increases
45644576
# return init
4565-
if isa(init, Expr) && init.head === :static_parameter
4577+
if (isa(init, Expr) && init.head === :static_parameter) || isa(init, corenumtype) ||
4578+
init === () || init === nothing
45664579
return init
45674580
elseif isa(init, Slot) && is_argument(nargs, init::Slot)
45684581
# the transformation is not ideal if the assignment

0 commit comments

Comments
 (0)