diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
index db73a8abc573..2c929e6ed38c 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
@@ -329,6 +329,25 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
    */
   def mask[VD2: ClassTag, ED2: ClassTag](other: Graph[VD2, ED2]): Graph[VD, ED]
 
+  /**
+   * Computes the union of this graph with another graph.
+   * The union of two graphs G(VG, EG) and H(VH, EH) is the union of their vertex sets
+   * and their edge sets, i.e. G u H = (VG u VH, EG u EH).
+   * @param other the other graph to union with this one
+   * @param mergeSameVertexAttr merges the attributes of a vertex that appears in both graphs
+   * @param mergeSameEdgeAttr merges the attributes of an edge that appears in both graphs
+   * @tparam VD2 the vertex attribute type of the other graph
+   * @tparam ED2 the edge attribute type of the other graph
+   * @tparam VD3 the vertex attribute type of the resulting graph
+   * @tparam ED3 the edge attribute type of the resulting graph
+   * @return a graph containing the union of the two graphs' vertex and edge sets
+   */
+  def union[VD2: ClassTag, ED2: ClassTag, VD3: ClassTag, ED3: ClassTag](
+      other: Graph[VD2, ED2],
+      mergeSameVertexAttr: (Option[VD], Option[VD2]) => VD3,
+      mergeSameEdgeAttr: (Option[ED], Option[ED2]) => ED3)
+    : Graph[VD3, ED3]
+
   /**
    * Merges multiple edges between two vertices into a single edge. For correct results, the graph
    * must have been partitioned using [[partitionBy]].
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
index 90a74d23a26c..7394389006b8 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
@@ -178,6 +178,34 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
     new GraphImpl(newVerts, replicatedVertexView.withEdges(newEdges))
   }
 
+  override def union[VD2: ClassTag, ED2: ClassTag, VD3: ClassTag, ED3: ClassTag](
+      other: Graph[VD2, ED2],
+      mergeSameVertexAttr: (Option[VD], Option[VD2]) => VD3,
+      mergeSameEdgeAttr: (Option[ED], Option[ED2]) => ED3): Graph[VD3, ED3] = {
+
+    // Merge the two vertex sets, combining the attributes of vertices that
+    // appear in both graphs.
+    val newVertexRDD: RDD[(VertexId, VD3)] = vertices.fullOuterJoin(other.vertices).map {
+      case (id, (thisAttr, otherAttr)) => (id, mergeSameVertexAttr(thisAttr, otherAttr))
+    }.cache()
+
+    // Key each edge attribute by its (srcId, dstId) pair. A tuple key avoids
+    // the collisions that concatenating the two ids as strings would produce.
+    val thisPair = edges.map { edge => ((edge.srcId, edge.dstId), edge.attr) }
+    val otherPair = other.edges.map { edge => ((edge.srcId, edge.dstId), edge.attr) }
+
+    // Full outer join the keyed edges, merging the attributes of edges that
+    // appear in both graphs.
+    val joinedRDD: RDD[Edge[ED3]] = thisPair.fullOuterJoin(otherPair).map {
+      case ((srcId, dstId), (thisAttr, otherAttr)) =>
+        Edge(srcId, dstId, mergeSameEdgeAttr(thisAttr, otherAttr))
+    }
+
+    // Build the result graph from the merged vertex and edge RDDs.
+    val newEdgeRDD: EdgeRDDImpl[ED3, VD3] = EdgeRDD.fromEdges[ED3, VD3](joinedRDD).cache()
+    new GraphImpl(VertexRDD(newVertexRDD), new ReplicatedVertexView(newEdgeRDD))
+  }
+
   override def groupEdges(merge: (ED, ED) => ED): Graph[VD, ED] = {
     val newEdges = replicatedVertexView.edges.mapEdgePartitions(
       (pid, part) => part.groupEdges(merge))
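For context, a minimal sketch of how the new union API could be exercised. The graphs, attribute types, and merge functions below are illustrative only (not part of the patch), and sc is assumed to be an existing SparkContext:

  import org.apache.spark.graphx._

  // Two tiny graphs: vertex attributes are Ints (the fromEdges defaults),
  // edge attributes are Longs. Vertices 2 and 3 and the edge (2, 3) appear
  // in both graphs, so their attributes are merged by summation.
  val g1 = Graph.fromEdges(sc.parallelize(Seq(Edge(1L, 2L, 10L), Edge(2L, 3L, 20L))), defaultValue = 1)
  val g2 = Graph.fromEdges(sc.parallelize(Seq(Edge(2L, 3L, 5L), Edge(3L, 4L, 7L))), defaultValue = 2)

  val merged: Graph[Int, Long] = g1.union(
    g2,
    (a: Option[Int], b: Option[Int]) => a.getOrElse(0) + b.getOrElse(0),
    (a: Option[Long], b: Option[Long]) => a.getOrElse(0L) + b.getOrElse(0L))

  // merged has vertices 1 through 4; vertices 2 and 3 carry attribute
  // 1 + 2 = 3, and the shared edge (2, 3) carries attribute 20L + 5L = 25L.
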
diff --git a/graphx/src/test/resources/union_1_test.data b/graphx/src/test/resources/union_1_test.data
new file mode 100644
index 000000000000..5c001da4afcf
--- /dev/null
+++ b/graphx/src/test/resources/union_1_test.data
@@ -0,0 +1,16 @@
+1,2
+1,3
+1,4
+2,1
+2,3
+2,4
+3,1
+3,2
+3,4
+4,1
+4,2
+4,3
+5,4
+5,6
+6,5
+6,4
\ No newline at end of file
diff --git a/graphx/src/test/resources/union_2_test.data b/graphx/src/test/resources/union_2_test.data
new file mode 100644
index 000000000000..5e3fe2f1a995
--- /dev/null
+++ b/graphx/src/test/resources/union_2_test.data
@@ -0,0 +1,12 @@
+5,6
+5,7
+5,8
+6,5
+6,7
+6,8
+7,5
+7,6
+7,8
+8,5
+8,6
+8,7
\ No newline at end of file
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index 1f5e27d5508b..c11ad5c908ae 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -428,4 +428,35 @@ class GraphSuite extends SparkFunSuite with LocalSparkContext {
     }
   }
 
+  test("union of two graphs") {
+    withSpark { sc =>
+      val rdd1 = sc.textFile(getClass.getResource("/union_1_test.data").getFile).zipWithIndex().map {
+        line =>
+          val fields = line._1.split(",")
+          Edge(fields(0).trim.toLong, fields(1).trim.toLong, line._2 + 1)
+      }
+      val rdd2 = sc.textFile(getClass.getResource("/union_2_test.data").getFile).zipWithIndex().map {
+        line =>
+          val fields = line._1.split(",")
+          Edge(fields(0).trim.toLong, fields(1).trim.toLong, line._2 + 1)
+      }
+
+      val mergeVertex = (a: Option[Int], b: Option[Int]) => a.getOrElse(0) + b.getOrElse(0)
+      val mergeEdge = (a: Option[Long], b: Option[Long]) => a.getOrElse(0L) + b.getOrElse(0L)
+
+      val graph1 = Graph.fromEdges(rdd1, 1)
+      val graph2 = Graph.fromEdges(rdd2, 2)
+      val graph3 = graph1.union(graph2, mergeVertex, mergeEdge)
+
+      // Edges 5 -> 6 and 6 -> 5 appear in both graphs, so the union has two fewer edges.
+      assert(graph1.edges.count() + graph2.edges.count() - 2 == graph3.edges.count())
+
+      val diff = (graph1.edges.collect() ++ graph2.edges.collect()).diff(graph3.edges.collect())
+      assert(diff.count(p => p.srcId == 5L || p.srcId == 6L || p.dstId == 5L || p.dstId == 6L) == 4)
+
+      val vdiff = graph3.vertices.collect().diff(graph1.vertices.collect() ++ graph2.vertices.collect())
+      assert(vdiff.diff(Array((6L, 3), (5L, 3))).length == 0)
+    }
+  }
+
 }
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 7a748fb5e38b..adec63e50100 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -64,7 +64,9 @@ object MimaExcludes {
             ProblemFilters.exclude[MissingClassProblem](
               "org.apache.spark.sql.parquet.CatalystTimestampConverter"),
             ProblemFilters.exclude[MissingClassProblem](
-              "org.apache.spark.sql.parquet.CatalystTimestampConverter$")
+              "org.apache.spark.sql.parquet.CatalystTimestampConverter$"),
+            // SPARK-7984
+            ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.Graph.union")
           )
 
         case v if v.startsWith("1.4") => Seq(
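To make the test's expectations concrete, a short sketch of the arithmetic behind the assertions, derived from the two data files above:

  // union_1_test.data contributes 16 edges and union_2_test.data contributes 12;
  // the pairs 5 -> 6 and 6 -> 5 appear in both files, so the union collapses two edges:
  val expectedEdges = 16 + 12 - 2  // == 26 == graph3.edges.count()
  // Those four shared edges (5 -> 6 and 6 -> 5, once per input graph) get re-merged
  // attributes in graph3, so exactly they survive the diff against graph3's edges.
  // Vertices 5 and 6 exist in both graphs; each merges to 1 + 2 == 3, which is
  // why vdiff reduces to exactly Array((6L, 3), (5L, 3)).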