1+ package dotty .tools
2+ package dotc
3+ package core
4+
5+ import Types .* , Contexts .* , Symbols .* , Constants .* , Decorators .*
6+ import config .Printers .typr
7+ import reporting .trace
8+ import StdNames .tpnme
9+
10+ object TypeEval :
11+
12+ def tryCompiletimeConstantFold (tp : AppliedType )(using Context ): Type = tp.tycon match
13+ case tycon : TypeRef if defn.isCompiletimeAppliedType(tycon.symbol) =>
14+ extension (tp : Type ) def fixForEvaluation : Type =
15+ tp.normalized.dealias match
16+ // enable operations for constant singleton terms. E.g.:
17+ // ```
18+ // final val one = 1
19+ // type Two = one.type + one.type
20+ // ```
21+ case tp : TypeProxy if tp.underlying.isStable => tp.underlying.fixForEvaluation
22+ case tp => tp
23+
24+ def constValue (tp : Type ): Option [Any ] = tp.fixForEvaluation match
25+ case ConstantType (Constant (n)) => Some (n)
26+ case _ => None
27+
28+ def boolValue (tp : Type ): Option [Boolean ] = tp.fixForEvaluation match
29+ case ConstantType (Constant (n : Boolean )) => Some (n)
30+ case _ => None
31+
32+ def intValue (tp : Type ): Option [Int ] = tp.fixForEvaluation match
33+ case ConstantType (Constant (n : Int )) => Some (n)
34+ case _ => None
35+
36+ def longValue (tp : Type ): Option [Long ] = tp.fixForEvaluation match
37+ case ConstantType (Constant (n : Long )) => Some (n)
38+ case _ => None
39+
40+ def floatValue (tp : Type ): Option [Float ] = tp.fixForEvaluation match
41+ case ConstantType (Constant (n : Float )) => Some (n)
42+ case _ => None
43+
44+ def doubleValue (tp : Type ): Option [Double ] = tp.fixForEvaluation match
45+ case ConstantType (Constant (n : Double )) => Some (n)
46+ case _ => None
47+
48+ def stringValue (tp : Type ): Option [String ] = tp.fixForEvaluation match
49+ case ConstantType (Constant (n : String )) => Some (n)
50+ case _ => None
51+
52+ // Returns Some(true) if the type is a constant.
53+ // Returns Some(false) if the type is not a constant.
54+ // Returns None if there is not enough information to determine if the type is a constant.
55+ // The type is a constant if it is a constant type or a type operation composition of constant types.
56+ // If we get a type reference for an argument, then the result is not yet known.
57+ def isConst (tp : Type ): Option [Boolean ] = tp.dealias match
58+ // known to be constant
59+ case ConstantType (_) => Some (true )
60+ // currently not a concrete known type
61+ case TypeRef (NoPrefix ,_) => None
62+ // currently not a concrete known type
63+ case _ : TypeParamRef => None
64+ // constant if the term is constant
65+ case t : TermRef => isConst(t.underlying)
66+ // an operation type => recursively check all argument compositions
67+ case applied : AppliedType if defn.isCompiletimeAppliedType(applied.typeSymbol) =>
68+ val argsConst = applied.args.map(isConst)
69+ if (argsConst.exists(_.isEmpty)) None
70+ else Some (argsConst.forall(_.get))
71+ // all other types are considered not to be constant
72+ case _ => Some (false )
73+
74+ def expectArgsNum (expectedNum : Int ): Unit =
75+ // We can use assert instead of a compiler type error because this error should not
76+ // occur since the type signature of the operation enforces the proper number of args.
77+ assert(tp.args.length == expectedNum, s " Type operation expects $expectedNum arguments but found ${tp.args.length}" )
78+
79+ def natValue (tp : Type ): Option [Int ] = intValue(tp).filter(n => n >= 0 && n < Int .MaxValue )
80+
81+ // Runs the op and returns the result as a constant type.
82+ // If the op throws an exception, then this exception is converted into a type error.
83+ def runConstantOp (op : => Any ): Type =
84+ val result =
85+ try op
86+ catch case e : Throwable =>
87+ throw new TypeError (e.getMessage.nn)
88+ ConstantType (Constant (result))
89+
90+ def constantFold1 [T ](extractor : Type => Option [T ], op : T => Any ): Option [Type ] =
91+ expectArgsNum(1 )
92+ extractor(tp.args.head).map(a => runConstantOp(op(a)))
93+
94+ def constantFold2 [T ](extractor : Type => Option [T ], op : (T , T ) => Any ): Option [Type ] =
95+ constantFold2AB(extractor, extractor, op)
96+
97+ def constantFold2AB [TA , TB ](extractorA : Type => Option [TA ], extractorB : Type => Option [TB ], op : (TA , TB ) => Any ): Option [Type ] =
98+ expectArgsNum(2 )
99+ for
100+ a <- extractorA(tp.args(0 ))
101+ b <- extractorB(tp.args(1 ))
102+ yield runConstantOp(op(a, b))
103+
104+ def constantFold3 [TA , TB , TC ](
105+ extractorA : Type => Option [TA ],
106+ extractorB : Type => Option [TB ],
107+ extractorC : Type => Option [TC ],
108+ op : (TA , TB , TC ) => Any
109+ ): Option [Type ] =
110+ expectArgsNum(3 )
111+ for
112+ a <- extractorA(tp.args(0 ))
113+ b <- extractorB(tp.args(1 ))
114+ c <- extractorC(tp.args(2 ))
115+ yield runConstantOp(op(a, b, c))
116+
117+ trace(i " compiletime constant fold $tp" , typr, show = true ) {
118+ val name = tycon.symbol.name
119+ val owner = tycon.symbol.owner
120+ val constantType =
121+ if defn.isCompiletime_S(tycon.symbol) then
122+ constantFold1(natValue, _ + 1 )
123+ else if owner == defn.CompiletimeOpsAnyModuleClass then name match
124+ case tpnme.Equals => constantFold2(constValue, _ == _)
125+ case tpnme.NotEquals => constantFold2(constValue, _ != _)
126+ case tpnme.ToString => constantFold1(constValue, _.toString)
127+ case tpnme.IsConst => isConst(tp.args.head).map(b => ConstantType (Constant (b)))
128+ case _ => None
129+ else if owner == defn.CompiletimeOpsIntModuleClass then name match
130+ case tpnme.Abs => constantFold1(intValue, _.abs)
131+ case tpnme.Negate => constantFold1(intValue, x => - x)
132+ // ToString is deprecated for ops.int, and moved to ops.any
133+ case tpnme.ToString => constantFold1(intValue, _.toString)
134+ case tpnme.Plus => constantFold2(intValue, _ + _)
135+ case tpnme.Minus => constantFold2(intValue, _ - _)
136+ case tpnme.Times => constantFold2(intValue, _ * _)
137+ case tpnme.Div => constantFold2(intValue, _ / _)
138+ case tpnme.Mod => constantFold2(intValue, _ % _)
139+ case tpnme.Lt => constantFold2(intValue, _ < _)
140+ case tpnme.Gt => constantFold2(intValue, _ > _)
141+ case tpnme.Ge => constantFold2(intValue, _ >= _)
142+ case tpnme.Le => constantFold2(intValue, _ <= _)
143+ case tpnme.Xor => constantFold2(intValue, _ ^ _)
144+ case tpnme.BitwiseAnd => constantFold2(intValue, _ & _)
145+ case tpnme.BitwiseOr => constantFold2(intValue, _ | _)
146+ case tpnme.ASR => constantFold2(intValue, _ >> _)
147+ case tpnme.LSL => constantFold2(intValue, _ << _)
148+ case tpnme.LSR => constantFold2(intValue, _ >>> _)
149+ case tpnme.Min => constantFold2(intValue, _ min _)
150+ case tpnme.Max => constantFold2(intValue, _ max _)
151+ case tpnme.NumberOfLeadingZeros => constantFold1(intValue, Integer .numberOfLeadingZeros(_))
152+ case tpnme.ToLong => constantFold1(intValue, _.toLong)
153+ case tpnme.ToFloat => constantFold1(intValue, _.toFloat)
154+ case tpnme.ToDouble => constantFold1(intValue, _.toDouble)
155+ case _ => None
156+ else if owner == defn.CompiletimeOpsLongModuleClass then name match
157+ case tpnme.Abs => constantFold1(longValue, _.abs)
158+ case tpnme.Negate => constantFold1(longValue, x => - x)
159+ case tpnme.Plus => constantFold2(longValue, _ + _)
160+ case tpnme.Minus => constantFold2(longValue, _ - _)
161+ case tpnme.Times => constantFold2(longValue, _ * _)
162+ case tpnme.Div => constantFold2(longValue, _ / _)
163+ case tpnme.Mod => constantFold2(longValue, _ % _)
164+ case tpnme.Lt => constantFold2(longValue, _ < _)
165+ case tpnme.Gt => constantFold2(longValue, _ > _)
166+ case tpnme.Ge => constantFold2(longValue, _ >= _)
167+ case tpnme.Le => constantFold2(longValue, _ <= _)
168+ case tpnme.Xor => constantFold2(longValue, _ ^ _)
169+ case tpnme.BitwiseAnd => constantFold2(longValue, _ & _)
170+ case tpnme.BitwiseOr => constantFold2(longValue, _ | _)
171+ case tpnme.ASR => constantFold2(longValue, _ >> _)
172+ case tpnme.LSL => constantFold2(longValue, _ << _)
173+ case tpnme.LSR => constantFold2(longValue, _ >>> _)
174+ case tpnme.Min => constantFold2(longValue, _ min _)
175+ case tpnme.Max => constantFold2(longValue, _ max _)
176+ case tpnme.NumberOfLeadingZeros =>
177+ constantFold1(longValue, java.lang.Long .numberOfLeadingZeros(_))
178+ case tpnme.ToInt => constantFold1(longValue, _.toInt)
179+ case tpnme.ToFloat => constantFold1(longValue, _.toFloat)
180+ case tpnme.ToDouble => constantFold1(longValue, _.toDouble)
181+ case _ => None
182+ else if owner == defn.CompiletimeOpsFloatModuleClass then name match
183+ case tpnme.Abs => constantFold1(floatValue, _.abs)
184+ case tpnme.Negate => constantFold1(floatValue, x => - x)
185+ case tpnme.Plus => constantFold2(floatValue, _ + _)
186+ case tpnme.Minus => constantFold2(floatValue, _ - _)
187+ case tpnme.Times => constantFold2(floatValue, _ * _)
188+ case tpnme.Div => constantFold2(floatValue, _ / _)
189+ case tpnme.Mod => constantFold2(floatValue, _ % _)
190+ case tpnme.Lt => constantFold2(floatValue, _ < _)
191+ case tpnme.Gt => constantFold2(floatValue, _ > _)
192+ case tpnme.Ge => constantFold2(floatValue, _ >= _)
193+ case tpnme.Le => constantFold2(floatValue, _ <= _)
194+ case tpnme.Min => constantFold2(floatValue, _ min _)
195+ case tpnme.Max => constantFold2(floatValue, _ max _)
196+ case tpnme.ToInt => constantFold1(floatValue, _.toInt)
197+ case tpnme.ToLong => constantFold1(floatValue, _.toLong)
198+ case tpnme.ToDouble => constantFold1(floatValue, _.toDouble)
199+ case _ => None
200+ else if owner == defn.CompiletimeOpsDoubleModuleClass then name match
201+ case tpnme.Abs => constantFold1(doubleValue, _.abs)
202+ case tpnme.Negate => constantFold1(doubleValue, x => - x)
203+ case tpnme.Plus => constantFold2(doubleValue, _ + _)
204+ case tpnme.Minus => constantFold2(doubleValue, _ - _)
205+ case tpnme.Times => constantFold2(doubleValue, _ * _)
206+ case tpnme.Div => constantFold2(doubleValue, _ / _)
207+ case tpnme.Mod => constantFold2(doubleValue, _ % _)
208+ case tpnme.Lt => constantFold2(doubleValue, _ < _)
209+ case tpnme.Gt => constantFold2(doubleValue, _ > _)
210+ case tpnme.Ge => constantFold2(doubleValue, _ >= _)
211+ case tpnme.Le => constantFold2(doubleValue, _ <= _)
212+ case tpnme.Min => constantFold2(doubleValue, _ min _)
213+ case tpnme.Max => constantFold2(doubleValue, _ max _)
214+ case tpnme.ToInt => constantFold1(doubleValue, _.toInt)
215+ case tpnme.ToLong => constantFold1(doubleValue, _.toLong)
216+ case tpnme.ToFloat => constantFold1(doubleValue, _.toFloat)
217+ case _ => None
218+ else if owner == defn.CompiletimeOpsStringModuleClass then name match
219+ case tpnme.Plus => constantFold2(stringValue, _ + _)
220+ case tpnme.Length => constantFold1(stringValue, _.length)
221+ case tpnme.Matches => constantFold2(stringValue, _ matches _)
222+ case tpnme.Substring =>
223+ constantFold3(stringValue, intValue, intValue, (s, b, e) => s.substring(b, e))
224+ case tpnme.CharAt =>
225+ constantFold2AB(stringValue, intValue, _ charAt _)
226+ case _ => None
227+ else if owner == defn.CompiletimeOpsBooleanModuleClass then name match
228+ case tpnme.Not => constantFold1(boolValue, x => ! x)
229+ case tpnme.And => constantFold2(boolValue, _ && _)
230+ case tpnme.Or => constantFold2(boolValue, _ || _)
231+ case tpnme.Xor => constantFold2(boolValue, _ ^ _)
232+ case _ => None
233+ else None
234+
235+ constantType.getOrElse(NoType )
236+ }
237+
238+ case _ => NoType
239+ end tryCompiletimeConstantFold
240+ end TypeEval
0 commit comments