@@ -827,7 +827,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       experimental.extraStrategies ++ (
         DataSourceStrategy ::
         DDLStrategy ::
-        TakeOrdered ::
+        TakeOrderedAndProject ::
         HashAggregation ::
         LeftSemiJoin ::
         HashJoin ::
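For context: `experimental.extraStrategies`, visible at the top of this hunk, is the public extension point, and strategies registered there are tried before the built-in list that now includes TakeOrderedAndProject. A minimal sketch of plugging in a custom strategy; the object name is hypothetical and the imports assume the Spark 1.x API:

    import org.apache.spark.sql.Strategy
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.execution.SparkPlan

    // Hypothetical no-op strategy: returning Nil simply defers to the
    // built-in strategies listed in the hunk above.
    object NoopStrategy extends Strategy {
      def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil
    }

    // Assuming sqlContext is an existing SQLContext:
    sqlContext.experimental.extraStrategies = NoopStrategy :: Nil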
@@ -205,10 +205,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   protected lazy val singleRowRdd =
     sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1)
 
-  object TakeOrdered extends Strategy {
+  object TakeOrderedAndProject extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
-        execution.TakeOrdered(limit, order, planLater(child)) :: Nil
+        execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil
+      case logical.Limit(
+          IntegerLiteral(limit),
+          logical.Project(projectList, logical.Sort(order, true, child))) =>
+        execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil
       case _ => Nil
     }
   }
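Illustration of what each case matches, using the DataFrame API (`df` is an assumed DataFrame with columns `key` and `value`; these queries are not part of the diff):

    // First case: Limit directly over Sort -> projectList = None
    df.sort("key").limit(10)

    // Second case: Limit over Project over Sort -> projectList = Some(...)
    df.sort("key").select("value").limit(10)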
@@ -144,20 +144,35 @@ case class Limit(limit: Int, child: SparkPlan)
 
 /**
  * :: DeveloperApi ::
- * Take the first limit elements as defined by the sortOrder. This is logically equivalent to
- * having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but
- * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion.
+ * Take the first limit elements as defined by the sortOrder, and do projection if needed.
+ * This is logically equivalent to having a [[Limit]] operator after a [[Sort]] operator,
+ * or having a [[Project]] operator between them.
+ * This could have been named TopK, but Spark's top operator does the opposite in ordering
+ * so we name it TakeOrdered to avoid confusion.
  */
 @DeveloperApi
-case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode {
+case class TakeOrderedAndProject(
+    limit: Int,
+    sortOrder: Seq[SortOrder],
+    projectList: Option[Seq[NamedExpression]],
+    child: SparkPlan) extends UnaryNode {
 
-  override def output: Seq[Attribute] = child.output
+  override def output: Seq[Attribute] = {
+    val projectOutput = projectList.map(_.map(_.toAttribute))
+    projectOutput.getOrElse(child.output)
+  }
 
   override def outputPartitioning: Partitioning = SinglePartition
 
   private val ord: RowOrdering = new RowOrdering(sortOrder, child.output)
 
-  private def collectData(): Array[Row] = child.execute().map(_.copy()).takeOrdered(limit)(ord)
+  // TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable.
+  @transient private val projection = projectList.map(new InterpretedProjection(_, child.output))
+
+  private def collectData(): Array[Row] = {
+    val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
+    projection.map(data.map(_)).getOrElse(data)
+  }
 
   override def executeCollect(): Array[Row] = {
     val converter = CatalystTypeConverters.createToScalaConverter(schema)
@@ -169,6 +184,13 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
   protected override def doExecute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1)
 
   override def outputOrdering: Seq[SortOrder] = sortOrder
+
+  override def simpleString: String = {
+    val orderByString = sortOrder.mkString("[", ",", "]")
+    val outputString = output.mkString("[", ",", "]")
+
+    s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)"
+  }
 }
 
 /**
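Why the single operator is cheaper than a [[Sort]] followed by a [[Limit]]: RDD.takeOrdered keeps at most `limit` rows per partition in a bounded priority queue and only merges those survivors on the driver, rather than fully sorting the child's output. A self-contained sketch of that per-partition step in plain Scala (names and structure are illustrative, not Spark's actual implementation):

    import scala.collection.mutable

    // Keep the k smallest elements of `it` under `ord`, using a bounded max-heap.
    def topK[T](it: Iterator[T], k: Int)(implicit ord: Ordering[T]): Seq[T] = {
      val heap = mutable.PriorityQueue.empty[T](ord) // max-heap: the worst kept element is on top
      it.foreach { t =>
        if (heap.size < k) {
          heap.enqueue(t)
        } else if (ord.lt(t, heap.head)) {
          heap.dequeue() // evict the current worst
          heap.enqueue(t)
        }
      }
      heap.dequeueAll.reverse // dequeueAll yields descending order; reverse to ascending
    }

    // topK(Iterator(5, 1, 4, 2, 3), 2) == Seq(1, 2)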
@@ -142,4 +142,24 @@ class PlannerSuite extends SparkFunSuite {
 
     setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString)
   }
+
+  test("efficient limit -> project -> sort") {
+    {
+      val query =
+        testData.select('key, 'value).sort('key).limit(2).logicalPlan
+      val planned = planner.TakeOrderedAndProject(query)
+      assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
+      assert(planned.head.output === testData.select('key, 'value).logicalPlan.output)
+    }
+
+    {
+      // We need to make sure TakeOrderedAndProject's output is correct when we push a project
+      // into it.
+      val query =
+        testData.select('key, 'value).sort('key).select('value, 'key).limit(2).logicalPlan
+      val planned = planner.TakeOrderedAndProject(query)
+      assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
+      assert(planned.head.output === testData.select('value, 'key).logicalPlan.output)
+    }
+  }
 }
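With the simpleString override above in place, the plan from the second test block would render in EXPLAIN output roughly as follows (illustrative; exact attribute ids and ordering rendering depend on the session):

    TakeOrderedAndProject(limit=2, orderBy=[key#0 ASC], output=[value#1,key#0])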
@@ -447,7 +447,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
         HiveCommandStrategy(self),
         HiveDDLStrategy,
         DDLStrategy,
-        TakeOrdered,
+        TakeOrderedAndProject,
         ParquetOperations,
         InMemoryScans,
         ParquetConversion, // Must be before HiveTableScans