@@ -3,6 +3,7 @@ package scala.internal.quoted
33import scala .annotation .internal .sharable
44
55import scala .quoted ._
6+ import scala .quoted .matching .Binding
67import scala .tasty ._
78
89object Matcher {
@@ -51,6 +52,18 @@ object Matcher {
5152 sFlags.is(Lazy ) == pFlags.is(Lazy ) && sFlags.is(Mutable ) == pFlags.is(Mutable )
5253 }
5354
55+ def bindingMatch (sym : Symbol ) =
56+ Some (Tuple1 (new Binding (sym.name, sym)))
57+
58+ def hasBindingTypeAnnotation (tpt : TypeTree ): Boolean = tpt match {
59+ case Annotated (tpt2, Apply (Select (New (TypeIdent (" patternBindHole" )), " <init>" ), Nil )) => true
60+ case Annotated (tpt2, _) => hasBindingTypeAnnotation(tpt2)
61+ case _ => false
62+ }
63+
64+ def hasBindingAnnotation (sym : Symbol ) =
65+ sym.annots.exists { case Apply (Select (New (TypeIdent (" patternBindHole" ))," <init>" ),List ()) => true ; case _ => true }
66+
5467 def treesMatch (scrutinees : List [Tree ], patterns : List [Tree ]): Option [Tuple ] =
5568 if (scrutinees.size != patterns.size) None
5669 else foldMatchings(scrutinees.zip(patterns).map(treeMatches): _* )
@@ -142,24 +155,30 @@ object Matcher {
142155 foldMatchings(treeMatches(tycon1, tycon2), treesMatch(args1, args2))
143156
144157 case (ValDef (_, tpt1, rhs1), ValDef (_, tpt2, rhs2)) if checkValFlags() =>
158+ val bindMatch =
159+ if (hasBindingAnnotation(pattern.symbol) || hasBindingTypeAnnotation(tpt2)) bindingMatch(scrutinee.symbol)
160+ else Some (())
145161 val returnTptMatch = treeMatches(tpt1, tpt2)
146162 val rhsEnv = env + (scrutinee.symbol -> pattern.symbol)
147163 val rhsMatchings = treeOptMatches(rhs1, rhs2)(rhsEnv)
148- foldMatchings(returnTptMatch, rhsMatchings)
164+ foldMatchings(bindMatch, returnTptMatch, rhsMatchings)
149165
150166 case (DefDef (_, typeParams1, paramss1, tpt1, Some (rhs1)), DefDef (_, typeParams2, paramss2, tpt2, Some (rhs2))) =>
151167 val typeParmasMatch = treesMatch(typeParams1, typeParams2)
152168 val paramssMatch =
153169 if (paramss1.size != paramss2.size) None
154170 else foldMatchings(paramss1.zip(paramss2).map { (params1, params2) => treesMatch(params1, params2) }: _* )
171+ val bindMatch =
172+ if (hasBindingAnnotation(pattern.symbol)) bindingMatch(scrutinee.symbol)
173+ else Some (())
155174 val tptMatch = treeMatches(tpt1, tpt2)
156175 val rhsEnv =
157176 env + (scrutinee.symbol -> pattern.symbol) ++
158- typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) ++
159- paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol)
177+ typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) ++
178+ paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol)
160179 val rhsMatch = treeMatches(rhs1, rhs2)(rhsEnv)
161180
162- foldMatchings(typeParmasMatch, paramssMatch, tptMatch, rhsMatch)
181+ foldMatchings(bindMatch, typeParmasMatch, paramssMatch, tptMatch, rhsMatch)
163182
164183 case (Lambda (_, tpt1), Lambda (_, tpt2)) =>
165184 // TODO match tpt1 with tpt2?
@@ -180,6 +199,10 @@ object Matcher {
180199 val finalizerMatch = treeOptMatches(finalizer1, finalizer2)
181200 foldMatchings(bodyMacth, casesMatch, finalizerMatch)
182201
202+ // Ignore type annotations
203+ case (Annotated (tpt, _), _) => treeMatches(tpt, pattern)
204+ case (_, Annotated (tpt, _)) => treeMatches(scrutinee, tpt)
205+
183206 // No Match
184207 case _ =>
185208 if (debug)
0 commit comments