Commit d8d604b

peter-toth authored and cloud-fan committed
[SPARK-40599][SQL] Add multiTransform methods to TreeNode to generate alternatives
### What changes were proposed in this pull request?
This PR introduces the `TreeNode.multiTransformDown()` and `TreeNode.multiTransformDownWithPruning()` methods, which recursively transform a `TreeNode` (and so a tree) into multiple alternatives. These functions are particularly useful if we want to transform an expression with a projection in which subexpressions can be aliased with multiple different attributes.

E.g. if we have a partitioning expression `HashPartitioning(a + b)` and a `Project` node that aliases `a` as `a1` and `a2` and `b` as `b1` and `b2`, we can easily generate a stream of alternative transformations of the original partitioning:
```
// This is a simplified test, some arguments are missing to make it concise
val partitioning = HashPartitioning(Add(a, b))
val aliases: Map[Expression, Seq[Attribute]] = ... // collect the alias map from project
val s = partitioning.multiTransformDown {
  case e: Expression if aliases.contains(e.canonicalized) => aliases(e.canonicalized).toStream
}
s // Stream(HashPartitioning(Add(a1, b1)), HashPartitioning(Add(a2, b1)), HashPartitioning(Add(a1, b2)), HashPartitioning(Add(a2, b2)))
```
The result of `multiTransformDown` is a lazy stream, so the caller can limit the number of alternatives generated as needed.

### Why are the changes needed?
`TreeNode.multiTransformDown()` is a useful helper method.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
New UTs are added.

Closes #38034 from peter-toth/SPARK-40599-multitransform.

Authored-by: Peter Toth <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
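The aliasing scenario from the commit message can be tried without Spark. The following is a minimal sketch under stated assumptions: `Expr`, `Attr`, `Add` and `MultiTransformSketch` are hypothetical stand-ins rather than Catalyst classes, and `LazyList` (Scala 2.13+) is swapped in for the `Stream` type used by the patch. It only mirrors the pre-order idea of the new method and leaves out pruning, rule ids, and tag copying.

```scala
// Toy stand-ins for Catalyst expressions; `Expr`, `Attr` and `Add` are hypothetical.
sealed trait Expr
final case class Attr(name: String) extends Expr
final case class Add(left: Expr, right: Expr) extends Expr

object MultiTransformSketch {
  // Pre-order multi-transform: apply the rule to the node first, then combine the
  // alternatives of the children of every alternative the rule produced.
  def multiTransformDown(e: Expr)(
      rule: PartialFunction[Expr, LazyList[Expr]]): LazyList[Expr] = {
    // If the rule does not match, the original node is the only alternative.
    val alternatives = rule.applyOrElse(e, (orig: Expr) => LazyList(orig))

    def withChildAlternatives(alt: Expr): LazyList[Expr] = alt match {
      case Add(l, r) =>
        // The left child varies fastest, as in the patch's `generateChildrenSeq`.
        for {
          newRight <- multiTransformDown(r)(rule)
          newLeft <- multiTransformDown(l)(rule)
        } yield Add(newLeft, newRight)
      case leaf => LazyList(leaf)
    }

    alternatives.flatMap(withChildAlternatives)
  }

  // Hypothetical alias map, playing the role of the map collected from a `Project`.
  val aliases: Map[Expr, Seq[Expr]] = Map(
    Attr("a") -> Seq(Attr("a1"), Attr("a2")),
    Attr("b") -> Seq(Attr("b1"), Attr("b2")))

  def demo: LazyList[Expr] =
    multiTransformDown(Add(Attr("a"), Attr("b"))) {
      case e if aliases.contains(e) => aliases(e).to(LazyList)
    }

  def main(args: Array[String]): Unit =
    // Only the first two alternatives are materialized here.
    demo.take(2).foreach(println)
}
```

Because the result is lazy, a caller that needs only a few alternatives can use `take(n)` and never pay for the rest of the cross product.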
1 parent f1a3f4a commit d8d604b

2 files changed: +232 -0 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala

Lines changed: 128 additions & 0 deletions
@@ -618,6 +618,134 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
     }
   }
 
+  /**
+   * Returns alternative copies of this node where `rule` has been recursively applied to it and all
+   * of its children (pre-order).
+   *
+   * @param rule a function used to generate alternatives for a node
+   * @return the stream of alternatives
+   */
+  def multiTransformDown(
+      rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = {
+    multiTransformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)
+  }
+
+  /**
+   * Returns alternative copies of this node where `rule` has been recursively applied to it and all
+   * of its children (pre-order).
+   *
+   * As it is very easy to generate an enormous number of alternatives when the input tree is huge
+   * or when the rule returns many alternatives for many nodes, this function returns the
+   * alternatives as a lazy `Stream` so that the caller can limit the number of alternatives
+   * generated as needed.
+   *
+   * To indicate that the original node without any transformation is a valid alternative, the
+   * rule should either not apply to it or return a one-element stream containing the original node.
+   *
+   * The rule can return `Stream.empty` to indicate that the original node should be pruned. In this
+   * case `multiTransformDown()` returns an empty `Stream`.
+   *
+   * Please consider the following examples of `input.multiTransformDown(rule)`:
+   *
+   * We have an input expression:
+   *   `Add(a, b)`
+   *
+   * 1.
+   * We have a simple rule:
+   *   `a` => `Stream(1, 2)`
+   *   `b` => `Stream(10, 20)`
+   *   `Add(a, b)` => `Stream(11, 12, 21, 22)`
+   *
+   * The output is:
+   *   `Stream(11, 12, 21, 22)`
+   *
+   * 2.
+   * In the previous example, if we want to generate alternatives of `a` and `b` too, then we need
+   * to explicitly add the original `Add(a, b)` expression to the alternatives returned by the rule:
+   *   `a` => `Stream(1, 2)`
+   *   `b` => `Stream(10, 20)`
+   *   `Add(a, b)` => `Stream(11, 12, 21, 22, Add(a, b))`
+   *
+   * The output is:
+   *   `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`
+   *
+   * @param rule   a function used to generate alternatives for a node
+   * @param cond   a Lambda expression to prune tree traversals. If `cond.apply` returns false
+   *               on a TreeNode T, skips processing T and its subtree; otherwise, processes
+   *               T and its subtree recursively.
+   * @param ruleId is a unique Id for `rule` to prune unnecessary tree traversals. When it is
+   *               UnknownRuleId, no pruning happens. Otherwise, if `rule` (with id `ruleId`)
+   *               has been marked as ineffective on a TreeNode T, skips processing T and its
+   *               subtree. Do not pass it if the rule is not purely functional and reads a
+   *               varying initial state for different invocations.
+   * @return the stream of alternatives
+   */
+  def multiTransformDownWithPruning(
+      cond: TreePatternBits => Boolean,
+      ruleId: RuleId = UnknownRuleId
+    )(rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = {
+    if (!cond.apply(this) || isRuleIneffective(ruleId)) {
+      return Stream(this)
+    }
+
+    // We could return `Stream(this)` if the `rule` doesn't apply and handle both
+    // - the rule doesn't apply
+    // - and the rule returns a one element `Stream(originalNode)`
+    // cases together. But, unfortunately, there doesn't seem to be a way to match on a one
+    // element stream without eagerly computing the head of its tail. That would contradict the
+    // purpose of only taking the necessary elements from the alternatives, i.e. the
+    // "multiTransformDown is lazy" test case in `TreeNodeSuite` would fail.
+    // Please note that this behaviour has a downside as well: we can only mark the rule
+    // ineffective on the original node if the rule didn't match.
+    var ruleApplied = true
+    val afterRules = CurrentOrigin.withOrigin(origin) {
+      rule.applyOrElse(this, (_: BaseType) => {
+        ruleApplied = false
+        Stream.empty
+      })
+    }
+
+    val afterRulesStream = if (afterRules.isEmpty) {
+      if (ruleApplied) {
+        // If the rule returned with empty alternatives then prune
+        Stream.empty
+      } else {
+        // If the rule was not applied then keep the original node
+        this.markRuleAsIneffective(ruleId)
+        Stream(this)
+      }
+    } else {
+      // If the rule was applied then use the returned alternatives
+      afterRules.map { afterRule =>
+        if (this fastEquals afterRule) {
+          this
+        } else {
+          afterRule.copyTagsFrom(this)
+          afterRule
+        }
+      }
+    }
+
+    afterRulesStream.flatMap { afterRule =>
+      if (afterRule.containsChild.nonEmpty) {
+        generateChildrenSeq(
+            afterRule.children.map(_.multiTransformDownWithPruning(cond, ruleId)(rule)))
+          .map(afterRule.withNewChildren)
+      } else {
+        Stream(afterRule)
+      }
+    }
+  }
+
+  private def generateChildrenSeq[T](childrenStreams: Seq[Stream[T]]): Stream[Seq[T]] = {
+    childrenStreams.foldRight(Stream(Seq.empty[T]))((childrenStream, childrenSeqStream) =>
+      for {
+        childrenSeq <- childrenSeqStream
+        child <- childrenStream
+      } yield child +: childrenSeq
+    )
+  }
+
   /**
    * Returns a copy of this node where `f` has been applied to all the nodes in `children`.
    */
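The order in which `generateChildrenSeq` combines the children's alternatives explains the output order in the Scaladoc examples. Below is a minimal standalone sketch of the same fold. It assumes Scala 2.13+ and swaps `LazyList` in for `Stream`. The names `ChildrenSeqSketch` and `childrenSeqs` are made up for illustration.

```scala
object ChildrenSeqSketch {
  // Lazy cross product of the per-child alternatives. The fold starts from the last
  // child, and the inner generator iterates the current child, so the first child
  // varies fastest in the result.
  def childrenSeqs[T](childrenAlternatives: Seq[LazyList[T]]): LazyList[Seq[T]] =
    childrenAlternatives.foldRight(LazyList(Seq.empty[T])) { (childAlternatives, tails) =>
      for {
        tail <- tails
        child <- childAlternatives
      } yield child +: tail
    }

  def main(args: Array[String]): Unit = {
    // Yields Seq(1, 10), Seq(2, 10), Seq(1, 20), Seq(2, 20), in that order.
    val combos = childrenSeqs(Seq(LazyList(1, 2), LazyList(10, 20)))
    combos.foreach(println)
  }
}
```

Iterating the already-built tails in the outer generator means that alternatives of later children are only forced once the alternatives of earlier children are exhausted, which is the property the "multiTransformDown is lazy" test relies on.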

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala

Lines changed: 104 additions & 0 deletions
@@ -977,4 +977,108 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
       assert(origin.context.summary.isEmpty)
     }
   }
+
+  private def newErrorAfterStream(es: Expression*) = {
+    es.toStream.append(
+      throw new NoSuchElementException("Stream should not return more elements")
+    )
+  }
+
+  test("multiTransformDown generates all alternatives") {
+    val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
+    val transformed = e.multiTransformDown {
+      case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
+      case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30))
+      case Add(StringLiteral("c"), StringLiteral("d"), _) =>
+        Stream(Literal(100), Literal(200), Literal(300))
+    }
+    val expected = for {
+      cd <- Seq(Literal(100), Literal(200), Literal(300))
+      b <- Seq(Literal(10), Literal(20), Literal(30))
+      a <- Seq(Literal(1), Literal(2), Literal(3))
+    } yield Add(Add(a, b), cd)
+    assert(transformed === expected)
+  }
+
+  test("multiTransformDown is lazy") {
+    val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
+    val transformed = e.multiTransformDown {
+      case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
+      case StringLiteral("b") => newErrorAfterStream(Literal(10))
+      case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100))
+    }
+    val expected = for {
+      a <- Seq(Literal(1), Literal(2), Literal(3))
+    } yield Add(Add(a, Literal(10)), Literal(100))
+    // We don't access alternatives for `b` after 10 and for `c` after 100
+    assert(transformed.take(3) == expected)
+    intercept[NoSuchElementException] {
+      transformed.take(3 + 1).toList
+    }
+
+    val transformed2 = e.multiTransformDown {
+      case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
+      case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30))
+      case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100))
+    }
+    val expected2 = for {
+      b <- Seq(Literal(10), Literal(20), Literal(30))
+      a <- Seq(Literal(1), Literal(2), Literal(3))
+    } yield Add(Add(a, b), Literal(100))
+    // We don't access alternatives for `c` after 100
+    assert(transformed2.take(3 * 3) === expected2)
+    intercept[NoSuchElementException] {
+      transformed2.take(3 * 3 + 1).toList
+    }
+  }
+
+  test("multiTransformDown rule return this") {
+    val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
+    val transformed = e.multiTransformDown {
+      case s @ StringLiteral("a") => Stream(Literal(1), Literal(2), s)
+      case s @ StringLiteral("b") => Stream(Literal(10), Literal(20), s)
+      case a @ Add(StringLiteral("c"), StringLiteral("d"), _) =>
+        Stream(Literal(100), Literal(200), a)
+    }
+    val expected = for {
+      cd <- Seq(Literal(100), Literal(200), Add(Literal("c"), Literal("d")))
+      b <- Seq(Literal(10), Literal(20), Literal("b"))
+      a <- Seq(Literal(1), Literal(2), Literal("a"))
+    } yield Add(Add(a, b), cd)
+    assert(transformed == expected)
+  }
+
+  test("multiTransformDown doesn't stop generating alternatives of descendants when non-leaf is " +
+    "transformed and itself is in the alternatives") {
+    val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
+    val transformed = e.multiTransformDown {
+      case a @ Add(StringLiteral("a"), StringLiteral("b"), _) =>
+        Stream(Literal(11), Literal(12), Literal(21), Literal(22), a)
+      case StringLiteral("a") => Stream(Literal(1), Literal(2))
+      case StringLiteral("b") => Stream(Literal(10), Literal(20))
+      case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream(Literal(100), Literal(200))
+    }
+    val expected = for {
+      cd <- Seq(Literal(100), Literal(200))
+      ab <- Seq(Literal(11), Literal(12), Literal(21), Literal(22)) ++
+        (for {
+          b <- Seq(Literal(10), Literal(20))
+          a <- Seq(Literal(1), Literal(2))
+        } yield Add(a, b))
+    } yield Add(ab, cd)
+    assert(transformed == expected)
+  }
+
+  test("multiTransformDown can prune") {
+    val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
+    val transformed = e.multiTransformDown {
+      case StringLiteral("a") => Stream.empty
+    }
+    assert(transformed.isEmpty)
+
+    val transformed2 = e.multiTransformDown {
+      case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream.empty
+    }
+    assert(transformed2.isEmpty)
+  }
 }
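The laziness test works by appending a throwing tail to the rule's alternatives, so that forcing one element too many fails loudly. A minimal sketch of that trick outside Spark follows. It assumes Scala 2.13+, where `LazyList.lazyAppendedAll` takes its suffix by name, in place of the `Stream.append` call used by `newErrorAfterStream`. The names `LazyTailSketch` and `errorAfter` are hypothetical.

```scala
import scala.util.Try

object LazyTailSketch {
  // Yields the given elements, then throws if anything beyond them is forced.
  // The suffix is passed by name, so the `throw` only runs when a consumer
  // walks past the last real element.
  def errorAfter[T](elems: T*): LazyList[T] =
    elems.to(LazyList).lazyAppendedAll(
      throw new NoSuchElementException("Stream should not return more elements"))

  def main(args: Array[String]): Unit = {
    val guarded = errorAfter(10, 20)
    // Taking exactly the available elements never touches the throwing suffix.
    println(guarded.take(2).toList)
    // Asking for one more element forces the suffix and fails.
    println(Try(guarded.take(3).toList).isFailure)
  }
}
```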