3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/Dependency.scala
@@ -61,7 +61,8 @@ class ShuffleDependency[K, V, C](
val partitioner: Partitioner,
val serializer: Option[Serializer] = None,
val keyOrdering: Option[Ordering[K]] = None,
val aggregator: Option[Aggregator[K, V, C]] = None)
val aggregator: Option[Aggregator[K, V, C]] = None,
val mapSideCombine: Boolean = false)
extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {

val shuffleId: Int = rdd.context.newShuffleId()
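
For reference, a minimal sketch — not taken from this patch — of how a ShuffleDependency can now be built with the new fields. The rdd and agg values are assumed to exist in the caller's scope; serializer and keyOrdering keep their None defaults.

import org.apache.spark.{Aggregator, HashPartitioner, ShuffleDependency}

// Hypothetical usage: only the aggregator and the new mapSideCombine flag are
// supplied explicitly; the remaining optional fields fall back to their defaults.
val dep = new ShuffleDependency[String, Int, Int](
  rdd,                    // assumed: a pair RDD of (String, Int) records
  new HashPartitioner(4),
  aggregator = Some(agg), // assumed: Aggregator[String, Int, Int]
  mapSideCombine = true)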
@@ -57,7 +57,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
*/
def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = {
val part = new RangePartitioner(numPartitions, self, ascending)
val shuffled = new ShuffledRDD[K, V, P](self, part)
val shuffled = new ShuffledRDD[K, V, V, P](self, part).setKeyOrdering(ordering)
shuffled.mapPartitions(iter => {
val buf = iter.toArray
if (ascending) {
20 changes: 5 additions & 15 deletions core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -90,21 +90,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
self.mapPartitionsWithContext((context, iter) => {
new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
}, preservesPartitioning = true)
} else if (mapSideCombine) {
val combined = self.mapPartitionsWithContext((context, iter) => {
aggregator.combineValuesByKey(iter, context)
}, preservesPartitioning = true)
val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
.setSerializer(serializer)
partitioned.mapPartitionsWithContext((context, iter) => {
new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter, context))
}, preservesPartitioning = true)
} else {
// Don't apply map-side combiner.
val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializer)
values.mapPartitionsWithContext((context, iter) => {
new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
}, preservesPartitioning = true)
new ShuffledRDD[K, V, C, (K, C)](self, partitioner)
.setSerializer(serializer)
.setAggregator(aggregator)
.setMapSideCombine(mapSideCombine)
}
}

@@ -401,7 +391,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
if (self.partitioner == Some(partitioner)) {
self
} else {
new ShuffledRDD[K, V, (K, V)](self, partitioner)
new ShuffledRDD[K, V, V, (K, V)](self, partitioner)
}
}

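
A hedged end-to-end illustration of what the simplified combineByKey path means in practice; sc is an assumed SparkContext and the import is the pair-RDD implicit conversion needed at this point in Spark's history.

import org.apache.spark.SparkContext._

// With this change, reduceByKey yields a single ShuffledRDD whose ShuffleDependency
// carries the aggregator and mapSideCombine = true, instead of the old
// combine -> shuffle -> combine chain of intermediate RDDs.
val counts = sc.parallelize(Seq(("a", 1), ("b", 1), ("a", 1)))
  .reduceByKey(_ + _)
// counts is the ShuffledRDD itself; its head dependency is the ShuffleDependency
// that now drives both map-side and reduce-side combining.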
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -340,7 +340,7 @@ abstract class RDD[T: ClassTag](

// include a shuffle step so that our upstream tasks are still distributed
new CoalescedRDD(
new ShuffledRDD[Int, T, (Int, T)](mapPartitionsWithIndex(distributePartition),
new ShuffledRDD[Int, T, T, (Int, T)](mapPartitionsWithIndex(distributePartition),
new HashPartitioner(numPartitions)),
numPartitions).values
} else {
37 changes: 31 additions & 6 deletions core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -19,7 +19,7 @@ package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark.{Dependency, Partition, Partitioner, ShuffleDependency, SparkEnv, TaskContext}
import org.apache.spark._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.serializer.Serializer

@@ -35,23 +35,48 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
* @param part the partitioner used to partition the RDD
* @tparam K the key class.
* @tparam V the value class.
* @tparam C the combiner class.
*/
@DeveloperApi
class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
@transient var prev: RDD[P],
class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
@transient var prev: RDD[_ <: Product2[K, V]],
part: Partitioner)
extends RDD[P](prev.context, Nil) {

private var serializer: Option[Serializer] = None

private var keyOrdering: Option[Ordering[K]] = None

private var aggregator: Option[Aggregator[K, V, C]] = None

private var mapSideCombine: Boolean = false

/** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = {
def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C, P] = {
this.serializer = Option(serializer)
this
}

/** Set key ordering for RDD's shuffle. */
def setKeyOrdering(keyOrdering: Ordering[K]): ShuffledRDD[K, V, C, P] = {
this.keyOrdering = Option(keyOrdering)
this
}

/** Set aggregator for RDD's shuffle. */
def setAggregator(aggregator: Aggregator[K, V, C]): ShuffledRDD[K, V, C, P] = {
this.aggregator = Option(aggregator)
this
}

/** Set mapSideCombine flag for RDD's shuffle. */
def setMapSideCombine(mapSideCombine: Boolean): ShuffledRDD[K, V, C, P] = {
this.mapSideCombine = mapSideCombine
this
}

override def getDependencies: Seq[Dependency[_]] = {
List(new ShuffleDependency(prev, part, serializer))
List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine))
}

override val partitioner = Some(part)
@@ -61,7 +86,7 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
}

override def compute(split: Partition, context: TaskContext): Iterator[P] = {
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, V]]
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
.read()
.asInstanceOf[Iterator[P]]
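
A minimal sketch of the no-aggregation case under the new four-parameter signature, assuming a pairs: RDD[(String, Int)] and a SparkConf named conf are in scope; when records are only repartitioned, the combiner type C is simply V and only the serializer setter is (optionally) used.

import org.apache.spark.HashPartitioner
import org.apache.spark.rdd.ShuffledRDD
import org.apache.spark.serializer.KryoSerializer

// ShuffledRDD[K, V, C, P] with C == V: plain repartitioning, no aggregator,
// no key ordering, no map-side combine.
val repartitioned = new ShuffledRDD[String, Int, Int, (String, Int)](
    pairs, new HashPartitioner(8))
  .setSerializer(new KryoSerializer(conf))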
@@ -144,10 +144,8 @@ private[spark] class ShuffleMapTask(
try {
val manager = SparkEnv.get.shuffleManager
writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
for (elem <- rdd.iterator(split, context)) {
writer.write(elem.asInstanceOf[Product2[Any, Any]])
}
writer.stop(success = true).get
writer.write(rdd.iterator(split, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
return writer.stop(success = true).get
} catch {
case e: Exception =>
if (writer != null) {
@@ -23,8 +23,8 @@ import org.apache.spark.scheduler.MapStatus
* Obtained inside a map task to write out records to the shuffle system.
*/
private[spark] trait ShuffleWriter[K, V] {
/** Write a record to this task's output */
def write(record: Product2[K, V]): Unit
/** Write a bunch of records to this task's output */
def write(records: Iterator[_ <: Product2[K, V]]): Unit

/** Close this writer, passing along whether the map completed */
def stop(success: Boolean): Option[MapStatus]
@@ -17,9 +17,9 @@

package org.apache.spark.shuffle.hash

import org.apache.spark.{InterruptibleIterator, TaskContext}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
import org.apache.spark.TaskContext

class HashShuffleReader[K, C](
handle: BaseShuffleHandle[K, _, C],
@@ -31,10 +31,24 @@ class HashShuffleReader[K, C](
require(endPartition == startPartition + 1,
"Hash shuffle currently only supports fetching one partition")

private val dep = handle.dependency

/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context,
Serializer.getSerializer(handle.dependency.serializer))
val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context,
Serializer.getSerializer(dep.serializer))

if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
} else {
new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
}
} else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
throw new IllegalStateException("Aggregator is empty for map-side combine")
} else {
iter
}
}

/** Close this reader */
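
To make the two reader-side branches concrete, a hypothetical word-count aggregator (not from this patch); the field names follow Spark's Aggregator case class.

import org.apache.spark.Aggregator

// Illustrative aggregator: V = Int counts, C = Int partial sums.
val agg = new Aggregator[String, Int, Int](
  createCombiner = (v: Int) => v,
  mergeValue = (c: Int, v: Int) => c + v,
  mergeCombiners = (c1: Int, c2: Int) => c1 + c2)
// mapSideCombine = true : the map side has already run combineValuesByKey, so the
//   reader merges partial sums with combineCombinersByKey, e.g. ("a", 2) and ("a", 3)
//   become ("a", 5).
// mapSideCombine = false: the reader receives raw ("a", 1) pairs and builds the
//   combiners itself with combineValuesByKey.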
@@ -40,11 +40,24 @@ class HashShuffleWriter[K, V](
private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null))
private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser)

/** Write a record to this task's output */
override def write(record: Product2[K, V]): Unit = {
val pair = record.asInstanceOf[Product2[Any, Any]]
val bucketId = dep.partitioner.getPartition(pair._1)
shuffle.writers(bucketId).write(pair)
/** Write a bunch of records to this task's output */
override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
val iter = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
dep.aggregator.get.combineValuesByKey(records, context)
} else {
records
}
} else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
throw new IllegalStateException("Aggregator is empty for map-side combine")
} else {
records
}

for (elem <- iter) {
val bucketId = dep.partitioner.getPartition(elem._1)
shuffle.writers(bucketId).write(elem)
}
}

/** Close this writer, passing along whether the map completed */
2 changes: 1 addition & 1 deletion core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -99,7 +99,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
test("ShuffledRDD") {
testRDD(rdd => {
// Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD
new ShuffledRDD[Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner)
new ShuffledRDD[Int, Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner)
})
}

22 changes: 14 additions & 8 deletions core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -56,8 +56,11 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
}
// If the Kryo serializer is not used correctly, the shuffle would fail because the
// default Java serializer cannot handle the non serializable class.
val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)](
b, new HashPartitioner(NUM_BLOCKS)).setSerializer(new KryoSerializer(conf))
val c = new ShuffledRDD[Int,
NonJavaSerializableClass,
NonJavaSerializableClass,
(Int, NonJavaSerializableClass)](b, new HashPartitioner(NUM_BLOCKS))
Reviewer comment (Contributor): Probably should split out the call to setSerializer into a new statement instead of chaining it. (Just do c.setSerializer(...).)
c.setSerializer(new KryoSerializer(conf))
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId

assert(c.count === 10)
@@ -78,8 +81,11 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
}
// If the Kryo serializer is not used correctly, the shuffle would fail because the
// default Java serializer cannot handle the non serializable class.
val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)](
b, new HashPartitioner(3)).setSerializer(new KryoSerializer(conf))
val c = new ShuffledRDD[Int,
NonJavaSerializableClass,
NonJavaSerializableClass,
(Int, NonJavaSerializableClass)](b, new HashPartitioner(3))
c.setSerializer(new KryoSerializer(conf))
assert(c.count === 10)
}

@@ -94,7 +100,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {

// NOTE: The default Java serializer doesn't create zero-sized blocks.
// So, use Kryo
val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10))
val c = new ShuffledRDD[Int, Int, Int, (Int, Int)](b, new HashPartitioner(10))
.setSerializer(new KryoSerializer(conf))

val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
@@ -120,7 +126,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
val b = a.map(x => (x, x*2))

// NOTE: The default Java serializer should create zero-sized blocks
val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10))
val c = new ShuffledRDD[Int, Int, Int, (Int, Int)](b, new HashPartitioner(10))

val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
assert(c.count === 4)
@@ -141,8 +147,8 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2)
val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1))
val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2)
val results = new ShuffledRDD[Int, Int, MutablePair[Int, Int]](pairs, new HashPartitioner(2))
.collect()
val results = new ShuffledRDD[Int, Int, Int, MutablePair[Int, Int]](pairs,
new HashPartitioner(2)).collect()

data.foreach { pair => results should contain (pair) }
}
14 changes: 7 additions & 7 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -276,7 +276,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
// we can optionally shuffle to keep the upstream parallel
val coalesced5 = data.coalesce(1, shuffle = true)
val isEquals = coalesced5.dependencies.head.rdd.dependencies.head.rdd.
asInstanceOf[ShuffledRDD[_, _, _]] != null
asInstanceOf[ShuffledRDD[_, _, _, _]] != null
assert(isEquals)

// when shuffling, we can increase the number of partitions
@@ -509,7 +509,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
test("takeSample") {
val n = 1000000
val data = sc.parallelize(1 to n, 2)

for (num <- List(5, 20, 100)) {
val sample = data.takeSample(withReplacement=false, num=num)
assert(sample.size === num) // Got exactly num elements
@@ -704,11 +704,11 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(ancestors3.count(_.isInstanceOf[MappedRDD[_, _]]) === 2)

// Any ancestors before the shuffle are not considered
assert(ancestors4.size === 1)
assert(ancestors4.count(_.isInstanceOf[ShuffledRDD[_, _, _]]) === 1)
assert(ancestors5.size === 4)
assert(ancestors5.count(_.isInstanceOf[ShuffledRDD[_, _, _]]) === 1)
assert(ancestors5.count(_.isInstanceOf[MapPartitionsRDD[_, _]]) === 1)
assert(ancestors4.size === 0)
assert(ancestors4.count(_.isInstanceOf[ShuffledRDD[_, _, _, _]]) === 0)
assert(ancestors5.size === 3)
assert(ancestors5.count(_.isInstanceOf[ShuffledRDD[_, _, _, _]]) === 1)
assert(ancestors5.count(_.isInstanceOf[MapPartitionsRDD[_, _]]) === 0)
assert(ancestors5.count(_.isInstanceOf[MappedValuesRDD[_, _, _]]) === 2)
}

@@ -181,7 +181,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
listener.stageInfos.size should be {2} // Shuffle map stage + result stage
val stageInfo3 = listener.stageInfos.keys.find(_.stageId == 2).get
stageInfo3.rddInfos.size should be {2} // ShuffledRDD, MapPartitionsRDD
stageInfo3.rddInfos.size should be {1} // ShuffledRDD
stageInfo3.rddInfos.forall(_.numPartitions == 4) should be {true}
stageInfo3.rddInfos.exists(_.name == "Trois") should be {true}
}
@@ -62,7 +62,8 @@ class MessageToPartition[@specialized(Int, Long, Double, Char, Boolean/* , AnyRe
private[graphx]
class VertexBroadcastMsgRDDFunctions[T: ClassTag](self: RDD[VertexBroadcastMsg[T]]) {
def partitionBy(partitioner: Partitioner): RDD[VertexBroadcastMsg[T]] = {
val rdd = new ShuffledRDD[PartitionID, (VertexId, T), VertexBroadcastMsg[T]](self, partitioner)
val rdd = new ShuffledRDD[PartitionID, (VertexId, T), (VertexId, T), VertexBroadcastMsg[T]](
self, partitioner)

// Set a custom serializer if the data is of int or double type.
if (classTag[T] == ClassTag.Int) {
@@ -84,7 +85,7 @@ class MsgRDDFunctions[T: ClassTag](self: RDD[MessageToPartition[T]]) {
* Return a copy of the RDD partitioned using the specified partitioner.
*/
def partitionBy(partitioner: Partitioner): RDD[MessageToPartition[T]] = {
new ShuffledRDD[PartitionID, T, MessageToPartition[T]](self, partitioner)
new ShuffledRDD[PartitionID, T, T, MessageToPartition[T]](self, partitioner)
}

}
@@ -103,7 +104,7 @@ object MsgRDDFunctions {
private[graphx]
class VertexRDDFunctions[VD: ClassTag](self: RDD[(VertexId, VD)]) {
def copartitionWithVertices(partitioner: Partitioner): RDD[(VertexId, VD)] = {
val rdd = new ShuffledRDD[VertexId, VD, (VertexId, VD)](self, partitioner)
val rdd = new ShuffledRDD[VertexId, VD, VD, (VertexId, VD)](self, partitioner)

// Set a custom serializer if the data is of int or double type.
if (classTag[VD] == ClassTag.Int) {
@@ -46,8 +46,8 @@ private[graphx]
class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) {
/** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */
def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = {
new ShuffledRDD[VertexId, (PartitionID, Byte), RoutingTableMessage](self, partitioner)
.setSerializer(new RoutingTableMessageSerializer)
new ShuffledRDD[VertexId, (PartitionID, Byte), (PartitionID, Byte), RoutingTableMessage](
self, partitioner).setSerializer(new RoutingTableMessageSerializer)
}
}
