Skip to content

Commit 2d25e34

Browse files
ankurdave authored and rxin committed
Replace RoutingTableMessage with pair
RoutingTableMessage was used to construct routing tables to enable joining VertexRDDs with partitioned edges. It stored three elements: the destination vertex ID, the source edge partition, and a byte specifying the position in which the edge partition referenced the vertex to enable join elimination.

However, this was incompatible with sort-based shuffle (SPARK-2045). It was also slightly wasteful, because partition IDs are usually much smaller than 2^32, though this was mitigated by a custom serializer that used variable-length encoding.

This commit replaces RoutingTableMessage with a pair of (VertexId, Int) where the Int encodes both the source partition ID (in the lower 30 bits) and the position (in the top 2 bits).

Author: Ankur Dave <[email protected]>

Closes apache#1553 from ankurdave/remove-RoutingTableMessage and squashes the following commits:

697e17b [Ankur Dave] Replace RoutingTableMessage with pair
1 parent 60f0ae3 commit 2d25e34

File tree

4 files changed

+36
-30
lines changed

4 files changed

+36
-30
lines changed

graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ class GraphKryoRegistrator extends KryoRegistrator {
3535

3636
def registerClasses(kryo: Kryo) {
3737
kryo.register(classOf[Edge[Object]])
38-
kryo.register(classOf[RoutingTableMessage])
3938
kryo.register(classOf[(VertexId, Object)])
4039
kryo.register(classOf[EdgePartition[Object, Object]])
4140
kryo.register(classOf[BitSet])

graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,26 +27,13 @@ import org.apache.spark.util.collection.{BitSet, PrimitiveVector}
2727
import org.apache.spark.graphx._
2828
import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
2929

30-
/**
31-
* A message from the edge partition `pid` to the vertex partition containing `vid` specifying that
32-
* the edge partition references `vid` in the specified `position` (src, dst, or both).
33-
*/
34-
private[graphx]
35-
class RoutingTableMessage(
36-
var vid: VertexId,
37-
var pid: PartitionID,
38-
var position: Byte)
39-
extends Product2[VertexId, (PartitionID, Byte)] with Serializable {
40-
override def _1 = vid
41-
override def _2 = (pid, position)
42-
override def canEqual(that: Any): Boolean = that.isInstanceOf[RoutingTableMessage]
43-
}
30+
import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage
4431

4532
private[graphx]
4633
class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) {
4734
/** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */
4835
def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = {
49-
new ShuffledRDD[VertexId, (PartitionID, Byte), (PartitionID, Byte), RoutingTableMessage](
36+
new ShuffledRDD[VertexId, Int, Int, RoutingTableMessage](
5037
self, partitioner).setSerializer(new RoutingTableMessageSerializer)
5138
}
5239
}
@@ -62,6 +49,23 @@ object RoutingTableMessageRDDFunctions {
6249

6350
private[graphx]
6451
object RoutingTablePartition {
52+
/**
53+
* A message from an edge partition to a vertex specifying the position in which the edge
54+
* partition references the vertex (src, dst, or both). The edge partition is encoded in the lower
55+
* 30 bits of the Int, and the position is encoded in the upper 2 bits of the Int.
56+
*/
57+
type RoutingTableMessage = (VertexId, Int)
58+
59+
private def toMessage(vid: VertexId, pid: PartitionID, position: Byte): RoutingTableMessage = {
60+
val positionUpper2 = position << 30
61+
val pidLower30 = pid & 0x3FFFFFFF
62+
(vid, positionUpper2 | pidLower30)
63+
}
64+
65+
private def vidFromMessage(msg: RoutingTableMessage): VertexId = msg._1
66+
private def pidFromMessage(msg: RoutingTableMessage): PartitionID = msg._2 & 0x3FFFFFFF
67+
private def positionFromMessage(msg: RoutingTableMessage): Byte = (msg._2 >> 30).toByte
68+
6569
val empty: RoutingTablePartition = new RoutingTablePartition(Array.empty)
6670

6771
/** Generate a `RoutingTableMessage` for each vertex referenced in `edgePartition`. */
@@ -77,7 +81,9 @@ object RoutingTablePartition {
7781
map.changeValue(dstId, 0x2, (b: Byte) => (b | 0x2).toByte)
7882
}
7983
map.iterator.map { vidAndPosition =>
80-
new RoutingTableMessage(vidAndPosition._1, pid, vidAndPosition._2)
84+
val vid = vidAndPosition._1
85+
val position = vidAndPosition._2
86+
toMessage(vid, pid, position)
8187
}
8288
}
8389

@@ -88,9 +94,12 @@ object RoutingTablePartition {
8894
val srcFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean])
8995
val dstFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean])
9096
for (msg <- iter) {
91-
pid2vid(msg.pid) += msg.vid
92-
srcFlags(msg.pid) += (msg.position & 0x1) != 0
93-
dstFlags(msg.pid) += (msg.position & 0x2) != 0
97+
val vid = vidFromMessage(msg)
98+
val pid = pidFromMessage(msg)
99+
val position = positionFromMessage(msg)
100+
pid2vid(pid) += vid
101+
srcFlags(pid) += (position & 0x1) != 0
102+
dstFlags(pid) += (position & 0x2) != 0
94103
}
95104

96105
new RoutingTablePartition(pid2vid.zipWithIndex.map {

graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@ import java.nio.ByteBuffer
2424

2525
import scala.reflect.ClassTag
2626

27-
import org.apache.spark.graphx._
2827
import org.apache.spark.serializer._
2928

29+
import org.apache.spark.graphx._
30+
import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage
31+
3032
private[graphx]
3133
class RoutingTableMessageSerializer extends Serializer with Serializable {
3234
override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
@@ -35,10 +37,8 @@ class RoutingTableMessageSerializer extends Serializer with Serializable {
3537
new ShuffleSerializationStream(s) {
3638
def writeObject[T: ClassTag](t: T): SerializationStream = {
3739
val msg = t.asInstanceOf[RoutingTableMessage]
38-
writeVarLong(msg.vid, optimizePositive = false)
39-
writeUnsignedVarInt(msg.pid)
40-
// TODO: Write only the bottom two bits of msg.position
41-
s.write(msg.position)
40+
writeVarLong(msg._1, optimizePositive = false)
41+
writeInt(msg._2)
4242
this
4343
}
4444
}
@@ -47,10 +47,8 @@ class RoutingTableMessageSerializer extends Serializer with Serializable {
4747
new ShuffleDeserializationStream(s) {
4848
override def readObject[T: ClassTag](): T = {
4949
val a = readVarLong(optimizePositive = false)
50-
val b = readUnsignedVarInt()
51-
val c = s.read()
52-
if (c == -1) throw new EOFException
53-
new RoutingTableMessage(a, b, c.toByte).asInstanceOf[T]
50+
val b = readInt()
51+
(a, b).asInstanceOf[T]
5452
}
5553
}
5654
}

graphx/src/main/scala/org/apache/spark/graphx/package.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ package object graphx {
3030
*/
3131
type VertexId = Long
3232

33-
/** Integer identifier of a graph partition. */
33+
/** Integer identifier of a graph partition. Must be less than 2^30. */
3434
// TODO: Consider using Char.
3535
type PartitionID = Int
3636

0 commit comments

Comments
 (0)