From 948f7405b3a35638b7f0d949702817703de3dde7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 12 Jun 2015 10:23:12 +0800 Subject: [PATCH 1/8] fix existing --- .../sql/catalyst/optimizer/Optimizer.scala | 55 +++++++++++-------- .../optimizer/LimitPushDownSuit.scala | 22 ++++++++ .../optimizer/UnionPushdownSuite.scala | 4 +- .../spark/sql/execution/basicOperators.scala | 4 +- 4 files changed, 57 insertions(+), 28 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownSuit.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 98b4476076854..92f3402b9258c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -39,7 +39,7 @@ object DefaultOptimizer extends Optimizer { Batch("Distinct", FixedPoint(100), ReplaceDistinctWithAggregate) :: Batch("Operator Optimizations", FixedPoint(100), - UnionPushdown, + UnionPushDown, CombineFilters, PushPredicateThroughProject, PushPredicateThroughGenerate, @@ -63,25 +63,25 @@ object DefaultOptimizer extends Optimizer { } /** - * Pushes operations to either side of a Union. - */ -object UnionPushdown extends Rule[LogicalPlan] { + * Pushes operations to either side of a Union. + */ +object UnionPushDown extends Rule[LogicalPlan] { /** - * Maps Attributes from the left side to the corresponding Attribute on the right side. - */ - def buildRewrites(union: Union): AttributeMap[Attribute] = { + * Maps Attributes from the left side to the corresponding Attribute on the right side. + */ + private def buildRewrites(union: Union): AttributeMap[Attribute] = { assert(union.left.output.size == union.right.output.size) AttributeMap(union.left.output.zip(union.right.output)) } /** - * Rewrites an expression so that it can be pushed to the right side of a Union operator. - * This method relies on the fact that the output attributes of a union are always equal - * to the left child's output. - */ - def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]): A = { + * Rewrites an expression so that it can be pushed to the right side of a Union operator. + * This method relies on the fact that the output attributes of a union are always equal + * to the left child's output. + */ + private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]): A = { val result = e transform { case a: Attribute => rewrites(a) } @@ -108,6 +108,17 @@ object UnionPushdown extends Rule[LogicalPlan] { } } +object LimitPushDown extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // Push down limit when the child is project on limit + case Limit(expr, Project(projectList, l: Limit)) => + Project(projectList, Limit(expr, l)) + + // Push down limit when the child is project on sort + case Limit(expr, Project(projectList, s: Sort)) => + Project(projectList, Limit(expr, s)) + } +} /** * Attempts to eliminate the reading of unneeded columns from the query plan using the following @@ -117,7 +128,6 @@ object UnionPushdown extends Rule[LogicalPlan] { * - Aggregate * - Project <- Join * - LeftSemiJoin - * - Performing alias substitution. */ object ColumnPruning extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -159,9 +169,6 @@ object ColumnPruning extends Rule[LogicalPlan] { Join(left, prunedChild(right, allReferences), LeftSemi, condition) - case Project(projectList, Limit(exp, child)) => - Limit(exp, Project(projectList, child)) - // push down project if possible when the child is sort case p @ Project(projectList, s @ Sort(_, _, grandChild)) if s.references.subsetOf(p.outputSet) => @@ -181,8 +188,8 @@ object ColumnPruning extends Rule[LogicalPlan] { } /** - * Combines two adjacent [[Project]] operators into one, merging the - * expressions into one single expression. + * Combines two adjacent [[Project]] operators into one and perform alias substitution, + * merging the expressions into one single expression. */ object ProjectCollapsing extends Rule[LogicalPlan] { @@ -222,10 +229,10 @@ object ProjectCollapsing extends Rule[LogicalPlan] { object LikeSimplification extends Rule[LogicalPlan] { // if guards below protect from escapes on trailing %. // Cases like "something\%" are not optimized, but this does not affect correctness. - val startsWith = "([^_%]+)%".r - val endsWith = "%([^_%]+)".r - val contains = "%([^_%]+)%".r - val equalTo = "([^_%]*)".r + private val startsWith = "([^_%]+)%".r + private val endsWith = "%([^_%]+)".r + private val contains = "%([^_%]+)%".r + private val equalTo = "([^_%]*)".r def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case Like(l, Literal(utf, StringType)) => @@ -497,7 +504,7 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] { grandChild)) } - def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]): Expression = { + private def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]): Expression = { condition transform { case a: AttributeReference => sourceAliases.getOrElse(a, a) } @@ -682,7 +689,7 @@ object DecimalAggregates extends Rule[LogicalPlan] { import Decimal.MAX_LONG_DIGITS /** Maximum number of decimal digits representable precisely in a Double */ - val MAX_DOUBLE_DIGITS = 15 + private val MAX_DOUBLE_DIGITS = 15 def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownSuit.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownSuit.scala new file mode 100644 index 0000000000000..299feed1ba4a2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownSuit.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +class LimitPushDownSuit { + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala index 35f50be46b76f..ec379489a6d1e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala @@ -24,13 +24,13 @@ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ -class UnionPushdownSuite extends PlanTest { +class UnionPushDownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", Once, EliminateSubQueries) :: Batch("Union Pushdown", Once, - UnionPushdown) :: Nil + UnionPushDown) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 7aedd630e3871..fbb9b52d0f92c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -39,8 +39,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends @transient lazy val buildProjection = newMutableProjection(projectList, child.output) protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => - val resuableProjection = buildProjection() - iter.map(resuableProjection) + val reusableProjection = buildProjection() + iter.map(reusableProjection) } override def outputOrdering: Seq[SortOrder] = child.outputOrdering From 2d8be83cf2aeb948b4f0fd15f2d978ec9b02c997 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 12 Jun 2015 14:06:13 +0800 Subject: [PATCH 2/8] add LimitPushDown --- .../sql/catalyst/optimizer/Optimizer.scala | 1 + .../optimizer/FilterPushdownSuite.scala | 16 ------ .../optimizer/LimitPushDownSuit.scala | 52 ++++++++++++++++++- 3 files changed, 52 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 92f3402b9258c..1b9aad80d4e74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -44,6 +44,7 @@ object DefaultOptimizer extends Optimizer { PushPredicateThroughProject, PushPredicateThroughGenerate, ColumnPruning, + LimitPushDown, ProjectCollapsing, CombineLimits, NullPropagation, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index ffdc673cdc455..fc78366e0c33f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -95,22 +95,6 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("column pruning for Project(ne, Limit)") { - val originalQuery = - testRelation - .select('a, 'b) - .limit(2) - .select('a) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .select('a) - .limit(2).analyze - - comparePlans(optimized, correctAnswer) - } - // After this line is unimplemented. test("simple push down") { val originalQuery = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownSuit.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownSuit.scala index 299feed1ba4a2..5d774efc63c62 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownSuit.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownSuit.scala @@ -17,6 +17,56 @@ package org.apache.spark.sql.catalyst.optimizer -class LimitPushDownSuit { +import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +class LimitPushDownSuit extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Limit PushDown", FixedPoint(10), + LimitPushDown, + CombineLimits, + ConstantFolding, + BooleanSimplification) :: Nil + } + + val testRelation = LocalRelation('a.int) + + test("push down limit when the child is project on limit") { + val originalQuery = + testRelation + .limit(10) + .select('a) + .limit(2) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .limit(2) + .select('a).analyze + + comparePlans(optimized, correctAnswer) + } + + test("push down limit when the child is project on sort") { + val originalQuery = + testRelation + .sortBy(SortOrder('a, Ascending)) + .select('a) + .limit(2) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .sortBy(SortOrder('a, Ascending)) + .limit(2) + .select('a).analyze + + comparePlans(optimized, correctAnswer) + } } From 214842b0a47edb4d043504ad32263a04fe993f98 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 12 Jun 2015 14:57:47 +0800 Subject: [PATCH 3/8] fix style --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 1b9aad80d4e74..92f79dab7be58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -82,7 +82,7 @@ object UnionPushDown extends Rule[LogicalPlan] { * This method relies on the fact that the output attributes of a union are always equal * to the left child's output. */ - private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]): A = { + private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = { val result = e transform { case a: Attribute => rewrites(a) } @@ -505,7 +505,7 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] { grandChild)) } - private def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]): Expression = { + private def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]) = { condition transform { case a: AttributeReference => sourceAliases.getOrElse(a, a) } From b5585493966ea94c9926b878a5d87ed56f62e6e3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 13 Jun 2015 10:31:23 +0800 Subject: [PATCH 4/8] address comments --- .../sql/catalyst/optimizer/Optimizer.scala | 17 +++++++++++------ ...hDownSuit.scala => LimitPushDownSuite.scala} | 2 +- 2 files changed, 12 insertions(+), 7 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{LimitPushDownSuit.scala => LimitPushDownSuite.scala} (98%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 92f79dab7be58..57d25069ef548 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -39,20 +39,23 @@ object DefaultOptimizer extends Optimizer { Batch("Distinct", FixedPoint(100), ReplaceDistinctWithAggregate) :: Batch("Operator Optimizations", FixedPoint(100), + // Operator push down UnionPushDown, - CombineFilters, + LimitPushDown, + PushPredicateThroughJoin, PushPredicateThroughProject, PushPredicateThroughGenerate, ColumnPruning, - LimitPushDown, + // Operator combine ProjectCollapsing, + CombineFilters, CombineLimits, + // Constant folding NullPropagation, OptimizeIn, ConstantFolding, LikeSimplification, BooleanSimplification, - PushPredicateThroughJoin, RemovePositive, SimplifyFilters, SimplifyCasts, @@ -111,12 +114,14 @@ object UnionPushDown extends Rule[LogicalPlan] { object LimitPushDown extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Push down limit when the child is project on limit + // Push down limit when the child is project on limit. case Limit(expr, Project(projectList, l: Limit)) => Project(projectList, Limit(expr, l)) - // Push down limit when the child is project on sort - case Limit(expr, Project(projectList, s: Sort)) => + // Push down limit when the child is project on sort, + // and we cannot push down this project through sort. + case Limit(expr, p @ Project(projectList, s: Sort)) + if !s.references.subsetOf(p.outputSet) => Project(projectList, Limit(expr, s)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownSuit.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownSuite.scala similarity index 98% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownSuit.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownSuite.scala index 5d774efc63c62..c8168b32eef94 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownSuit.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -class LimitPushDownSuit extends PlanTest { +class LimitPushDownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = From 3676a821de960bfbea1422fdcd37a92ed857b89b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 15 Jun 2015 01:01:23 +0800 Subject: [PATCH 5/8] address comments --- .../sql/catalyst/optimizer/Optimizer.scala | 21 ++---- .../optimizer/FilterPushdownSuite.scala | 16 +++++ .../optimizer/LimitPushDownSuite.scala | 72 ------------------- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 8 ++- .../spark/sql/execution/basicOperators.scala | 27 ++++--- .../spark/sql/execution/PlannerSuite.scala | 6 ++ .../apache/spark/sql/hive/HiveContext.scala | 2 +- 8 files changed, 54 insertions(+), 100 deletions(-) delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 57d25069ef548..bfd24287c9645 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -41,7 +41,6 @@ object DefaultOptimizer extends Optimizer { Batch("Operator Optimizations", FixedPoint(100), // Operator push down UnionPushDown, - LimitPushDown, PushPredicateThroughJoin, PushPredicateThroughProject, PushPredicateThroughGenerate, @@ -112,20 +111,6 @@ object UnionPushDown extends Rule[LogicalPlan] { } } -object LimitPushDown extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Push down limit when the child is project on limit. - case Limit(expr, Project(projectList, l: Limit)) => - Project(projectList, Limit(expr, l)) - - // Push down limit when the child is project on sort, - // and we cannot push down this project through sort. - case Limit(expr, p @ Project(projectList, s: Sort)) - if !s.references.subsetOf(p.outputSet) => - Project(projectList, Limit(expr, s)) - } -} - /** * Attempts to eliminate the reading of unneeded columns from the query plan using the following * transformations: @@ -175,7 +160,11 @@ object ColumnPruning extends Rule[LogicalPlan] { Join(left, prunedChild(right, allReferences), LeftSemi, condition) - // push down project if possible when the child is sort + // Push down project through limit, so that we may have chance to push it further. + case Project(projectList, Limit(exp, child)) => + Limit(exp, Project(projectList, child)) + + // Push down project if possible when the child is sort case p @ Project(projectList, s @ Sort(_, _, grandChild)) if s.references.subsetOf(p.outputSet) => s.copy(child = Project(projectList, grandChild)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index fc78366e0c33f..ffdc673cdc455 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -95,6 +95,22 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("column pruning for Project(ne, Limit)") { + val originalQuery = + testRelation + .select('a, 'b) + .limit(2) + .select('a) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .limit(2).analyze + + comparePlans(optimized, correctAnswer) + } + // After this line is unimplemented. test("simple push down") { val originalQuery = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownSuite.scala deleted file mode 100644 index c8168b32eef94..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownSuite.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder} -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ - -class LimitPushDownSuite extends PlanTest { - - object Optimize extends RuleExecutor[LogicalPlan] { - val batches = - Batch("Limit PushDown", FixedPoint(10), - LimitPushDown, - CombineLimits, - ConstantFolding, - BooleanSimplification) :: Nil - } - - val testRelation = LocalRelation('a.int) - - test("push down limit when the child is project on limit") { - val originalQuery = - testRelation - .limit(10) - .select('a) - .limit(2) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .limit(2) - .select('a).analyze - - comparePlans(optimized, correctAnswer) - } - - test("push down limit when the child is project on sort") { - val originalQuery = - testRelation - .sortBy(SortOrder('a, Ascending)) - .select('a) - .limit(2) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .sortBy(SortOrder('a, Ascending)) - .limit(2) - .select('a).analyze - - comparePlans(optimized, correctAnswer) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 04fc798bf3738..5708df82de12f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -858,7 +858,7 @@ class SQLContext(@transient val sparkContext: SparkContext) experimental.extraStrategies ++ ( DataSourceStrategy :: DDLStrategy :: - TakeOrdered :: + TakeOrderedAndProject :: HashAggregation :: LeftSemiJoin :: HashJoin :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 1ff1cc224de8c..21912cf24933e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -213,10 +213,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { protected lazy val singleRowRdd = sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): InternalRow), 1) - object TakeOrdered extends Strategy { + object TakeOrderedAndProject extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => - execution.TakeOrdered(limit, order, planLater(child)) :: Nil + execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil + case logical.Limit( + IntegerLiteral(limit), + logical.Project(projectList, logical.Sort(order, true, child))) => + execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index fbb9b52d0f92c..c97c2031d33ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -147,12 +147,18 @@ case class Limit(limit: Int, child: SparkPlan) /** * :: DeveloperApi :: - * Take the first limit elements as defined by the sortOrder. This is logically equivalent to - * having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but - * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion. + * Take the first limit elements as defined by the sortOrder, and do projection if needed. + * This is logically equivalent to having a [[Limit]] operator after a [[Sort]] operator, + * or having a [[Project]] operator between them. + * This could have been named TopK, but Spark's top operator does the opposite in ordering + * so we name it TakeOrdered to avoid confusion. */ @DeveloperApi -case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode { +case class TakeOrderedAndProject( + limit: Int, + sortOrder: Seq[SortOrder], + projectList: Option[Seq[NamedExpression]], + child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output @@ -160,17 +166,22 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) private val ord: RowOrdering = new RowOrdering(sortOrder, child.output) - private def collectData(): Array[InternalRow] = - child.execute().map(_.copy()).takeOrdered(limit)(ord) + private val projection = projectList.map(newProjection(_, child.output)) + + private def collectData(): Iterator[InternalRow] = { + val data = child.execute().map(_.copy()).takeOrdered(limit)(ord).toIterator + projection.map(data.map(_)).getOrElse(data) + } override def executeCollect(): Array[Row] = { val converter = CatalystTypeConverters.createToScalaConverter(schema) - collectData().map(converter(_).asInstanceOf[Row]) + collectData().map(converter(_).asInstanceOf[Row]).toArray } // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. - protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1) + protected override def doExecute(): RDD[InternalRow] = + sparkContext.makeRDD(collectData().toArray[InternalRow], 1) override def outputOrdering: Seq[SortOrder] = sortOrder } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 5854ab48db552..3dd24130af81a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -141,4 +141,10 @@ class PlannerSuite extends SparkFunSuite { setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) } + + test("efficient limit -> project -> sort") { + val query = testData.sort('key).select('value).limit(2).logicalPlan + val planned = planner.TakeOrderedAndProject(query) + assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index cf05c6c989655..8021f915bb821 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -442,7 +442,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { HiveCommandStrategy(self), HiveDDLStrategy, DDLStrategy, - TakeOrdered, + TakeOrderedAndProject, ParquetOperations, InMemoryScans, ParquetConversion, // Must be before HiveTableScans From 20821ec227522607454f28b651ec5b31e335fbb9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 15 Jun 2015 11:45:04 +0800 Subject: [PATCH 6/8] fix --- .../org/apache/spark/sql/execution/basicOperators.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index c97c2031d33ec..eeb2b2a92b1db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -168,20 +168,20 @@ case class TakeOrderedAndProject( private val projection = projectList.map(newProjection(_, child.output)) - private def collectData(): Iterator[InternalRow] = { - val data = child.execute().map(_.copy()).takeOrdered(limit)(ord).toIterator + private def collectData(): Array[InternalRow] = { + val data = child.execute().map(_.copy()).takeOrdered(limit)(ord) projection.map(data.map(_)).getOrElse(data) } override def executeCollect(): Array[Row] = { val converter = CatalystTypeConverters.createToScalaConverter(schema) - collectData().map(converter(_).asInstanceOf[Row]).toArray + collectData().map(converter(_).asInstanceOf[Row]) } // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. protected override def doExecute(): RDD[InternalRow] = - sparkContext.makeRDD(collectData().toArray[InternalRow], 1) + sparkContext.makeRDD(collectData(), 1) override def outputOrdering: Seq[SortOrder] = sortOrder } From 07d545694690af2da3f3c661aa812f7b55f245d1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 21 Jun 2015 22:45:26 +0800 Subject: [PATCH 7/8] clean closure --- .../spark/sql/execution/basicOperators.scala | 3 +- .../hive/execution/InsertIntoHiveTable.scala | 31 ++++++++++--------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index eeb2b2a92b1db..b2671f5340bd4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -180,8 +180,7 @@ case class TakeOrderedAndProject( // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. - protected override def doExecute(): RDD[InternalRow] = - sparkContext.makeRDD(collectData(), 1) + protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1) override def outputOrdering: Seq[SortOrder] = sortOrder } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 05f425f2b65f3..9d76fc870d320 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -48,16 +48,16 @@ case class InsertIntoHiveTable( overwrite: Boolean, ifNotExists: Boolean) extends UnaryNode with HiveInspectors { - @transient val sc: HiveContext = sqlContext.asInstanceOf[HiveContext] - @transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass - @transient private lazy val hiveContext = new Context(sc.hiveconf) - @transient private lazy val catalog = sc.catalog + val sc: HiveContext = sqlContext.asInstanceOf[HiveContext] + lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass + private lazy val hiveContext = new Context(sc.hiveconf) + private lazy val catalog = sc.catalog - private def newSerializer(tableDesc: TableDesc): Serializer = { + private val newSerializer = (tableDesc: TableDesc) => { val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] serializer.initialize(null, tableDesc.getProperties) serializer - } + }: Serializer def output: Seq[Attribute] = child.output @@ -79,13 +79,10 @@ case class InsertIntoHiveTable( SparkHiveWriterContainer.createPathFromString(fileSinkConf.getDirName, conf.value)) log.debug("Saving as hadoop file of type " + valueClass.getSimpleName) - writerContainer.driverSideSetup() - sc.sparkContext.runJob(rdd, writeToFile _) - writerContainer.commitJob() - + val newSer = newSerializer + val schema = table.schema // Note that this function is executed on executor side - def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = { - val serializer = newSerializer(fileSinkConf.getTableInfo) + val writeToFile = (context: TaskContext, iterator: Iterator[InternalRow]) => { val standardOI = ObjectInspectorUtils .getStandardObjectInspector( fileSinkConf.getTableInfo.getDeserializer.getObjectInspector, @@ -106,12 +103,16 @@ case class InsertIntoHiveTable( } writerContainer - .getLocalFileWriter(row, table.schema) - .write(serializer.serialize(outputData, standardOI)) + .getLocalFileWriter(row, schema) + .write(newSer(fileSinkConf.getTableInfo).serialize(outputData, standardOI)) } writerContainer.close() - } + }: Unit + + writerContainer.driverSideSetup() + sc.sparkContext.runJob(rdd, sc.sparkContext.clean(writeToFile)) + writerContainer.commitJob() } /** From 34aa07bd601999222d95e51a5dc1365e38b0ad23 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 24 Jun 2015 00:17:52 +0800 Subject: [PATCH 8/8] revert --- .../spark/sql/execution/SparkPlan.scala | 1 - .../spark/sql/execution/basicOperators.scala | 3 +- .../hive/execution/InsertIntoHiveTable.scala | 31 +++++++++---------- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 2b8d30294293c..47f56b2b7ebe6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -169,7 +169,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ log.debug( s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") if(codegenEnabled && expressions.forall(_.isThreadSafe)) { - GenerateMutableProjection.generate(expressions, inputSchema) } else { () => new InterpretedMutableProjection(expressions, inputSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index b2671f5340bd4..647c4ab5cb651 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -166,7 +166,8 @@ case class TakeOrderedAndProject( private val ord: RowOrdering = new RowOrdering(sortOrder, child.output) - private val projection = projectList.map(newProjection(_, child.output)) + // TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable. + @transient private val projection = projectList.map(new InterpretedProjection(_, child.output)) private def collectData(): Array[InternalRow] = { val data = child.execute().map(_.copy()).takeOrdered(limit)(ord) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 9d76fc870d320..05f425f2b65f3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -48,16 +48,16 @@ case class InsertIntoHiveTable( overwrite: Boolean, ifNotExists: Boolean) extends UnaryNode with HiveInspectors { - val sc: HiveContext = sqlContext.asInstanceOf[HiveContext] - lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass - private lazy val hiveContext = new Context(sc.hiveconf) - private lazy val catalog = sc.catalog + @transient val sc: HiveContext = sqlContext.asInstanceOf[HiveContext] + @transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass + @transient private lazy val hiveContext = new Context(sc.hiveconf) + @transient private lazy val catalog = sc.catalog - private val newSerializer = (tableDesc: TableDesc) => { + private def newSerializer(tableDesc: TableDesc): Serializer = { val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] serializer.initialize(null, tableDesc.getProperties) serializer - }: Serializer + } def output: Seq[Attribute] = child.output @@ -79,10 +79,13 @@ case class InsertIntoHiveTable( SparkHiveWriterContainer.createPathFromString(fileSinkConf.getDirName, conf.value)) log.debug("Saving as hadoop file of type " + valueClass.getSimpleName) - val newSer = newSerializer - val schema = table.schema + writerContainer.driverSideSetup() + sc.sparkContext.runJob(rdd, writeToFile _) + writerContainer.commitJob() + // Note that this function is executed on executor side - val writeToFile = (context: TaskContext, iterator: Iterator[InternalRow]) => { + def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = { + val serializer = newSerializer(fileSinkConf.getTableInfo) val standardOI = ObjectInspectorUtils .getStandardObjectInspector( fileSinkConf.getTableInfo.getDeserializer.getObjectInspector, @@ -103,16 +106,12 @@ case class InsertIntoHiveTable( } writerContainer - .getLocalFileWriter(row, schema) - .write(newSer(fileSinkConf.getTableInfo).serialize(outputData, standardOI)) + .getLocalFileWriter(row, table.schema) + .write(serializer.serialize(outputData, standardOI)) } writerContainer.close() - }: Unit - - writerContainer.driverSideSetup() - sc.sparkContext.runJob(rdd, sc.sparkContext.clean(writeToFile)) - writerContainer.commitJob() + } } /**