@@ -3,6 +3,7 @@ package scala.internal.quoted
33import scala .annotation .internal .sharable
44
55import scala .quoted ._
6+ import scala .quoted .matching .Bind
67import scala .tasty ._
78
89object Matcher {
@@ -30,7 +31,8 @@ object Matcher {
3031 * @return None if it did not match, `Some(tup)` if it matched where `tup` contains `Expr[Ti]``
3132 */
3233 def unapply [Tup <: Tuple ](scrutineeExpr : Expr [_])(implicit patternExpr : Expr [_], reflection : Reflection ): Option [Tup ] = {
33- import reflection ._
34+ import reflection .{Bind => BindPattern , _ }
35+
3436 // TODO improve performance
3537
3638 /** Check that the trees match and return the contents from the pattern holes.
@@ -51,6 +53,18 @@ object Matcher {
5153 sFlags.is(Lazy ) == pFlags.is(Lazy ) && sFlags.is(Mutable ) == pFlags.is(Mutable )
5254 }
5355
56+ def bindingMatch (sym : Symbol ) =
57+ Some (Tuple1 (new Bind (sym.name, sym)))
58+
59+ def hasBindTypeAnnotation (tpt : TypeTree ): Boolean = tpt match {
60+ case Annotated (tpt2, Apply (Select (New (TypeIdent (" patternBindHole" )), " <init>" ), Nil )) => true
61+ case Annotated (tpt2, _) => hasBindTypeAnnotation(tpt2)
62+ case _ => false
63+ }
64+
65+ def hasBindAnnotation (sym : Symbol ) =
66+ sym.annots.exists { case Apply (Select (New (TypeIdent (" patternBindHole" ))," <init>" ),List ()) => true ; case _ => true }
67+
5468 def treesMatch (scrutinees : List [Tree ], patterns : List [Tree ]): Option [Tuple ] =
5569 if (scrutinees.size != patterns.size) None
5670 else foldMatchings(scrutinees.zip(patterns).map(treeMatches): _* )
@@ -142,24 +156,30 @@ object Matcher {
142156 foldMatchings(treeMatches(tycon1, tycon2), treesMatch(args1, args2))
143157
144158 case (ValDef (_, tpt1, rhs1), ValDef (_, tpt2, rhs2)) if checkValFlags() =>
159+ val bindMatch =
160+ if (hasBindAnnotation(pattern.symbol) || hasBindTypeAnnotation(tpt2)) bindingMatch(scrutinee.symbol)
161+ else Some (())
145162 val returnTptMatch = treeMatches(tpt1, tpt2)
146163 val rhsEnv = env + (scrutinee.symbol -> pattern.symbol)
147164 val rhsMatchings = treeOptMatches(rhs1, rhs2)(rhsEnv)
148- foldMatchings(returnTptMatch, rhsMatchings)
165+ foldMatchings(bindMatch, returnTptMatch, rhsMatchings)
149166
150167 case (DefDef (_, typeParams1, paramss1, tpt1, Some (rhs1)), DefDef (_, typeParams2, paramss2, tpt2, Some (rhs2))) =>
151168 val typeParmasMatch = treesMatch(typeParams1, typeParams2)
152169 val paramssMatch =
153170 if (paramss1.size != paramss2.size) None
154171 else foldMatchings(paramss1.zip(paramss2).map { (params1, params2) => treesMatch(params1, params2) }: _* )
172+ val bindMatch =
173+ if (hasBindAnnotation(pattern.symbol)) bindingMatch(scrutinee.symbol)
174+ else Some (())
155175 val tptMatch = treeMatches(tpt1, tpt2)
156176 val rhsEnv =
157177 env + (scrutinee.symbol -> pattern.symbol) ++
158- typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) ++
159- paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol)
178+ typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) ++
179+ paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol)
160180 val rhsMatch = treeMatches(rhs1, rhs2)(rhsEnv)
161181
162- foldMatchings(typeParmasMatch, paramssMatch, tptMatch, rhsMatch)
182+ foldMatchings(bindMatch, typeParmasMatch, paramssMatch, tptMatch, rhsMatch)
163183
164184 case (Lambda (_, tpt1), Lambda (_, tpt2)) =>
165185 // TODO match tpt1 with tpt2?
@@ -180,6 +200,10 @@ object Matcher {
180200 val finalizerMatch = treeOptMatches(finalizer1, finalizer2)
181201 foldMatchings(bodyMacth, casesMatch, finalizerMatch)
182202
203+ // Ignore type annotations
204+ case (Annotated (tpt, _), _) => treeMatches(tpt, pattern)
205+ case (_, Annotated (tpt, _)) => treeMatches(scrutinee, tpt)
206+
183207 // No Match
184208 case _ =>
185209 if (debug)
0 commit comments