From e6be8beb72936bb457343e6c9bd0dfddeede040f Mon Sep 17 00:00:00 2001 From: Michael Davies Date: Fri, 5 Jun 2015 19:02:15 +0100 Subject: [PATCH 1/3] SPARK-8077: Optimization for TreeNodes with large numbers of children For example large IN clauses --- .../catalyst/plans/logical/LogicalPlan.scala | 2 +- .../spark/sql/catalyst/trees/TreeNode.scala | 22 ++++++++++--------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index dba69659afc8..9b521cd52bda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -90,7 +90,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { val input = children.flatMap(_.output) productIterator.map { // Children are checked using sameResult above. - case tn: TreeNode[_] if children contains tn => null + case tn: TreeNode[_] if childrenSet contains tn => null case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) case s: Option[_] => s.map { case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 36d005d0e168..3788add9cae5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -62,6 +62,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { /** Returns a Seq of the children of this node */ def children: Seq[BaseType] + lazy val childrenSet:Set[TreeNode[_]] = children.toSet + /** * Faster version of equality which short-circuits when two treeNodes are the same instance. * We don't just override Object.equals, as doing so prevents the scala compiler from @@ -147,7 +149,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { def mapChildren(f: BaseType => BaseType): this.type = { var changed = false val newArgs = productIterator.map { - case arg: TreeNode[_] if children contains arg => + case arg: TreeNode[_] if childrenSet contains arg => val newChild = f(arg.asInstanceOf[BaseType]) if (newChild fastEquals arg) { arg @@ -173,7 +175,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { val newArgs = productIterator.map { // Handle Seq[TreeNode] in TreeNode parameters. case s: Seq[_] => s.map { - case arg: TreeNode[_] if children contains arg => + case arg: TreeNode[_] if childrenSet contains arg => val newChild = remainingNewChildren.remove(0) val oldChild = remainingOldChildren.remove(0) if (newChild fastEquals oldChild) { @@ -185,7 +187,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { case nonChild: AnyRef => nonChild case null => null } - case arg: TreeNode[_] if children contains arg => + case arg: TreeNode[_] if childrenSet contains arg => val newChild = remainingNewChildren.remove(0) val oldChild = remainingOldChildren.remove(0) if (newChild fastEquals oldChild) { @@ -238,7 +240,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { def transformChildrenDown(rule: PartialFunction[BaseType, BaseType]): this.type = { var changed = false val newArgs = productIterator.map { - case arg: TreeNode[_] if children contains arg => + case arg: TreeNode[_] if childrenSet contains arg => val newChild = arg.asInstanceOf[BaseType].transformDown(rule) if (!(newChild fastEquals arg)) { changed = true @@ -246,7 +248,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { arg } - case Some(arg: TreeNode[_]) if children contains arg => + case Some(arg: TreeNode[_]) if childrenSet contains arg => val newChild = arg.asInstanceOf[BaseType].transformDown(rule) if (!(newChild fastEquals arg)) { changed = true @@ -257,7 +259,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { - case arg: TreeNode[_] if children contains arg => + case arg: TreeNode[_] if childrenSet contains arg => val newChild = arg.asInstanceOf[BaseType].transformDown(rule) if (!(newChild fastEquals arg)) { changed = true @@ -295,7 +297,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { def transformChildrenUp(rule: PartialFunction[BaseType, BaseType]): this.type = { var changed = false val newArgs = productIterator.map { - case arg: TreeNode[_] if children contains arg => + case arg: TreeNode[_] if childrenSet contains arg => val newChild = arg.asInstanceOf[BaseType].transformUp(rule) if (!(newChild fastEquals arg)) { changed = true @@ -303,7 +305,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { arg } - case Some(arg: TreeNode[_]) if children contains arg => + case Some(arg: TreeNode[_]) if childrenSet contains arg => val newChild = arg.asInstanceOf[BaseType].transformUp(rule) if (!(newChild fastEquals arg)) { changed = true @@ -314,7 +316,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { - case arg: TreeNode[_] if children contains arg => + case arg: TreeNode[_] if childrenSet contains arg => val newChild = arg.asInstanceOf[BaseType].transformUp(rule) if (!(newChild fastEquals arg)) { changed = true @@ -383,7 +385,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { /** Returns a string representing the arguments to this node, minus any children */ def argString: String = productIterator.flatMap { - case tn: TreeNode[_] if children contains tn => Nil + case tn: TreeNode[_] if childrenSet contains tn => Nil case tn: TreeNode[_] if tn.toString contains "\n" => s"(${tn.simpleString})" :: Nil case seq: Seq[BaseType] if seq.toSet.subsetOf(children.toSet) => Nil case seq: Seq[_] => seq.mkString("[", ",", "]") :: Nil From d80103b88d7a147928980d242a9d52fc2ece9412 Mon Sep 17 00:00:00 2001 From: Michael Davies Date: Fri, 12 Jun 2015 08:07:05 +0100 Subject: [PATCH 2/3] SPARK-8077: Optimization for TreeNodes with large numbers of children Correct style and rename attribute --- .../catalyst/plans/logical/LogicalPlan.scala | 2 +- .../spark/sql/catalyst/trees/TreeNode.scala | 22 +++++++++---------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 9b521cd52bda..c8c6676f24c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -90,7 +90,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { val input = children.flatMap(_.output) productIterator.map { // Children are checked using sameResult above. - case tn: TreeNode[_] if childrenSet contains tn => null + case tn: TreeNode[_] if containsChild(tn) => null case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) case s: Option[_] => s.map { case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 3788add9cae5..04776080d373 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -62,7 +62,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { /** Returns a Seq of the children of this node */ def children: Seq[BaseType] - lazy val childrenSet:Set[TreeNode[_]] = children.toSet + lazy val containsChild: Set[TreeNode[_]] = children.toSet /** * Faster version of equality which short-circuits when two treeNodes are the same instance. @@ -149,7 +149,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { def mapChildren(f: BaseType => BaseType): this.type = { var changed = false val newArgs = productIterator.map { - case arg: TreeNode[_] if childrenSet contains arg => + case arg: TreeNode[_] if containsChild(arg) => val newChild = f(arg.asInstanceOf[BaseType]) if (newChild fastEquals arg) { arg @@ -175,7 +175,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { val newArgs = productIterator.map { // Handle Seq[TreeNode] in TreeNode parameters. case s: Seq[_] => s.map { - case arg: TreeNode[_] if childrenSet contains arg => + case arg: TreeNode[_] if containsChild(arg) => val newChild = remainingNewChildren.remove(0) val oldChild = remainingOldChildren.remove(0) if (newChild fastEquals oldChild) { @@ -187,7 +187,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { case nonChild: AnyRef => nonChild case null => null } - case arg: TreeNode[_] if childrenSet contains arg => + case arg: TreeNode[_] if containsChild(arg) => val newChild = remainingNewChildren.remove(0) val oldChild = remainingOldChildren.remove(0) if (newChild fastEquals oldChild) { @@ -240,7 +240,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { def transformChildrenDown(rule: PartialFunction[BaseType, BaseType]): this.type = { var changed = false val newArgs = productIterator.map { - case arg: TreeNode[_] if childrenSet contains arg => + case arg: TreeNode[_] if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformDown(rule) if (!(newChild fastEquals arg)) { changed = true @@ -248,7 +248,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { arg } - case Some(arg: TreeNode[_]) if childrenSet contains arg => + case Some(arg: TreeNode[_]) if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformDown(rule) if (!(newChild fastEquals arg)) { changed = true @@ -259,7 +259,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { - case arg: TreeNode[_] if childrenSet contains arg => + case arg: TreeNode[_] if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformDown(rule) if (!(newChild fastEquals arg)) { changed = true @@ -297,7 +297,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { def transformChildrenUp(rule: PartialFunction[BaseType, BaseType]): this.type = { var changed = false val newArgs = productIterator.map { - case arg: TreeNode[_] if childrenSet contains arg => + case arg: TreeNode[_] if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformUp(rule) if (!(newChild fastEquals arg)) { changed = true @@ -305,7 +305,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { arg } - case Some(arg: TreeNode[_]) if childrenSet contains arg => + case Some(arg: TreeNode[_]) if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformUp(rule) if (!(newChild fastEquals arg)) { changed = true @@ -316,7 +316,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { - case arg: TreeNode[_] if childrenSet contains arg => + case arg: TreeNode[_] if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformUp(rule) if (!(newChild fastEquals arg)) { changed = true @@ -385,7 +385,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { /** Returns a string representing the arguments to this node, minus any children */ def argString: String = productIterator.flatMap { - case tn: TreeNode[_] if childrenSet contains tn => Nil + case tn: TreeNode[_] if containsChild(tn) => Nil case tn: TreeNode[_] if tn.toString contains "\n" => s"(${tn.simpleString})" :: Nil case seq: Seq[BaseType] if seq.toSet.subsetOf(children.toSet) => Nil case seq: Seq[_] => seq.mkString("[", ",", "]") :: Nil From 38cd42549ea39188216d43998230efa474bf546b Mon Sep 17 00:00:00 2001 From: Michael Davies Date: Fri, 12 Jun 2015 08:14:00 +0100 Subject: [PATCH 3/3] SPARK-8077: Optimization for TreeNodes with large numbers of children Add comment about required immutability of children --- .../scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 04776080d373..e3c04fc9c171 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -59,7 +59,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { val origin: Origin = CurrentOrigin.get - /** Returns a Seq of the children of this node */ + /** + * Returns a Seq of the children of this node. + * Children should not change. Immutability required for containsChild optimization + */ def children: Seq[BaseType] lazy val containsChild: Set[TreeNode[_]] = children.toSet