Skip to content

Commit 8de8f88

Browse files
committed
change rule to require Stream of alternatives instead of Seq, add explicit flag to enable the autoContinue feature, add more examples, remove general versions to highlight this is a top-down rule
1 parent b99f2f9 commit 8de8f88

File tree

2 files changed

+139
-95
lines changed

2 files changed

+139
-95
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala

Lines changed: 96 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -622,13 +622,29 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
622622
* Returns alternative copies of this node where `rule` has been recursively applied to the tree.
623623
*
624624
* Users should not expect a specific directionality. If a specific directionality is needed,
625-
* multiTransformDown or multiTransformUp should be used.
625+
* multiTransformDownWithPruning or multiTransformUpWithPruning should be used.
626626
*
627627
* @param rule a function used to generate transformed alternatives for a node
628628
* @return the stream of alternatives
629629
*/
630-
def multiTransform(rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[BaseType] = {
631-
multiTransformDown(rule)
630+
def multiTransformDown(
631+
rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = {
632+
multiTransformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)
633+
}
634+
635+
/**
636+
* Returns alternative copies of this node where `rule` has been recursively applied to the tree.
637+
*
638+
* Users should not expect a specific directionality. If a specific directionality is needed,
639+
* multiTransformDownWithPruning or multiTransformUpWithPruning should be used.
640+
*
641+
* @param rule a function used to generate transformed alternatives for a node and the
642+
* `autoContinue` flag
643+
* @return the stream of alternatives
644+
*/
645+
def multiTransformDownWithContinuation(
646+
rule: PartialFunction[BaseType, (Stream[BaseType], Boolean)]): Stream[BaseType] = {
647+
multiTransformDownWithContinuationAndPruning(AlwaysProcess.fn, UnknownRuleId)(rule)
632648
}
633649

634650
/**
@@ -648,22 +664,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
648664
* varying initial state for different invocations.
649665
* @return the stream of alternatives
650666
*/
651-
def multiTransformWithPruning(
667+
def multiTransformDownWithPruning(
652668
cond: TreePatternBits => Boolean,
653669
ruleId: RuleId = UnknownRuleId
654-
)(rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[BaseType] = {
655-
multiTransformDownWithPruning(cond, ruleId)(rule).map(_._1)
656-
}
657-
658-
/**
659-
* Returns alternative copies of this node where `rule` has been recursively applied to it and all
660-
* of its children (pre-order).
661-
*
662-
* @param rule the function used to generate transformed alternatives for a node
663-
* @return the stream of alternatives
664-
*/
665-
def multiTransformDown(rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[BaseType] = {
666-
multiTransformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule).map(_._1)
670+
)(rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = {
671+
multiTransformDownWithContinuationAndPruning(cond, ruleId)(rule.andThen(_ -> false))
667672
}
668673

669674
/**
@@ -675,30 +680,53 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
675680
* lazy `Stream` to be able to limit the number of alternatives generated at the caller side as
676681
* needed.
677682
*
678-
* To indicate that the original node without any transformation is a valid alternative the rule
679-
* can either:
680-
* - not apply or
681-
* - a `Seq` that contains a node that is equal to the original node.
683+
* The rule should not apply to indicate that the original node without any transformation is a
684+
* valid alternative.
682685
*
683-
* The rule can return `Seq.empty` to indicate that the original node should be pruned from the
684-
* alternatives.
686+
* The rule can return `Stream.empty` to indicate that the original node should be pruned. In this
687+
* case `multiTransform` returns an empty `Stream`.
688+
*
689+
* Please consider the following examples of `input.multiTransform(rule)`:
690+
*
691+
* We have an input expression:
692+
* `Add(a, b)`
693+
*
694+
* 1.
695+
* We have a simple rule:
696+
* `a` => `Stream(1, 2)`
697+
* `b` => `Stream(10, 20)`
698+
* `Add(a, b)` => `Stream(11, 12, 21, 22)`
699+
*
700+
* The output is:
701+
* `Stream(11, 12, 21, 22)`
702+
*
703+
* 2.
704+
* In the previous example if we want to generate alternatives of `a` and `b` too then we need to
705+
* explicitly add the original `Add(a, b)` expression to the rule:
706+
* `a` => `Stream(1, 2)`
707+
* `b` => `Stream(10, 20)`
708+
* `Add(a, b)` => `Stream(11, 12, 21, 22, Add(a, b))`
709+
*
710+
* The output is:
711+
* `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`
712+
*
713+
* 3.
714+
* It is not always easy to determine if we will do any child expression mapping but we can enable
715+
* the `autoContinue` flag to get the same result:
716+
* `a` => `(Stream(1, 2), false)`
717+
* `b` => `(Stream(10, 20), false)`
718+
* `Add(a, b)` => `(Stream(11, 12, 21, 22), true)` (Note the `true` flag and the missing
719+
* `Add(a, b)`)
720+
* The output is the same as in 2.:
721+
* `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`
685722
*
686-
* Please note that this function always consider the original node as a valid alternative (even
687-
* if the original node is not included in the returned `Seq`) if the rule can transform any of
688-
* the descendants of the node. E.g. consider a simple expression:
689-
* `Add(a, b)`
690-
* and a rule that returns:
691-
* `Seq(1, 2)` for `a` and
692-
* `Seq(10, 20)` for `b` and
693-
* `Seq(11, 12, 21, 22)` for `Add(a, b)` (note that the original `Add(a, b)` is not returned)
694-
* then the result of `multiTransform` is:
695-
* `Seq(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`.
696723
* This feature makes the usage of `multiTransform` easier as a non-leaf transforming rule doesn't
697724
* need to take into account that it can transform a descendant node of the non-leaf node as well
698725
* and so it doesn't need return the non-leaf node itself in the list of alternatives to not stop
699726
* generating alternatives.
700727
*
701-
* @param rule a function used to generate transformed alternatives for a node
728+
* @param rule a function used to generate transformed alternatives for a node and the
729+
* `autoContinue` flag
702730
* @param cond a Lambda expression to prune tree traversals. If `cond.apply` returns false
703731
* on a TreeNode T, skips processing T and its subtree; otherwise, processes
704732
* T and its subtree recursively.
@@ -707,37 +735,48 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
707735
* has been marked as in effective on a TreeNode T, skips processing T and its
708736
* subtree. Do not pass it if the rule is not purely functional and reads a
709737
* varying initial state for different invocations.
710-
* @return the stream of alternatives with a flag if any transformation was done
738+
* @return the stream of alternatives
711739
*/
712-
def multiTransformDownWithPruning(
740+
def multiTransformDownWithContinuationAndPruning(
713741
cond: TreePatternBits => Boolean,
714742
ruleId: RuleId = UnknownRuleId
715-
)(rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[(BaseType, Boolean)] = {
743+
)(rule: PartialFunction[BaseType, (Stream[BaseType], Boolean)]): Stream[BaseType] = {
744+
multiTransformDownHelper(cond, ruleId)(rule).map(_._1)
745+
}
746+
747+
private def multiTransformDownHelper(
748+
cond: TreePatternBits => Boolean,
749+
ruleId: RuleId = UnknownRuleId
750+
)(rule: PartialFunction[BaseType, (Stream[BaseType], Boolean)]): Stream[(BaseType, Boolean)] = {
716751
if (!cond.apply(this) || isRuleIneffective(ruleId)) {
717752
return Stream(this -> false)
718753
}
719754

720-
val afterRules = CurrentOrigin.withOrigin(origin) {
721-
rule.applyOrElse(this, (t: BaseType) => Seq(t))
755+
var ruleApplied = true
756+
val (afterRules, autoContinue) = CurrentOrigin.withOrigin(origin) {
757+
rule.applyOrElse(this, (_: BaseType) => {
758+
ruleApplied = false
759+
Stream.empty -> false
760+
})
722761
}
723762
// A stream of a tuple that contains:
724763
// - a node that is either the transformed alternative of the current node or the current node,
725764
// - a boolean flag if the node was actually transformed,
726765
// - a boolean flag if a node's children needs to be transformed to add the node to the valid
727766
// alternatives
728-
val afterRulesStream = afterRules match {
729-
// If the rule returns with empty alternatives then prune
730-
case Nil => Stream.empty
731-
732-
// If the rule returns with a node equal to the original (or not applied) then keep the
733-
// original node
734-
case afterRule :: Nil if this fastEquals afterRule => Stream((this, false, false))
735-
736-
// If the rule is applied then use the returned alternatives
737-
case _ =>
767+
val afterRulesStream = if (afterRules.isEmpty) {
768+
if (ruleApplied) {
769+
// If the rule returned with empty alternatives then prune
770+
Stream.empty
771+
} else {
772+
// If the rule was not applied then keep the original node
773+
Stream((this, false, false))
774+
}
775+
} else {
776+
// If the rule was applied then use the returned alternatives
738777
// The alternatives can include the current node and we need to keep track of that
739778
var foundEqual = false
740-
afterRules.toStream.map { afterRule =>
779+
afterRules.map { afterRule =>
741780
(if (this fastEquals afterRule) {
742781
foundEqual = true
743782
this
@@ -746,10 +785,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
746785
afterRule
747786
}, true, false)
748787
}.append(
749-
// If the current node is not a leaf node and the alternatives returned by the rule
750-
// doesn't contain it then we need to add the current node to the stream, but require any
751-
// of its child nodes to be transformed to keep it as a valid alternative
752-
if (containsChild.nonEmpty && !foundEqual) {
788+
// If autoContinue is enabled and the current node is not a leaf node and the alternatives
789+
// returned by the rule doesn't contain the current node then we need to add the current
790+
// node to the stream, but require any of its child nodes to be transformed to keep it as
791+
// a valid alternative
792+
if (autoContinue && containsChild.nonEmpty && !foundEqual) {
753793
Stream((this, false, true))
754794
} else {
755795
Stream.empty
@@ -761,7 +801,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
761801
children.foldRight(Stream((Seq.empty[BaseType], false)))((child, childrenSeqStream) =>
762802
for {
763803
(childrenSeq, childrenSeqChanged) <- childrenSeqStream
764-
(newChild, childChanged) <- child.multiTransformDownWithPruning(cond, ruleId)(rule)
804+
(newChild, childChanged) <- child.multiTransformDownHelper(cond, ruleId)(rule)
765805
} yield (newChild +: childrenSeq) -> (childChanged || childrenSeqChanged)
766806
)
767807
}
@@ -774,7 +814,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
774814
afterRule.withNewChildren(newChildren) -> (transformed || childrenTransformed)
775815
}
776816
} else {
777-
Seq(afterRule -> transformed)
817+
Stream(afterRule -> transformed)
778818
}.map { rewritten_plan =>
779819
if (this eq rewritten_plan) {
780820
markRuleAsIneffective(ruleId)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala

Lines changed: 43 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -987,10 +987,10 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
987987
test("multiTransformDown generates all alternatives") {
988988
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
989989
val transformed = e.multiTransformDown {
990-
case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3))
991-
case StringLiteral("b") => Seq(Literal(10), Literal(20), Literal(30))
990+
case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
991+
case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30))
992992
case Add(StringLiteral("c"), StringLiteral("d"), _) =>
993-
Seq(Literal(100), Literal(200), Literal(300))
993+
Stream(Literal(100), Literal(200), Literal(300))
994994
}
995995
val expected = for {
996996
cd <- Seq(Literal(100), Literal(200), Literal(300))
@@ -1002,10 +1002,11 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
10021002

10031003
test("multiTransformDown is lazy") {
10041004
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
1005-
val transformed = e.multiTransformDown {
1006-
case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3))
1007-
case StringLiteral("b") => newErrorAfterStream(Literal(10))
1008-
case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100))
1005+
val transformed = e.multiTransformDownWithContinuation {
1006+
case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3)) -> true
1007+
case StringLiteral("b") => newErrorAfterStream(Literal(10)) -> true
1008+
case Add(StringLiteral("c"), StringLiteral("d"), _) =>
1009+
newErrorAfterStream(Literal(100)) -> true
10091010
}
10101011
val expected = for {
10111012
a <- Seq(Literal(1), Literal(2), Literal(3))
@@ -1016,28 +1017,30 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
10161017
transformed.take(3 + 1).toList
10171018
}
10181019

1019-
val transformed2 = e.multiTransformDown {
1020-
case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3))
1021-
case StringLiteral("b") => Seq(Literal(10), Literal(20), Literal(30))
1022-
case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100))
1023-
}
1024-
val expected2 = for {
1025-
b <- Seq(Literal(10), Literal(20), Literal(30))
1026-
a <- Seq(Literal(1), Literal(2), Literal(3))
1027-
} yield Add(Add(a, b), Literal(100))
1028-
// We don't access alternatives for `c` after 100
1029-
assert(transformed2.take(3 * 3) === expected2)
1030-
intercept[NoSuchElementException] {
1031-
transformed.take(3 * 3 + 1).toList
1032-
}
1020+
// val transformed2 = e.multiTransformDownWithContinuation {
1021+
// case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3)) -> true
1022+
// case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30)) -> true
1023+
// case Add(StringLiteral("c"), StringLiteral("d"), _) =>
1024+
// newErrorAfterStream(Literal(100)) -> true
1025+
// }
1026+
// val expected2 = for {
1027+
// b <- Seq(Literal(10), Literal(20), Literal(30))
1028+
// a <- Seq(Literal(1), Literal(2), Literal(3))
1029+
// } yield Add(Add(a, b), Literal(100))
1030+
// // We don't access alternatives for `c` after 100
1031+
// assert(transformed2.take(3 * 3) === expected2)
1032+
// intercept[NoSuchElementException] {
1033+
// transformed.take(3 * 3 + 1).toList
1034+
// }
10331035
}
10341036

10351037
test("multiTransformDown rule return this") {
10361038
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
10371039
val transformed = e.multiTransformDown {
1038-
case s @ StringLiteral("a") => Seq(Literal(1), Literal(2), s)
1039-
case s @ StringLiteral("b") => Seq(Literal(10), Literal(20), s)
1040-
case a @ Add(StringLiteral("c"), StringLiteral("d"), _) => Seq(Literal(100), Literal(200), a)
1040+
case s @ StringLiteral("a") => Stream(Literal(1), Literal(2), s)
1041+
case s @ StringLiteral("b") => Stream(Literal(10), Literal(20), s)
1042+
case a @ Add(StringLiteral("c"), StringLiteral("d"), _) =>
1043+
Stream(Literal(100), Literal(200), a)
10411044
}
10421045
val expected = for {
10431046
cd <- Seq(Literal(100), Literal(200), Add(Literal("c"), Literal("d")))
@@ -1047,15 +1050,16 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
10471050
assert(transformed == expected)
10481051
}
10491052

1050-
test("multiTransformDown doesn't stop generating alternatives of descendants when non-leaf is " +
1051-
"transformed") {
1053+
test("multiTransformDownWithContinuation doesn't stop generating alternatives of descendants " +
1054+
"when non-leaf is transformed but the itself is not in the alternatives") {
10521055
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
1053-
val transformed = e.multiTransformDown {
1056+
val transformed = e.multiTransformDownWithContinuation {
10541057
case Add(StringLiteral("a"), StringLiteral("b"), _) =>
1055-
Seq(Literal(11), Literal(12), Literal(21), Literal(22))
1056-
case StringLiteral("a") => Seq(Literal(1), Literal(2))
1057-
case StringLiteral("b") => Seq(Literal(10), Literal(20))
1058-
case Add(StringLiteral("c"), StringLiteral("d"), _) => Seq(Literal(100), Literal(200))
1058+
Stream(Literal(11), Literal(12), Literal(21), Literal(22)) -> true
1059+
case StringLiteral("a") => Stream(Literal(1), Literal(2)) -> true
1060+
case StringLiteral("b") => Stream(Literal(10), Literal(20)) -> true
1061+
case Add(StringLiteral("c"), StringLiteral("d"), _) =>
1062+
Stream(Literal(100), Literal(200)) -> true
10591063
}
10601064
val expected = for {
10611065
cd <- Seq(Literal(100), Literal(200))
@@ -1068,15 +1072,15 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
10681072
assert(transformed == expected)
10691073
}
10701074

1071-
test("multiTransformDown non-leaf transformation if a descendant can be transformed too " +
1072-
"behaves like non-leaf returned itself") {
1075+
test("multiTransformDown doesn't stop generating alternatives of descendants when non-leaf is " +
1076+
"transformed and itself is in the alternatives") {
10731077
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
10741078
val transformed = e.multiTransformDown {
10751079
case a @ Add(StringLiteral("a"), StringLiteral("b"), _) =>
1076-
Seq(Literal(11), Literal(12), Literal(21), Literal(22), a)
1077-
case StringLiteral("a") => Seq(Literal(1), Literal(2))
1078-
case StringLiteral("b") => Seq(Literal(10), Literal(20))
1079-
case Add(StringLiteral("c"), StringLiteral("d"), _) => Seq(Literal(100), Literal(200))
1080+
Stream(Literal(11), Literal(12), Literal(21), Literal(22), a)
1081+
case StringLiteral("a") => Stream(Literal(1), Literal(2))
1082+
case StringLiteral("b") => Stream(Literal(10), Literal(20))
1083+
case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream(Literal(100), Literal(200))
10801084
}
10811085
val expected = for {
10821086
cd <- Seq(Literal(100), Literal(200))
@@ -1092,12 +1096,12 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
10921096
test("multiTransformDown can prune") {
10931097
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
10941098
val transformed = e.multiTransformDown {
1095-
case StringLiteral("a") => Seq.empty
1099+
case StringLiteral("a") => Stream.empty
10961100
}
10971101
assert(transformed.isEmpty)
10981102

10991103
val transformed2 = e.multiTransformDown {
1100-
case Add(StringLiteral("c"), StringLiteral("d"), _) => Seq.empty
1104+
case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream.empty
11011105
}
11021106
assert(transformed2.isEmpty)
11031107
}

0 commit comments

Comments
 (0)