Skip to content

Commit a704e5f

Browse files
committed
Use type equality constraint with default argument
1 parent 29a5ab7 commit a704e5f

File tree

4 files changed

+15
-8
lines changed

4 files changed

+15
-8
lines changed

graphx/src/main/scala/org/apache/spark/graphx/Graph.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
128128
* }}}
129129
*
130130
*/
131-
def mapVertices[VD2: ClassTag](map: (VertexId, VD) => VD2): Graph[VD2, ED]
131+
def mapVertices[VD2: ClassTag](map: (VertexId, VD) => VD2)
132+
(implicit eq: VD =:= VD2 = null): Graph[VD2, ED]
132133

133134
/**
134135
* Transforms each edge attribute in the graph using the map function. The map function is not
@@ -338,7 +339,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
338339
* }}}
339340
*/
340341
def outerJoinVertices[U: ClassTag, VD2: ClassTag](other: RDD[(VertexId, U)])
341-
(mapFunc: (VertexId, VD, Option[U]) => VD2)
342+
(mapFunc: (VertexId, VD, Option[U]) => VD2)(implicit eq: VD =:= VD2 = null)
342343
: Graph[VD2, ED]
343344

344345
/**

graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,11 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
100100
new GraphImpl(vertices.reverseRoutingTables(), replicatedVertexView.reverse())
101101
}
102102

103-
override def mapVertices[VD2: ClassTag](f: (VertexId, VD) => VD2): Graph[VD2, ED] = {
104-
if (classTag[VD] equals classTag[VD2]) {
103+
override def mapVertices[VD2: ClassTag]
104+
(f: (VertexId, VD) => VD2)(implicit eq: VD =:= VD2 = null): Graph[VD2, ED] = {
105+
// The implicit parameter eq will be populated by the compiler if VD and VD2 are equal, and left
106+
// null if not
107+
if (eq != null) {
105108
vertices.cache()
106109
// The map preserves type, so we can use incremental replication
107110
val newVerts = vertices.mapVertexPartitions(_.map(f)).cache()
@@ -228,8 +231,11 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
228231

229232
override def outerJoinVertices[U: ClassTag, VD2: ClassTag]
230233
(other: RDD[(VertexId, U)])
231-
(updateF: (VertexId, VD, Option[U]) => VD2): Graph[VD2, ED] = {
232-
if (classTag[VD] equals classTag[VD2]) {
234+
(updateF: (VertexId, VD, Option[U]) => VD2)
235+
(implicit eq: VD =:= VD2 = null): Graph[VD2, ED] = {
236+
// The implicit parameter eq will be populated by the compiler if VD and VD2 are equal, and left
237+
// null if not
238+
if (eq != null) {
233239
vertices.cache()
234240
// updateF preserves type, so we can use incremental replication
235241
val newVerts = vertices.leftJoin(other)(updateF).cache()

graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ object LabelPropagation {
4141
*
4242
* @return a graph with vertex attributes containing the label of community affiliation
4343
*/
44-
def run[ED: ClassTag](graph: Graph[_, ED], maxSteps: Int): Graph[VertexId, ED] = {
44+
def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): Graph[VertexId, ED] = {
4545
val lpaGraph = graph.mapVertices { case (vid, _) => vid }
4646
def sendMessage(e: EdgeTriplet[VertexId, ED]) = {
4747
Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 1L)))

graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ object ShortestPaths {
4949
* @return a graph where each vertex attribute is a map containing the shortest-path distance to
5050
* each reachable landmark vertex.
5151
*/
52-
def run[ED: ClassTag](graph: Graph[_, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = {
52+
def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = {
5353
val spGraph = graph.mapVertices { (vid, attr) =>
5454
if (landmarks.contains(vid)) makeMap(vid -> 0) else makeMap()
5555
}

0 commit comments

Comments
 (0)