
Commit f04b567

cloud-fan authored and marmbrus committed
[SPARK-7289] handle project -> limit -> sort efficiently
make the `TakeOrdered` strategy and operator more general, such that it can optionally handle a projection when necessary

Author: Wenchen Fan <[email protected]>

Closes apache#6780 from cloud-fan/limit and squashes the following commits:

34aa07b [Wenchen Fan] revert
07d5456 [Wenchen Fan] clean closure
20821ec [Wenchen Fan] fix
3676a82 [Wenchen Fan] address comments
b558549 [Wenchen Fan] address comments
214842b [Wenchen Fan] fix style
2d8be83 [Wenchen Fan] add LimitPushDown
948f740 [Wenchen Fan] fix existing
1 parent b84d4b4 · commit f04b567
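For context, the plan shape this commit targets is a sort followed by a narrowing projection and a limit. The snippet below is an illustrative sketch against the 1.4-era DataFrame API, not part of the commit; it assumes an existing `sqlContext` and a registered table `t` with columns `key` and `value`:

```scala
// Hypothetical repro of the limit -> project -> sort plan shape.
val df = sqlContext.table("t").sort("key").select("value").limit(2)

// Before this patch the planner emits a Project on top of TakeOrdered, so all
// columns of `t` flow through the top-k step; afterwards the whole chain
// collapses into a single TakeOrderedAndProject physical operator.
df.explain()
```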

File tree: 8 files changed (+62, −40 lines)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 27 additions & 25 deletions
```diff
@@ -39,19 +39,22 @@ object DefaultOptimizer extends Optimizer {
     Batch("Distinct", FixedPoint(100),
       ReplaceDistinctWithAggregate) ::
     Batch("Operator Optimizations", FixedPoint(100),
-      UnionPushdown,
-      CombineFilters,
+      // Operator push down
+      UnionPushDown,
+      PushPredicateThroughJoin,
       PushPredicateThroughProject,
       PushPredicateThroughGenerate,
       ColumnPruning,
+      // Operator combine
       ProjectCollapsing,
+      CombineFilters,
       CombineLimits,
+      // Constant folding
       NullPropagation,
       OptimizeIn,
       ConstantFolding,
       LikeSimplification,
       BooleanSimplification,
-      PushPredicateThroughJoin,
       RemovePositive,
       SimplifyFilters,
       SimplifyCasts,
```
```diff
@@ -63,25 +66,25 @@ object DefaultOptimizer extends Optimizer {
 }
 
 /**
-  * Pushes operations to either side of a Union.
-  */
-object UnionPushdown extends Rule[LogicalPlan] {
+ * Pushes operations to either side of a Union.
+ */
+object UnionPushDown extends Rule[LogicalPlan] {
 
   /**
-    * Maps Attributes from the left side to the corresponding Attribute on the right side.
-    */
-  def buildRewrites(union: Union): AttributeMap[Attribute] = {
+   * Maps Attributes from the left side to the corresponding Attribute on the right side.
+   */
+  private def buildRewrites(union: Union): AttributeMap[Attribute] = {
     assert(union.left.output.size == union.right.output.size)
 
     AttributeMap(union.left.output.zip(union.right.output))
   }
 
   /**
-    * Rewrites an expression so that it can be pushed to the right side of a Union operator.
-    * This method relies on the fact that the output attributes of a union are always equal
-    * to the left child's output.
-    */
-  def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]): A = {
+   * Rewrites an expression so that it can be pushed to the right side of a Union operator.
+   * This method relies on the fact that the output attributes of a union are always equal
+   * to the left child's output.
+   */
+  private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = {
     val result = e transform {
       case a: Attribute => rewrites(a)
     }
```
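Rename aside, the `buildRewrites`/`pushToRight` pair is easy to model outside Catalyst. Below is a minimal, Spark-free sketch of the idea under simplified assumptions (attributes are plain strings, an expression is just its set of references; all names are illustrative, not Spark's):

```scala
object UnionPushDownSketch {
  // Stand-in for buildRewrites: pair each left-side output attribute with the
  // right-side attribute in the same position.
  def buildRewrites(left: Seq[String], right: Seq[String]): Map[String, String] = {
    assert(left.size == right.size)
    left.zip(right).toMap
  }

  // Stand-in for an expression: only the attributes it references matter here.
  final case class Expr(refs: Set[String])

  // Stand-in for pushToRight: rewrite references so the expression can be
  // evaluated against the union's right child.
  def pushToRight(e: Expr, rewrites: Map[String, String]): Expr =
    Expr(e.refs.map(rewrites))

  def main(args: Array[String]): Unit = {
    val rewrites = buildRewrites(Seq("a", "b"), Seq("x", "y"))
    println(pushToRight(Expr(Set("a", "b")), rewrites)) // Expr(Set(x, y))
  }
}
```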
```diff
@@ -108,7 +111,6 @@ object UnionPushdown extends Rule[LogicalPlan] {
   }
 }
 
-
 /**
  * Attempts to eliminate the reading of unneeded columns from the query plan using the following
  * transformations:
@@ -117,7 +119,6 @@ object UnionPushdown extends Rule[LogicalPlan] {
  *  - Aggregate
  *  - Project <- Join
  *  - LeftSemiJoin
- *  - Performing alias substitution.
  */
 object ColumnPruning extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
```
```diff
@@ -159,10 +160,11 @@ object ColumnPruning extends Rule[LogicalPlan] {
 
       Join(left, prunedChild(right, allReferences), LeftSemi, condition)
 
+    // Push down project through limit, so that we may have chance to push it further.
     case Project(projectList, Limit(exp, child)) =>
       Limit(exp, Project(projectList, child))
 
-    // push down project if possible when the child is sort
+    // Push down project if possible when the child is sort
     case p @ Project(projectList, s @ Sort(_, _, grandChild))
         if s.references.subsetOf(p.outputSet) =>
       s.copy(child = Project(projectList, grandChild))
```
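These two cases are the heart of SPARK-7289: swapping `Project` inside `Limit` exposes a `Project` directly over `Sort`, and when the sort keys do not survive the projection the plan settles into the `limit -> project -> sort` shape that the new planner strategy (further down) recognizes. A runnable toy-plan sketch of the rewrite, with hypothetical stand-ins for Catalyst's operators:

```scala
object PushdownSketch {
  sealed trait Plan
  final case class Scan(cols: Set[String]) extends Plan
  final case class Proj(cols: Set[String], child: Plan) extends Plan
  final case class Lim(n: Int, child: Plan) extends Plan
  final case class Srt(keys: Set[String], child: Plan) extends Plan

  // Mimics the two ColumnPruning cases in the hunk above.
  def pushProject(plan: Plan): Plan = plan match {
    // Project over Limit: always swap, so the project may sink further down.
    case Proj(cols, Lim(n, child)) =>
      Lim(n, pushProject(Proj(cols, child)))
    // Project over Sort: swap only if the sort keys survive the projection,
    // mirroring `s.references.subsetOf(p.outputSet)`.
    case Proj(cols, Srt(keys, child)) if keys.subsetOf(cols) =>
      Srt(keys, pushProject(Proj(cols, child)))
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val plan = Proj(Set("value"), Lim(2, Srt(Set("key"), Scan(Set("key", "value")))))
    // Prints Lim(2,Proj(Set(value),Srt(...))): the project sinks below the
    // limit but is (correctly) refused below the sort, which still needs
    // `key`. That is exactly the limit -> project -> sort shape.
    println(pushProject(plan))
  }
}
```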
```diff
@@ -181,8 +183,8 @@ object ColumnPruning extends Rule[LogicalPlan] {
 }
 
 /**
- * Combines two adjacent [[Project]] operators into one, merging the
- * expressions into one single expression.
+ * Combines two adjacent [[Project]] operators into one and perform alias substitution,
+ * merging the expressions into one single expression.
  */
 object ProjectCollapsing extends Rule[LogicalPlan] {
 
@@ -222,10 +224,10 @@ object ProjectCollapsing extends Rule[LogicalPlan] {
 object LikeSimplification extends Rule[LogicalPlan] {
   // if guards below protect from escapes on trailing %.
   // Cases like "something\%" are not optimized, but this does not affect correctness.
-  val startsWith = "([^_%]+)%".r
-  val endsWith = "%([^_%]+)".r
-  val contains = "%([^_%]+)%".r
-  val equalTo = "([^_%]*)".r
+  private val startsWith = "([^_%]+)%".r
+  private val endsWith = "%([^_%]+)".r
+  private val contains = "%([^_%]+)%".r
+  private val equalTo = "([^_%]*)".r
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
     case Like(l, Literal(utf, StringType)) =>
```
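Incidentally, the four regexes made private here classify entire LIKE patterns. A self-contained check of their behavior (the regexes are copied from the hunk; the driver and match arms are illustrative, though they keep the rule's ordering and its trailing-backslash guard):

```scala
object LikeSimplificationDemo {
  private val startsWith = "([^_%]+)%".r
  private val endsWith = "%([^_%]+)".r
  private val contains = "%([^_%]+)%".r
  private val equalTo = "([^_%]*)".r

  def classify(pattern: String): String = pattern match {
    case startsWith(prefix) if !prefix.endsWith("\\") => s"StartsWith($prefix)"
    case endsWith(suffix) => s"EndsWith($suffix)"
    case contains(infix) => s"Contains($infix)"
    case equalTo(str) => s"EqualTo($str)"
    case _ => "left as Like" // e.g. patterns containing _
  }

  def main(args: Array[String]): Unit =
    // StartsWith(abc), EndsWith(abc), Contains(abc), EqualTo(abc), left as Like
    Seq("abc%", "%abc", "%abc%", "abc", "a_c").map(classify).foreach(println)
}
```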
```diff
@@ -497,7 +499,7 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
         grandChild))
   }
 
-  def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]): Expression = {
+  private def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]) = {
     condition transform {
       case a: AttributeReference => sourceAliases.getOrElse(a, a)
     }
@@ -682,7 +684,7 @@ object DecimalAggregates extends Rule[LogicalPlan] {
   import Decimal.MAX_LONG_DIGITS
 
   /** Maximum number of decimal digits representable precisely in a Double */
-  val MAX_DOUBLE_DIGITS = 15
+  private val MAX_DOUBLE_DIGITS = 15
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
     case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
```

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -24,13 +24,13 @@ import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.dsl.expressions._
 
-class UnionPushdownSuite extends PlanTest {
+class UnionPushDownSuite extends PlanTest {
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
       Batch("Subqueries", Once,
         EliminateSubQueries) ::
       Batch("Union Pushdown", Once,
-        UnionPushdown) :: Nil
+        UnionPushDown) :: Nil
   }
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
```

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -858,7 +858,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       experimental.extraStrategies ++ (
       DataSourceStrategy ::
       DDLStrategy ::
-      TakeOrdered ::
+      TakeOrderedAndProject ::
       HashAggregation ::
       LeftSemiJoin ::
       HashJoin ::
```

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 0 additions & 1 deletion
```diff
@@ -169,7 +169,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     log.debug(
       s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
     if(codegenEnabled && expressions.forall(_.isThreadSafe)) {
-
       GenerateMutableProjection.generate(expressions, inputSchema)
     } else {
       () => new InterpretedMutableProjection(expressions, inputSchema)
```

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 6 additions & 2 deletions
```diff
@@ -213,10 +213,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   protected lazy val singleRowRdd =
     sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): InternalRow), 1)
 
-  object TakeOrdered extends Strategy {
+  object TakeOrderedAndProject extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
-        execution.TakeOrdered(limit, order, planLater(child)) :: Nil
+        execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil
+      case logical.Limit(
+          IntegerLiteral(limit),
+          logical.Project(projectList, logical.Sort(order, true, child))) =>
+        execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil
       case _ => Nil
     }
   }
```
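Stripped of planner machinery, the strategy just recognizes two plan shapes: a limit over a global sort, with or without a projection in between (the literal `true` in the match above is the sort's global flag). A toy sketch with hypothetical plan classes, since Catalyst's `logical.*` operators carry much more detail:

```scala
object StrategySketch {
  sealed trait LPlan
  final case class LSort(global: Boolean, child: LPlan) extends LPlan
  final case class LProject(cols: Seq[String], child: LPlan) extends LPlan
  final case class LLimit(n: Int, child: LPlan) extends LPlan
  final case class LScan(name: String) extends LPlan

  // Both shapes collapse into one top-k-with-projection operator; anything
  // else falls through to the remaining strategies (the None case).
  def plan(p: LPlan): Option[String] = p match {
    case LLimit(n, LSort(true, _)) =>
      Some(s"TakeOrderedAndProject(limit=$n, project=None)")
    case LLimit(n, LProject(cols, LSort(true, _))) =>
      Some(s"TakeOrderedAndProject(limit=$n, project=Some(${cols.mkString(",")}))")
    case _ => None
  }

  def main(args: Array[String]): Unit = {
    val q = LLimit(2, LProject(Seq("value"), LSort(true, LScan("t"))))
    println(plan(q)) // Some(TakeOrderedAndProject(limit=2, project=Some(value)))
  }
}
```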

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 19 additions & 8 deletions
```diff
@@ -39,8 +39,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
   @transient lazy val buildProjection = newMutableProjection(projectList, child.output)
 
   protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
-    val resuableProjection = buildProjection()
-    iter.map(resuableProjection)
+    val reusableProjection = buildProjection()
+    iter.map(reusableProjection)
   }
 
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
```
```diff
@@ -147,21 +147,32 @@ case class Limit(limit: Int, child: SparkPlan)
 
 /**
  * :: DeveloperApi ::
- * Take the first limit elements as defined by the sortOrder. This is logically equivalent to
- * having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but
- * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion.
+ * Take the first limit elements as defined by the sortOrder, and do projection if needed.
+ * This is logically equivalent to having a [[Limit]] operator after a [[Sort]] operator,
+ * or having a [[Project]] operator between them.
+ * This could have been named TopK, but Spark's top operator does the opposite in ordering
+ * so we name it TakeOrdered to avoid confusion.
  */
 @DeveloperApi
-case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode {
+case class TakeOrderedAndProject(
+    limit: Int,
+    sortOrder: Seq[SortOrder],
+    projectList: Option[Seq[NamedExpression]],
+    child: SparkPlan) extends UnaryNode {
 
   override def output: Seq[Attribute] = child.output
 
   override def outputPartitioning: Partitioning = SinglePartition
 
   private val ord: RowOrdering = new RowOrdering(sortOrder, child.output)
 
-  private def collectData(): Array[InternalRow] =
-    child.execute().map(_.copy()).takeOrdered(limit)(ord)
+  // TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable.
+  @transient private val projection = projectList.map(new InterpretedProjection(_, child.output))
+
+  private def collectData(): Array[InternalRow] = {
+    val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
+    projection.map(data.map(_)).getOrElse(data)
+  }
 
   override def executeCollect(): Array[Row] = {
     val converter = CatalystTypeConverters.createToScalaConverter(schema)
```
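Operationally, `collectData` takes the `limit` smallest rows under `ord` via `RDD.takeOrdered` (a per-partition top-k followed by a merge, never a full sort of the dataset) and only then applies the optional projection, so at most `limit` rows are ever projected. A Spark-free sketch of that contract, with hypothetical names and a plain sort standing in for `takeOrdered`:

```scala
object TakeOrderedAndProjectSketch {
  type Row = Seq[Any]

  def takeOrderedAndProject(
      rows: Iterator[Row],
      limit: Int,
      ord: Ordering[Row],
      projection: Option[Row => Row]): Array[Row] = {
    // Stand-in for child.execute().map(_.copy()).takeOrdered(limit)(ord).
    val data = rows.toArray.sorted(ord).take(limit)
    // Stand-in for projection.map(data.map(_)).getOrElse(data): only the few
    // surviving rows are projected, not the whole input.
    projection.map(p => data.map(p)).getOrElse(data)
  }

  def main(args: Array[String]): Unit = {
    val rows = Iterator[Row](Seq(3, "c"), Seq(1, "a"), Seq(2, "b"))
    val byKey = Ordering.by[Row, Int](_.head.asInstanceOf[Int])
    val valueOnly: Row => Row = r => Seq(r(1))
    // Prints List(a) then List(b): the rows with the two smallest keys,
    // projected down to their value column.
    takeOrderedAndProject(rows, 2, byKey, Some(valueOnly)).foreach(println)
  }
}
```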

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 6 additions & 0 deletions
```diff
@@ -141,4 +141,10 @@ class PlannerSuite extends SparkFunSuite {
 
     setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
   }
+
+  test("efficient limit -> project -> sort") {
+    val query = testData.sort('key).select('value).limit(2).logicalPlan
+    val planned = planner.TakeOrderedAndProject(query)
+    assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
+  }
 }
```

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -442,7 +442,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
       HiveCommandStrategy(self),
       HiveDDLStrategy,
       DDLStrategy,
-      TakeOrdered,
+      TakeOrderedAndProject,
       ParquetOperations,
       InMemoryScans,
       ParquetConversion, // Must be before HiveTableScans
```
