@@ -21,7 +21,7 @@ import CheckRealizable._
2121import Variances .{Variance , setStructuralVariances , Invariant }
2222import typer .Nullables
2323import util .Stats ._
24- import util .SimpleIdentitySet
24+ import util .{ SimpleIdentityMap , SimpleIdentitySet }
2525import ast .tpd ._
2626import ast .TreeTypeMap
2727import printing .Texts ._
@@ -1741,7 +1741,7 @@ object Types {
17411741 t
17421742 case t if defn.isErasedFunctionType(t) =>
17431743 t
1744- case t @ SAMType (_) =>
1744+ case t @ SAMType (_, _ ) =>
17451745 t
17461746 case _ =>
17471747 NoType
@@ -5497,104 +5497,119 @@ object Types {
54975497 * A type is a SAM type if it is a reference to a class or trait, which
54985498 *
54995499 * - has a single abstract method with a method type (ExprType
5500- * and PolyType not allowed!) whose result type is not an implicit function type
5501- * and which is not marked inline.
5500+ * and PolyType not allowed!) according to `possibleSamMethods`.
55025501 * - can be instantiated without arguments or with just () as argument.
55035502 *
5504- * The pattern `SAMType(sam)` matches a SAM type, where `sam` is the
5505- * type of the single abstract method.
5503+ * The pattern `SAMType(samMethod, samParent)` matches a SAM type, where `samMethod` is the
5504+ * type of the single abstract method and `samParent` is a subtype of the matched
5505+ * SAM type which has been stripped of wildcards to turn it into a valid parent
5506+ * type.
55065507 */
55075508 object SAMType {
5508- def zeroParamClass (tp : Type )(using Context ): Type = tp match {
5509+ /** If possible, return a type which is both a subtype of `origTp` and a type
5510+ * application of `samClass` where none of the type arguments are
5511+ * wildcards (thus making it a valid parent type), otherwise return
5512+ * NoType.
5513+ *
5514+ * A wildcard in the original type will be replaced by its upper or lower bound in a way
5515+ * that maximizes the number of possible implementations of `samMeth`. For example,
5516+ * java.util.function defines an interface equivalent to:
5517+ *
5518+ * trait Function[T, R]:
5519+ * def apply(t: T): R
5520+ *
5521+ * and it usually appears with wildcards to compensate for the lack of
5522+ * definition-site variance in Java:
5523+ *
5524+ * (x => x.toInt): Function[? >: String, ? <: Int]
5525+ *
5526+ * When typechecking this lambda, we need to approximate the wildcards to find
5527+ * a valid parent type for our lambda to extend. We can see that in `apply`,
5528+ * `T` only appears contravariantly and `R` only appears covariantly, so by
5529+ * minimizing the first parameter and maximizing the second, we maximize the
5530+ * number of valid implementations of `apply` which lets us implement the lambda
5531+ * with a closure equivalent to:
5532+ *
5533+ * new Function[String, Int] { def apply(x: String): Int = x.toInt }
5534+ *
5535+ * If a type parameter appears invariantly or does not appear at all in `samMeth`, then
5536+ * we arbitrarily pick the upper-bound.
5537+ */
5538+ def samParent (origTp : Type , samClass : Symbol , samMeth : Symbol )(using Context ): Type =
5539+ val tp = origTp.baseType(samClass)
5540+ if ! (tp <:< origTp) then NoType
5541+ else tp match
5542+ case tp @ AppliedType (tycon, args) if tp.hasWildcardArg =>
5543+ val accu = new TypeAccumulator [VarianceMap [Symbol ]]:
5544+ def apply (vmap : VarianceMap [Symbol ], t : Type ): VarianceMap [Symbol ] = t match
5545+ case tp : TypeRef if tp.symbol.isAllOf(ClassTypeParam ) =>
5546+ vmap.recordLocalVariance(tp.symbol, variance)
5547+ case _ =>
5548+ foldOver(vmap, t)
5549+ val vmap = accu(VarianceMap .empty, samMeth.info)
5550+ val tparams = tycon.typeParamSymbols
5551+ val args1 = args.zipWithConserve(tparams):
5552+ case (arg @ TypeBounds (lo, hi), tparam) =>
5553+ val v = vmap.computedVariance(tparam)
5554+ if v.uncheckedNN < 0 then lo
5555+ else hi
5556+ case (arg, _) => arg
5557+ tp.derivedAppliedType(tycon, args1)
5558+ case _ =>
5559+ tp
5560+
5561+ def samClass (tp : Type )(using Context ): Symbol = tp match
55095562 case tp : ClassInfo =>
5510- def zeroParams (tp : Type ): Boolean = tp.stripPoly match {
5563+ def zeroParams (tp : Type ): Boolean = tp.stripPoly match
55115564 case mt : MethodType => mt.paramInfos.isEmpty && ! mt.resultType.isInstanceOf [MethodType ]
55125565 case et : ExprType => true
55135566 case _ => false
5514- }
5515- // `ContextFunctionN` does not have constructors
5516- val ctor = tp.cls.primaryConstructor
5517- if (! ctor.exists || zeroParams(ctor.info)) tp
5518- else NoType
5567+ val cls = tp.cls
5568+ val validCtor =
5569+ val ctor = cls.primaryConstructor
5570+ // `ContextFunctionN` does not have constructors
5571+ ! ctor.exists || zeroParams(ctor.info)
5572+ val isInstantiable = ! cls.isOneOf(FinalOrSealed ) && (tp.appliedRef <:< tp.selfType)
5573+ if validCtor && isInstantiable then tp.cls
5574+ else NoSymbol
55195575 case tp : AppliedType =>
5520- zeroParamClass (tp.superType)
5576+ samClass (tp.superType)
55215577 case tp : TypeRef =>
5522- zeroParamClass (tp.underlying)
5578+ samClass (tp.underlying)
55235579 case tp : RefinedType =>
5524- zeroParamClass (tp.underlying)
5580+ samClass (tp.underlying)
55255581 case tp : TypeBounds =>
5526- zeroParamClass (tp.underlying)
5582+ samClass (tp.underlying)
55275583 case tp : TypeVar =>
5528- zeroParamClass (tp.underlying)
5584+ samClass (tp.underlying)
55295585 case tp : AnnotatedType =>
5530- zeroParamClass(tp.underlying)
5531- case _ =>
5532- NoType
5533- }
5534- def isInstantiatable (tp : Type )(using Context ): Boolean = zeroParamClass(tp) match {
5535- case cinfo : ClassInfo if ! cinfo.cls.isOneOf(FinalOrSealed ) =>
5536- val selfType = cinfo.selfType.asSeenFrom(tp, cinfo.cls)
5537- tp <:< selfType
5586+ samClass(tp.underlying)
55385587 case _ =>
5539- false
5540- }
5541- def unapply (tp : Type )(using Context ): Option [MethodType ] =
5542- if (isInstantiatable(tp)) {
5543- val absMems = tp.possibleSamMethods
5544- if (absMems.size == 1 )
5545- absMems.head.info match {
5546- case mt : MethodType if ! mt.isParamDependent &&
5547- mt.resultType.isValueTypeOrWildcard =>
5548- val cls = tp.classSymbol
5549-
5550- // Given a SAM type such as:
5551- //
5552- // import java.util.function.Function
5553- // Function[? >: String, ? <: Int]
5554- //
5555- // the single abstract method will have type:
5556- //
5557- // (x: Function[? >: String, ? <: Int]#T): Function[? >: String, ? <: Int]#R
5558- //
5559- // which is not implementable outside of the scope of Function.
5560- //
5561- // To avoid this kind of issue, we approximate references to
5562- // parameters of the SAM type by their bounds, this way in the
5563- // above example we get:
5564- //
5565- // (x: String): Int
5566- val approxParams = new ApproximatingTypeMap {
5567- def apply (tp : Type ): Type = tp match {
5568- case tp : TypeRef if tp.symbol.isAllOf(ClassTypeParam ) && tp.symbol.owner == cls =>
5569- tp.info match {
5570- case info : AliasingBounds =>
5571- mapOver(info.alias)
5572- case TypeBounds (lo, hi) =>
5573- range(atVariance(- variance)(apply(lo)), apply(hi))
5574- case _ =>
5575- range(defn.NothingType , defn.AnyType ) // should happen only in error cases
5576- }
5577- case _ =>
5578- mapOver(tp)
5579- }
5580- }
5581- val approx =
5582- if ctx.owner.isContainedIn(cls) then mt
5583- else approxParams(mt).asInstanceOf [MethodType ]
5584- Some (approx)
5588+ NoSymbol
5589+
5590+ def unapply (tp : Type )(using Context ): Option [(MethodType , Type )] =
5591+ val cls = samClass(tp)
5592+ if cls.exists then
5593+ val absMems =
5594+ if tp.isRef(defn.PartialFunctionClass ) then
5595+ // To maintain compatibility with 2.x, we treat PartialFunction specially,
5596+ // pretending it is a SAM type. In the future it would be better to merge
5597+ // Function and PartialFunction, have Function1 contain a isDefinedAt method
5598+ // def isDefinedAt(x: T) = true
5599+ // and overwrite that method whenever the function body is a sequence of
5600+ // case clauses.
5601+ List (defn.PartialFunction_apply )
5602+ else
5603+ tp.possibleSamMethods.map(_.symbol)
5604+ if absMems.lengthCompare(1 ) == 0 then
5605+ val samMethSym = absMems.head
5606+ val parent = samParent(tp, cls, samMethSym)
5607+ samMethSym.asSeenFrom(parent).info match
5608+ case mt : MethodType if ! mt.isParamDependent && mt.resultType.isValueTypeOrWildcard =>
5609+ Some (mt, parent)
55855610 case _ =>
55865611 None
5587- }
5588- else if (tp isRef defn.PartialFunctionClass )
5589- // To maintain compatibility with 2.x, we treat PartialFunction specially,
5590- // pretending it is a SAM type. In the future it would be better to merge
5591- // Function and PartialFunction, have Function1 contain a isDefinedAt method
5592- // def isDefinedAt(x: T) = true
5593- // and overwrite that method whenever the function body is a sequence of
5594- // case clauses.
5595- absMems.find(_.symbol.name == nme.apply).map(_.info.asInstanceOf [MethodType ])
55965612 else None
5597- }
55985613 else None
55995614 }
56005615
@@ -6427,6 +6442,37 @@ object Types {
64276442 }
64286443 }
64296444
6445+ object VarianceMap :
6446+ /** An immutable map representing the variance of keys of type `K` */
6447+ opaque type VarianceMap [K <: AnyRef ] <: AnyRef = SimpleIdentityMap [K , Integer ]
6448+ def empty [K <: AnyRef ]: VarianceMap [K ] = SimpleIdentityMap .empty[K ]
6449+ extension [K <: AnyRef ](vmap : VarianceMap [K ])
6450+ /** The backing map used to implement this VarianceMap. */
6451+ inline def underlying : SimpleIdentityMap [K , Integer ] = vmap
6452+
6453+ /** Return a new map taking into account that K appears in a
6454+ * {co,contra,in}-variant position if `localVariance` is {positive,negative,zero}.
6455+ */
6456+ def recordLocalVariance (k : K , localVariance : Int ): VarianceMap [K ] =
6457+ val previousVariance = vmap(k)
6458+ if previousVariance == null then
6459+ vmap.updated(k, localVariance)
6460+ else if previousVariance == localVariance || previousVariance == 0 then
6461+ vmap
6462+ else
6463+ vmap.updated(k, 0 )
6464+
6465+ /** Return the variance of `k`:
6466+ * - A positive value means that `k` appears only covariantly.
6467+ * - A negative value means that `k` appears only contravariantly.
6468+ * - A zero value means that `k` appears both covariantly and
6469+ * contravariantly, or appears invariantly.
6470+ * - A null value means that `k` does not appear at all.
6471+ */
6472+ def computedVariance (k : K ): Integer | Null =
6473+ vmap(k)
6474+ export VarianceMap .VarianceMap
6475+
64306476 // ----- Name Filters --------------------------------------------------
64316477
64326478 /** A name filter selects or discards a member name of a type `pre`.
0 commit comments