Skip to content

Commit 8d85359

Browse files
ankurdave authored and rxin committed
[SPARK-1552] Fix type comparison bug in {map,outerJoin}Vertices
In GraphImpl, mapVertices and outerJoinVertices use a more efficient implementation when the map function conserves vertex attribute types. This is implemented by comparing the ClassTags of the old and new vertex attribute types. However, ClassTags store erased types, so the comparison will return a false positive for types with different type parameters, such as Option[Int] and Option[Double].

This PR resolves the problem by requesting that the compiler generate evidence of equality between the old and new vertex attribute types, and providing a default value for the evidence parameter if the two types are not equal. The methods can then check the value of the evidence parameter to see whether the types are equal.

It also adds a test called "mapVertices changing type with same erased type" that failed before the PR and succeeds now.

Callers of mapVertices and outerJoinVertices can no longer use a wildcard for a graph's VD type. To avoid "Error occurred in an application involving default arguments," they must bind VD to a type parameter, as this PR does for ShortestPaths and LabelPropagation.

Author: Ankur Dave <[email protected]>

Closes apache#967 from ankurdave/SPARK-1552 and squashes the following commits:

68a4fff [Ankur Dave] Undo conserve naming
7388705 [Ankur Dave] Remove unnecessary ClassTag for VD parameters
a704e5f [Ankur Dave] Use type equality constraint with default argument
29a5ab7 [Ankur Dave] Add failing test
f458c83 [Ankur Dave] Revert "[SPARK-1552] Fix type comparison bug in mapVertices and outerJoinVertices"
16d6af8 [Ankur Dave] [SPARK-1552] Fix type comparison bug in mapVertices and outerJoinVertices
1 parent 41db44c commit 8d85359

File tree

5 files changed

+40
-8
lines changed

5 files changed

+40
-8
lines changed

graphx/src/main/scala/org/apache/spark/graphx/Graph.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
138138
* }}}
139139
*
140140
*/
141-
def mapVertices[VD2: ClassTag](map: (VertexId, VD) => VD2): Graph[VD2, ED]
141+
def mapVertices[VD2: ClassTag](map: (VertexId, VD) => VD2)
142+
(implicit eq: VD =:= VD2 = null): Graph[VD2, ED]
142143

143144
/**
144145
* Transforms each edge attribute in the graph using the map function. The map function is not
@@ -348,7 +349,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
348349
* }}}
349350
*/
350351
def outerJoinVertices[U: ClassTag, VD2: ClassTag](other: RDD[(VertexId, U)])
351-
(mapFunc: (VertexId, VD, Option[U]) => VD2)
352+
(mapFunc: (VertexId, VD, Option[U]) => VD2)(implicit eq: VD =:= VD2 = null)
352353
: Graph[VD2, ED]
353354

354355
/**

graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,11 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
104104
new GraphImpl(vertices.reverseRoutingTables(), replicatedVertexView.reverse())
105105
}
106106

107-
override def mapVertices[VD2: ClassTag](f: (VertexId, VD) => VD2): Graph[VD2, ED] = {
108-
if (classTag[VD] equals classTag[VD2]) {
107+
override def mapVertices[VD2: ClassTag]
108+
(f: (VertexId, VD) => VD2)(implicit eq: VD =:= VD2 = null): Graph[VD2, ED] = {
109+
// The implicit parameter eq will be populated by the compiler if VD and VD2 are equal, and left
110+
// null if not
111+
if (eq != null) {
109112
vertices.cache()
110113
// The map preserves type, so we can use incremental replication
111114
val newVerts = vertices.mapVertexPartitions(_.map(f)).cache()
@@ -232,8 +235,11 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
232235

233236
override def outerJoinVertices[U: ClassTag, VD2: ClassTag]
234237
(other: RDD[(VertexId, U)])
235-
(updateF: (VertexId, VD, Option[U]) => VD2): Graph[VD2, ED] = {
236-
if (classTag[VD] equals classTag[VD2]) {
238+
(updateF: (VertexId, VD, Option[U]) => VD2)
239+
(implicit eq: VD =:= VD2 = null): Graph[VD2, ED] = {
240+
// The implicit parameter eq will be populated by the compiler if VD and VD2 are equal, and left
241+
// null if not
242+
if (eq != null) {
237243
vertices.cache()
238244
// updateF preserves type, so we can use incremental replication
239245
val newVerts = vertices.leftJoin(other)(updateF).cache()

graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ object LabelPropagation {
4141
*
4242
* @return a graph with vertex attributes containing the label of community affiliation
4343
*/
44-
def run[ED: ClassTag](graph: Graph[_, ED], maxSteps: Int): Graph[VertexId, ED] = {
44+
def run[VD, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): Graph[VertexId, ED] = {
4545
val lpaGraph = graph.mapVertices { case (vid, _) => vid }
4646
def sendMessage(e: EdgeTriplet[VertexId, ED]) = {
4747
Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 1L)))

graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ object ShortestPaths {
4949
* @return a graph where each vertex attribute is a map containing the shortest-path distance to
5050
* each reachable landmark vertex.
5151
*/
52-
def run[ED: ClassTag](graph: Graph[_, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = {
52+
def run[VD, ED: ClassTag](graph: Graph[VD, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = {
5353
val spGraph = graph.mapVertices { (vid, attr) =>
5454
if (landmarks.contains(vid)) makeMap(vid -> 0) else makeMap()
5555
}

graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,31 @@ class GraphSuite extends FunSuite with LocalSparkContext {
159159
}
160160
}
161161

162+
test("mapVertices changing type with same erased type") {
163+
withSpark { sc =>
164+
val vertices = sc.parallelize(Array[(Long, Option[java.lang.Integer])](
165+
(1L, Some(1)),
166+
(2L, Some(2)),
167+
(3L, Some(3))
168+
))
169+
val edges = sc.parallelize(Array(
170+
Edge(1L, 2L, 0),
171+
Edge(2L, 3L, 0),
172+
Edge(3L, 1L, 0)
173+
))
174+
val graph0 = Graph(vertices, edges)
175+
// Trigger initial vertex replication
176+
graph0.triplets.foreach(x => {})
177+
// Change type of replicated vertices, but preserve erased type
178+
val graph1 = graph0.mapVertices {
179+
case (vid, integerOpt) => integerOpt.map((x: java.lang.Integer) => (x.toDouble): java.lang.Double)
180+
}
181+
// Access replicated vertices, exposing the erased type
182+
val graph2 = graph1.mapTriplets(t => t.srcAttr.get)
183+
assert(graph2.edges.map(_.attr).collect.toSet === Set[java.lang.Double](1.0, 2.0, 3.0))
184+
}
185+
}
186+
162187
test("mapEdges") {
163188
withSpark { sc =>
164189
val n = 3

0 commit comments

Comments
 (0)