Skip to content

Commit 2ec47da

Browse files
Keno, aviatesk
authored and committed
Very WIP: Refactor core inference loops to use less memory
Currently inference uses `O(<number of statements>*<number of slots>)` state in the core inference loop. This is usually fine, because users don't tend to write functions that are particularly long. However, ModelingToolkit (MTK) does generate functions that are excessively long, and we've observed MTK models that spend 99% of their inference time just allocating and copying this state. It is possible to get away with significantly smaller state, and this PR is a first step in that direction, reducing the state to `O(<number of basic blocks>*<number of slots>)`. Further improvements are possible by making use of slot liveness information and only storing those slots that are live across a particular basic block. The core change here is to keep a full set of `slottypes` only at basic block boundaries rather than at each statement. For statements in between, the full variable state can be recovered by linearly scanning through the basic block, taking note of slot assignments (together with the SSA type) and `NewvarNode`s. The current status of this branch is that the changes appear correct (no known functional regressions) and significantly improve the MTK test cases in question (no exact benchmarks here for now, since the branch still needs a number of fixes before final numbers make sense), but somewhat regress optimizer quality (which is expected and just a missing TODO) and bootstrap time (which is not expected and something I need to dig into).
1 parent 65b9be4 commit 2ec47da

File tree

9 files changed

+472
-303
lines changed

9 files changed

+472
-303
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 312 additions & 172 deletions
Large diffs are not rendered by default.

base/compiler/compiler.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,14 @@ include("compiler/utilities.jl")
128128
include("compiler/validation.jl")
129129
include("compiler/methodtable.jl")
130130

131+
function argextype end # imported by EscapeAnalysis
132+
function stmt_effect_free end # imported by EscapeAnalysis
133+
function alloc_array_ndims end # imported by EscapeAnalysis
134+
function try_compute_field end # imported by EscapeAnalysis
135+
include("compiler/ssair/basicblock.jl")
136+
include("compiler/ssair/domtree.jl")
137+
include("compiler/ssair/ir.jl")
138+
131139
include("compiler/inferenceresult.jl")
132140
include("compiler/inferencestate.jl")
133141

base/compiler/inferencestate.jl

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -88,17 +88,22 @@ mutable struct InferenceState
8888
sptypes::Vector{Any}
8989
slottypes::Vector{Any}
9090
src::CodeInfo
91+
cfg::CFG
9192

9293
#= intermediate states for local abstract interpretation =#
94+
currbb::Int
9395
currpc::Int
94-
ip::BitSetBoundedMinPrioritySet # current active instruction pointers
96+
ip::BitSet#=TODO BoundedMinPrioritySet=# # current active instruction pointers
97+
was_reached::BitSet
9598
handler_at::Vector{Int} # current exception handler info
9699
ssavalue_uses::Vector{BitSet} # ssavalue sparsity and restart info
97-
stmt_types::Vector{Union{Nothing, VarTable}}
100+
# TODO: Could keep this sparsely by doing structural liveness analysis ahead of time.
101+
bb_vartables::Vector{VarTable}
102+
pc_vartable::VarTable
98103
stmt_edges::Vector{Union{Nothing, Vector{Any}}}
99104
stmt_info::Vector{Any}
100105

101-
#= interprocedural intermediate states for abstract interpretation =#
106+
#= intermediate states for interprocedural abstract interpretation =#
102107
pclimitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
103108
limitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on return
104109
cycle_backedges::Vector{Tuple{InferenceState, Int}} # call-graph backedges connecting from callee to caller
@@ -125,36 +130,40 @@ mutable struct InferenceState
125130
interp::AbstractInterpreter
126131

127132
# src is assumed to be a newly-allocated CodeInfo, that can be modified in-place to contain intermediate results
128-
function InferenceState(result::InferenceResult,
129-
src::CodeInfo, cache::Symbol, interp::AbstractInterpreter)
133+
function InferenceState(result::InferenceResult, src::CodeInfo, cache::Symbol,
134+
interp::AbstractInterpreter)
130135
linfo = result.linfo
131136
world = get_world_counter(interp)
132137
def = linfo.def
133138
mod = isa(def, Method) ? def.module : def
134139
sptypes = sptypes_from_meth_instance(linfo)
135-
136140
code = src.code::Vector{Any}
137-
nstmts = length(code)
138-
currpc = 1
139-
ip = BitSetBoundedMinPrioritySet(nstmts)
140-
handler_at = compute_trycatch(code, ip.elems)
141-
push!(ip, 1)
141+
cfg = compute_basic_blocks(code)
142+
143+
currbb = currpc = 1
144+
ip = BitSet(1) # TODO BitSetBoundedMinPrioritySet(1)
145+
was_reached = BitSet()
146+
handler_at = compute_trycatch(code, BitSet())
142147
nssavalues = src.ssavaluetypes::Int
143148
ssavalue_uses = find_ssavalue_uses(code, nssavalues)
144-
stmt_types = Union{Nothing, VarTable}[ nothing for i = 1:nstmts ]
149+
nstmts = length(code)
145150
stmt_edges = Union{Nothing, Vector{Any}}[ nothing for i = 1:nstmts ]
146151
stmt_info = Any[ nothing for i = 1:nstmts ]
147152

148153
nslots = length(src.slotflags)
149154
slottypes = Vector{Any}(undef, nslots)
155+
pc_vartable = VarTable(undef, nslots)
156+
bb_vartable_proto = VarTable(undef, nslots)
150157
argtypes = result.argtypes
151-
nargs = length(argtypes)
152-
stmt_types[1] = stmt_type1 = VarTable(undef, nslots)
158+
nargtypes = length(argtypes)
153159
for i in 1:nslots
154-
argtyp = (i > nargs) ? Bottom : argtypes[i]
155-
stmt_type1[i] = VarState(argtyp, i > nargs)
160+
argtyp = (i > nargtypes) ? Bottom : argtypes[i]
161+
pc_vartable[i] = VarState(argtyp, i > nargtypes)
162+
bb_vartable_proto[i] = VarState(Bottom, i > nargtypes)
156163
slottypes[i] = argtyp
157164
end
165+
bb_vartables = VarTable[i == 1 ? copy(pc_vartable) : copy(bb_vartable_proto)
166+
for i = 1:length(cfg.blocks)]
158167

159168
pclimitations = IdSet{InferenceState}()
160169
limitations = IdSet{InferenceState}()
@@ -183,8 +192,8 @@ mutable struct InferenceState
183192
cached = cache === :global
184193

185194
frame = new(
186-
linfo, world, mod, sptypes, slottypes, src,
187-
currpc, ip, handler_at, ssavalue_uses, stmt_types, stmt_edges, stmt_info,
195+
linfo, world, mod, sptypes, slottypes, src, cfg,
196+
currbb, currpc, ip, was_reached, handler_at, ssavalue_uses, bb_vartables, pc_vartable, stmt_edges, stmt_info,
188197
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent, inferred,
189198
result, valid_worlds, bestguess, ipo_effects,
190199
params, restrict_abstract_call_sites, cached,
@@ -226,6 +235,8 @@ function any_inbounds(code::Vector{Any})
226235
return false
227236
end
228237

238+
# Report whether `pc` is a member of the frame's `was_reached` set.
was_reached(frame::InferenceState, pc::Int) = pc in frame.was_reached
239+
229240
function compute_trycatch(code::Vector{Any}, ip::BitSet)
230241
# The goal initially is to record the frame like this for the state at exit:
231242
# 1: (enter 3) # == 0
@@ -422,7 +433,7 @@ end
422433

423434
update_valid_age!(edge::InferenceState, sv::InferenceState) = update_valid_age!(sv, edge.valid_worlds)
424435

425-
function record_ssa_assign(ssa_id::Int, @nospecialize(new), frame::InferenceState)
436+
function record_ssa_assign!(ssa_id::Int, @nospecialize(new), frame::InferenceState)
426437
ssavaluetypes = frame.src.ssavaluetypes::Vector{Any}
427438
old = ssavaluetypes[ssa_id]
428439
if old === NOT_FOUND || !(new ⊑ old)
@@ -431,14 +442,19 @@ function record_ssa_assign(ssa_id::Int, @nospecialize(new), frame::InferenceStat
431442
# guarantee convergence we need to use tmerge here to ensure that is true
432443
ssavaluetypes[ssa_id] = old === NOT_FOUND ? new : tmerge(old, new)
433444
W = frame.ip
434-
s = frame.stmt_types
435445
for r in frame.ssavalue_uses[ssa_id]
436-
if s[r] !== nothing # s[r] === nothing => unreached statement
437-
push!(W, r)
446+
if was_reached(frame, r)
447+
usebb = block_for_inst(frame.cfg, r)
448+
# We're guaranteed to visit the statement if it's in the current
449+
# basic block, since SSA values can only ever appear after their
450+
# def.
451+
if usebb != frame.currbb
452+
push!(W, usebb)
453+
end
438454
end
439455
end
440456
end
441-
nothing
457+
return nothing
442458
end
443459

444460
function add_cycle_backedge!(frame::InferenceState, caller::InferenceState, currpc::Int)

base/compiler/optimize.jl

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ mutable struct OptimizationState
8888
linfo::MethodInstance
8989
src::CodeInfo
9090
ir::Union{Nothing, IRCode}
91+
was_reached::Union{Nothing, BitSet}
9192
stmt_info::Vector{Any}
9293
mod::Module
9394
sptypes::Vector{Any} # static parameters
@@ -100,7 +101,7 @@ mutable struct OptimizationState
100101
WorldView(code_cache(interp), frame.world),
101102
interp)
102103
return new(frame.linfo,
103-
frame.src, nothing, frame.stmt_info, frame.mod,
104+
frame.src, nothing, frame.was_reached, frame.stmt_info, frame.mod,
104105
frame.sptypes, frame.slottypes, inlining)
105106
end
106107
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter)
@@ -128,11 +129,13 @@ mutable struct OptimizationState
128129
WorldView(code_cache(interp), get_world_counter()),
129130
interp)
130131
return new(linfo,
131-
src, nothing, stmt_info, mod,
132+
src, nothing, nothing, stmt_info, mod,
132133
sptypes_from_meth_instance(linfo), slottypes, inlining)
133134
end
134135
end
135136

137+
# Report whether statement `pc` counts as reached for this optimization state.
# When the state's `was_reached` field is `nothing`, every statement is
# treated as reached.
function was_reached(sv::OptimizationState, pc::Int)
    reached = sv.was_reached
    reached === nothing && return true
    return pc in reached
end
138+
136139
function OptimizationState(linfo::MethodInstance, params::OptimizationParams, interp::AbstractInterpreter)
137140
src = retrieve_code_info(linfo)
138141
src === nothing && return nothing
@@ -572,9 +575,29 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
572575
end
573576
end
574577
end
578+
meta = Any[]
575579
labelmap = coverage ? fill(0, length(code)) : changemap
576580
while idx <= length(code)
577581
codeloc = codelocs[idx]
582+
stmt = code[idx]
583+
if process_meta!(meta, stmt) || !was_reached(sv, oldidx)
584+
if oldidx < length(labelmap)
585+
changemap[oldidx] != 0 && (changemap[oldidx+1] = changemap[oldidx])
586+
if coverage && labelmap[oldidx] != 0
587+
labelmap[oldidx + 1] = labelmap[oldidx]
588+
end
589+
changemap[oldidx] = -1
590+
coverage && (labelmap[oldidx] = -1)
591+
end
592+
# TODO: It would be more efficient to do this in bulk
593+
deleteat!(code, idx)
594+
deleteat!(codelocs, idx)
595+
deleteat!(ssavaluetypes, idx)
596+
deleteat!(stmtinfo, idx)
597+
deleteat!(ssaflags, idx)
598+
oldidx += 1
599+
continue
600+
end
578601
if coverage && codeloc != prevloc && codeloc != 0
579602
# insert a side-effect instruction before the current instruction in the same basic block
580603
insert!(code, idx, Expr(:code_coverage_effect))
@@ -589,7 +612,13 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
589612
idx += 1
590613
prevloc = codeloc
591614
end
592-
if code[idx] isa Expr && ssavaluetypes[idx] === Union{}
615+
if isa(stmt, GotoIfNot)
616+
if !was_reached(sv, oldidx + 1)
617+
code[idx] = GotoNode(stmt.dest)
618+
elseif !was_reached(sv, stmt.dest)
619+
code[idx] = nothing
620+
end
621+
elseif stmt isa Expr && ssavaluetypes[idx] === Union{}
593622
if !(idx < length(code) && isa(code[idx + 1], ReturnNode) && !isdefined((code[idx + 1]::ReturnNode), :val))
594623
# insert unreachable in the same basic block after the current instruction (splitting it)
595624
insert!(code, idx + 1, ReturnNode())
@@ -607,12 +636,9 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
607636
idx += 1
608637
oldidx += 1
609638
end
639+
610640
renumber_ir_elements!(code, changemap, labelmap)
611641

612-
meta = Any[]
613-
for i = 1:length(code)
614-
code[i] = remove_meta!(code[i], meta)
615-
end
616642
strip_trailing_junk!(ci, code, stmtinfo)
617643
cfg = compute_basic_blocks(code)
618644
types = Any[]
@@ -623,18 +649,13 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
623649
return ir
624650
end
625651

626-
function remove_meta!(@nospecialize(stmt), meta::Vector{Any})
627-
if isa(stmt, Expr)
628-
head = stmt.head
629-
if head === :meta
630-
args = stmt.args
631-
if length(args) > 0
632-
push!(meta, stmt)
633-
end
634-
return nothing
635-
end
652+
function process_meta!(meta::Vector{Any}, @nospecialize stmt)
653+
isa(stmt, Expr) || return false
654+
stmt.head === :meta || return false
655+
if length(stmt.args) > 0
656+
push!(meta, stmt)
636657
end
637-
return stmt
658+
return true
638659
end
639660

640661
function slot2reg(ir::IRCode, ci::CodeInfo, sv::OptimizationState)
@@ -796,7 +817,9 @@ end
796817

797818
function cumsum_ssamap!(ssamap::Vector{Int})
798819
rel_change = 0
820+
any_change = false
799821
for i = 1:length(ssamap)
822+
any_change = any_change || ssamap[i] != 0
800823
rel_change += ssamap[i]
801824
if ssamap[i] == -1
802825
# Keep a marker that this statement was deleted
@@ -805,16 +828,15 @@ function cumsum_ssamap!(ssamap::Vector{Int})
805828
ssamap[i] = rel_change
806829
end
807830
end
831+
return any_change
808832
end
809833

810834
function renumber_ir_elements!(body::Vector{Any}, ssachangemap::Vector{Int}, labelchangemap::Vector{Int})
811-
cumsum_ssamap!(labelchangemap)
835+
any_change = cumsum_ssamap!(labelchangemap)
812836
if ssachangemap !== labelchangemap
813-
cumsum_ssamap!(ssachangemap)
814-
end
815-
if labelchangemap[end] == 0 && ssachangemap[end] == 0
816-
return
837+
any_change = cumsum_ssamap!(ssachangemap)
817838
end
839+
any_change || return
818840
for i = 1:length(body)
819841
el = body[i]
820842
if isa(el, GotoNode)
@@ -824,7 +846,8 @@ function renumber_ir_elements!(body::Vector{Any}, ssachangemap::Vector{Int}, lab
824846
if isa(cond, SSAValue)
825847
cond = SSAValue(cond.id + ssachangemap[cond.id])
826848
end
827-
body[i] = GotoIfNot(cond, el.dest + labelchangemap[el.dest])
849+
was_deleted = labelchangemap[el.dest] == typemin(Int)
850+
body[i] = was_deleted ? cond : GotoIfNot(cond, el.dest + labelchangemap[el.dest])
828851
elseif isa(el, ReturnNode)
829852
if isdefined(el, :val)
830853
val = el.val

base/compiler/ssair/driver.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,6 @@ else
88
end
99
end
1010

11-
function argextype end # imported by EscapeAnalysis
12-
function stmt_effect_free end # imported by EscapeAnalysis
13-
function alloc_array_ndims end # imported by EscapeAnalysis
14-
function try_compute_field end # imported by EscapeAnalysis
15-
16-
include("compiler/ssair/basicblock.jl")
17-
include("compiler/ssair/domtree.jl")
18-
include("compiler/ssair/ir.jl")
1911
include("compiler/ssair/slot2ssa.jl")
2012
include("compiler/ssair/inlining.jl")
2113
include("compiler/ssair/verify.jl")

base/compiler/ssair/ir.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -736,7 +736,7 @@ function dominates_ssa(compact::IncrementalCompact, domtree::DomTree, x::AnySSAV
736736
elseif xinfo !== nothing
737737
return !xinfo.attach_after
738738
else
739-
return yinfo.attach_after
739+
return (yinfo::NewNodeInfo).attach_after
740740
end
741741
end
742742
return x′.id < y′.id

0 commit comments

Comments
 (0)