Skip to content

Commit cc00448

Browse files
sameeragarwalmarmbrus
authored andcommitted
[SPARK-2042] Prevent unnecessary shuffle triggered by take()
This PR implements `take()` on a `SchemaRDD` by inserting a logical limit that is followed by a `collect()`. This is also accompanied by adding a catalyst optimizer rule for collapsing adjacent limits. Doing so prevents an unnecessary shuffle that is sometimes triggered by `take()`. Author: Sameer Agarwal <[email protected]> Closes #1048 from sameeragarwal/master and squashes the following commits: 3eeb848 [Sameer Agarwal] Fixing Tests 1b76ff1 [Sameer Agarwal] Deprecating limit(limitExpr: Expression) in v1.1.0 b723ac4 [Sameer Agarwal] Added limit folding tests a0ff7c4 [Sameer Agarwal] Adding catalyst rule to fold two consecutive limits 8d42d03 [Sameer Agarwal] Implement trigger() as limit() followed by collect() (cherry picked from commit 4107cce) Signed-off-by: Michael Armbrust <[email protected]>
1 parent 684a93a commit cc00448

File tree

5 files changed

+97
-5
lines changed

5 files changed

+97
-5
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,8 @@ package object dsl {
175175

176176
def where(condition: Expression) = Filter(condition, logicalPlan)
177177

178+
def limit(limitExpr: Expression) = Limit(limitExpr, logicalPlan)
179+
178180
def join(
179181
otherPlan: LogicalPlan,
180182
joinType: JoinType = Inner,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ import org.apache.spark.sql.catalyst.types._
2929

3030
object Optimizer extends RuleExecutor[LogicalPlan] {
3131
val batches =
32+
Batch("Combine Limits", FixedPoint(100),
33+
CombineLimits) ::
3234
Batch("ConstantFolding", FixedPoint(100),
3335
NullPropagation,
3436
ConstantFolding,
@@ -362,3 +364,14 @@ object SimplifyCasts extends Rule[LogicalPlan] {
362364
case Cast(e, dataType) if e.dataType == dataType => e
363365
}
364366
}
367+
368+
/**
369+
* Combines two adjacent [[catalyst.plans.logical.Limit Limit]] operators into one, merging the
370+
* expressions into one single expression.
371+
*/
372+
object CombineLimits extends Rule[LogicalPlan] {
373+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
374+
case ll @ Limit(le, nl @ Limit(ne, grandChild)) =>
375+
Limit(If(LessThan(ne, le), ne, le), grandChild)
376+
}
377+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,9 @@ case class Aggregate(
135135
def references = (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet
136136
}
137137

138-
case class Limit(limit: Expression, child: LogicalPlan) extends UnaryNode {
138+
case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
139139
def output = child.output
140-
def references = limit.references
140+
def references = limitExpr.references
141141
}
142142

143143
case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.optimizer
19+
20+
import org.apache.spark.sql.catalyst.plans.logical._
21+
import org.apache.spark.sql.catalyst.rules._
22+
import org.apache.spark.sql.catalyst.dsl.plans._
23+
import org.apache.spark.sql.catalyst.dsl.expressions._
24+
25+
class CombiningLimitsSuite extends OptimizerTest {
26+
27+
object Optimize extends RuleExecutor[LogicalPlan] {
28+
val batches =
29+
Batch("Combine Limit", FixedPoint(2),
30+
CombineLimits) ::
31+
Batch("Constant Folding", FixedPoint(3),
32+
NullPropagation,
33+
ConstantFolding,
34+
BooleanSimplification) :: Nil
35+
}
36+
37+
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
38+
39+
test("limits: combines two limits") {
40+
val originalQuery =
41+
testRelation
42+
.select('a)
43+
.limit(10)
44+
.limit(5)
45+
46+
val optimized = Optimize(originalQuery.analyze)
47+
val correctAnswer =
48+
testRelation
49+
.select('a)
50+
.limit(5).analyze
51+
52+
comparePlans(optimized, correctAnswer)
53+
}
54+
55+
test("limits: combines three limits") {
56+
val originalQuery =
57+
testRelation
58+
.select('a)
59+
.limit(2)
60+
.limit(7)
61+
.limit(5)
62+
63+
val optimized = Optimize(originalQuery.analyze)
64+
val correctAnswer =
65+
testRelation
66+
.select('a)
67+
.limit(2).analyze
68+
69+
comparePlans(optimized, correctAnswer)
70+
}
71+
}

sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,14 +178,18 @@ class SchemaRDD(
178178
def orderBy(sortExprs: SortOrder*): SchemaRDD =
179179
new SchemaRDD(sqlContext, Sort(sortExprs, logicalPlan))
180180

181+
@deprecated("use limit with integer argument", "1.1.0")
182+
def limit(limitExpr: Expression): SchemaRDD =
183+
new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan))
184+
181185
/**
182-
* Limits the results by the given expressions.
186+
* Limits the results by the given integer.
183187
* {{{
184188
* schemaRDD.limit(10)
185189
* }}}
186190
*/
187-
def limit(limitExpr: Expression): SchemaRDD =
188-
new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan))
191+
def limit(limitNum: Int): SchemaRDD =
192+
new SchemaRDD(sqlContext, Limit(Literal(limitNum), logicalPlan))
189193

190194
/**
191195
* Performs a grouping followed by an aggregation.
@@ -374,6 +378,8 @@ class SchemaRDD(
374378

375379
override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect()
376380

381+
override def take(num: Int): Array[Row] = limit(num).collect()
382+
377383
// =======================================================================
378384
// Base RDD functions that do NOT change schema
379385
// =======================================================================

0 commit comments

Comments
 (0)