Skip to content

Commit 0c1b2df

Browse files
MickDaviesmarmbrus
authored andcommitted
[SPARK-8077] [SQL] Optimization for TreeNodes with large numbers of children
For example, large IN clauses. Large IN clauses are parsed very slowly. For example, the SQL below (10K items in IN) takes 45-50s: s"""SELECT * FROM Person WHERE ForeName IN ('${(1 to 10000).map("n" + _).mkString("','")}')""" This is principally due to TreeNode, which repeatedly calls contains on children, where children in this case is a List that is 10K long. In effect, parsing for large IN clauses is O(N squared). A lazily initialised Set based on children, used for contains, reduces parse time to around 2.5s. Author: Michael Davies <[email protected]> Closes #6673 from MickDavies/SPARK-8077 and squashes the following commits: 38cd425 [Michael Davies] SPARK-8077: Optimization for TreeNodes with large numbers of children d80103b [Michael Davies] SPARK-8077: Optimization for TreeNodes with large numbers of children e6be8be [Michael Davies] SPARK-8077: Optimization for TreeNodes with large numbers of children
1 parent 50a0496 commit 0c1b2df

File tree

2 files changed

+17
-12
lines changed

2 files changed

+17
-12
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
9090
val input = children.flatMap(_.output)
9191
productIterator.map {
9292
// Children are checked using sameResult above.
93-
case tn: TreeNode[_] if children contains tn => null
93+
case tn: TreeNode[_] if containsChild(tn) => null
9494
case e: Expression => BindReferences.bindReference(e, input, allowFailures = true)
9595
case s: Option[_] => s.map {
9696
case e: Expression => BindReferences.bindReference(e, input, allowFailures = true)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
5959

6060
val origin: Origin = CurrentOrigin.get
6161

62-
/** Returns a Seq of the children of this node */
62+
/**
63+
* Returns a Seq of the children of this node.
64+
* Children should not change: immutability is required for the containsChild optimization.
65+
*/
6366
def children: Seq[BaseType]
6467

68+
lazy val containsChild: Set[TreeNode[_]] = children.toSet
69+
6570
/**
6671
* Faster version of equality which short-circuits when two treeNodes are the same instance.
6772
* We don't just override Object.equals, as doing so prevents the scala compiler from
@@ -147,7 +152,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
147152
def mapChildren(f: BaseType => BaseType): this.type = {
148153
var changed = false
149154
val newArgs = productIterator.map {
150-
case arg: TreeNode[_] if children contains arg =>
155+
case arg: TreeNode[_] if containsChild(arg) =>
151156
val newChild = f(arg.asInstanceOf[BaseType])
152157
if (newChild fastEquals arg) {
153158
arg
@@ -173,7 +178,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
173178
val newArgs = productIterator.map {
174179
// Handle Seq[TreeNode] in TreeNode parameters.
175180
case s: Seq[_] => s.map {
176-
case arg: TreeNode[_] if children contains arg =>
181+
case arg: TreeNode[_] if containsChild(arg) =>
177182
val newChild = remainingNewChildren.remove(0)
178183
val oldChild = remainingOldChildren.remove(0)
179184
if (newChild fastEquals oldChild) {
@@ -185,7 +190,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
185190
case nonChild: AnyRef => nonChild
186191
case null => null
187192
}
188-
case arg: TreeNode[_] if children contains arg =>
193+
case arg: TreeNode[_] if containsChild(arg) =>
189194
val newChild = remainingNewChildren.remove(0)
190195
val oldChild = remainingOldChildren.remove(0)
191196
if (newChild fastEquals oldChild) {
@@ -238,15 +243,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
238243
def transformChildrenDown(rule: PartialFunction[BaseType, BaseType]): this.type = {
239244
var changed = false
240245
val newArgs = productIterator.map {
241-
case arg: TreeNode[_] if children contains arg =>
246+
case arg: TreeNode[_] if containsChild(arg) =>
242247
val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
243248
if (!(newChild fastEquals arg)) {
244249
changed = true
245250
newChild
246251
} else {
247252
arg
248253
}
249-
case Some(arg: TreeNode[_]) if children contains arg =>
254+
case Some(arg: TreeNode[_]) if containsChild(arg) =>
250255
val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
251256
if (!(newChild fastEquals arg)) {
252257
changed = true
@@ -257,7 +262,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
257262
case m: Map[_, _] => m
258263
case d: DataType => d // Avoid unpacking Structs
259264
case args: Traversable[_] => args.map {
260-
case arg: TreeNode[_] if children contains arg =>
265+
case arg: TreeNode[_] if containsChild(arg) =>
261266
val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
262267
if (!(newChild fastEquals arg)) {
263268
changed = true
@@ -295,15 +300,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
295300
def transformChildrenUp(rule: PartialFunction[BaseType, BaseType]): this.type = {
296301
var changed = false
297302
val newArgs = productIterator.map {
298-
case arg: TreeNode[_] if children contains arg =>
303+
case arg: TreeNode[_] if containsChild(arg) =>
299304
val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
300305
if (!(newChild fastEquals arg)) {
301306
changed = true
302307
newChild
303308
} else {
304309
arg
305310
}
306-
case Some(arg: TreeNode[_]) if children contains arg =>
311+
case Some(arg: TreeNode[_]) if containsChild(arg) =>
307312
val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
308313
if (!(newChild fastEquals arg)) {
309314
changed = true
@@ -314,7 +319,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
314319
case m: Map[_, _] => m
315320
case d: DataType => d // Avoid unpacking Structs
316321
case args: Traversable[_] => args.map {
317-
case arg: TreeNode[_] if children contains arg =>
322+
case arg: TreeNode[_] if containsChild(arg) =>
318323
val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
319324
if (!(newChild fastEquals arg)) {
320325
changed = true
@@ -383,7 +388,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
383388

384389
/** Returns a string representing the arguments to this node, minus any children */
385390
def argString: String = productIterator.flatMap {
386-
case tn: TreeNode[_] if children contains tn => Nil
391+
case tn: TreeNode[_] if containsChild(tn) => Nil
387392
case tn: TreeNode[_] if tn.toString contains "\n" => s"(${tn.simpleString})" :: Nil
388393
case seq: Seq[BaseType] if seq.toSet.subsetOf(children.toSet) => Nil
389394
case seq: Seq[_] => seq.mkString("[", ",", "]") :: Nil

0 commit comments

Comments
 (0)