@@ -19,65 +19,88 @@ package org.apache.spark.sql.execution
 
 import scala.reflect.runtime.universe.TypeTag
 
-import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext
-
+import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
+import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.{OrderedDistribution, UnspecifiedDistribution}
-import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.util.MutablePair
+
 
 case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
-  def output = projectList.map(_.toAttribute)
+  override def output = projectList.map(_.toAttribute)
 
-  def execute() = child.execute().mapPartitions { iter =>
+  override def execute() = child.execute().mapPartitions { iter =>
     @transient val reusableProjection = new MutableProjection(projectList)
     iter.map(reusableProjection)
   }
 }
 
 case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
-  def output = child.output
+  override def output = child.output
 
-  def execute() = child.execute().mapPartitions { iter =>
+  override def execute() = child.execute().mapPartitions { iter =>
     iter.filter(condition.apply(_).asInstanceOf[Boolean])
   }
 }
 
 case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: SparkPlan)
   extends UnaryNode {
 
-  def output = child.output
+  override def output = child.output
 
   // TODO: How to pick seed?
-  def execute() = child.execute().sample(withReplacement, fraction, seed)
+  override def execute() = child.execute().sample(withReplacement, fraction, seed)
 }
 
 case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends SparkPlan {
   // TODO: attributes output by union should be distinct for nullability purposes
-  def output = children.head.output
-  def execute() = sc.union(children.map(_.execute()))
+  override def output = children.head.output
+  override def execute() = sc.union(children.map(_.execute()))
 
   override def otherCopyArgs = sc :: Nil
 }
 
-case class StopAfter(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+/**
+ * Take the first limit elements. Note that the implementation is different depending on whether
+ * this is a terminal operator or not. If it is terminal and is invoked using executeCollect,
+ * this operator uses Spark's take method on the Spark driver. If it is not terminal or is
+ * invoked using execute, we first take the limit on each partition, and then repartition all the
+ * data to a single partition to compute the global limit.
+ */
+case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+  // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
+  // partition local limit -> exchange into one partition -> partition local limit again
+
   override def otherCopyArgs = sc :: Nil
 
-  def output = child.output
+  override def output = child.output
 
   override def executeCollect() = child.execute().map(_.copy()).take(limit)
 
-  // TODO: Terminal split should be implemented differently from non-terminal split.
-  // TODO: Pick num splits based on |limit|.
-  def execute() = sc.makeRDD(executeCollect(), 1)
+  override def execute() = {
+    val rdd = child.execute().mapPartitions { iter =>
+      val mutablePair = new MutablePair[Boolean, Row]()
+      iter.take(limit).map(row => mutablePair.update(false, row))
+    }
+    val part = new HashPartitioner(1)
+    val shuffled = new ShuffledRDD[Boolean, Row, MutablePair[Boolean, Row]](rdd, part)
+    shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+    shuffled.mapPartitions(_.take(limit).map(_._2))
+  }
 }
 
-case class TopK(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
-               (@transient sc: SparkContext) extends UnaryNode {
+/**
+ * Take the first limit elements as defined by the sortOrder. This is logically equivalent to
+ * having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but
+ * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion.
+ */
+case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
+                      (@transient sc: SparkContext) extends UnaryNode {
   override def otherCopyArgs = sc :: Nil
 
-  def output = child.output
+  override def output = child.output
 
   @transient
   lazy val ordering = new RowOrdering(sortOrder)
@@ -86,7 +109,7 @@ case class TopK(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
 
   // TODO: Terminal split should be implemented differently from non-terminal split.
   // TODO: Pick num splits based on |limit|.
-  def execute() = sc.makeRDD(executeCollect(), 1)
+  override def execute() = sc.makeRDD(executeCollect(), 1)
 }
 
 
@@ -101,15 +124,15 @@ case class Sort(
   @transient
   lazy val ordering = new RowOrdering(sortOrder)
 
-  def execute() = attachTree(this, "sort") {
+  override def execute() = attachTree(this, "sort") {
     // TODO: Optimize sorting operation?
     child.execute()
       .mapPartitions(
         iterator => iterator.map(_.copy()).toArray.sorted(ordering).iterator,
         preservesPartitioning = true)
   }
 
-  def output = child.output
+  override def output = child.output
 }
 
 object ExistingRdd {
@@ -130,6 +153,6 @@ object ExistingRdd {
 }
 
 case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
-  def execute() = rdd
+  override def execute() = rdd
 }
 
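For reference, the two-phase strategy that the new `Limit.execute` implements can be demonstrated with plain RDDs. This is a minimal sketch, not the patch itself: the object name is hypothetical, it assumes a local SparkContext, and `repartition(1)` stands in for the patch's hand-built `ShuffledRDD` exchange (which keys every row with a dummy `false` via a reused `MutablePair` and serializes with `SparkSqlSerializer` to keep the shuffle cheap).

import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical illustration of the two-phase limit: take `limit` rows per
// partition, exchange the survivors into a single partition, then take
// `limit` again to produce the global result.
object TwoPhaseLimitSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[4]").setAppName("limit-sketch"))
    val limit = 5
    val result = sc.parallelize(1 to 100, numSlices = 4)
      .mapPartitions(_.take(limit))   // partition-local limit: at most 5 rows per partition
      .repartition(1)                 // exchange everything into one partition
      .mapPartitions(_.take(limit))   // local limit again, now equal to the global limit
    println(result.collect().mkString(", "))
    sc.stop()
  }
}

Because each partition forwards at most `limit` rows, the single reduce partition sees at most `limit * numPartitions` rows rather than the full dataset, which is the point of doing the local take before the exchange.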
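On the naming point in the TakeOrdered Scaladoc: RDD.takeOrdered(k) returns the first k elements under the implicit ordering, while RDD.top(k) returns the last k, i.e. it uses the reversed ordering. A quick local check (hypothetical object name, assuming a local SparkContext):

import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical demo contrasting takeOrdered with top on the same RDD.
object TakeOrderedVsTop {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("ordering-demo"))
    val rdd = sc.parallelize(Seq(5, 1, 4, 2, 3))
    println(rdd.takeOrdered(2).mkString(", "))  // prints "1, 2" -- smallest elements first
    println(rdd.top(2).mkString(", "))          // prints "5, 4" -- largest elements first
    sc.stop()
  }
}

Since the new operator returns the first `limit` rows under `sortOrder`, calling it TopK would have suggested the opposite semantics to anyone familiar with RDD.top.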