@@ -622,13 +622,29 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
622622 * Returns alternative copies of this node where `rule` has been recursively applied to the tree.
623623 *
624624 * Users should not expect a specific directionality. If a specific directionality is needed,
625- * multiTransformDown or multiTransformUp should be used.
625+ * multiTransformDownWithPruning or multiTransformUpWithPruning should be used.
626626 *
627627 * @param rule a function used to generate transformed alternatives for a node
628628 * @return the stream of alternatives
629629 */
630- def multiTransform (rule : PartialFunction [BaseType , Seq [BaseType ]]): Stream [BaseType ] = {
631- multiTransformDown(rule)
630+ def multiTransformDown (
631+ rule : PartialFunction [BaseType , Stream [BaseType ]]): Stream [BaseType ] = {
632+ multiTransformDownWithPruning(AlwaysProcess .fn, UnknownRuleId )(rule)
633+ }
634+
635+ /**
636+ * Returns alternative copies of this node where `rule` has been recursively applied to the tree.
637+ *
638+ * Users should not expect a specific directionality. If a specific directionality is needed,
639+ * multiTransformDownWithPruning or multiTransformUpWithPruning should be used.
640+ *
641+ * @param rule a function used to generate transformed alternatives for a node and the
642+ * `autoContinue` flag
643+ * @return the stream of alternatives
644+ */
645+ def multiTransformDownWithContinuation (
646+ rule : PartialFunction [BaseType , (Stream [BaseType ], Boolean )]): Stream [BaseType ] = {
647+ multiTransformDownWithContinuationAndPruning(AlwaysProcess .fn, UnknownRuleId )(rule)
632648 }
633649
634650 /**
@@ -648,22 +664,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
648664 * varying initial state for different invocations.
649665 * @return the stream of alternatives
650666 */
651- def multiTransformWithPruning (
667+ def multiTransformDownWithPruning (
652668 cond : TreePatternBits => Boolean ,
653669 ruleId : RuleId = UnknownRuleId
654- )(rule : PartialFunction [BaseType , Seq [BaseType ]]): Stream [BaseType ] = {
655- multiTransformDownWithPruning(cond, ruleId)(rule).map(_._1)
656- }
657-
658- /**
659- * Returns alternative copies of this node where `rule` has been recursively applied to it and all
660- * of its children (pre-order).
661- *
662- * @param rule the function used to generate transformed alternatives for a node
663- * @return the stream of alternatives
664- */
665- def multiTransformDown (rule : PartialFunction [BaseType , Seq [BaseType ]]): Stream [BaseType ] = {
666- multiTransformDownWithPruning(AlwaysProcess .fn, UnknownRuleId )(rule).map(_._1)
670+ )(rule : PartialFunction [BaseType , Stream [BaseType ]]): Stream [BaseType ] = {
671+ multiTransformDownWithContinuationAndPruning(cond, ruleId)(rule.andThen(_ -> false ))
667672 }
668673
669674 /**
@@ -675,30 +680,53 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
675680 * lazy `Stream` to be able to limit the number of alternatives generated at the caller side as
676681 * needed.
677682 *
678- * To indicate that the original node without any transformation is a valid alternative the rule
679- * can either:
680- * - not apply or
681- * - a `Seq` that contains a node that is equal to the original node.
683+ * The rule should not apply to indicate that the original node without any transformation is a
684+ * valid alternative.
682685 *
683- * The rule can return `Seq.empty` to indicate that the original node should be pruned from the
684- * alternatives.
686+ * The rule can return `Stream.empty` to indicate that the original node should be pruned. In this
687+ * case `multiTransform` returns an empty `Stream`.
688+ *
689+ * Please consider the following examples of `input.multiTransform(rule)`:
690+ *
691+ * We have an input expression:
692+ * `Add(a, b)`
693+ *
694+ * 1.
695+ * We have a simple rule:
696+ * `a` => `Stream(1, 2)`
697+ * `b` => `Stream(10, 20)`
698+ * `Add(a, b)` => `Stream(11, 12, 21, 22)`
699+ *
700+ * The output is:
701+ * `Stream(11, 12, 21, 22)`
702+ *
703+ * 2.
704+ * In the previous example if we want to generate alternatives of `a` and `b` too then we need to
705+ * explicitly add the original `Add(a, b)` expression to the rule:
706+ * `a` => `Stream(1, 2)`
707+ * `b` => `Stream(10, 20)`
708+ * `Add(a, b)` => `Stream(11, 12, 21, 22, Add(a, b))`
709+ *
710+ * The output is:
711+ * `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`
712+ *
713+ * 3.
714+ * It is not always easy to determine if we will do any child expression mapping but we can enable
715+ * the `autoContinue` flag to get the same result:
716+ * `a` => `(Stream(1, 2), false)`
717+ * `b` => `(Stream(10, 20), false)`
718+ * `Add(a, b)` => `(Stream(11, 12, 21, 22), true)` (Note the `true` flag and the missing
719+ * `Add(a, b)`)
720+ * The output is the same as in 2.:
721+ * `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`
685722 *
686- * Please note that this function always consider the original node as a valid alternative (even
687- * if the original node is not included in the returned `Seq`) if the rule can transform any of
688- * the descendants of the node. E.g. consider a simple expression:
689- * `Add(a, b)`
690- * and a rule that returns:
691- * `Seq(1, 2)` for `a` and
692- * `Seq(10, 20)` for `b` and
693- * `Seq(11, 12, 21, 22)` for `Add(a, b)` (note that the original `Add(a, b)` is not returned)
694- * then the result of `multiTransform` is:
695- * `Seq(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`.
696723 * This feature makes the usage of `multiTransform` easier as a non-leaf transforming rule doesn't
697724 * need to take into account that it can transform a descendant node of the non-leaf node as well
698725 * and so it doesn't need return the non-leaf node itself in the list of alternatives to not stop
699726 * generating alternatives.
700727 *
701- * @param rule a function used to generate transformed alternatives for a node
728+ * @param rule a function used to generate transformed alternatives for a node and the
729+ * `autoContinue` flag
702730 * @param cond a Lambda expression to prune tree traversals. If `cond.apply` returns false
703731 * on a TreeNode T, skips processing T and its subtree; otherwise, processes
704732 * T and its subtree recursively.
@@ -707,37 +735,48 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
707735 * has been marked as in effective on a TreeNode T, skips processing T and its
708736 * subtree. Do not pass it if the rule is not purely functional and reads a
709737 * varying initial state for different invocations.
710- * @return the stream of alternatives with a flag if any transformation was done
738+ * @return the stream of alternatives
711739 */
712- def multiTransformDownWithPruning (
740+ def multiTransformDownWithContinuationAndPruning (
713741 cond : TreePatternBits => Boolean ,
714742 ruleId : RuleId = UnknownRuleId
715- )(rule : PartialFunction [BaseType , Seq [BaseType ]]): Stream [(BaseType , Boolean )] = {
743+ )(rule : PartialFunction [BaseType , (Stream [BaseType ], Boolean )]): Stream [BaseType ] = {
744+ multiTransformDownHelper(cond, ruleId)(rule).map(_._1)
745+ }
746+
747+ private def multiTransformDownHelper (
748+ cond : TreePatternBits => Boolean ,
749+ ruleId : RuleId = UnknownRuleId
750+ )(rule : PartialFunction [BaseType , (Stream [BaseType ], Boolean )]): Stream [(BaseType , Boolean )] = {
716751 if (! cond.apply(this ) || isRuleIneffective(ruleId)) {
717752 return Stream (this -> false )
718753 }
719754
720- val afterRules = CurrentOrigin .withOrigin(origin) {
721- rule.applyOrElse(this , (t : BaseType ) => Seq (t))
755+ var ruleApplied = true
756+ val (afterRules, autoContinue) = CurrentOrigin .withOrigin(origin) {
757+ rule.applyOrElse(this , (_ : BaseType ) => {
758+ ruleApplied = false
759+ Stream .empty -> false
760+ })
722761 }
723762 // A stream of a tuple that contains:
724763 // - a node that is either the transformed alternative of the current node or the current node,
725764 // - a boolean flag if the node was actually transformed,
726765 // - a boolean flag if a node's children needs to be transformed to add the node to the valid
727766 // alternatives
728- val afterRulesStream = afterRules match {
729- // If the rule returns with empty alternatives then prune
730- case Nil => Stream . empty
731-
732- // If the rule returns with a node equal to the original (or not applied) then keep the
733- // original node
734- case afterRule :: Nil if this fastEquals afterRule => Stream ((this , false , false ))
735-
736- // If the rule is applied then use the returned alternatives
737- case _ =>
767+ val afterRulesStream = if (afterRules.isEmpty) {
768+ if (ruleApplied) {
769+ // If the rule returned with empty alternatives then prune
770+ Stream .empty
771+ } else {
772+ // If the rule was not applied then keep the original node
773+ Stream ((this , false , false ))
774+ }
775+ } else {
776+ // If the rule was applied then use the returned alternatives
738777 // The alternatives can include the current node and we need to keep track of that
739778 var foundEqual = false
740- afterRules.toStream. map { afterRule =>
779+ afterRules.map { afterRule =>
741780 (if (this fastEquals afterRule) {
742781 foundEqual = true
743782 this
@@ -746,10 +785,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
746785 afterRule
747786 }, true , false )
748787 }.append(
749- // If the current node is not a leaf node and the alternatives returned by the rule
750- // doesn't contain it then we need to add the current node to the stream, but require any
751- // of its child nodes to be transformed to keep it as a valid alternative
752- if (containsChild.nonEmpty && ! foundEqual) {
788+ // If autoContinue is enabled and the current node is not a leaf node and the alternatives
789+ // returned by the rule doesn't contain the current node then we need to add the current
790+ // node to the stream, but require any of its child nodes to be transformed to keep it as
791+ // a valid alternative
792+ if (autoContinue && containsChild.nonEmpty && ! foundEqual) {
753793 Stream ((this , false , true ))
754794 } else {
755795 Stream .empty
@@ -761,7 +801,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
761801 children.foldRight(Stream ((Seq .empty[BaseType ], false )))((child, childrenSeqStream) =>
762802 for {
763803 (childrenSeq, childrenSeqChanged) <- childrenSeqStream
764- (newChild, childChanged) <- child.multiTransformDownWithPruning (cond, ruleId)(rule)
804+ (newChild, childChanged) <- child.multiTransformDownHelper (cond, ruleId)(rule)
765805 } yield (newChild +: childrenSeq) -> (childChanged || childrenSeqChanged)
766806 )
767807 }
@@ -774,7 +814,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
774814 afterRule.withNewChildren(newChildren) -> (transformed || childrenTransformed)
775815 }
776816 } else {
777- Seq (afterRule -> transformed)
817+ Stream (afterRule -> transformed)
778818 }.map { rewritten_plan =>
779819 if (this eq rewritten_plan) {
780820 markRuleAsIneffective(ruleId)
0 commit comments