Skip to content

Commit 13e854f

Browse files
committed
optimize construction of InferenceResult for constant inference
1 parent fabc166 commit 13e854f

File tree

8 files changed

+81
-115
lines changed

8 files changed

+81
-115
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,7 +1239,7 @@ const_prop_result(inf_result::InferenceResult) =
12391239
ConstCallResults(inf_result.result, inf_result.exc_result, ConstPropResult(inf_result),
12401240
inf_result.ipo_effects, inf_result.linfo)
12411241

1242-
# return cached constant analysis result
1242+
# return cached result of constant analysis
12431243
return_cached_result(::AbstractInterpreter, inf_result::InferenceResult, ::AbsIntState) =
12441244
const_prop_result(inf_result)
12451245

@@ -1249,7 +1249,14 @@ function const_prop_call(interp::AbstractInterpreter,
12491249
inf_cache = get_inference_cache(interp)
12501250
𝕃ᵢ = typeinf_lattice(interp)
12511251
argtypes = has_conditional(𝕃ᵢ, sv) ? ConditionalArgtypes(arginfo, sv) : SimpleArgtypes(arginfo.argtypes)
1252-
given_argtypes, overridden_by_const = matching_cache_argtypes(𝕃ᵢ, mi, argtypes)
1252+
# use `cache_argtypes` that has been constructed for fresh regular inference if available
1253+
volatile_inf_result = result.volatile_inf_result
1254+
if volatile_inf_result !== nothing
1255+
cache_argtypes = volatile_inf_result.inf_result.argtypes
1256+
else
1257+
cache_argtypes = matching_cache_argtypes(𝕃ᵢ, mi)
1258+
end
1259+
given_argtypes = matching_cache_argtypes(𝕃ᵢ, mi, argtypes, cache_argtypes)
12531260
inf_result = cache_lookup(𝕃ᵢ, mi, given_argtypes, inf_cache)
12541261
if inf_result !== nothing
12551262
# found the cache for this constant prop'
@@ -1260,12 +1267,18 @@ function const_prop_call(interp::AbstractInterpreter,
12601267
@assert inf_result.linfo === mi "MethodInstance for cached inference result does not match"
12611268
return return_cached_result(interp, inf_result, sv)
12621269
end
1263-
# perform fresh constant prop'
1264-
inf_result = InferenceResult(mi, given_argtypes, overridden_by_const)
1265-
if !any(inf_result.overridden_by_const)
1270+
overridden_by_const = falses(length(given_argtypes))
1271+
for i = 1:length(given_argtypes)
1272+
if given_argtypes[i] !== cache_argtypes[i]
1273+
overridden_by_const[i] = true
1274+
end
1275+
end
1276+
if !any(overridden_by_const)
12661277
add_remark!(interp, sv, "[constprop] Could not handle constant info in matching_cache_argtypes")
12671278
return nothing
12681279
end
1280+
# perform fresh constant prop'
1281+
inf_result = InferenceResult(mi, given_argtypes, overridden_by_const)
12691282
frame = InferenceState(inf_result, #=cache_mode=#:local, interp)
12701283
if frame === nothing
12711284
add_remark!(interp, sv, "[constprop] Could not retrieve the source")
@@ -1287,26 +1300,19 @@ end
12871300

12881301
# TODO implement MustAlias forwarding
12891302

1290-
struct ConditionalArgtypes <: ForwardableArgtypes
1303+
struct ConditionalArgtypes
12911304
arginfo::ArgInfo
12921305
sv::InferenceState
12931306
end
12941307

1295-
"""
1296-
matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance,
1297-
conditional_argtypes::ConditionalArgtypes)
1298-
1299-
The implementation is able to forward `Conditional` of `conditional_argtypes`,
1300-
as well as the other general extended lattice information.
1301-
"""
13021308
function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance,
1303-
conditional_argtypes::ConditionalArgtypes)
1309+
conditional_argtypes::ConditionalArgtypes,
1310+
cache_argtypes::Vector{Any})
13041311
(; arginfo, sv) = conditional_argtypes
13051312
(; fargs, argtypes) = arginfo
13061313
given_argtypes = Vector{Any}(undef, length(argtypes))
13071314
def = mi.def::Method
13081315
nargs = Int(def.nargs)
1309-
cache_argtypes, overridden_by_const = matching_cache_argtypes(𝕃, mi)
13101316
local condargs = nothing
13111317
for i in 1:length(argtypes)
13121318
argtype = argtypes[i]
@@ -1349,7 +1355,7 @@ function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance,
13491355
else
13501356
given_argtypes = va_process_argtypes(𝕃, given_argtypes, mi)
13511357
end
1352-
return pick_const_args!(𝕃, cache_argtypes, overridden_by_const, given_argtypes)
1358+
return pick_const_args!(𝕃, given_argtypes, cache_argtypes)
13531359
end
13541360

13551361
# This is only for use with `Conditional`.

base/compiler/compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,14 +203,14 @@ include("compiler/ssair/ir.jl")
203203
include("compiler/ssair/tarjan.jl")
204204

205205
include("compiler/abstractlattice.jl")
206+
include("compiler/stmtinfo.jl")
206207
include("compiler/inferenceresult.jl")
207208
include("compiler/inferencestate.jl")
208209

209210
include("compiler/typeutils.jl")
210211
include("compiler/typelimits.jl")
211212
include("compiler/typelattice.jl")
212213
include("compiler/tfuncs.jl")
213-
include("compiler/stmtinfo.jl")
214214

215215
include("compiler/abstractinterpretation.jl")
216216
include("compiler/typeinfer.jl")

base/compiler/inferenceresult.jl

Lines changed: 33 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,30 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

3-
"""
4-
matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance) ->
5-
(cache_argtypes::Vector{Any}, overridden_by_const::BitVector)
6-
7-
Returns argument types `cache_argtypes::Vector{Any}` for `mi` that are in the native
8-
Julia type domain. `overridden_by_const::BitVector` is all `false` meaning that
9-
there is no additional extended lattice information there.
10-
11-
matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance, argtypes::ForwardableArgtypes) ->
12-
(cache_argtypes::Vector{Any}, overridden_by_const::BitVector)
13-
14-
Returns cache-correct extended lattice argument types `cache_argtypes::Vector{Any}`
15-
for `mi` given some `argtypes` accompanied by `overridden_by_const::BitVector`
16-
that marks which argument contains additional extended lattice information.
17-
18-
In theory, there could be a `cache` containing a matching `InferenceResult`
19-
for the provided `mi` and `given_argtypes`. The purpose of this function is
20-
to return a valid value for `cache_lookup(𝕃, mi, argtypes, cache).argtypes`,
21-
so that we can construct cache-correct `InferenceResult`s in the first place.
22-
"""
23-
function matching_cache_argtypes end
24-
253
function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance)
26-
method = isa(mi.def, Method) ? mi.def::Method : nothing
27-
cache_argtypes = most_general_argtypes(method, mi.specTypes)
28-
overridden_by_const = falses(length(cache_argtypes))
29-
return cache_argtypes, overridden_by_const
4+
(; def, specTypes) = mi
5+
return most_general_argtypes(isa(def, Method) ? def : nothing, specTypes)
306
end
317

32-
struct SimpleArgtypes <: ForwardableArgtypes
8+
struct SimpleArgtypes
339
argtypes::Vector{Any}
3410
end
3511

36-
"""
37-
matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance, argtypes::SimpleArgtypes)
38-
39-
The implementation for `argtypes` with general extended lattice information.
40-
This is supposed to be used for debugging and testing or external `AbstractInterpreter`
41-
usages and in general `matching_cache_argtypes(::MethodInstance, ::ConditionalArgtypes)`
42-
is more preferred it can forward `Conditional` information.
43-
"""
44-
function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance, simple_argtypes::SimpleArgtypes)
12+
function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance,
13+
simple_argtypes::SimpleArgtypes,
14+
cache_argtypes::Vector{Any})
4515
(; argtypes) = simple_argtypes
4616
given_argtypes = Vector{Any}(undef, length(argtypes))
4717
for i = 1:length(argtypes)
4818
given_argtypes[i] = widenslotwrapper(argtypes[i])
4919
end
5020
given_argtypes = va_process_argtypes(𝕃, given_argtypes, mi)
51-
return pick_const_args(𝕃, mi, given_argtypes)
21+
return pick_const_args!(𝕃, given_argtypes, cache_argtypes)
5222
end
5323

54-
function pick_const_args(𝕃::AbstractLattice, mi::MethodInstance, given_argtypes::Vector{Any})
55-
cache_argtypes, overridden_by_const = matching_cache_argtypes(𝕃, mi)
56-
return pick_const_args!(𝕃, cache_argtypes, overridden_by_const, given_argtypes)
57-
end
58-
59-
function pick_const_args!(𝕃::AbstractLattice, cache_argtypes::Vector{Any}, overridden_by_const::BitVector, given_argtypes::Vector{Any})
60-
for i = 1:length(given_argtypes)
24+
function pick_const_args!(𝕃::AbstractLattice, given_argtypes::Vector{Any}, cache_argtypes::Vector{Any})
25+
nargtypes = length(given_argtypes)
26+
@assert nargtypes == length(cache_argtypes) #= == nargs =# "invalid `given_argtypes` for `mi`"
27+
for i = 1:nargtypes
6128
given_argtype = given_argtypes[i]
6229
cache_argtype = cache_argtypes[i]
6330
if !is_argtype_match(𝕃, given_argtype, cache_argtype, false)
@@ -66,13 +33,13 @@ function pick_const_args!(𝕃::AbstractLattice, cache_argtypes::Vector{Any}, ov
6633
!(𝕃, given_argtype, cache_argtype))
6734
# if the type information of this `PartialStruct` is less strict than
6835
# declared method signature, narrow it down using `tmeet`
69-
given_argtype = tmeet(𝕃, given_argtype, cache_argtype)
36+
given_argtypes[i] = tmeet(𝕃, given_argtype, cache_argtype)
7037
end
71-
cache_argtypes[i] = given_argtype
72-
overridden_by_const[i] = true
38+
else
39+
given_argtypes[i] = cache_argtype
7340
end
7441
end
75-
return cache_argtypes, overridden_by_const
42+
return given_argtypes
7643
end
7744

7845
function is_argtype_match(𝕃::AbstractLattice,
@@ -89,9 +56,9 @@ end
8956
va_process_argtypes(𝕃::AbstractLattice, given_argtypes::Vector{Any}, mi::MethodInstance) =
9057
va_process_argtypes(Returns(nothing), 𝕃, given_argtypes, mi)
9158
function va_process_argtypes(@specialize(va_handler!), 𝕃::AbstractLattice, given_argtypes::Vector{Any}, mi::MethodInstance)
92-
def = mi.def
93-
isva = isa(def, Method) ? def.isva : false
94-
nargs = isa(def, Method) ? Int(def.nargs) : length(mi.specTypes.parameters)
59+
def = mi.def::Method
60+
isva = def.isva
61+
nargs = Int(def.nargs)
9562
if isva || isvarargtype(given_argtypes[end])
9663
isva_given_argtypes = Vector{Any}(undef, nargs)
9764
for i = 1:(nargs-isva)
@@ -112,14 +79,11 @@ function va_process_argtypes(@specialize(va_handler!), 𝕃::AbstractLattice, gi
11279
return given_argtypes
11380
end
11481

115-
function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(specTypes),
116-
withfirst::Bool = true)
82+
function most_general_argtypes(method::Union{Method,Nothing}, @nospecialize(specTypes))
11783
toplevel = method === nothing
11884
isva = !toplevel && method.isva
11985
mi_argtypes = Any[(unwrap_unionall(specTypes)::DataType).parameters...]
12086
nargs::Int = toplevel ? 0 : method.nargs
121-
# For opaque closure, the closure environment is processed elsewhere
122-
withfirst || (nargs -= 1)
12387
cache_argtypes = Vector{Any}(undef, nargs)
12488
# First, if we're dealing with a varargs method, then we set the last element of `args`
12589
# to the appropriate `Tuple` type or `PartialStruct` instance.
@@ -162,17 +126,16 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
162126
cache_argtypes[nargs] = vargtype
163127
nargs -= 1
164128
end
165-
# Now, we propagate type info from `linfo_argtypes` into `cache_argtypes`, improving some
129+
# Now, we propagate type info from `mi_argtypes` into `cache_argtypes`, improving some
166130
# type info as we go (where possible). Note that if we're dealing with a varargs method,
167131
# we already handled the last element of `cache_argtypes` (and decremented `nargs` so that
168132
# we don't overwrite the result of that work here).
169133
if mi_argtypes_length > 0
170-
n = mi_argtypes_length > nargs ? nargs : mi_argtypes_length
171-
tail_index = n
134+
tail_index = nargtypes = min(mi_argtypes_length, nargs)
172135
local lastatype
173-
for i = 1:n
136+
for i = 1:nargtypes
174137
atyp = mi_argtypes[i]
175-
if i == n && isvarargtype(atyp)
138+
if i == nargtypes && isvarargtype(atyp)
176139
atyp = unwrapva(atyp)
177140
tail_index -= 1
178141
end
@@ -185,16 +148,16 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
185148
else
186149
atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes))
187150
end
188-
i == n && (lastatype = atyp)
151+
i == nargtypes && (lastatype = atyp)
189152
cache_argtypes[i] = atyp
190153
end
191-
for i = (tail_index + 1):nargs
154+
for i = (tail_index+1):nargs
192155
cache_argtypes[i] = lastatype
193156
end
194157
else
195158
@assert nargs == 0 "invalid specialization of method" # wrong number of arguments
196159
end
197-
cache_argtypes
160+
return cache_argtypes
198161
end
199162

200163
# eliminate free `TypeVar`s in order to make the life much easier down the road:
@@ -213,22 +176,15 @@ end
213176
function cache_lookup(𝕃::AbstractLattice, mi::MethodInstance, given_argtypes::Vector{Any},
214177
cache::Vector{InferenceResult})
215178
method = mi.def::Method
216-
nargs = Int(method.nargs)
217-
method.isva && (nargs -= 1)
218-
length(given_argtypes) nargs || return nothing
179+
nargtypes = length(given_argtypes)
180+
@assert nargtypes == Int(method.nargs) "invalid `given_argtypes` for `mi`"
219181
for cached_result in cache
220-
cached_result.linfo === mi || continue
182+
cached_result.linfo === mi || @goto next_cache
221183
cache_argtypes = cached_result.argtypes
222-
cache_overridden_by_const = cached_result.overridden_by_const
223-
for i in 1:nargs
224-
if !is_argtype_match(𝕃, widenmustalias(given_argtypes[i]),
225-
cache_argtypes[i], cache_overridden_by_const[i])
226-
@goto next_cache
227-
end
228-
end
229-
if method.isva
230-
if !is_argtype_match(𝕃, tuple_tfunc(𝕃, given_argtypes[(nargs + 1):end]),
231-
cache_argtypes[end], cache_overridden_by_const[end])
184+
@assert length(cache_argtypes) == nargtypes "invalid `cache_argtypes` for `mi`"
185+
cache_overridden_by_const = cached_result.overridden_by_const::BitVector
186+
for i in 1:nargtypes
187+
if !is_argtype_match(𝕃, given_argtypes[i], cache_argtypes[i], cache_overridden_by_const[i])
232188
@goto next_cache
233189
end
234190
end

base/compiler/inferencestate.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -832,7 +832,10 @@ end
832832
frame_parent(sv::InferenceState) = sv.parent::Union{Nothing,AbsIntState}
833833
frame_parent(sv::IRInterpretationState) = sv.parent::Union{Nothing,AbsIntState}
834834

835-
is_constproped(sv::InferenceState) = any(sv.result.overridden_by_const)
835+
function is_constproped(sv::InferenceState)
836+
(;overridden_by_const) = sv.result
837+
return overridden_by_const !== nothing
838+
end
836839
is_constproped(::IRInterpretationState) = true
837840

838841
is_cached(sv::InferenceState) = !iszero(sv.cache_mode & CACHE_MODE_GLOBAL)

base/compiler/ssair/legacy.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ the original `ci::CodeInfo` are modified.
1010
"""
1111
function inflate_ir!(ci::CodeInfo, mi::MethodInstance)
1212
sptypes = sptypes_from_meth_instance(mi)
13-
argtypes, _ = matching_cache_argtypes(fallback_lattice, mi)
13+
argtypes = matching_cache_argtypes(fallback_lattice, mi)
1414
return inflate_ir!(ci, sptypes, argtypes)
1515
end
1616
function inflate_ir!(ci::CodeInfo, sptypes::Vector{VarState}, argtypes::Vector{Any})

base/compiler/typeinfer.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -810,7 +810,7 @@ struct EdgeCallResult
810810
end
811811
end
812812

813-
# return cached regular inference result
813+
# return cached result of regular inference
814814
function return_cached_result(::AbstractInterpreter, codeinst::CodeInstance, caller::AbsIntState)
815815
rt = cached_return_type(codeinst)
816816
effects = ipo_effects(codeinst)
@@ -869,10 +869,8 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
869869
effects = isinferred ? frame.result.ipo_effects : adjust_effects(Effects(), method) # effects are adjusted already within `finish` for ipo_effects
870870
exc_bestguess = refine_exception_type(frame.exc_bestguess, effects)
871871
# propagate newly inferred source to the inliner, allowing efficient inlining w/o deserialization:
872-
# note that this result is cached globally exclusively, we can use this local result destructively
873-
volatile_inf_result = (isinferred && (force_inline ||
874-
src_inlining_policy(interp, result.src, NoCallInfo(), IR_FLAG_NULL))) ?
875-
VolatileInferenceResult(result) : nothing
872+
# note that this result is cached globally exclusively, so we can use this local result destructively
873+
volatile_inf_result = isinferred ? VolatileInferenceResult(result) : nothing
876874
return EdgeCallResult(frame.bestguess, exc_bestguess, edge, effects, volatile_inf_result)
877875
elseif frame === true
878876
# unresolvable cycle

base/compiler/types.jl

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,6 @@ struct VarState
5757
VarState(@nospecialize(typ), undef::Bool) = new(typ, undef)
5858
end
5959

60-
abstract type ForwardableArgtypes end
61-
6260
struct AnalysisResults
6361
result
6462
next::AnalysisResults
@@ -70,16 +68,19 @@ end
7068
const NULL_ANALYSIS_RESULTS = AnalysisResults(nothing)
7169

7270
"""
73-
InferenceResult(mi::MethodInstance, [argtypes::ForwardableArgtypes, 𝕃::AbstractLattice])
71+
result::InferenceResult
7472
7573
A type that represents the result of running type inference on a chunk of code.
76-
77-
See also [`matching_cache_argtypes`](@ref).
74+
There are two constructor available:
75+
- `InferenceResult(mi::MethodInstance, [𝕃::AbstractLattice])` for regular inference,
76+
without extended lattice information included in `result.argtypes`.
77+
- `InferenceResult(mi::MethodInstance, argtypes::Vector{Any}, overridden_by_const::BitVector)`
78+
for constant inference, with extended lattice information included in `result.argtypes`.
7879
"""
7980
mutable struct InferenceResult
8081
const linfo::MethodInstance
8182
const argtypes::Vector{Any}
82-
const overridden_by_const::BitVector
83+
const overridden_by_const::Union{Nothing,BitVector}
8384
result # extended lattice element if inferred, nothing otherwise
8485
exc_result # like `result`, but for the thrown value
8586
src # ::Union{CodeInfo, IRCode, OptimizationState} if inferred copy is available, nothing otherwise
@@ -89,16 +90,18 @@ mutable struct InferenceResult
8990
analysis_results::AnalysisResults # AnalysisResults with e.g. result::ArgEscapeCache if optimized, otherwise NULL_ANALYSIS_RESULTS
9091
is_src_volatile::Bool # `src` has been cached globally as the compressed format already, allowing `src` to be used destructively
9192
ci::CodeInstance # CodeInstance if this result has been added to the cache
92-
function InferenceResult(mi::MethodInstance, cache_argtypes::Vector{Any}, overridden_by_const::BitVector)
93-
# def = mi.def
94-
# nargs = def isa Method ? Int(def.nargs) : 0
95-
# @assert length(cache_argtypes) == nargs
96-
return new(mi, cache_argtypes, overridden_by_const, nothing, nothing, nothing,
93+
function InferenceResult(mi::MethodInstance, argtypes::Vector{Any}, overridden_by_const::Union{Nothing,BitVector})
94+
def = mi.def
95+
nargs = def isa Method ? Int(def.nargs) : 0
96+
@assert length(argtypes) == nargs "invalid `argtypes` for `mi`"
97+
return new(mi, argtypes, overridden_by_const, nothing, nothing, nothing,
9798
WorldRange(), Effects(), Effects(), NULL_ANALYSIS_RESULTS, false)
9899
end
99100
end
100-
InferenceResult(mi::MethodInstance, 𝕃::AbstractLattice=fallback_lattice) =
101-
InferenceResult(mi, matching_cache_argtypes(𝕃, mi)...)
101+
function InferenceResult(mi::MethodInstance, 𝕃::AbstractLattice=fallback_lattice)
102+
argtypes = matching_cache_argtypes(𝕃, mi)
103+
return InferenceResult(mi, argtypes, #=overridden_by_const=#nothing)
104+
end
102105

103106
function stack_analysis_result!(inf_result::InferenceResult, @nospecialize(result))
104107
return inf_result.analysis_results = AnalysisResults(result, inf_result.analysis_results)

0 commit comments

Comments
 (0)