
Commit 56eb8af

jerryshao authored and mateiz committed
[SPARK-2124] Move aggregation into shuffle implementations
This PR is a sub-task of SPARK-2044 that moves the execution of aggregation into the shuffle implementations. I leave `CoGroupedRDD` and `SubtractedRDD` unchanged because they have their own implementations of aggregation, and I am not sure whether it is suitable to change these two RDDs. I also do not move the sort-related code of `OrderedRDDFunctions` into the shuffle; that will be handled in another sub-task.

Author: jerryshao <[email protected]>

Closes #1064 from jerryshao/SPARK-2124 and squashes the following commits:

4a05a40 [jerryshao] Modify according to comments
1f7dcc8 [jerryshao] Style changes
50a2fd6 [jerryshao] Fix test suite issue after moving aggregator to shuffle reader and writer
1a96190 [jerryshao] Code modifications related to the ShuffledRDD
308f635 [jerryshao] Initial work on moving the combiner to the ShuffleManager's reader and writer
1 parent 51c8168 commit 56eb8af


17 files changed: 112 additions and 64 deletions

core/src/main/scala/org/apache/spark/Dependency.scala

Lines changed: 2 additions & 1 deletion

@@ -61,7 +61,8 @@ class ShuffleDependency[K, V, C](
     val partitioner: Partitioner,
     val serializer: Option[Serializer] = None,
     val keyOrdering: Option[Ordering[K]] = None,
-    val aggregator: Option[Aggregator[K, V, C]] = None)
+    val aggregator: Option[Aggregator[K, V, C]] = None,
+    val mapSideCombine: Boolean = false)
   extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {

   val shuffleId: Int = rdd.context.newShuffleId()
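With this change, a ShuffleDependency carries everything the shuffle layer needs to perform aggregation: an optional Aggregator and a mapSideCombine flag, next to the existing serializer and key ordering. A minimal sketch of constructing one directly, assuming an existing pair RDD named pairs (the RDD, partitioner, and summing aggregator here are illustrative, not part of this patch):

    import org.apache.spark.{Aggregator, HashPartitioner, ShuffleDependency}

    // Hypothetical example: sum Int values per String key, combining on the map side.
    val agg = new Aggregator[String, Int, Int](
      (v: Int) => v,                   // createCombiner
      (c: Int, v: Int) => c + v,       // mergeValue
      (c1: Int, c2: Int) => c1 + c2)   // mergeCombiners

    val dep = new ShuffleDependency[String, Int, Int](
      pairs,                           // assumed RDD[(String, Int)]
      new HashPartitioner(4),
      serializer = None,
      keyOrdering = None,
      aggregator = Some(agg),
      mapSideCombine = true)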

core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala

Lines changed: 1 addition & 1 deletion

@@ -57,7 +57,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
    */
   def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = {
     val part = new RangePartitioner(numPartitions, self, ascending)
-    val shuffled = new ShuffledRDD[K, V, P](self, part)
+    val shuffled = new ShuffledRDD[K, V, V, P](self, part).setKeyOrdering(ordering)
     shuffled.mapPartitions(iter => {
       val buf = iter.toArray
       if (ascending) {
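The user-facing behaviour of sortByKey is unchanged here: the Ordering[K] is now recorded on the shuffle via setKeyOrdering, while the actual sorting still happens in mapPartitions (moving it into the shuffle is left to another sub-task, per the commit message). A quick usage sketch, assuming a SparkContext named sc:

    // sortByKey still returns the same result; only the shuffle bookkeeping changed.
    val sorted = sc.parallelize(Seq(("b", 2), ("a", 1), ("c", 3))).sortByKey()
    sorted.collect()   // Array((a,1), (b,2), (c,3))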

core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala

Lines changed: 5 additions & 15 deletions

@@ -90,21 +90,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
       self.mapPartitionsWithContext((context, iter) => {
         new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
       }, preservesPartitioning = true)
-    } else if (mapSideCombine) {
-      val combined = self.mapPartitionsWithContext((context, iter) => {
-        aggregator.combineValuesByKey(iter, context)
-      }, preservesPartitioning = true)
-      val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
-        .setSerializer(serializer)
-      partitioned.mapPartitionsWithContext((context, iter) => {
-        new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter, context))
-      }, preservesPartitioning = true)
     } else {
-      // Don't apply map-side combiner.
-      val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializer)
-      values.mapPartitionsWithContext((context, iter) => {
-        new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
-      }, preservesPartitioning = true)
+      new ShuffledRDD[K, V, C, (K, C)](self, partitioner)
+        .setSerializer(serializer)
+        .setAggregator(aggregator)
+        .setMapSideCombine(mapSideCombine)
     }
   }

@@ -401,7 +391,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
     if (self.partitioner == Some(partitioner)) {
       self
     } else {
-      new ShuffledRDD[K, V, (K, V)](self, partitioner)
+      new ShuffledRDD[K, V, V, (K, V)](self, partitioner)
     }
   }
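From a caller's point of view combineByKey is also unchanged; it now simply hands the aggregator and the mapSideCombine flag to a single ShuffledRDD instead of wiring the combine steps up itself with mapPartitionsWithContext. A word-count style sketch, assuming a SparkContext named sc:

    // These three functions become the Aggregator that this patch pushes into the shuffle.
    val counts = sc.parallelize(Seq("a", "b", "a", "c", "b", "a"))
      .map(word => (word, 1))
      .combineByKey[Int](
        (v: Int) => v,                   // createCombiner
        (c: Int, v: Int) => c + v,       // mergeValue
        (c1: Int, c2: Int) => c1 + c2)   // mergeCombiners
    counts.collect()   // Array((a,3), (b,2), (c,1)), in some order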

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 1 addition & 1 deletion

@@ -340,7 +340,7 @@ abstract class RDD[T: ClassTag](

       // include a shuffle step so that our upstream tasks are still distributed
       new CoalescedRDD(
-        new ShuffledRDD[Int, T, (Int, T)](mapPartitionsWithIndex(distributePartition),
+        new ShuffledRDD[Int, T, T, (Int, T)](mapPartitionsWithIndex(distributePartition),
         new HashPartitioner(numPartitions)),
         numPartitions).values
     } else {

core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala

Lines changed: 31 additions & 6 deletions

@@ -19,7 +19,7 @@ package org.apache.spark.rdd

 import scala.reflect.ClassTag

-import org.apache.spark.{Dependency, Partition, Partitioner, ShuffleDependency, SparkEnv, TaskContext}
+import org.apache.spark._
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.serializer.Serializer

@@ -35,23 +35,48 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
  * @param part the partitioner used to partition the RDD
  * @tparam K the key class.
  * @tparam V the value class.
+ * @tparam C the combiner class.
  */
 @DeveloperApi
-class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
-    @transient var prev: RDD[P],
+class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
+    @transient var prev: RDD[_ <: Product2[K, V]],
     part: Partitioner)
   extends RDD[P](prev.context, Nil) {

   private var serializer: Option[Serializer] = None

+  private var keyOrdering: Option[Ordering[K]] = None
+
+  private var aggregator: Option[Aggregator[K, V, C]] = None
+
+  private var mapSideCombine: Boolean = false
+
   /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
-  def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = {
+  def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C, P] = {
     this.serializer = Option(serializer)
     this
   }

+  /** Set key ordering for RDD's shuffle. */
+  def setKeyOrdering(keyOrdering: Ordering[K]): ShuffledRDD[K, V, C, P] = {
+    this.keyOrdering = Option(keyOrdering)
+    this
+  }
+
+  /** Set aggregator for RDD's shuffle. */
+  def setAggregator(aggregator: Aggregator[K, V, C]): ShuffledRDD[K, V, C, P] = {
+    this.aggregator = Option(aggregator)
+    this
+  }
+
+  /** Set mapSideCombine flag for RDD's shuffle. */
+  def setMapSideCombine(mapSideCombine: Boolean): ShuffledRDD[K, V, C, P] = {
+    this.mapSideCombine = mapSideCombine
+    this
+  }
+
   override def getDependencies: Seq[Dependency[_]] = {
-    List(new ShuffleDependency(prev, part, serializer))
+    List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine))
   }

   override val partitioner = Some(part)

@@ -61,7 +86,7 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
   }

   override def compute(split: Partition, context: TaskContext): Iterator[P] = {
-    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, V]]
+    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
     SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
       .read()
       .asInstanceOf[Iterator[P]]
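The extra C type parameter and the builder-style setters let a caller describe the whole shuffle up front; getDependencies then folds everything into a single ShuffleDependency. A sketch of direct construction, mirroring what combineByKey now does (prev and agg are assumed inputs, as in the earlier sketches):

    // prev: RDD[(String, Int)], agg: Aggregator[String, Int, Int] — assumed to exist.
    val shuffled = new ShuffledRDD[String, Int, Int, (String, Int)](prev, new HashPartitioner(8))
      .setAggregator(agg)
      .setMapSideCombine(true)
    // compute() then just asks the ShuffleManager's reader for the already-combined records.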

core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala

Lines changed: 2 additions & 4 deletions

@@ -144,10 +144,8 @@ private[spark] class ShuffleMapTask(
     try {
       val manager = SparkEnv.get.shuffleManager
       writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
-      for (elem <- rdd.iterator(split, context)) {
-        writer.write(elem.asInstanceOf[Product2[Any, Any]])
-      }
-      writer.stop(success = true).get
+      writer.write(rdd.iterator(split, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
+      return writer.stop(success = true).get
     } catch {
       case e: Exception =>
         if (writer != null) {

core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala

Lines changed: 2 additions & 2 deletions

@@ -23,8 +23,8 @@ import org.apache.spark.scheduler.MapStatus
  * Obtained inside a map task to write out records to the shuffle system.
  */
 private[spark] trait ShuffleWriter[K, V] {
-  /** Write a record to this task's output */
-  def write(record: Product2[K, V]): Unit
+  /** Write a bunch of records to this task's output */
+  def write(records: Iterator[_ <: Product2[K, V]]): Unit

   /** Close this writer, passing along whether the map completed */
   def stop(success: Boolean): Option[MapStatus]

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala

Lines changed: 17 additions & 3 deletions

@@ -17,9 +17,9 @@

 package org.apache.spark.shuffle.hash

+import org.apache.spark.{InterruptibleIterator, TaskContext}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
-import org.apache.spark.TaskContext

 class HashShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],

@@ -31,10 +31,24 @@ class HashShuffleReader[K, C](
   require(endPartition == startPartition + 1,
     "Hash shuffle currently only supports fetching one partition")

+  private val dep = handle.dependency
+
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
-    BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context,
-      Serializer.getSerializer(handle.dependency.serializer))
+    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context,
+      Serializer.getSerializer(dep.serializer))
+
+    if (dep.aggregator.isDefined) {
+      if (dep.mapSideCombine) {
+        new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
+      } else {
+        new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
+      }
+    } else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
+      throw new IllegalStateException("Aggregator is empty for map-side combine")
+    } else {
+      iter
+    }
   }

   /** Close this reader */
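The reader now decides per dependency which half of the aggregation it owes: when the map side already combined, it only merges the pre-combined values (combineCombinersByKey); otherwise it performs the full combine from raw values (combineValuesByKey). Conceptually, with plain Scala collections standing in for fetched shuffle blocks and the summing aggregator from the earlier sketch:

    // Two map-side outputs for the same reduce partition, already partially summed
    // because mapSideCombine was true:
    val fromMap1 = Seq("a" -> 2, "b" -> 1)
    val fromMap2 = Seq("a" -> 1, "b" -> 3)
    // combineCombinersByKey then effectively folds them with mergeCombiners:
    val merged = (fromMap1 ++ fromMap2).groupBy(_._1).mapValues(_.map(_._2).sum)
    // Map(a -> 3, b -> 4)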

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala

Lines changed: 18 additions & 5 deletions

@@ -40,11 +40,24 @@ class HashShuffleWriter[K, V](
   private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null))
   private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser)

-  /** Write a record to this task's output */
-  override def write(record: Product2[K, V]): Unit = {
-    val pair = record.asInstanceOf[Product2[Any, Any]]
-    val bucketId = dep.partitioner.getPartition(pair._1)
-    shuffle.writers(bucketId).write(pair)
+  /** Write a bunch of records to this task's output */
+  override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
+    val iter = if (dep.aggregator.isDefined) {
+      if (dep.mapSideCombine) {
+        dep.aggregator.get.combineValuesByKey(records, context)
+      } else {
+        records
+      }
+    } else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
+      throw new IllegalStateException("Aggregator is empty for map-side combine")
+    } else {
+      records
+    }
+
+    for (elem <- iter) {
+      val bucketId = dep.partitioner.getPartition(elem._1)
+      shuffle.writers(bucketId).write(elem)
+    }
   }

   /** Close this writer, passing along whether the map completed */
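Put together, the writer runs combineValuesByKey before partitioning when mapSideCombine is requested, and the reader finishes with combineCombinersByKey; when it is not requested, raw records are shuffled and the reader does the whole combine. Which path an application hits depends on the operator it calls; for example, assuming the usual Spark defaults (not changed by this patch) and the pairs RDD from the earlier sketches:

    // reduceByKey asks for map-side combine, so each map task pre-sums its keys
    // and the reduce side merges the partial sums.
    val sums = pairs.reduceByKey(_ + _)

    // groupByKey does not use map-side combine (grouping would not shrink the data),
    // so raw records are shuffled and grouped entirely on the reduce side.
    val groups = pairs.groupByKey()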

core/src/test/scala/org/apache/spark/CheckpointSuite.scala

Lines changed: 1 addition & 1 deletion

@@ -99,7 +99,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
   test("ShuffledRDD") {
     testRDD(rdd => {
       // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD
-      new ShuffledRDD[Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner)
+      new ShuffledRDD[Int, Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner)
     })
   }
