From 43cf4a33667a4c7f061c285a579b1d94d71cd1bc Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Sat, 4 Oct 2014 18:08:39 -0700 Subject: [PATCH 1/3] Java API for GraphX Adds JavaGraph, JavaVertexRDD, JavaEdgeRDD, and tests for these classes. --- .../spark/api/java/function/Function4.java | 27 + .../apache/spark/api/java/JavaPairRDD.scala | 15 +- .../scala/org/apache/spark/graphx/Graph.scala | 2 +- .../spark/graphx/api/java/JavaEdgeRDD.scala | 94 +++ .../spark/graphx/api/java/JavaGraph.scala | 744 ++++++++++++++++++ .../spark/graphx/api/java/JavaVertexRDD.scala | 253 ++++++ .../graphx/api/java/PartitionStrategies.java | 36 + .../apache/spark/graphx/JavaEdgeRDDSuite.java | 166 ++++ .../apache/spark/graphx/JavaGraphSuite.java | 386 +++++++++ .../spark/graphx/JavaVertexRDDSuite.java | 260 ++++++ .../org/apache/spark/graphx/GraphSuite.scala | 2 +- 11 files changed, 1982 insertions(+), 3 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/api/java/function/Function4.java create mode 100644 graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaEdgeRDD.scala create mode 100644 graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaGraph.scala create mode 100644 graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaVertexRDD.scala create mode 100644 graphx/src/main/scala/org/apache/spark/graphx/api/java/PartitionStrategies.java create mode 100644 graphx/src/test/java/org/apache/spark/graphx/JavaEdgeRDDSuite.java create mode 100644 graphx/src/test/java/org/apache/spark/graphx/JavaGraphSuite.java create mode 100644 graphx/src/test/java/org/apache/spark/graphx/JavaVertexRDDSuite.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function4.java b/core/src/main/java/org/apache/spark/api/java/function/Function4.java new file mode 100644 index 0000000000000..d4485a96e9bf7 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/Function4.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * A four-argument function that takes arguments of type T1, T2, T3, and T4 and returns an R. + */ +public interface Function4 extends Serializable { + public R call(T1 v1, T2 v2, T3 v3, T4 v4) throws Exception; +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index e37f3acaf6e30..47f270a228b7b 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -36,7 +36,8 @@ import org.apache.spark.SparkContext.rddToPairRDDFunctions import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap -import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, PairFunction} +import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, + Function3 => JFunction3, Function4 => JFunction4, PairFunction, _} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.{OrderedRDDFunctions, RDD} import org.apache.spark.storage.StorageLevel @@ -991,6 +992,18 @@ object JavaPairRDD { implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd + private[spark] + implicit def toScalaFunction4[T1, T2, T3, T4, R]( + fun: JFunction4[T1, T2, T3, T4, R]): Function4[T1, T2, T3, T4, R] = { + (x1: T1, x2: T2, x3: T3, x4: T4) => fun.call(x1, x2, x3, x4) + } + + private[spark] + implicit def toScalaFunction3[T1, T2, T3, R]( + fun: JFunction3[T1, T2, T3, R]): Function3[T1, T2, T3, R] = { + (x1: T1, x2: T2, x3: T3) => fun.call(x1, x2, x3) + } + private[spark] implicit def toScalaFunction2[T1, T2, R](fun: JFunction2[T1, T2, R]): Function2[T1, T2, R] = { (x: T1, x1: T2) => fun.call(x, x1) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index 2c1b9518a3d16..b6ef4498c5743 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -490,7 +490,7 @@ object Graph { * @param vertexStorageLevel the desired storage level at which to cache the vertices if necessary * * @return a graph with edge attributes containing either the count of duplicate edges or 1 - * (if `uniqueEdges` is `None`) and vertex attributes containing the total degree of each vertex. + * (if `uniqueEdges` is `None`) and all vertex attributes set to `defaultValue`. */ def fromEdgeTuples[VD: ClassTag]( rawEdges: RDD[(VertexId, VertexId)], diff --git a/graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaEdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaEdgeRDD.scala new file mode 100644 index 0000000000000..0e9153d479fb6 --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaEdgeRDD.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx.api.java + +import java.lang.{Integer => JInt, Long => JLong, Boolean => JBoolean} + +import scala.collection.JavaConversions._ +import scala.language.implicitConversions +import scala.reflect.ClassTag + +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.api.java.JavaUtils +import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, + Function3 => JFunction3, Function4 => JFunction4, _} +import org.apache.spark.graphx._ +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.java.JavaPairRDD._ + +class JavaEdgeRDD[ED, VD](override val rdd: EdgeRDD[ED, VD])( + implicit val edTag: ClassTag[ED], val vdTag: ClassTag[VD]) + extends JavaRDD[Edge[ED]](rdd) { + + /** + * Map the values in an edge partitioning preserving the structure but changing the values. + * + * @tparam ED2 the new edge value type + * @param f the function from an edge to a new edge value + * @return a new EdgeRDD containing the new edge values + */ + def mapValues[ED2](f: JFunction[Edge[ED], ED2]): JavaEdgeRDD[ED2, VD] = { + implicit val ed2Tag: ClassTag[ED2] = fakeClassTag + rdd.mapValues(f) + } + + /** + * Reverse all the edges in this RDD. + * + * @return a new EdgeRDD containing all the edges reversed + */ + def reverse(): JavaEdgeRDD[ED, VD] = rdd.reverse + + /** Removes all edges but those matching `epred` and where both vertices match `vpred`. */ + def filter( + epred: JFunction[EdgeTriplet[VD, ED], JBoolean], + vpred: JFunction2[JLong, VD, JBoolean]): JavaEdgeRDD[ED, VD] = + rdd.filter(et => epred.call(et), (id, attr) => vpred.call(id, attr)) + + /** + * Inner joins this EdgeRDD with another EdgeRDD, assuming both are partitioned using the same + * [[PartitionStrategy]]. + * + * @param other the EdgeRDD to join with + * @param f the join function applied to corresponding values of `this` and `other` + * @return a new EdgeRDD containing only edges that appear in both `this` and `other`, + * with values supplied by `f` + */ + def innerJoin[ED2, ED3]( + other: JavaEdgeRDD[ED2, _], + f: JFunction4[JLong, JLong, ED, ED2, ED3]): JavaEdgeRDD[ED3, VD] = { + implicit val ed2Tag: ClassTag[ED2] = fakeClassTag + implicit val ed3Tag: ClassTag[ED3] = fakeClassTag + rdd.innerJoin(other) { (src, dst, a, b) => f(src, dst, a, b) } + } +} + +object JavaEdgeRDD { + + def fromEdges[ED, VD](edges: JavaRDD[Edge[ED]]): JavaEdgeRDD[ED, VD] = { + implicit val edTag: ClassTag[ED] = fakeClassTag + implicit val vdTag: ClassTag[VD] = fakeClassTag + fromEdgeRDD(EdgeRDD.fromEdges(edges)) + } + + implicit def fromEdgeRDD[ED: ClassTag, VD: ClassTag](rdd: EdgeRDD[ED, VD]): JavaEdgeRDD[ED, VD] = + new JavaEdgeRDD[ED, VD](rdd) + + implicit def toEdgeRDD[ED, VD](rdd: JavaEdgeRDD[ED, VD]): EdgeRDD[ED, VD] = + rdd.rdd +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaGraph.scala b/graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaGraph.scala new file mode 100644 index 0000000000000..c702ffa246436 --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaGraph.scala @@ -0,0 +1,744 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx.api.java + +import java.lang.{Integer => JInt, Long => JLong, Boolean => JBoolean} +import java.util.{Iterator => JIterator} + +import scala.collection.JavaConversions._ +import scala.language.implicitConversions +import scala.reflect.ClassTag + +import org.apache.spark.api.java.JavaPairRDD +import org.apache.spark.api.java.JavaUtils +import com.google.common.base.Optional +import org.apache.spark.api.java.JavaPairRDD._ +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.java.JavaRDD._ +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, + Function3 => JFunction3, _} +import org.apache.spark.graphx._ +import org.apache.spark.storage.StorageLevel +import org.apache.spark.rdd.RDD + +class JavaGraph[VD, ED](val graph: Graph[VD, ED])( + implicit val vdTag: ClassTag[VD], implicit val edTag: ClassTag[ED]) extends Serializable { + + /** + * An RDD containing the vertices and their associated attributes. + * + * @note vertex ids are unique. + * @return an RDD containing the vertices in this graph + */ + @transient val vertices: JavaVertexRDD[VD] = new JavaVertexRDD[VD](graph.vertices) + + /** + * An RDD containing the edges and their associated attributes. The entries in the RDD contain + * just the source id and target id along with the edge data. + * + * @return an RDD containing the edges in this graph + * + * @see [[Edge]] for the edge type. + * @see [[triplets]] to get an RDD which contains all the edges + * along with their vertex data. + * + */ + @transient val edges: JavaEdgeRDD[ED, VD] = new JavaEdgeRDD[ED, VD](graph.edges) + + /** + * An RDD containing the edge triplets, which are edges along with the vertex data associated with + * the adjacent vertices. The caller should use [[edges]] if the vertex data are not needed, i.e. + * if only the edge data and adjacent vertex ids are needed. + * + * @return an RDD containing edge triplets + * + * @example This operation might be used to evaluate a graph + * coloring where we would like to check that both vertices are a + * different color. + * {{{ + * type Color = Int + * val graph: Graph[Color, Int] = GraphLoader.edgeListFile("hdfs://file.tsv") + * val numInvalid = graph.triplets.map(e => if (e.src.data == e.dst.data) 1 else 0).sum + * }}} + */ + @transient val triplets: JavaRDD[EdgeTriplet[VD, ED]] = graph.triplets + + /** + * Caches the vertices and edges associated with this graph at the specified storage level, + * ignoring any target storage levels previously set. + * + * @param newLevel the level at which to cache the graph. + * + * @return A reference to this graph for convenience. + */ + def persist(newLevel: StorageLevel): JavaGraph[VD, ED] = graph.persist(newLevel) + + /** + * Caches the vertices and edges associated with this graph at the previously-specified target + * storage levels, which default to `MEMORY_ONLY`. This is used to pin a graph in memory enabling + * multiple queries to reuse the same construction process. + */ + def cache(): JavaGraph[VD, ED] = graph.cache() + + /** + * Repartitions the edges in the graph according to `partitionStrategy`. + * + * @param partitionStrategy the partitioning strategy to use when partitioning the edges + * in the graph. + */ + def partitionBy(partitionStrategy: PartitionStrategy): JavaGraph[VD, ED] = + graph.partitionBy(partitionStrategy) + + /** + * Repartitions the edges in the graph according to `partitionStrategy`. + * + * @param partitionStrategy the partitioning strategy to use when partitioning the edges + * in the graph. + * @param numPartitions the number of edge partitions in the new graph. + */ + def partitionBy(partitionStrategy: PartitionStrategy, numPartitions: Int): JavaGraph[VD, ED] = + graph.partitionBy(partitionStrategy, numPartitions) + + /** + * Transforms each vertex attribute in the graph using the map function. + * + * @note The new graph has the same structure. As a consequence the underlying index structures + * can be reused. + * + * @param map the function from a vertex object to a new vertex value + * + * @tparam VD2 the new vertex data type + * + * @example We might use this operation to change the vertex values + * from one type to another to initialize an algorithm. + * {{{ + * val rawGraph: Graph[(), ()] = Graph.textFile("hdfs://file") + * val root = 42 + * var bfsGraph = rawGraph.mapVertices[Int]((vid, data) => if (vid == root) 0 else Math.MaxValue) + * }}} + * + */ + def mapVertices[VD2](map: JFunction2[JLong, VD, VD2]): JavaGraph[VD2, ED] = { + implicit val vd2Tag: ClassTag[VD2] = fakeClassTag + new JavaGraph(graph.mapVertices((id, attr) => map(id, attr))(fakeClassTag, eq = null)) + } + + /** + * Transforms each edge attribute in the graph using the map function. The map function is not + * passed the vertex value for the vertices adjacent to the edge. If vertex values are desired, + * use `mapTriplets`. + * + * @note This graph is not changed and that the new graph has the + * same structure. As a consequence the underlying index structures + * can be reused. + * + * @param map the function from an edge object to a new edge value. + * + * @tparam ED2 the new edge data type + * + * @example This function might be used to initialize edge + * attributes. + * + */ + def mapEdges[ED2](map: JFunction[Edge[ED], ED2]): JavaGraph[VD, ED2] = { + implicit val ed2Tag: ClassTag[ED2] = fakeClassTag + new JavaGraph(graph.mapEdges(map)) + } + + /** + * Transforms each edge attribute using the map function, passing it the adjacent vertex + * attributes as well. If adjacent vertex values are not required, + * consider using `mapEdges` instead. + * + * @note This does not change the structure of the + * graph or modify the values of this graph. As a consequence + * the underlying index structures can be reused. + * + * @param map the function from an edge object to a new edge value. + * + * @tparam ED2 the new edge data type + * + * @example This function might be used to initialize edge + * attributes based on the attributes associated with each vertex. + * {{{ + * val rawGraph: Graph[Int, Int] = someLoadFunction() + * val graph = rawGraph.mapTriplets[Int]( edge => + * edge.src.data - edge.dst.data) + * }}} + * + */ + def mapTriplets[ED2]( + map: JFunction[EdgeTriplet[VD, ED], ED2], + tripletFields: TripletFields): JavaGraph[VD, ED2] = { + implicit val ed2Tag: ClassTag[ED2] = fakeClassTag + new JavaGraph(graph.mapTriplets(map, tripletFields)) + } + + /** + * Reverses all edges in the graph. If this graph contains an edge from a to b then the returned + * graph contains an edge from b to a. + */ + def reverse(): JavaGraph[VD, ED] = graph.reverse + + /** + * Restricts the graph to only the vertices and edges satisfying the predicates. The resulting + * subgraph satisifies + * + * {{{ + * V' = {v : for all v in V where vpred(v)} + * E' = {(u,v): for all (u,v) in E where epred((u,v)) && vpred(u) && vpred(v)} + * }}} + * + * @param epred the edge predicate, which takes a triplet and + * evaluates to true if the edge is to remain in the subgraph. Note + * that only edges where both vertices satisfy the vertex + * predicate are considered. + * + * @param vpred the vertex predicate, which takes a vertex object and + * evaluates to true if the vertex is to be included in the subgraph + * + * @return the subgraph containing only the vertices and edges that + * satisfy the predicates + */ + def subgraph( + epred: JFunction[EdgeTriplet[VD, ED], JBoolean], + vpred: JFunction2[JLong, VD, JBoolean]): JavaGraph[VD, ED] = + graph.subgraph(et => epred.call(et), (id, attr) => vpred.call(id, attr)) + + /** + * Restricts the graph to only the vertices and edges that are also in `other`, but keeps the + * attributes from this graph. + * @param other the graph to project this graph onto + * @return a graph with vertices and edges that exist in both the current graph and `other`, + * with vertex and edge data from the current graph + */ + def mask[VD2, ED2](other: JavaGraph[VD2, ED2]): JavaGraph[VD, ED] = + graph.mask(other)(fakeClassTag, fakeClassTag) + + /** + * Merges multiple edges between two vertices into a single edge. For correct results, the graph + * must have been partitioned using [[partitionBy]]. + * + * @param merge the user-supplied commutative associative function to merge edge attributes + * for duplicate edges. + * + * @return The resulting graph with a single edge for each (source, dest) vertex pair. + */ + def groupEdges(merge: JFunction2[ED, ED, ED]): JavaGraph[VD, ED] = + graph.groupEdges(merge) + + /** + * Aggregates values from the neighboring edges and vertices of each vertex. The user-supplied + * `sendMsg` function is invoked on each edge of the graph, generating 0 or more messages to be + * sent to either vertex in the edge. The `mergeMsg` function is then used to combine all messages + * destined to the same vertex. + * + * @tparam A the type of message to be sent to each vertex + * + * @param sendMsg runs on each edge, sending messages to neighboring vertices using the + * [[EdgeContext]]. + * @param mergeMsg used to combine messages from `sendMsg` destined to the same vertex. This + * combiner should be commutative and associative. + * @param tripletFields which fields should be included in the [[EdgeContext]] passed to the + * `sendMsg` function. If not all fields are needed, specifying this can improve performance. + * + * @example We can use this function to compute the in-degree of each + * vertex + * {{{ + * val rawGraph: Graph[_, _] = Graph.textFile("twittergraph") + * val inDeg: RDD[(VertexId, Int)] = + * aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _) + * }}} + * + * @note By expressing computation at the edge level we achieve + * maximum parallelism. This is one of the core functions in the + * Graph API in that enables neighborhood level computation. For + * example this function can be used to count neighbors satisfying a + * predicate or implement PageRank. + * + */ + def aggregateMessages[A]( + sendMsg: VoidFunction[EdgeContext[VD, ED, A]], + mergeMsg: JFunction2[A, A, A], + tripletFields: TripletFields): JavaVertexRDD[A] = { + implicit val aTag: ClassTag[A] = fakeClassTag + new JavaVertexRDD(graph.aggregateMessages(ctx => sendMsg.call(ctx), mergeMsg, tripletFields)) + } + + /** + * Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`. The + * input table should contain at most one entry for each vertex. If no entry in `other` is + * provided for a particular vertex in the graph, the map function receives `None`. + * + * @tparam U the type of entry in the table of updates + * @tparam VD2 the new vertex value type + * + * @param other the table to join with the vertices in the graph. + * The table should contain at most one entry for each vertex. + * @param mapFunc the function used to compute the new vertex values. + * The map function is invoked for all vertices, even those + * that do not have a corresponding entry in the table. + * + * @example This function is used to update the vertices with new values based on external data. + * For example we could add the out-degree to each vertex record: + * + * {{{ + * val rawGraph: Graph[_, _] = Graph.textFile("webgraph") + * val outDeg: RDD[(Long, Int)] = rawGraph.outDegrees + * val graph = rawGraph.outerJoinVertices(outDeg) { + * (vid, data, optDeg) => optDeg.getOrElse(0) + * } + * }}} + */ + def outerJoinVertices[U, VD2]( + other: JavaPairRDD[JLong, U], + mapFunc: JFunction3[JLong, VD, Optional[U], VD2]): JavaGraph[VD2, ED] = { + implicit val uTag: ClassTag[U] = fakeClassTag + implicit val vd2Tag: ClassTag[VD2] = fakeClassTag + + val scalaOther: RDD[(VertexId, U)] = other.rdd.map(kv => (kv._1, kv._2)) + new JavaGraph(graph.outerJoinVertices(scalaOther) { + (id, a, bOpt) => mapFunc.call(id, a, JavaUtils.optionToOptional(bOpt)) + }) + } + + /** The number of edges in the graph. */ + def numEdges(): Long = graph.numEdges + + /** The number of vertices in the graph. */ + def numVertices(): Long = graph.numVertices + + /** + * The in-degree of each vertex in the graph. + * @note Vertices with no in-edges are not returned in the resulting RDD. + */ + def inDegrees(): JavaVertexRDD[JInt] = graph.inDegrees.mapValues((deg: Int) => deg: JInt) + + /** + * The out-degree of each vertex in the graph. + * @note Vertices with no out-edges are not returned in the resulting RDD. + */ + def outDegrees(): JavaVertexRDD[JInt] = graph.outDegrees.mapValues((deg: Int) => deg: JInt) + + /** + * The degree of each vertex in the graph. + * @note Vertices with no edges are not returned in the resulting RDD. + */ + def degrees(): VertexRDD[JInt] = graph.degrees.mapValues((deg: Int) => deg: JInt) + + /** + * Collect the neighbor vertex ids for each vertex. + * + * @param edgeDirection the direction along which to collect + * neighboring vertices + * + * @return the set of neighboring ids for each vertex + */ + def collectNeighborIds(edgeDirection: EdgeDirection): JavaVertexRDD[Array[Long]] = + graph.collectNeighborIds(edgeDirection) + + /** + * Collect the neighbor vertex attributes for each vertex. + * + * @note This function could be highly inefficient on power-law + * graphs where high degree vertices may force a large ammount of + * information to be collected to a single location. + * + * @param edgeDirection the direction along which to collect + * neighboring vertices + * + * @return the vertex set of neighboring vertex attributes for each vertex + */ + def collectNeighbors(edgeDirection: EdgeDirection): JavaVertexRDD[Array[(JLong, VD)]] = + graph.collectNeighbors(edgeDirection).mapValues((nbrs: Array[(Long, VD)]) => + nbrs.map(kv => (kv._1: JLong, kv._2))) + + /** + * Returns an RDD that contains for each vertex v its local edges, + * i.e., the edges that are incident on v, in the user-specified direction. + * Warning: note that singleton vertices, those with no edges in the given + * direction will not be part of the return value. + * + * @note This function could be highly inefficient on power-law + * graphs where high degree vertices may force a large amount of + * information to be collected to a single location. + * + * @param edgeDirection the direction along which to collect + * the local edges of vertices + * + * @return the local edges for each vertex + */ + def collectEdges(edgeDirection: EdgeDirection): JavaVertexRDD[Array[Edge[ED]]] = + graph.collectEdges(edgeDirection) + + /** + * Join the vertices with an RDD and then apply a function from the + * the vertex and RDD entry to a new vertex value. The input table + * should contain at most one entry for each vertex. If no entry is + * provided the map function is skipped and the old value is used. + * + * @tparam U the type of entry in the table of updates + * @param table the table to join with the vertices in the graph. + * The table should contain at most one entry for each vertex. + * @param mapFunc the function used to compute the new vertex + * values. The map function is invoked only for vertices with a + * corresponding entry in the table otherwise the old vertex value + * is used. + * + * @example This function is used to update the vertices with new + * values based on external data. For example we could add the out + * degree to each vertex record + * + * {{{ + * val rawGraph: Graph[Int, Int] = GraphLoader.edgeListFile(sc, "webgraph") + * .mapVertices((_, _) => 0) + * val outDeg = rawGraph.outDegrees + * val graph = rawGraph.joinVertices[Int](outDeg) + * ((_, _, outDeg) => outDeg) + * }}} + * + */ + def joinVertices[U]( + table: JavaPairRDD[JLong, U], + mapFunc: JFunction3[JLong, VD, U, VD]): JavaGraph[VD, ED] = { + implicit val uTag: ClassTag[U] = fakeClassTag + val scalaTable: RDD[(VertexId, U)] = table.rdd.map(kv => (kv._1, kv._2)) + graph.joinVertices(scalaTable) { (vid, a, b) => mapFunc.call(vid, a, b) } + } + + /** + * Filter the graph by computing some values to filter on, and applying the predicates. + * + * @param preprocess a function to compute new vertex and edge data before filtering + * @param epred edge pred to filter on after preprocess, see more details under + * [[org.apache.spark.graphx.Graph#subgraph]] + * @param vpred vertex pred to filter on after prerocess, see more details under + * [[org.apache.spark.graphx.Graph#subgraph]] + * @tparam VD2 vertex type the vpred operates on + * @tparam ED2 edge type the epred operates on + * @return a subgraph of the orginal graph, with its data unchanged + * + * @example This function can be used to filter the graph based on some property, without + * changing the vertex and edge values in your program. For example, we could remove the vertices + * in a graph with 0 outdegree + * + * {{{ + * graph.filter( + * graph => { + * val degrees: VertexRDD[Int] = graph.outDegrees + * graph.outerJoinVertices(degrees) {(vid, data, deg) => deg.getOrElse(0)} + * }, + * vpred = (vid: Long, deg:Int) => deg > 0 + * ) + * }}} + * + */ + def filter[VD2, ED2]( + preprocess: JFunction[JavaGraph[VD, ED], JavaGraph[VD2, ED2]], + epred: JFunction[EdgeTriplet[VD2, ED2], JBoolean], + vpred: JFunction2[JLong, VD2, JBoolean]): JavaGraph[VD, ED] = { + implicit val vd2Tag: ClassTag[VD2] = fakeClassTag + implicit val ed2Tag: ClassTag[ED2] = fakeClassTag + graph.filter( + origGraph => preprocess.call(origGraph), + (et: EdgeTriplet[VD2, ED2]) => epred.call(et), + (id, attr: VD2) => vpred.call(id, attr)) + } + + /** + * Picks a random vertex from the graph and returns its ID. + */ + def pickRandomVertex(): Long = graph.pickRandomVertex() + + /** + * Execute a Pregel-like iterative vertex-parallel abstraction. The + * user-defined vertex-program `vprog` is executed in parallel on + * each vertex receiving any inbound messages and computing a new + * value for the vertex. The `sendMsg` function is then invoked on + * all out-edges and is used to compute an optional message to the + * destination vertex. The `mergeMsg` function is a commutative + * associative function used to combine messages destined to the + * same vertex. + * + * On the first iteration all vertices receive the `initialMsg` and + * on subsequent iterations if a vertex does not receive a message + * then the vertex-program is not invoked. + * + * This function iterates until there are no remaining messages, or + * for `maxIterations` iterations. + * + * @tparam A the Pregel message type + * + * @param initialMsg the message each vertex will receive at the on + * the first iteration + * + * @param maxIterations the maximum number of iterations to run for + * + * @param activeDirection the direction of edges incident to a vertex that received a message in + * the previous round on which to run `sendMsg`. For example, if this is `EdgeDirection.Out`, only + * out-edges of vertices that received a message in the previous round will run. + * + * @param vprog the user-defined vertex program which runs on each + * vertex and receives the inbound message and computes a new vertex + * value. On the first iteration the vertex program is invoked on + * all vertices and is passed the default message. On subsequent + * iterations the vertex program is only invoked on those vertices + * that receive messages. + * + * @param sendMsg a user supplied function that is applied to out + * edges of vertices that received messages in the current + * iteration + * + * @param mergeMsg a user supplied function that takes two incoming + * messages of type A and merges them into a single message of type + * A. ''This function must be commutative and associative and + * ideally the size of A should not increase.'' + * + * @return the resulting graph at the end of the computation + * + */ + def pregel[A]( + initialMsg: A, + maxIterations: Int, + activeDirection: EdgeDirection, + vprog: JFunction3[JLong, VD, A, VD], + sendMsg: PairFlatMapFunction[EdgeTriplet[VD, ED], JLong, A], + mergeMsg: JFunction2[A, A, A]) + : JavaGraph[VD, ED] = { + implicit val aTag: ClassTag[A] = fakeClassTag + + import scala.collection.JavaConverters._ + def scalaSendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)] = + (e: EdgeTriplet[VD, ED]) => sendMsg.call(e).asScala.iterator.map(kv => (kv._1, kv._2)) + + graph.pregel(initialMsg, maxIterations, activeDirection)( + (vid, a, b) => vprog.call(vid, a, b), scalaSendMsg, mergeMsg) + } + + /** + * Run a dynamic version of PageRank returning a graph with vertex attributes containing the + * PageRank and edge attributes containing the normalized edge weight. + * + * @see [[org.apache.spark.graphx.lib.PageRank$#runUntilConvergence]] + */ + def pageRank(tol: Double, resetProb: Double): JavaGraph[Double, Double] = + graph.pageRank(tol, resetProb) + + /** + * Run PageRank for a fixed number of iterations returning a graph with vertex attributes + * containing the PageRank and edge attributes the normalized edge weight. + * + * @see [[org.apache.spark.graphx.lib.PageRank$#run]] + */ + def staticPageRank(numIter: Int, resetProb: Double): JavaGraph[Double, Double] = + graph.staticPageRank(numIter, resetProb) + + /** + * Compute the connected component membership of each vertex and return a graph with the vertex + * value containing the lowest vertex id in the connected component containing that vertex. + * + * @see [[org.apache.spark.graphx.lib.ConnectedComponents$#run]] + */ + def connectedComponents(): JavaGraph[JLong, ED] = + graph.connectedComponents().mapVertices((id, cc) => (cc: JLong)) + + /** + * Compute the number of triangles passing through each vertex. + * + * @see [[org.apache.spark.graphx.lib.TriangleCount$#run]] + */ + def triangleCount(): JavaGraph[JInt, ED] = + graph.triangleCount().mapVertices((id, numTriangles) => (numTriangles: JInt)) + + /** + * Compute the strongly connected component (SCC) of each vertex and return a graph with the + * vertex value containing the lowest vertex id in the SCC containing that vertex. + * + * @see [[org.apache.spark.graphx.lib.StronglyConnectedComponents$#run]] + */ + def stronglyConnectedComponents(numIter: Int): JavaGraph[JLong, ED] = + graph.stronglyConnectedComponents(numIter).mapVertices((id, scc) => (scc: JLong)) +} + +object JavaGraph { + + /** + * Construct a graph from a collection of edges encoded as vertex id pairs. + * + * @param rawEdges a collection of edges in (src, dst) form + * @param defaultValue the vertex attributes with which to create vertices referenced by the edges + * + * @return a graph with all edge attributes set to 1, all vertex attributes set to + * `defaultValue`, and target storage level set to `MEMORY_ONLY`. + */ + def fromEdgeTuples[VD]( + rawEdges: JavaPairRDD[JLong, JLong], + defaultValue: VD): JavaGraph[VD, JInt] = { + implicit val vdTag: ClassTag[VD] = fakeClassTag + val scalaRawEdges: RDD[(VertexId, VertexId)] = rawEdges.rdd.map(kv => (kv._1, kv._2)) + Graph.fromEdgeTuples(scalaRawEdges, defaultValue, uniqueEdges = None) + .mapEdges((e: Edge[Int]) => (e.attr: JInt)) + } + + /** + * Construct a graph from a collection of edges encoded as vertex id pairs. + * + * @param rawEdges a collection of edges in (src, dst) form + * @param defaultValue the vertex attributes with which to create vertices referenced by the edges + * @param edgeStorageLevel the desired storage level at which to cache the edges if necessary + * @param vertexStorageLevel the desired storage level at which to cache the vertices if necessary + * + * @return a graph with all edge attributes set to 1 and all vertex attributes set to + * `defaultValue`. + */ + def fromEdgeTuples[VD]( + rawEdges: JavaPairRDD[JLong, JLong], + defaultValue: VD, + edgeStorageLevel: StorageLevel, + vertexStorageLevel: StorageLevel): JavaGraph[VD, JInt] = { + implicit val vdTag: ClassTag[VD] = fakeClassTag + val scalaRawEdges: RDD[(VertexId, VertexId)] = rawEdges.rdd.map(kv => (kv._1, kv._2)) + Graph.fromEdgeTuples(scalaRawEdges, defaultValue, uniqueEdges = None, + edgeStorageLevel = edgeStorageLevel, vertexStorageLevel = vertexStorageLevel) + .mapEdges((e: Edge[Int]) => (e.attr: JInt)) + } + + /** + * Construct a graph from a collection of edges encoded as vertex id pairs, repartitioning edges + * and merging duplicate edges. + * + * @param rawEdges a collection of edges in (src, dst) form + * @param defaultValue the vertex attributes with which to create vertices referenced by the edges + * @param partitionStrategy how to repartition the edges. In addition to repartitioning, if + * multiple identical edges are found they are combined and the edge attribute is set to the + * number of merged edges + * @param edgeStorageLevel the desired storage level at which to cache the edges if necessary + * @param vertexStorageLevel the desired storage level at which to cache the vertices if necessary + * + * @return a graph with edge attributes containing the count of duplicate edges and all vertex + * attributes set to `defaultValue`. + */ + def fromEdgeTuples[VD]( + rawEdges: JavaPairRDD[JLong, JLong], + defaultValue: VD, + partitionStrategy: PartitionStrategy, + edgeStorageLevel: StorageLevel, + vertexStorageLevel: StorageLevel): JavaGraph[VD, JInt] = { + implicit val vdTag: ClassTag[VD] = fakeClassTag + val scalaRawEdges: RDD[(VertexId, VertexId)] = rawEdges.rdd.map(kv => (kv._1, kv._2)) + Graph.fromEdgeTuples(scalaRawEdges, defaultValue, uniqueEdges = Some(partitionStrategy), + edgeStorageLevel = edgeStorageLevel, vertexStorageLevel = vertexStorageLevel) + .mapEdges((e: Edge[Int]) => (e.attr: JInt)) + } + + /** + * Construct a graph from a collection of edges. + * + * @param edges the RDD containing the set of edges in the graph + * @param defaultValue the default vertex attribute to use for each vertex + * + * @return a graph with edge attributes described by `edges`, vertices + * given by all vertices in `edges` with value `defaultValue`, and target storage level + * set to `MEMORY_ONLY`. + */ + def fromEdges[VD, ED]( + edges: JavaRDD[Edge[ED]], + defaultValue: VD): JavaGraph[VD, ED] = { + implicit val vdTag: ClassTag[VD] = fakeClassTag + implicit val edTag: ClassTag[ED] = fakeClassTag + Graph.fromEdges(edges, defaultValue) + } + + /** + * Construct a graph from a collection of edges. + * + * @param edges the RDD containing the set of edges in the graph + * @param defaultValue the default vertex attribute to use for each vertex + * @param edgeStorageLevel the desired storage level at which to cache the edges if necessary + * @param vertexStorageLevel the desired storage level at which to cache the vertices if necessary + * + * @return a graph with edge attributes described by `edges` and vertices + * given by all vertices in `edges` with value `defaultValue` + */ + def fromEdges[VD, ED]( + edges: JavaRDD[Edge[ED]], + defaultValue: VD, + edgeStorageLevel: StorageLevel, + vertexStorageLevel: StorageLevel): JavaGraph[VD, ED] = { + implicit val vdTag: ClassTag[VD] = fakeClassTag + implicit val edTag: ClassTag[ED] = fakeClassTag + Graph.fromEdges(edges, defaultValue, edgeStorageLevel, vertexStorageLevel) + } + + /** + * Construct a graph from a collection of vertices and + * edges with attributes. Duplicate vertices are picked arbitrarily and + * vertices found in the edge collection but not in the input + * vertices are assigned the default attribute. + * + * @tparam VD the vertex attribute type + * @tparam ED the edge attribute type + * @param vertices the "set" of vertices and their attributes + * @param edges the collection of edges in the graph + * @param defaultVertexAttr the default vertex attribute to use for vertices that are + * mentioned in edges but not in vertices + */ + def create[VD, ED]( + vertices: JavaPairRDD[JLong, VD], + edges: JavaRDD[Edge[ED]], + defaultVertexAttr: VD): JavaGraph[VD, ED] = { + implicit val vdTag: ClassTag[VD] = fakeClassTag + implicit val edTag: ClassTag[ED] = fakeClassTag + val scalaVertices: RDD[(VertexId, VD)] = vertices.rdd.map(kv => (kv._1, kv._2)) + Graph(scalaVertices, edges, defaultVertexAttr) + } + + /** + * Construct a graph from a collection of vertices and + * edges with attributes. Duplicate vertices are picked arbitrarily and + * vertices found in the edge collection but not in the input + * vertices are assigned the default attribute. + * + * @tparam VD the vertex attribute type + * @tparam ED the edge attribute type + * @param vertices the "set" of vertices and their attributes + * @param edges the collection of edges in the graph + * @param defaultVertexAttr the default vertex attribute to use for vertices that are + * mentioned in edges but not in vertices + * @param edgeStorageLevel the desired storage level at which to cache the edges if necessary + * @param vertexStorageLevel the desired storage level at which to cache the vertices if necessary + */ + def create[VD, ED]( + vertices: JavaPairRDD[JLong, VD], + edges: JavaRDD[Edge[ED]], + defaultVertexAttr: VD, + edgeStorageLevel: StorageLevel, + vertexStorageLevel: StorageLevel): JavaGraph[VD, ED] = { + implicit val vdTag: ClassTag[VD] = fakeClassTag + implicit val edTag: ClassTag[ED] = fakeClassTag + val scalaVertices: RDD[(VertexId, VD)] = vertices.rdd.map(kv => (kv._1, kv._2)) + Graph(scalaVertices, edges, defaultVertexAttr, edgeStorageLevel, vertexStorageLevel) + } + + implicit def fromGraph[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): JavaGraph[VD, ED] = + new JavaGraph[VD, ED](graph) + + implicit def toGraph[VD, ED](graph: JavaGraph[VD, ED]): Graph[VD, ED] = + graph.graph +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaVertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaVertexRDD.scala new file mode 100644 index 0000000000000..cbed8922363f7 --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaVertexRDD.scala @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx.api.java + +import java.lang.{Integer => JInt, Long => JLong} + +import scala.collection.JavaConversions._ +import scala.language.implicitConversions +import scala.reflect.ClassTag + +import com.google.common.base.Optional +import org.apache.spark.api.java.JavaPairRDD +import org.apache.spark.api.java.JavaPairRDD._ +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.api.java.JavaUtils +import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, + Function3 => JFunction3, _} +import org.apache.spark.graphx._ +import org.apache.spark.rdd.RDD + +class JavaVertexRDD[VD](val vertexRDD: VertexRDD[VD])(implicit val vdTag: ClassTag[VD]) + extends JavaPairRDD[JLong, VD](vertexRDD.map(kv => (kv._1, kv._2))) { + + /** + * Maps each vertex attribute, preserving the index. + * + * @tparam VD2 the type returned by the map function + * + * @param f the function applied to each value in the RDD + * @return a new VertexRDD with values obtained by applying `f` to each of the entries in the + * original VertexRDD + */ + override def mapValues[VD2](f: JFunction[VD, VD2]): JavaVertexRDD[VD2] = { + implicit val vd2Tag: ClassTag[VD2] = fakeClassTag + vertexRDD.mapValues(f) + } + + /** + * Maps each vertex attribute, additionally supplying the vertex ID. + * + * @tparam VD2 the type returned by the map function + * + * @param f the function applied to each ID-value pair in the RDD + * @return a new VertexRDD with values obtained by applying `f` to each of the entries in the + * original VertexRDD. The resulting VertexRDD retains the same index. + */ + def mapValues[VD2](f: JFunction2[JLong, VD, VD2]): JavaVertexRDD[VD2] = { + implicit val vd2Tag: ClassTag[VD2] = fakeClassTag + vertexRDD.mapValues((id: VertexId, v) => f.call(id, v)) + } + + /** + * Hides vertices that are the same between `this` and `other`; for vertices that are different, + * keeps the values from `other`. + */ + def diff(other: JavaVertexRDD[VD]): JavaVertexRDD[VD] = + vertexRDD.diff(other) + + /** + * Left joins this RDD with another VertexRDD with the same index. This function will fail if + * both VertexRDDs do not share the same index. The resulting vertex set contains an entry for + * each vertex in `this`. + * If `other` is missing any vertex in this VertexRDD, `f` is passed `None`. + * + * @tparam VD2 the attribute type of the other VertexRDD + * @tparam VD3 the attribute type of the resulting VertexRDD + * + * @param other the other VertexRDD with which to join. + * @param f the function mapping a vertex id and its attributes in this and the other vertex set + * to a new vertex attribute. + * @return a VertexRDD containing the results of `f` + */ + def leftZipJoin[VD2, VD3]( + other: JavaVertexRDD[VD2], + f: JFunction3[JLong, VD, Optional[VD2], VD3]): JavaVertexRDD[VD3] = { + implicit val vd2Tag: ClassTag[VD2] = fakeClassTag + implicit val vd3Tag: ClassTag[VD3] = fakeClassTag + vertexRDD.leftZipJoin(other) { (vid, a, bOpt) => f.call(vid, a, JavaUtils.optionToOptional(bOpt)) } + } + + /** + * Left joins this JavaVertexRDD with an RDD containing vertex attribute pairs. If the other RDD is + * backed by a JavaVertexRDD with the same index then the efficient [[leftZipJoin]] implementation is + * used. The resulting JavaVertexRDD contains an entry for each vertex in `this`. If `other` is + * missing any vertex in this JavaVertexRDD, `f` is passed `None`. If there are duplicates, + * the vertex is picked arbitrarily. + * + * @tparam VD2 the attribute type of the other JavaVertexRDD + * @tparam VD3 the attribute type of the resulting JavaVertexRDD + * + * @param other the other JavaVertexRDD with which to join + * @param f the function mapping a vertex id and its attributes in this and the other vertex set + * to a new vertex attribute. + * @return a JavaVertexRDD containing all the vertices in this JavaVertexRDD with the attributes emitted + * by `f`. + */ + def leftJoin[VD2, VD3]( + other: JavaPairRDD[JLong, VD2], + f: JFunction3[JLong, VD, Optional[VD2], VD3]) + : JavaVertexRDD[VD3] = { + implicit val vd2Tag: ClassTag[VD2] = fakeClassTag + implicit val vd3Tag: ClassTag[VD3] = fakeClassTag + val scalaOther: RDD[(VertexId, VD2)] = other.rdd.map(kv => (kv._1, kv._2)) + vertexRDD.leftJoin(scalaOther) { (vid, a, bOpt) => f.call(vid, a, JavaUtils.optionToOptional(bOpt)) } + } + + /** + * Efficiently inner joins this JavaVertexRDD with another JavaVertexRDD sharing the same index. See + * [[innerJoin]] for the behavior of the join. + */ + def innerZipJoin[U, VD2]( + other: JavaVertexRDD[U], + f: JFunction3[JLong, VD, U, VD2]): JavaVertexRDD[VD2] = { + implicit val uTag: ClassTag[U] = fakeClassTag + implicit val vd2Tag: ClassTag[VD2] = fakeClassTag + vertexRDD.innerZipJoin(other) { (id, a, b) => f.call(id, a, b) } + } + + /** + * Inner joins this VertexRDD with an RDD containing vertex attribute pairs. If the other RDD is + * backed by a VertexRDD with the same index then the efficient [[innerZipJoin]] implementation + * is used. + * + * @param other an RDD containing vertices to join. If there are multiple entries for the same + * vertex, one is picked arbitrarily. Use [[aggregateUsingIndex]] to merge multiple entries. + * @param f the join function applied to corresponding values of `this` and `other` + * @return a VertexRDD co-indexed with `this`, containing only vertices that appear in both + * `this` and `other`, with values supplied by `f` + */ + def innerJoin[U, VD2]( + other: JavaPairRDD[JLong, U], + f: JFunction3[JLong, VD, U, VD2]): JavaVertexRDD[VD2] = { + implicit val uTag: ClassTag[U] = fakeClassTag + implicit val vd2Tag: ClassTag[VD2] = fakeClassTag + val scalaOther: RDD[(VertexId, U)] = other.rdd.map(kv => (kv._1, kv._2)) + vertexRDD.innerJoin(scalaOther) { (id, a, b) => f.call(id, a, b) } + } + + /** + * Aggregates vertices in `messages` that have the same ids using `reduceFunc`, returning a + * VertexRDD co-indexed with `this`. + * + * @param messages an RDD containing messages to aggregate, where each message is a pair of its + * target vertex ID and the message data + * @param reduceFunc the associative aggregation function for merging messages to the same vertex + * @return a VertexRDD co-indexed with `this`, containing only vertices that received messages. + * For those vertices, their values are the result of applying `reduceFunc` to all received + * messages. + */ + def aggregateUsingIndex[VD2]( + messages: JavaPairRDD[JLong, VD2], + reduceFunc: JFunction2[VD2, VD2, VD2]): JavaVertexRDD[VD2] = { + implicit val vd2Tag: ClassTag[VD2] = fakeClassTag + val scalaMessages: RDD[(VertexId, VD2)] = messages.rdd.map(kv => (kv._1, kv._2)) + vertexRDD.aggregateUsingIndex(scalaMessages, reduceFunc) + } + + /** Prepares this VertexRDD for efficient joins with the given EdgeRDD. */ + def withEdges(edges: JavaEdgeRDD[_, _]): JavaVertexRDD[VD] = vertexRDD.withEdges(edges) +} + +object JavaVertexRDD { + + /** + * Constructs a standalone `VertexRDD` (one that is not set up for efficient joins with an + * [[EdgeRDD]]) from an RDD of vertex-attribute pairs. Duplicate entries are removed arbitrarily. + * + * @tparam VD the vertex attribute type + * + * @param vertices the collection of vertex-attribute pairs + */ + def create[VD](vertices: JavaPairRDD[JLong, VD]): JavaVertexRDD[VD] = { + implicit val vdTag: ClassTag[VD] = fakeClassTag + val scalaVertices: RDD[(VertexId, VD)] = vertices.rdd.map(kv => (kv._1, kv._2)) + VertexRDD(scalaVertices) + } + + /** + * Constructs a `VertexRDD` from an RDD of vertex-attribute pairs. Duplicate vertex entries are + * removed arbitrarily. The resulting `VertexRDD` will be joinable with `edges`, and any missing + * vertices referred to by `edges` will be created with the attribute `defaultVal`. + * + * @tparam VD the vertex attribute type + * + * @param vertices the collection of vertex-attribute pairs + * @param edges the [[EdgeRDD]] that these vertices may be joined with + * @param defaultVal the vertex attribute to use when creating missing vertices + */ + def create[VD](vertices: JavaPairRDD[JLong, VD], edges: JavaEdgeRDD[_, _], defaultVal: VD) + : JavaVertexRDD[VD] = { + implicit val vdTag: ClassTag[VD] = fakeClassTag + val scalaVertices: RDD[(VertexId, VD)] = vertices.rdd.map(kv => (kv._1, kv._2)) + VertexRDD(scalaVertices, edges, defaultVal) + } + + /** + * Constructs a `VertexRDD` from an RDD of vertex-attribute pairs. Duplicate vertex entries are + * merged using `mergeFunc`. The resulting `VertexRDD` will be joinable with `edges`, and any + * missing vertices referred to by `edges` will be created with the attribute `defaultVal`. + * + * @tparam VD the vertex attribute type + * + * @param vertices the collection of vertex-attribute pairs + * @param edges the [[EdgeRDD]] that these vertices may be joined with + * @param defaultVal the vertex attribute to use when creating missing vertices + * @param mergeFunc the commutative, associative duplicate vertex attribute merge function + */ + def create[VD]( + vertices: JavaPairRDD[JLong, VD], edges: JavaEdgeRDD[_, _], defaultVal: VD, + mergeFunc: JFunction2[VD, VD, VD]): JavaVertexRDD[VD] = { + implicit val vdTag: ClassTag[VD] = fakeClassTag + val scalaVertices: RDD[(VertexId, VD)] = vertices.rdd.map(kv => (kv._1, kv._2)) + VertexRDD(scalaVertices, edges, defaultVal, mergeFunc) + } + + /** + * Constructs a `VertexRDD` containing all vertices referred to in `edges`. The vertices will be + * created with the attribute `defaultVal`. The resulting `VertexRDD` will be joinable with + * `edges`. + * + * @tparam VD the vertex attribute type + * + * @param edges the [[EdgeRDD]] referring to the vertices to create + * @param numPartitions the desired number of partitions for the resulting `VertexRDD` + * @param defaultVal the vertex attribute to use when creating missing vertices + */ + def fromEdges[VD]( + edges: JavaEdgeRDD[_, _], numPartitions: Int, defaultVal: VD): JavaVertexRDD[VD] = { + implicit val vdTag: ClassTag[VD] = fakeClassTag + VertexRDD.fromEdges(edges, numPartitions, defaultVal) + } + + implicit def fromVertexRDD[VD: ClassTag](rdd: VertexRDD[VD]): JavaVertexRDD[VD] = + new JavaVertexRDD[VD](rdd) + + implicit def toVertexRDD[VD](rdd: JavaVertexRDD[VD]): VertexRDD[VD] = + rdd.vertexRDD +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/api/java/PartitionStrategies.java b/graphx/src/main/scala/org/apache/spark/graphx/api/java/PartitionStrategies.java new file mode 100644 index 0000000000000..5269005558aa2 --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/api/java/PartitionStrategies.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx.api.java; + +import org.apache.spark.graphx.PartitionStrategy; + +/** + * Expose the partition strategies as constants. + * + * @see org.apache.spark.graphx.PartitionStrategy + */ +public class PartitionStrategies { + public static final PartitionStrategy EdgePartition2D = + PartitionStrategy.EdgePartition2D$.MODULE$; + public static final PartitionStrategy EdgePartition1D = + PartitionStrategy.EdgePartition1D$.MODULE$; + public static final PartitionStrategy RandomVertexCut = + PartitionStrategy.RandomVertexCut$.MODULE$; + public static final PartitionStrategy CanonicalRandomVertexCut = + PartitionStrategy.CanonicalRandomVertexCut$.MODULE$; +} diff --git a/graphx/src/test/java/org/apache/spark/graphx/JavaEdgeRDDSuite.java b/graphx/src/test/java/org/apache/spark/graphx/JavaEdgeRDDSuite.java new file mode 100644 index 0000000000000..d9e429415f656 --- /dev/null +++ b/graphx/src/test/java/org/apache/spark/graphx/JavaEdgeRDDSuite.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx; + +import java.io.*; +import java.net.URI; +import java.util.*; + +import scala.Tuple2; +import scala.Tuple3; +import scala.Tuple4; + +import com.google.common.collect.Iterables; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.base.Optional; +import com.google.common.base.Charsets; +import com.google.common.io.Files; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.compress.DefaultCodec; +import org.apache.hadoop.mapred.SequenceFileInputFormat; +import org.apache.hadoop.mapred.SequenceFileOutputFormat; +import org.apache.hadoop.mapreduce.Job; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.*; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.partial.BoundedDouble; +import org.apache.spark.partial.PartialResult; +import org.apache.spark.storage.StorageLevel; +import org.apache.spark.util.StatCounter; + +import org.apache.spark.graphx.*; +import org.apache.spark.graphx.api.java.*; + +public class JavaEdgeRDDSuite implements Serializable { + private transient JavaSparkContext sc; + private transient File tempDir; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaEdgeRDDSuite"); + tempDir = Files.createTempDir(); + tempDir.deleteOnExit(); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + private JavaEdgeRDD edgeRDD(long n) { + List> edges = new ArrayList>(); + for (long i = 0; i < n; i++) { + edges.add(new Edge(i, (i + 1) % n, "e")); + } + return JavaEdgeRDD.fromEdges(sc.parallelize(edges, 3)); + } + + @SuppressWarnings("unchecked") + @Test + public void edgeRDDFromEdges() { + long n = 100; + JavaEdgeRDD edges = edgeRDD(n); + + Assert.assertEquals(n, edges.count()); + } + + @SuppressWarnings("unchecked") + @Test + public void mapValues() { + long n = 100; + JavaEdgeRDD edges = edgeRDD(n); + JavaEdgeRDD mapped = edges.mapValues( + new Function, Integer>() { + public Integer call(Edge e) { + return (int)e.srcId(); + } + }); + Assert.assertEquals(n, mapped.count()); + + } + + @SuppressWarnings("unchecked") + @Test + public void reverse() { + long n = 100; + JavaEdgeRDD edges = edgeRDD(n); + Assert.assertEquals(n, edges.reverse().count()); + } + + @SuppressWarnings("unchecked") + @Test + public void filter() { + long n = 100; + JavaEdgeRDD edges = edgeRDD(n); + JavaEdgeRDD filtered = edges.filter( + new Function, Boolean>() { + public Boolean call(EdgeTriplet e) { + return e.attr().equals("e"); + } + }, + new Function2() { + public Boolean call(Long id, Integer attr) { + return id < 10; + } + }); + Assert.assertEquals(9, filtered.count()); + } + + @SuppressWarnings("unchecked") + @Test + public void innerJoin() { + long n = 100; + JavaEdgeRDD a = edgeRDD(n); + JavaEdgeRDD b = a.filter( + new Function, Boolean>() { + public Boolean call(EdgeTriplet e) { + return true; + } + }, + new Function2() { + public Boolean call(Long id, Integer attr) { + return id < 10; + } + }); + + JavaEdgeRDD joined = a.innerJoin( + b, + new Function4() { + public String call(Long src, Long dst, String a, String b) { + return a + b; + } + }); + + for (Edge e : joined.collect()) { + Assert.assertEquals("ee", e.attr()); + } + Assert.assertEquals(b.count(), joined.count()); + } +} diff --git a/graphx/src/test/java/org/apache/spark/graphx/JavaGraphSuite.java b/graphx/src/test/java/org/apache/spark/graphx/JavaGraphSuite.java new file mode 100644 index 0000000000000..c59aa26f46694 --- /dev/null +++ b/graphx/src/test/java/org/apache/spark/graphx/JavaGraphSuite.java @@ -0,0 +1,386 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx; + +import java.io.*; +import java.net.URI; +import java.util.*; + +import scala.Tuple2; +import scala.Tuple3; +import scala.Tuple4; + +import com.google.common.collect.Iterables; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.base.Optional; +import com.google.common.base.Charsets; +import com.google.common.io.Files; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.compress.DefaultCodec; +import org.apache.hadoop.mapred.SequenceFileInputFormat; +import org.apache.hadoop.mapred.SequenceFileOutputFormat; +import org.apache.hadoop.mapreduce.Job; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.*; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.partial.BoundedDouble; +import org.apache.spark.partial.PartialResult; +import org.apache.spark.storage.StorageLevel; +import org.apache.spark.util.StatCounter; + +import org.apache.spark.graphx.*; +import org.apache.spark.graphx.api.java.*; + +public class JavaGraphSuite implements Serializable { + private transient JavaSparkContext sc; + private transient File tempDir; + + private JavaGraph starGraph(long n) { + List> edges = new ArrayList>(); + for (long i = 1; i < n; i++) { + edges.add(new Tuple2(0L, i)); + } + return JavaGraph.fromEdgeTuples(sc.parallelizePairs(edges, 3), "v"); + } + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaGraphSuite"); + tempDir = Files.createTempDir(); + tempDir.deleteOnExit(); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @SuppressWarnings("unchecked") + @Test + public void graphFromEdgeTuples() { + List> doubleRing = new ArrayList>(); + long n = 100; + for (long i = 0; i < n; i++) { + doubleRing.add(new Tuple2(i, (i + 1) % n)); + doubleRing.add(new Tuple2(i, (i + 1) % n)); + } + + JavaGraph graph = + JavaGraph.fromEdgeTuples(sc.parallelizePairs(doubleRing), 1); + Assert.assertEquals(doubleRing.size(), graph.edges().count()); + for (Edge e : graph.edges().collect()) { + Assert.assertEquals(1, e.attr.longValue()); + } + + // uniqueEdges option should uniquify edges and store duplicate count in edge attributes + JavaGraph uniqueGraph = JavaGraph.fromEdgeTuples( + sc.parallelizePairs(doubleRing), 1, PartitionStrategies.RandomVertexCut, + StorageLevel.MEMORY_ONLY(), StorageLevel.MEMORY_ONLY()); + Assert.assertEquals(n, uniqueGraph.edges().count()); + for (Edge e : uniqueGraph.edges().collect()) { + Assert.assertEquals(2, e.attr.longValue()); + } + } + + @SuppressWarnings("unchecked") + @Test + public void graphFromEdges() { + List> ring = new ArrayList>(); + long n = 100; + for (long i = 0; i < n; i++) { + ring.add(new Edge(i, (i + 1) % n, 1)); + } + + JavaGraph graph = JavaGraph.fromEdges(sc.parallelize(ring), 1.0f); + Assert.assertEquals(ring.size(), graph.edges().count()); + } + + @SuppressWarnings("unchecked") + @Test + public void graphCreate() { + List> edges = new ArrayList>(); + long n = 100; + for (long i = 0; i < n; i++) { + edges.add(new Edge(i, (i + 1) % n, 1)); + } + + List> vertices = new ArrayList>(); + for (long i = 0; i < 10; i++) { + vertices.add(new Tuple2(i, true)); + } + + JavaGraph graph = + JavaGraph.create(sc.parallelizePairs(vertices), sc.parallelize(edges), false); + + Assert.assertEquals(edges.size(), graph.edges().count()); + + // Vertices not explicitly provided but referenced by edges should be created automatically + Assert.assertEquals(100, graph.vertices().count()); + + for (EdgeTriplet et : graph.triplets().collect()) { + Assert.assertTrue((et.srcId() < 10 && et.srcAttr()) || (et.srcId() >= 10 && !et.srcAttr())); + Assert.assertTrue((et.dstId() < 10 && et.dstAttr()) || (et.dstId() >= 10 && !et.dstAttr())); + } + } + + @SuppressWarnings("unchecked") + @Test + public void triplets() { + long n = 5; + JavaGraph star = starGraph(n); + JavaRDD> triplets = star.triplets().map( + new Function, Tuple4>() { + public Tuple4 call(EdgeTriplet et) { + return new Tuple4( + et.srcId(), et.dstId(), et.srcAttr(), et.dstAttr()); + } + }); + + Set> tripletsExpected = + new HashSet>(); + for (long i = 1; i < n; i++) { + tripletsExpected.add(new Tuple4(0L, i, "v", "v")); + } + Assert.assertEquals( + tripletsExpected, new HashSet>(triplets.collect())); + } + + @SuppressWarnings("unchecked") + @Test + public void partitionBy() { + JavaGraph star = starGraph(10); + JavaGraph star2D = star.partitionBy(PartitionStrategies.EdgePartition2D); + } + + @SuppressWarnings("unchecked") + @Test + public void mapVertices() { + long n = 5; + JavaGraph star = starGraph(n); + star.mapVertices( + new Function2() { + public String call(Long id, String attr) { + return attr + id; + } + }); + } + + @SuppressWarnings("unchecked") + @Test + public void mapEdges() { + long n = 5; + JavaGraph star = starGraph(n); + star.mapEdges( + new Function, String>() { + public String call(Edge e) { + return e.attr().toString(); + } + }); + } + + @SuppressWarnings("unchecked") + @Test + public void mapTriplets() { + long n = 5; + JavaGraph star = starGraph(n); + star.mapTriplets( + new Function, String>() { + public String call(EdgeTriplet et) { + return et.srcAttr() + et.dstAttr(); + } + }, TripletFields.SrcDstOnly); + } + + @SuppressWarnings("unchecked") + @Test + public void reverse() { + long n = 5; + JavaGraph star = starGraph(n); + star.reverse(); + } + + @SuppressWarnings("unchecked") + @Test + public void subgraph() { + long n = 5; + JavaGraph star = starGraph(n); + star.subgraph( + new Function, Boolean>() { + public Boolean call(EdgeTriplet et) { + return et.attr() == 1; + } + }, + new Function2() { + public Boolean call(Long id, String attr) { + return id > 3; + } + }); + } + + @SuppressWarnings("unchecked") + @Test + public void mask() { + long n = 5; + JavaGraph star = starGraph(n); + JavaGraph star2 = starGraph(n + 5); + star2.mask(star); + } + + @SuppressWarnings("unchecked") + @Test + public void groupEdges() { + long n = 5; + JavaGraph star = starGraph(n); + star.groupEdges( + new Function2() { + public Integer call(Integer a, Integer b) { + return a + b; + } + }); + } + + @SuppressWarnings("unchecked") + @Test + public void aggregateMessages() { + long n = 5; + JavaGraph star = starGraph(n); + VoidFunction> sendMsg = + new VoidFunction>() { + public void call(EdgeContext ctx) { + ctx.sendToDst(ctx.srcAttr()); + } + }; + Function2 mergeMsg = + new Function2() { + public String call(String a, String b) { + return a + b; + } + }; + star.aggregateMessages(sendMsg, mergeMsg, TripletFields.SrcOnly); + } + + @SuppressWarnings("unchecked") + @Test + public void outerJoinVertices() { + long n = 5; + JavaGraph star = starGraph(n); + star.outerJoinVertices( + star.vertices(), + new Function3, String>() { + public String call(Long id, String a, Optional bOpt) { + return a + bOpt.or(""); + } + }); + } + + @SuppressWarnings("unchecked") + @Test + public void simpleOps() { + long n = 5; + JavaGraph star = starGraph(n); + star.numEdges(); + star.numVertices(); + star.inDegrees(); + star.outDegrees(); + star.degrees(); + star.collectNeighborIds(EdgeDirection.Out()); + star.collectNeighbors(EdgeDirection.Out()); + star.collectEdges(EdgeDirection.Out()); + star.pickRandomVertex(); + star.pageRank(0.01, 0.15); + star.staticPageRank(10, 0.15); + star.connectedComponents(); + star.triangleCount(); + star.stronglyConnectedComponents(10); + } + + @SuppressWarnings("unchecked") + @Test + public void joinVertices() { + long n = 5; + JavaGraph star = starGraph(n); + star.joinVertices( + star.vertices(), + new Function3() { + public String call(Long id, String a, String b) { + return a + b; + } + }); + } + + @SuppressWarnings("unchecked") + @Test + public void filter() { + long n = 5; + JavaGraph star = starGraph(n); + star.filter( + new Function, JavaGraph>() { + public JavaGraph call(JavaGraph graph) { + return graph; + } + }, + new Function, Boolean>() { + public Boolean call(EdgeTriplet et) { + return et.attr() == 1; + } + }, + new Function2() { + public Boolean call(Long id, String attr) { + return id > 3; + } + }); + } + + @SuppressWarnings("unchecked") + @Test + public void pregel() { + long n = 5; + JavaGraph star = starGraph(n); + star.pregel( + "", 10, EdgeDirection.Either(), + new Function3() { + public String call(Long id, String attr, String msg) { + return attr + msg; + } + }, + new PairFlatMapFunction, Long, String>() { + public Iterable> call(EdgeTriplet et) { + List> msgs = new ArrayList>(); + msgs.add(new Tuple2(et.dstId(), et.srcAttr())); + return msgs; + } + }, + new Function2() { + public String call(String a, String b) { + return a + b; + } + }); + } +} diff --git a/graphx/src/test/java/org/apache/spark/graphx/JavaVertexRDDSuite.java b/graphx/src/test/java/org/apache/spark/graphx/JavaVertexRDDSuite.java new file mode 100644 index 0000000000000..04c1f2491d17d --- /dev/null +++ b/graphx/src/test/java/org/apache/spark/graphx/JavaVertexRDDSuite.java @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx; + +import java.io.*; +import java.net.URI; +import java.util.*; + +import scala.Tuple2; +import scala.Tuple3; +import scala.Tuple4; + +import com.google.common.collect.Iterables; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.base.Optional; +import com.google.common.base.Charsets; +import com.google.common.io.Files; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.compress.DefaultCodec; +import org.apache.hadoop.mapred.SequenceFileInputFormat; +import org.apache.hadoop.mapred.SequenceFileOutputFormat; +import org.apache.hadoop.mapreduce.Job; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.*; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.partial.BoundedDouble; +import org.apache.spark.partial.PartialResult; +import org.apache.spark.storage.StorageLevel; +import org.apache.spark.util.StatCounter; + +import org.apache.spark.graphx.*; +import org.apache.spark.graphx.api.java.*; + +public class JavaVertexRDDSuite implements Serializable { + private transient JavaSparkContext sc; + private transient File tempDir; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaVertexRDDSuite"); + tempDir = Files.createTempDir(); + tempDir.deleteOnExit(); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + private JavaVertexRDD vertexRDD(long n) { + return vertexRDD(n, 1); + } + + private JavaVertexRDD vertexRDD(long n, int val) { + return JavaVertexRDD.create(pairRDD(n, val)); + } + + private JavaPairRDD pairRDD(long n, int val) { + List> tuples = new ArrayList>(); + for (long i = 0; i < n; i++) { + tuples.add(new Tuple2(i, val)); + } + return sc.parallelizePairs(tuples, 3); + } + + @SuppressWarnings("unchecked") + @Test + public void vertexRDDCreate() { + List> tuples = new ArrayList>(); + long n = 100; + for (long i = 0; i < n; i++) { + tuples.add(new Tuple2(i, 1)); + } + + JavaVertexRDD vertexRDD1 = JavaVertexRDD.create(sc.parallelizePairs(tuples)); + Assert.assertEquals( + new HashSet>(tuples), + new HashSet>(vertexRDD1.collect())); + + // Create a graph so we can use its JavaEdgeRDD to construct a JavaVertexRDD + List> ring = new ArrayList>(); + long m = 200; + for (long i = 0; i < m; i++) { + ring.add(new Edge(i, (i + 1) % m, 1)); + } + JavaGraph graph = JavaGraph.fromEdges(sc.parallelize(ring), 1); + + JavaVertexRDD vertexRDD2 = + JavaVertexRDD.create(sc.parallelizePairs(tuples), graph.edges(), 2); + Assert.assertEquals(m, vertexRDD2.count()); + + List> duplicateTuples = new ArrayList>(); + for (long i = 0; i < n; i++) { + duplicateTuples.add(new Tuple2(i, 1)); + duplicateTuples.add(new Tuple2(i, 2)); + } + JavaGraph emptyGraph = + JavaGraph.fromEdges(sc.parallelize(new ArrayList>()), 1); + + JavaVertexRDD vertexRDD3 = JavaVertexRDD.create( + sc.parallelizePairs(duplicateTuples), emptyGraph.edges(), 2, + new Function2() { + public Integer call(Integer a, Integer b) { + return a + b; + } + }); + Assert.assertEquals(n, vertexRDD3.count()); + for (Tuple2 kv : vertexRDD3.collect()) { + Assert.assertEquals(3, kv._2().intValue()); + } + } + + @SuppressWarnings("unchecked") + @Test + public void vertexRDDFromEdges() { + // Create a graph so we can use its JavaEdgeRDD to construct a JavaVertexRDD + List> ring = new ArrayList>(); + long m = 200; + for (long i = 0; i < m; i++) { + ring.add(new Edge(i, (i + 1) % m, 1)); + } + JavaGraph graph = JavaGraph.fromEdges(sc.parallelize(ring), 1); + + JavaVertexRDD vertexRDD = JavaVertexRDD.fromEdges(graph.edges(), 1, 1); + Assert.assertEquals(m, vertexRDD.count()); + } + + @SuppressWarnings("unchecked") + @Test + public void mapValues() { + vertexRDD(100).mapValues( + new Function() { + public Integer call(Integer v) { + return v * 2; + } + }).collect(); + + vertexRDD(100).mapValues( + new Function2() { + public Long call(Long id, Integer v) { + return id + v * 2; + } + }).collect(); + } + + @SuppressWarnings("unchecked") + @Test + public void diff() { + JavaVertexRDD a = vertexRDD(100, 1); + JavaVertexRDD b = vertexRDD(100, 2); + a.diff(b).collect(); + } + + @SuppressWarnings("unchecked") + @Test + public void leftZipJoin() { + JavaVertexRDD a = vertexRDD(100, 1); + JavaVertexRDD b = vertexRDD(100, 2); + a.leftZipJoin(b, + new Function3, Integer>() { + public Integer call(Long id, Integer a, Optional bOpt) { + return a + bOpt.or(0); + } + }).collect(); + } + + @SuppressWarnings("unchecked") + @Test + public void leftJoin() { + JavaVertexRDD a = vertexRDD(100, 1); + JavaPairRDD b = pairRDD(100, 2); + a.leftJoin(b, + new Function3, Integer>() { + public Integer call(Long id, Integer a, Optional bOpt) { + return a + bOpt.or(0); + } + }).collect(); + } + + @SuppressWarnings("unchecked") + @Test + public void innerZipJoin() { + JavaVertexRDD a = vertexRDD(100, 1); + JavaVertexRDD b = vertexRDD(100, 2); + a.innerZipJoin(b, + new Function3() { + public Integer call(Long id, Integer a, Integer b) { + return a + b; + } + }).collect(); + } + + @SuppressWarnings("unchecked") + @Test + public void innerJoin() { + JavaVertexRDD a = vertexRDD(100, 1); + JavaPairRDD b = pairRDD(100, 2); + a.innerJoin(b, + new Function3() { + public Integer call(Long id, Integer a, Integer b) { + return a + b; + } + }).collect(); + } + + @SuppressWarnings("unchecked") + @Test + public void aggregateUsingIndex() { + JavaVertexRDD a = vertexRDD(100, 1); + JavaPairRDD b = pairRDD(100, 2); + a.aggregateUsingIndex(b, + new Function2() { + public Integer call(Integer a, Integer b) { + return a + b; + } + }).collect(); + } + + @SuppressWarnings("unchecked") + @Test + public void withEdges() { + // Create a graph so we can use its JavaEdgeRDD + List> ring = new ArrayList>(); + long m = 200; + for (long i = 0; i < m; i++) { + ring.add(new Edge(i, (i + 1) % m, 1)); + } + JavaGraph graph = JavaGraph.fromEdges(sc.parallelize(ring), 1); + + JavaVertexRDD vertexRDD = vertexRDD(100, 1); + vertexRDD.withEdges(graph.edges()).collect(); + } +} diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index df773db6e4326..8aca24c2148a7 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -62,7 +62,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { assert( graph.edges.count() === rawEdges.size ) // Vertices not explicitly provided but referenced by edges should be created automatically assert( graph.vertices.count() === 100) - graph.triplets.collect.map { et => + graph.triplets.collect.foreach { et => assert((et.srcId < 10 && et.srcAttr) || (et.srcId >= 10 && !et.srcAttr)) assert((et.dstId < 10 && et.dstAttr) || (et.dstId >= 10 && !et.dstAttr)) } From 0d2f4c76c4a33aa5333f19937b0a002d4a34379e Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Wed, 12 Nov 2014 15:36:13 -0800 Subject: [PATCH 2/3] Organize imports --- .../spark/graphx/api/java/JavaEdgeRDD.scala | 4 +- .../spark/graphx/api/java/JavaGraph.scala | 6 +-- .../apache/spark/graphx/JavaEdgeRDDSuite.java | 38 +++---------------- .../apache/spark/graphx/JavaGraphSuite.java | 38 +++++-------------- .../spark/graphx/JavaVertexRDDSuite.java | 36 ++++-------------- 5 files changed, 27 insertions(+), 95 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaEdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaEdgeRDD.scala index 0e9153d479fb6..288f579a0874f 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaEdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaEdgeRDD.scala @@ -23,13 +23,13 @@ import scala.collection.JavaConversions._ import scala.language.implicitConversions import scala.reflect.ClassTag +import org.apache.spark.api.java.JavaPairRDD._ +import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.JavaUtils import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, Function3 => JFunction3, Function4 => JFunction4, _} import org.apache.spark.graphx._ -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.api.java.JavaPairRDD._ class JavaEdgeRDD[ED, VD](override val rdd: EdgeRDD[ED, VD])( implicit val edTag: ClassTag[ED], val vdTag: ClassTag[VD]) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaGraph.scala b/graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaGraph.scala index c702ffa246436..ac01b9172ec15 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaGraph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaGraph.scala @@ -24,18 +24,18 @@ import scala.collection.JavaConversions._ import scala.language.implicitConversions import scala.reflect.ClassTag -import org.apache.spark.api.java.JavaPairRDD -import org.apache.spark.api.java.JavaUtils import com.google.common.base.Optional +import org.apache.spark.api.java.JavaPairRDD import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.api.java.JavaUtils import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, Function3 => JFunction3, _} import org.apache.spark.graphx._ -import org.apache.spark.storage.StorageLevel import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel class JavaGraph[VD, ED](val graph: Graph[VD, ED])( implicit val vdTag: ClassTag[VD], implicit val edTag: ClassTag[ED]) extends Serializable { diff --git a/graphx/src/test/java/org/apache/spark/graphx/JavaEdgeRDDSuite.java b/graphx/src/test/java/org/apache/spark/graphx/JavaEdgeRDDSuite.java index d9e429415f656..27ee2b4443609 100644 --- a/graphx/src/test/java/org/apache/spark/graphx/JavaEdgeRDDSuite.java +++ b/graphx/src/test/java/org/apache/spark/graphx/JavaEdgeRDDSuite.java @@ -17,45 +17,19 @@ package org.apache.spark.graphx; -import java.io.*; -import java.net.URI; +import java.io.File; +import java.io.Serializable; import java.util.*; -import scala.Tuple2; -import scala.Tuple3; -import scala.Tuple4; - -import com.google.common.collect.Iterables; -import com.google.common.collect.Iterators; -import com.google.common.collect.Lists; -import com.google.common.collect.Maps; -import com.google.common.base.Optional; -import com.google.common.base.Charsets; import com.google.common.io.Files; -import org.apache.hadoop.io.IntWritable; -import org.apache.hadoop.io.Text; -import org.apache.hadoop.io.compress.DefaultCodec; -import org.apache.hadoop.mapred.SequenceFileInputFormat; -import org.apache.hadoop.mapred.SequenceFileOutputFormat; -import org.apache.hadoop.mapreduce.Job; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import org.apache.spark.api.java.JavaDoubleRDD; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.*; -import org.apache.spark.executor.TaskMetrics; -import org.apache.spark.partial.BoundedDouble; -import org.apache.spark.partial.PartialResult; -import org.apache.spark.storage.StorageLevel; -import org.apache.spark.util.StatCounter; - import org.apache.spark.graphx.*; import org.apache.spark.graphx.api.java.*; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; public class JavaEdgeRDDSuite implements Serializable { private transient JavaSparkContext sc; diff --git a/graphx/src/test/java/org/apache/spark/graphx/JavaGraphSuite.java b/graphx/src/test/java/org/apache/spark/graphx/JavaGraphSuite.java index c59aa26f46694..43c08d1b3cbb3 100644 --- a/graphx/src/test/java/org/apache/spark/graphx/JavaGraphSuite.java +++ b/graphx/src/test/java/org/apache/spark/graphx/JavaGraphSuite.java @@ -17,45 +17,25 @@ package org.apache.spark.graphx; -import java.io.*; -import java.net.URI; +import java.io.File; +import java.io.Serializable; import java.util.*; -import scala.Tuple2; -import scala.Tuple3; -import scala.Tuple4; - -import com.google.common.collect.Iterables; -import com.google.common.collect.Iterators; -import com.google.common.collect.Lists; -import com.google.common.collect.Maps; import com.google.common.base.Optional; -import com.google.common.base.Charsets; import com.google.common.io.Files; -import org.apache.hadoop.io.IntWritable; -import org.apache.hadoop.io.Text; -import org.apache.hadoop.io.compress.DefaultCodec; -import org.apache.hadoop.mapred.SequenceFileInputFormat; -import org.apache.hadoop.mapred.SequenceFileOutputFormat; -import org.apache.hadoop.mapreduce.Job; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import org.apache.spark.api.java.JavaDoubleRDD; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.*; -import org.apache.spark.executor.TaskMetrics; -import org.apache.spark.partial.BoundedDouble; -import org.apache.spark.partial.PartialResult; -import org.apache.spark.storage.StorageLevel; -import org.apache.spark.util.StatCounter; - import org.apache.spark.graphx.*; import org.apache.spark.graphx.api.java.*; +import org.apache.spark.storage.StorageLevel; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import scala.Tuple2; +import scala.Tuple4; public class JavaGraphSuite implements Serializable { private transient JavaSparkContext sc; diff --git a/graphx/src/test/java/org/apache/spark/graphx/JavaVertexRDDSuite.java b/graphx/src/test/java/org/apache/spark/graphx/JavaVertexRDDSuite.java index 04c1f2491d17d..25fd7cfffa9b3 100644 --- a/graphx/src/test/java/org/apache/spark/graphx/JavaVertexRDDSuite.java +++ b/graphx/src/test/java/org/apache/spark/graphx/JavaVertexRDDSuite.java @@ -17,45 +17,23 @@ package org.apache.spark.graphx; -import java.io.*; -import java.net.URI; +import java.io.File; +import java.io.Serializable; import java.util.*; -import scala.Tuple2; -import scala.Tuple3; -import scala.Tuple4; - -import com.google.common.collect.Iterables; -import com.google.common.collect.Iterators; -import com.google.common.collect.Lists; -import com.google.common.collect.Maps; import com.google.common.base.Optional; -import com.google.common.base.Charsets; import com.google.common.io.Files; -import org.apache.hadoop.io.IntWritable; -import org.apache.hadoop.io.Text; -import org.apache.hadoop.io.compress.DefaultCodec; -import org.apache.hadoop.mapred.SequenceFileInputFormat; -import org.apache.hadoop.mapred.SequenceFileOutputFormat; -import org.apache.hadoop.mapreduce.Job; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import org.apache.spark.api.java.JavaDoubleRDD; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.*; -import org.apache.spark.executor.TaskMetrics; -import org.apache.spark.partial.BoundedDouble; -import org.apache.spark.partial.PartialResult; -import org.apache.spark.storage.StorageLevel; -import org.apache.spark.util.StatCounter; - import org.apache.spark.graphx.*; import org.apache.spark.graphx.api.java.*; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import scala.Tuple2; public class JavaVertexRDDSuite implements Serializable { private transient JavaSparkContext sc; From b2d65904f9ee42c59793805e55a07a35462ae1f6 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Wed, 12 Nov 2014 15:39:06 -0800 Subject: [PATCH 3/3] Fix long lines --- .../spark/graphx/api/java/JavaVertexRDD.scala | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaVertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaVertexRDD.scala index cbed8922363f7..7f52b08ba6277 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaVertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/api/java/JavaVertexRDD.scala @@ -90,15 +90,17 @@ class JavaVertexRDD[VD](val vertexRDD: VertexRDD[VD])(implicit val vdTag: ClassT f: JFunction3[JLong, VD, Optional[VD2], VD3]): JavaVertexRDD[VD3] = { implicit val vd2Tag: ClassTag[VD2] = fakeClassTag implicit val vd3Tag: ClassTag[VD3] = fakeClassTag - vertexRDD.leftZipJoin(other) { (vid, a, bOpt) => f.call(vid, a, JavaUtils.optionToOptional(bOpt)) } + vertexRDD.leftZipJoin(other) { + (vid, a, bOpt) => f.call(vid, a, JavaUtils.optionToOptional(bOpt)) + } } /** - * Left joins this JavaVertexRDD with an RDD containing vertex attribute pairs. If the other RDD is - * backed by a JavaVertexRDD with the same index then the efficient [[leftZipJoin]] implementation is - * used. The resulting JavaVertexRDD contains an entry for each vertex in `this`. If `other` is - * missing any vertex in this JavaVertexRDD, `f` is passed `None`. If there are duplicates, - * the vertex is picked arbitrarily. + * Left joins this JavaVertexRDD with an RDD containing vertex attribute pairs. If the other RDD + * is backed by a JavaVertexRDD with the same index then the efficient [[leftZipJoin]] + * implementation is used. The resulting JavaVertexRDD contains an entry for each vertex in + * `this`. If `other` is missing any vertex in this JavaVertexRDD, `f` is passed `None`. If there + * are duplicates, the vertex is picked arbitrarily. * * @tparam VD2 the attribute type of the other JavaVertexRDD * @tparam VD3 the attribute type of the resulting JavaVertexRDD @@ -106,8 +108,8 @@ class JavaVertexRDD[VD](val vertexRDD: VertexRDD[VD])(implicit val vdTag: ClassT * @param other the other JavaVertexRDD with which to join * @param f the function mapping a vertex id and its attributes in this and the other vertex set * to a new vertex attribute. - * @return a JavaVertexRDD containing all the vertices in this JavaVertexRDD with the attributes emitted - * by `f`. + * @return a JavaVertexRDD containing all the vertices in this JavaVertexRDD with the attributes + * emitted by `f`. */ def leftJoin[VD2, VD3]( other: JavaPairRDD[JLong, VD2], @@ -116,12 +118,14 @@ class JavaVertexRDD[VD](val vertexRDD: VertexRDD[VD])(implicit val vdTag: ClassT implicit val vd2Tag: ClassTag[VD2] = fakeClassTag implicit val vd3Tag: ClassTag[VD3] = fakeClassTag val scalaOther: RDD[(VertexId, VD2)] = other.rdd.map(kv => (kv._1, kv._2)) - vertexRDD.leftJoin(scalaOther) { (vid, a, bOpt) => f.call(vid, a, JavaUtils.optionToOptional(bOpt)) } + vertexRDD.leftJoin(scalaOther) { + (vid, a, bOpt) => f.call(vid, a, JavaUtils.optionToOptional(bOpt)) + } } /** - * Efficiently inner joins this JavaVertexRDD with another JavaVertexRDD sharing the same index. See - * [[innerJoin]] for the behavior of the join. + * Efficiently inner joins this JavaVertexRDD with another JavaVertexRDD sharing the same index. + * See [[innerJoin]] for the behavior of the join. */ def innerZipJoin[U, VD2]( other: JavaVertexRDD[U],