Skip to content

Commit 697e17b

Browse files
committed
Replace RoutingTableMessage with pair
RoutingTableMessage was used to construct routing tables to enable joining VertexRDDs with partitioned edges. It stored three elements: the destination vertex ID, the source edge partition, and a byte specifying the position in which the edge partition referenced the vertex to enable join elimination. However, this was incompatible with sort-based shuffle (SPARK-2045). It was also slightly wasteful, because partition IDs are usually much smaller than 2^32, though this was mitigated by a custom serializer that used variable-length encoding. This commit replaces RoutingTableMessage with a pair of (VertexId, Int) where the Int encodes both the source partition ID (in the lower 30 bits) and the position (in the top 2 bits).
1 parent f776bc9 commit 697e17b

File tree

4 files changed

+36
-30
lines changed

4 files changed

+36
-30
lines changed

graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ class GraphKryoRegistrator extends KryoRegistrator {
3535

3636
def registerClasses(kryo: Kryo) {
3737
kryo.register(classOf[Edge[Object]])
38-
kryo.register(classOf[RoutingTableMessage])
3938
kryo.register(classOf[(VertexId, Object)])
4039
kryo.register(classOf[EdgePartition[Object, Object]])
4140
kryo.register(classOf[BitSet])

graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,26 +27,13 @@ import org.apache.spark.util.collection.{BitSet, PrimitiveVector}
2727
import org.apache.spark.graphx._
2828
import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
2929

30-
/**
31-
* A message from the edge partition `pid` to the vertex partition containing `vid` specifying that
32-
* the edge partition references `vid` in the specified `position` (src, dst, or both).
33-
*/
34-
private[graphx]
35-
class RoutingTableMessage(
36-
var vid: VertexId,
37-
var pid: PartitionID,
38-
var position: Byte)
39-
extends Product2[VertexId, (PartitionID, Byte)] with Serializable {
40-
override def _1 = vid
41-
override def _2 = (pid, position)
42-
override def canEqual(that: Any): Boolean = that.isInstanceOf[RoutingTableMessage]
43-
}
30+
import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage
4431

4532
private[graphx]
4633
class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) {
4734
/** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */
4835
def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = {
49-
new ShuffledRDD[VertexId, (PartitionID, Byte), (PartitionID, Byte), RoutingTableMessage](
36+
new ShuffledRDD[VertexId, Int, Int, RoutingTableMessage](
5037
self, partitioner).setSerializer(new RoutingTableMessageSerializer)
5138
}
5239
}
@@ -62,6 +49,23 @@ object RoutingTableMessageRDDFunctions {
6249

6350
private[graphx]
6451
object RoutingTablePartition {
52+
/**
53+
* A message from an edge partition to a vertex specifying the position in which the edge
54+
* partition references the vertex (src, dst, or both). The edge partition is encoded in the lower
55+
* 30 bytes of the Int, and the position is encoded in the upper 2 bytes of the Int.
56+
*/
57+
type RoutingTableMessage = (VertexId, Int)
58+
59+
private def toMessage(vid: VertexId, pid: PartitionID, position: Byte): RoutingTableMessage = {
60+
val positionUpper2 = position << 30
61+
val pidLower30 = pid & 0x3FFFFFFF
62+
(vid, positionUpper2 | pidLower30)
63+
}
64+
65+
private def vidFromMessage(msg: RoutingTableMessage): VertexId = msg._1
66+
private def pidFromMessage(msg: RoutingTableMessage): PartitionID = msg._2 & 0x3FFFFFFF
67+
private def positionFromMessage(msg: RoutingTableMessage): Byte = (msg._2 >> 30).toByte
68+
6569
val empty: RoutingTablePartition = new RoutingTablePartition(Array.empty)
6670

6771
/** Generate a `RoutingTableMessage` for each vertex referenced in `edgePartition`. */
@@ -77,7 +81,9 @@ object RoutingTablePartition {
7781
map.changeValue(dstId, 0x2, (b: Byte) => (b | 0x2).toByte)
7882
}
7983
map.iterator.map { vidAndPosition =>
80-
new RoutingTableMessage(vidAndPosition._1, pid, vidAndPosition._2)
84+
val vid = vidAndPosition._1
85+
val position = vidAndPosition._2
86+
toMessage(vid, pid, position)
8187
}
8288
}
8389

@@ -88,9 +94,12 @@ object RoutingTablePartition {
8894
val srcFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean])
8995
val dstFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean])
9096
for (msg <- iter) {
91-
pid2vid(msg.pid) += msg.vid
92-
srcFlags(msg.pid) += (msg.position & 0x1) != 0
93-
dstFlags(msg.pid) += (msg.position & 0x2) != 0
97+
val vid = vidFromMessage(msg)
98+
val pid = pidFromMessage(msg)
99+
val position = positionFromMessage(msg)
100+
pid2vid(pid) += vid
101+
srcFlags(pid) += (position & 0x1) != 0
102+
dstFlags(pid) += (position & 0x2) != 0
94103
}
95104

96105
new RoutingTablePartition(pid2vid.zipWithIndex.map {

graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@ import java.nio.ByteBuffer
2424

2525
import scala.reflect.ClassTag
2626

27-
import org.apache.spark.graphx._
2827
import org.apache.spark.serializer._
2928

29+
import org.apache.spark.graphx._
30+
import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage
31+
3032
private[graphx]
3133
class RoutingTableMessageSerializer extends Serializer with Serializable {
3234
override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
@@ -35,10 +37,8 @@ class RoutingTableMessageSerializer extends Serializer with Serializable {
3537
new ShuffleSerializationStream(s) {
3638
def writeObject[T: ClassTag](t: T): SerializationStream = {
3739
val msg = t.asInstanceOf[RoutingTableMessage]
38-
writeVarLong(msg.vid, optimizePositive = false)
39-
writeUnsignedVarInt(msg.pid)
40-
// TODO: Write only the bottom two bits of msg.position
41-
s.write(msg.position)
40+
writeVarLong(msg._1, optimizePositive = false)
41+
writeInt(msg._2)
4242
this
4343
}
4444
}
@@ -47,10 +47,8 @@ class RoutingTableMessageSerializer extends Serializer with Serializable {
4747
new ShuffleDeserializationStream(s) {
4848
override def readObject[T: ClassTag](): T = {
4949
val a = readVarLong(optimizePositive = false)
50-
val b = readUnsignedVarInt()
51-
val c = s.read()
52-
if (c == -1) throw new EOFException
53-
new RoutingTableMessage(a, b, c.toByte).asInstanceOf[T]
50+
val b = readInt()
51+
(a, b).asInstanceOf[T]
5452
}
5553
}
5654
}

graphx/src/main/scala/org/apache/spark/graphx/package.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ package object graphx {
3030
*/
3131
type VertexId = Long
3232

33-
/** Integer identifer of a graph partition. */
33+
/** Integer identifer of a graph partition. Must be less than 2^30. */
3434
// TODO: Consider using Char.
3535
type PartitionID = Int
3636

0 commit comments

Comments
 (0)