
Commit c734987

cloud-fan authored and yhuai committed
[SPARK-7289] [SPARK-9949] [SQL] Backport SPARK-7289 and SPARK-9949 to branch 1.4
The bug fixed by SPARK-7289 is a pretty serious one: Spark SQL generates wrong results. We should backport the fix to branch 1.4 (#6780). We also need to backport the fix to `TakeOrderedAndProject` (#8179).

Author: Wenchen Fan <[email protected]>
Author: Yin Huai <[email protected]>

Closes #8252 from yhuai/backport7289And9949.
1 parent f7f2ac6 commit c734987

File tree: 5 files changed, +56 -10 lines

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

1 addition, 1 deletion

@@ -827,7 +827,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       experimental.extraStrategies ++ (
       DataSourceStrategy ::
       DDLStrategy ::
-      TakeOrdered ::
+      TakeOrderedAndProject ::
       HashAggregation ::
       LeftSemiJoin ::
       HashJoin ::

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

6 additions, 2 deletions

@@ -205,10 +205,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   protected lazy val singleRowRdd =
     sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1)

-  object TakeOrdered extends Strategy {
+  object TakeOrderedAndProject extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
-        execution.TakeOrdered(limit, order, planLater(child)) :: Nil
+        execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil
+      case logical.Limit(
+          IntegerLiteral(limit),
+          logical.Project(projectList, logical.Sort(order, true, child))) =>
+        execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil
       case _ => Nil
     }
   }
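
For readers skimming the strategy change: the sketch below mirrors the two plan shapes the patched strategy matches (Limit over a global Sort, with or without an intervening Project). It uses simplified stand-in types, not Spark's Catalyst classes, so it is an illustration of the matching logic only.

object StrategySketch extends App {
  sealed trait Plan
  case class Scan(table: String) extends Plan
  case class Sort(order: Seq[String], global: Boolean, child: Plan) extends Plan
  case class Project(projectList: Seq[String], child: Plan) extends Plan
  case class Limit(limit: Int, child: Plan) extends Plan
  case class TakeOrderedAndProject(
      limit: Int,
      order: Seq[String],
      projectList: Option[Seq[String]],
      child: Plan) extends Plan

  // Mirrors the two cases in the patched strategy: Limit over a global Sort
  // keeps projectList = None; Limit over Project over a global Sort carries
  // the project list along as Some(...).
  def plan(p: Plan): Seq[Plan] = p match {
    case Limit(n, Sort(order, true, child)) =>
      TakeOrderedAndProject(n, order, None, child) :: Nil
    case Limit(n, Project(projectList, Sort(order, true, child))) =>
      TakeOrderedAndProject(n, order, Some(projectList), child) :: Nil
    case _ => Nil
  }

  // Shape of: SELECT value, key FROM t ORDER BY key LIMIT 2
  val q = Limit(2, Project(Seq("value", "key"), Sort(Seq("key"), global = true, Scan("t"))))
  println(plan(q).head)
  // TakeOrderedAndProject(2,List(key),Some(List(value, key)),Scan(t))
}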

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

28 additions, 6 deletions

@@ -144,20 +144,35 @@ case class Limit(limit: Int, child: SparkPlan)

 /**
  * :: DeveloperApi ::
- * Take the first limit elements as defined by the sortOrder. This is logically equivalent to
- * having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but
- * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion.
+ * Take the first limit elements as defined by the sortOrder, and do projection if needed.
+ * This is logically equivalent to having a [[Limit]] operator after a [[Sort]] operator,
+ * or having a [[Project]] operator between them.
+ * This could have been named TopK, but Spark's top operator does the opposite in ordering
+ * so we name it TakeOrdered to avoid confusion.
  */
 @DeveloperApi
-case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode {
+case class TakeOrderedAndProject(
+    limit: Int,
+    sortOrder: Seq[SortOrder],
+    projectList: Option[Seq[NamedExpression]],
+    child: SparkPlan) extends UnaryNode {

-  override def output: Seq[Attribute] = child.output
+  override def output: Seq[Attribute] = {
+    val projectOutput = projectList.map(_.map(_.toAttribute))
+    projectOutput.getOrElse(child.output)
+  }

   override def outputPartitioning: Partitioning = SinglePartition

   private val ord: RowOrdering = new RowOrdering(sortOrder, child.output)

-  private def collectData(): Array[Row] = child.execute().map(_.copy()).takeOrdered(limit)(ord)
+  // TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable.
+  @transient private val projection = projectList.map(new InterpretedProjection(_, child.output))
+
+  private def collectData(): Array[Row] = {
+    val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
+    projection.map(data.map(_)).getOrElse(data)
+  }

   override def executeCollect(): Array[Row] = {
     val converter = CatalystTypeConverters.createToScalaConverter(schema)

@@ -169,6 +184,13 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
   protected override def doExecute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1)

   override def outputOrdering: Seq[SortOrder] = sortOrder
+
+  override def simpleString: String = {
+    val orderByString = sortOrder.mkString("[", ",", "]")
+    val outputString = output.mkString("[", ",", "]")
+
+    s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)"
+  }
 }

 /**
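
A note on the `projection.map(data.map(_)).getOrElse(data)` idiom in `collectData()`: the interpreted projection is applied only when a project list was pushed into the operator; otherwise the sorted rows pass through unchanged. Below is a minimal self-contained sketch of the same idiom, with a hypothetical `project` helper and a simplified `Row` type standing in for Spark's:

object ProjectionSketch extends App {
  type Row = Seq[Any]

  // Hypothetical projection: keep only the columns at the given indices.
  def project(indices: Seq[Int]): Row => Row = row => indices.map(row)

  // Present when a project list was pushed in; None means pass-through.
  val projection: Option[Row => Row] = Some(project(Seq(1, 0)))
  val data: Array[Row] = Array(Seq(1, "a"), Seq(2, "b"))

  // Mirrors `projection.map(data.map(_)).getOrElse(data)` in the patch.
  val result: Array[Row] = projection.map(data.map(_)).getOrElse(data)
  result.foreach(println)
  // List(a, 1)
  // List(b, 2)
}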

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

20 additions, 0 deletions

@@ -142,4 +142,24 @@ class PlannerSuite extends SparkFunSuite {

     setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString)
   }
+
+  test("efficient limit -> project -> sort") {
+    {
+      val query =
+        testData.select('key, 'value).sort('key).limit(2).logicalPlan
+      val planned = planner.TakeOrderedAndProject(query)
+      assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
+      assert(planned.head.output === testData.select('key, 'value).logicalPlan.output)
+    }
+
+    {
+      // We need to make sure TakeOrderedAndProject's output is correct when we push a project
+      // into it.
+      val query =
+        testData.select('key, 'value).sort('key).select('value, 'key).limit(2).logicalPlan
+      val planned = planner.TakeOrderedAndProject(query)
+      assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
+      assert(planned.head.output === testData.select('value, 'key).logicalPlan.output)
+    }
+  }
 }
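
End to end, the second test case corresponds to a query like the following hypothetical spark-shell session on branch 1.4 (the `sqlContext` handle and a registered `testData` table are assumptions for illustration):

// After this patch, sort -> project -> limit should collapse into a single
// physical operator rather than a separate Project over a top-k node.
val df = sqlContext.table("testData").sort("key").select("value", "key").limit(2)
df.explain()
// With the patched simpleString, the physical plan should print roughly:
//   TakeOrderedAndProject(limit=2, orderBy=[key#0 ASC], output=[value#1,key#0])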

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala

1 addition, 1 deletion

@@ -447,7 +447,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
       HiveCommandStrategy(self),
       HiveDDLStrategy,
       DDLStrategy,
-      TakeOrdered,
+      TakeOrderedAndProject,
       ParquetOperations,
       InMemoryScans,
       ParquetConversion, // Must be before HiveTableScans
