
Commit 3676a82

address comments
1 parent b558549 commit 3676a82

File tree: 8 files changed, +54 −100 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 5 additions & 16 deletions
@@ -41,7 +41,6 @@ object DefaultOptimizer extends Optimizer {
     Batch("Operator Optimizations", FixedPoint(100),
       // Operator push down
       UnionPushDown,
-      LimitPushDown,
       PushPredicateThroughJoin,
       PushPredicateThroughProject,
       PushPredicateThroughGenerate,
@@ -112,20 +111,6 @@ object UnionPushDown extends Rule[LogicalPlan] {
   }
 }
 
-object LimitPushDown extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    // Push down limit when the child is project on limit.
-    case Limit(expr, Project(projectList, l: Limit)) =>
-      Project(projectList, Limit(expr, l))
-
-    // Push down limit when the child is project on sort,
-    // and we cannot push down this project through sort.
-    case Limit(expr, p @ Project(projectList, s: Sort))
-        if !s.references.subsetOf(p.outputSet) =>
-      Project(projectList, Limit(expr, s))
-  }
-}
-
 /**
  * Attempts to eliminate the reading of unneeded columns from the query plan using the following
  * transformations:
@@ -175,7 +160,11 @@ object ColumnPruning extends Rule[LogicalPlan] {
 
       Join(left, prunedChild(right, allReferences), LeftSemi, condition)
 
-    // push down project if possible when the child is sort
+    // Push down project through limit, so that we may have chance to push it further.
+    case Project(projectList, Limit(exp, child)) =>
+      Limit(exp, Project(projectList, child))
+
+    // Push down project if possible when the child is sort
     case p @ Project(projectList, s @ Sort(_, _, grandChild))
         if s.references.subsetOf(p.outputSet) =>
       s.copy(child = Project(projectList, grandChild))
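
The removed LimitPushDown rule pulled Limit up above Project; the new ColumnPruning case moves in the opposite direction, pushing Project beneath Limit so projection pruning can keep propagating toward the leaves. A rough sketch of the rewrite, written with the same Catalyst test DSL the suites below use (the relation and column names are illustrative, not part of this commit):

// Before optimization: an outer Project sits on top of Limit.
val query = testRelation.select('a, 'b).limit(2).select('a)

// After the new ColumnPruning case fires, the outer Project moves below Limit,
// where later passes can merge it with the inner Project and drop column 'b:
val expected = testRelation.select('a).limit(2)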

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala

Lines changed: 16 additions & 0 deletions
@@ -95,6 +95,22 @@ class FilterPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
+  test("column pruning for Project(ne, Limit)") {
+    val originalQuery =
+      testRelation
+        .select('a, 'b)
+        .limit(2)
+        .select('a)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer =
+      testRelation
+        .select('a)
+        .limit(2).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
   // After this line is unimplemented.
   test("simple push down") {
     val originalQuery =

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownSuite.scala

Lines changed: 0 additions & 72 deletions
This file was deleted.

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 1 addition & 1 deletion
@@ -858,7 +858,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       experimental.extraStrategies ++ (
         DataSourceStrategy ::
         DDLStrategy ::
-        TakeOrdered ::
+        TakeOrderedAndProject ::
         HashAggregation ::
         LeftSemiJoin ::
         HashJoin ::

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 6 additions & 2 deletions
@@ -213,10 +213,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   protected lazy val singleRowRdd =
     sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): InternalRow), 1)
 
-  object TakeOrdered extends Strategy {
+  object TakeOrderedAndProject extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
-        execution.TakeOrdered(limit, order, planLater(child)) :: Nil
+        execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil
+      case logical.Limit(
+          IntegerLiteral(limit),
+          logical.Project(projectList, logical.Sort(order, true, child))) =>
+        execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil
       case _ => Nil
     }
   }
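
The second pattern exists because the ColumnPruning change above leaves a Project between Limit and Sort; without it, such plans would fall back to a full Sort followed by a separate Limit. A minimal way to check the two shapes the strategy now matches, assuming a SQLContext `sqlContext` and a registered table `t` with columns `key` and `value` (table and column names are hypothetical):

val q1 = sqlContext.sql("SELECT key, value FROM t ORDER BY key LIMIT 2")  // Limit(Sort)
val q2 = sqlContext.sql("SELECT value FROM t ORDER BY key LIMIT 2")       // Limit(Project(Sort))
// Both should now plan to TakeOrderedAndProject; q2 carries Some(projectList).
q1.queryExecution.executedPlan
q2.queryExecution.executedPlan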

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 19 additions & 8 deletions
@@ -147,30 +147,41 @@ case class Limit(limit: Int, child: SparkPlan)
 
 /**
  * :: DeveloperApi ::
- * Take the first limit elements as defined by the sortOrder. This is logically equivalent to
- * having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but
- * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion.
+ * Take the first limit elements as defined by the sortOrder, and do projection if needed.
+ * This is logically equivalent to having a [[Limit]] operator after a [[Sort]] operator,
+ * or having a [[Project]] operator between them.
+ * This could have been named TopK, but Spark's top operator does the opposite in ordering
+ * so we name it TakeOrdered to avoid confusion.
  */
 @DeveloperApi
-case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode {
+case class TakeOrderedAndProject(
+    limit: Int,
+    sortOrder: Seq[SortOrder],
+    projectList: Option[Seq[NamedExpression]],
+    child: SparkPlan) extends UnaryNode {
 
   override def output: Seq[Attribute] = child.output
 
   override def outputPartitioning: Partitioning = SinglePartition
 
   private val ord: RowOrdering = new RowOrdering(sortOrder, child.output)
 
-  private def collectData(): Array[InternalRow] =
-    child.execute().map(_.copy()).takeOrdered(limit)(ord)
+  private val projection = projectList.map(newProjection(_, child.output))
+
+  private def collectData(): Iterator[InternalRow] = {
+    val data = child.execute().map(_.copy()).takeOrdered(limit)(ord).toIterator
+    projection.map(data.map(_)).getOrElse(data)
+  }
 
   override def executeCollect(): Array[Row] = {
     val converter = CatalystTypeConverters.createToScalaConverter(schema)
-    collectData().map(converter(_).asInstanceOf[Row])
+    collectData().map(converter(_).asInstanceOf[Row]).toArray
   }
 
   // TODO: Terminal split should be implemented differently from non-terminal split.
   // TODO: Pick num splits based on |limit|.
-  protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1)
+  protected override def doExecute(): RDD[InternalRow] =
+    sparkContext.makeRDD(collectData().toArray[InternalRow], 1)
 
   override def outputOrdering: Seq[SortOrder] = sortOrder
 }
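
The optional projection is applied only after takeOrdered has reduced the input to at most `limit` rows, so its cost is bounded by the limit rather than by the child's size. A Spark-free sketch of that ordering-then-projection semantics, using plain Scala collections in place of the child RDD (the data is made up):

// Emulate TakeOrderedAndProject(limit = 2, order by key ascending, project value).
val rows = Seq((3, "c"), (1, "a"), (2, "b"), (5, "e"))   // (key, value) pairs
val topK = rows.sortBy(_._1).take(2)                     // stand-in for takeOrdered(limit)(ord)
val projected = topK.map(_._2)                           // projection touches only the 2 kept rows
// projected == Seq("a", "b")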

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 6 additions & 0 deletions
@@ -141,4 +141,10 @@ class PlannerSuite extends SparkFunSuite {
 
     setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
   }
+
+  test("efficient limit -> project -> sort") {
+    val query = testData.sort('key).select('value).limit(2).logicalPlan
+    val planned = planner.TakeOrderedAndProject(query)
+    assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
+  }
 }

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala

Lines changed: 1 addition & 1 deletion
@@ -442,7 +442,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
       HiveCommandStrategy(self),
       HiveDDLStrategy,
       DDLStrategy,
-      TakeOrdered,
+      TakeOrderedAndProject,
       ParquetOperations,
       InMemoryScans,
       ParquetConversion, // Must be before HiveTableScans
