@@ -484,31 +484,54 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
484
484
485
485
// If LHS is a hard union, constrain any type variables of the RHS with it as lower bound
486
486
// before splitting the LHS into its constituents. That way, the RHS variables are
487
- // constraint by the hard union and can be instantiated to it. If we just split and add
487
+ // constrained by the hard union and can be instantiated to it. If we just split and add
488
488
// the two parts of the LHS separately to the constraint, the lower bound would become
489
489
// a soft union.
490
490
def constrainRHSVars (tp2 : Type ): Boolean = tp2.dealiasKeepRefiningAnnots match
491
491
case tp2 : TypeParamRef if constraint contains tp2 => compareTypeParamRef(tp2)
492
492
case AndType (tp21, tp22) => constrainRHSVars(tp21) && constrainRHSVars(tp22)
493
493
case _ => true
494
494
495
- widenOK
496
- || joinOK
497
- || (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2)
498
- || containsAnd(tp1)
499
- && ! joined
500
- && {
501
- joined = true
502
- try inFrozenGadt(recur(tp1.join, tp2))
503
- finally joined = false
504
- }
505
- // An & on the left side loses information. We compensate by also trying the join.
506
- // This is less ad-hoc than it looks since we produce joins in type inference,
507
- // and then need to check that they are indeed supertypes of the original types
508
- // under -Ycheck. Test case is i7965.scala.
509
- // On the other hand, we could get a combinatorial explosion by applying such joins
510
- // recursively, so we do it only once. See i14870.scala as a test case, which would
511
- // loop for a very long time without the recursion brake.
495
+ /** Mark toplevel type vars in `tp2` as hard in the current typerState */
496
+ def hardenTypeVars (tp2 : Type ): Unit = tp2.dealiasKeepRefiningAnnots match
497
+ case tvar : TypeVar if constraint.contains(tvar.origin) =>
498
+ state.hardVars += tvar
499
+ case tp2 : TypeParamRef if constraint.contains(tp2) =>
500
+ hardenTypeVars(constraint.typeVarOfParam(tp2))
501
+ case tp2 : AndOrType =>
502
+ hardenTypeVars(tp2.tp1)
503
+ hardenTypeVars(tp2.tp2)
504
+ case _ =>
505
+
506
+ val res = widenOK
507
+ || joinOK
508
+ || (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2)
509
+ || containsAnd(tp1)
510
+ && ! joined
511
+ && {
512
+ joined = true
513
+ try inFrozenGadt(recur(tp1.join, tp2))
514
+ finally joined = false
515
+ }
516
+ // An & on the left side loses information. We compensate by also trying the join.
517
+ // This is less ad-hoc than it looks since we produce joins in type inference,
518
+ // and then need to check that they are indeed supertypes of the original types
519
+ // under -Ycheck. Test case is i7965.scala.
520
+ // On the other hand, we could get a combinatorial explosion by applying such joins
521
+ // recursively, so we do it only once. See i14870.scala as a test case, which would
522
+ // loop for a very long time without the recursion brake.
523
+
524
+ if res && ! tp1.isSoft then
525
+ // We use a heuristic here where every toplevel type variable on the right hand side
526
+ // is marked so that it converts all soft unions in its lower bound to hard unions
527
+ // before it is instantiated. The reason is that the union might have come from
528
+ // (decomposed and reconstituted) `tp1`. But of course there might be false positives
529
+ // where we also treat unions that come from elsewhere as hard unions. Or the constraint
530
+ // that created the union is ultimately thrown away, but the type variable will
531
+ // stay marked. So it is a coarse measure to take. But it works in the obvious cases.
532
+ hardenTypeVars(tp2)
533
+
534
+ res
512
535
513
536
case tp1 : MatchType =>
514
537
val reduced = tp1.reduced
@@ -2863,8 +2886,8 @@ object TypeComparer {
2863
2886
def subtypeCheckInProgress (using Context ): Boolean =
2864
2887
comparing(_.subtypeCheckInProgress)
2865
2888
2866
- def instanceType (param : TypeParamRef , fromBelow : Boolean )(using Context ): Type =
2867
- comparing(_.instanceType(param, fromBelow))
2889
+ def instanceType (param : TypeParamRef , fromBelow : Boolean , widenUnions : Boolean )(using Context ): Type =
2890
+ comparing(_.instanceType(param, fromBelow, widenUnions ))
2868
2891
2869
2892
def approximation (param : TypeParamRef , fromBelow : Boolean )(using Context ): Type =
2870
2893
comparing(_.approximation(param, fromBelow))
@@ -2884,8 +2907,8 @@ object TypeComparer {
2884
2907
def addToConstraint (tl : TypeLambda , tvars : List [TypeVar ])(using Context ): Boolean =
2885
2908
comparing(_.addToConstraint(tl, tvars))
2886
2909
2887
- def widenInferred (inst : Type , bound : Type )(using Context ): Type =
2888
- comparing(_.widenInferred(inst, bound))
2910
+ def widenInferred (inst : Type , bound : Type , widenUnions : Boolean )(using Context ): Type =
2911
+ comparing(_.widenInferred(inst, bound, widenUnions ))
2889
2912
2890
2913
def dropTransparentTraits (tp : Type , bound : Type )(using Context ): Type =
2891
2914
comparing(_.dropTransparentTraits(tp, bound))
0 commit comments