diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index 116d1ea70017..ae8fb1487d15 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -21,7 +21,6 @@ import scala.reflect.ClassTag import scala.util.Random import org.apache.spark.SparkException -import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.graphx.lib._ @@ -336,6 +335,59 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali Pregel(graph, initialMsg, maxIterations, activeDirection)(vprog, sendMsg, mergeMsg) } + /** + * An additional functionality for [[GraphOps.pregel)]] using `aggregateMessages` + * + * Only the parameter `sendMsg` is different from [[GraphOps.pregel]]. + * + * @example for `sendMsg`: + * {{{ + * private def sendMessage(ctx: EdgeContext[VD, ED, A): Unit = { + * //logic code defined by yourself. + * ctx.sendToDst(aMsg1) + * ctx.sendToSrc(aMsg2) + * } + * }}} + * + * @tparam A the Pregel message type + * @param graph the input graph. + * @param initialMsg the message each vertex will receive at the on + * the first iteration + * @param maxIterations the maximum number of iterations to run for + * + * @param tripletFields which fields should be included in the [[EdgeContext]] + * passed to the `sendMsg` function. If not all fields are needed, + * specifying this can improve performance. + * + * @param vprog the user-defined vertex program which runs on each + * vertex and receives the inbound message and computes a new vertex + * value. On the first iteration the vertex program is invoked on + * all vertices and is passed the default message. On subsequent + * iterations the vertex program is only invoked on those vertices + * that receive messages. + * + * @param sendMsg a user supplied function that is applied to out + * edges of vertices that received messages in the current + * iteration + * + * @param mergeMsg a user supplied function that takes two incoming + * messages of type A and merges them into a single message of type + * A. ''This function must be commutative and associative and + * ideally the size of A should not increase.'' + * + * @return the resulting graph at the end of the computation + */ + def pregel2[A: ClassTag](graph: Graph[VD, ED], + initialMsg: A, + maxIterations: Int = Int.MaxValue, + tripletFields: TripletFields = TripletFields.All) + (vprog: (VertexId, VD, A) => VD, + sendMsg: EdgeContext[VD, ED, A] => Unit, + mergeMsg: (A, A) => A) + : Graph[VD, ED] = { + Pregel.apply2(graph, initialMsg, maxIterations, tripletFields)(vprog, sendMsg, mergeMsg) + } + /** * Run a dynamic version of PageRank returning a graph with vertex attributes containing the * PageRank and edge attributes containing the normalized edge weight. diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 5e55620147df..914a20f95c5b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -17,9 +17,10 @@ package org.apache.spark.graphx -import scala.reflect.ClassTag import org.apache.spark.Logging +import scala.reflect.ClassTag + /** * Implements a Pregel-like bulk-synchronous message-passing API. @@ -158,4 +159,91 @@ object Pregel extends Logging { g } // end of apply + + /** + * An additional functionality for [[Pregel.apply()]] using `aggregateMessages` + * + * @tparam VD the vertex data type + * @tparam ED the edge data type + * @tparam A the Pregel message type + * @param graph the input graph. + * @param initialMsg the message each vertex will receive at the on + * the first iteration + * @param maxIterations the maximum number of iterations to run for + * + * @param tripletFields which fields should be included in the [[EdgeContext]] + * passed to the `sendMsg` function. If not all fields are needed, + * specifying this can improve performance. + * + * @param vprog the user-defined vertex program which runs on each + * vertex and receives the inbound message and computes a new vertex + * value. On the first iteration the vertex program is invoked on + * all vertices and is passed the default message. On subsequent + * iterations the vertex program is only invoked on those vertices + * that receive messages. + * + * @param sendMsg a user supplied function that is applied to out + * edges of vertices that received messages in the current + * iteration + * + * @param mergeMsg a user supplied function that takes two incoming + * messages of type A and merges them into a single message of type + * A. ''This function must be commutative and associative and + * ideally the size of A should not increase.'' + * + * @return the resulting graph at the end of the computation + */ + def apply2[VD: ClassTag, ED: ClassTag, A: ClassTag] + (graph: Graph[VD, ED], + initialMsg: A, + maxIterations: Int = Int.MaxValue, + tripletFields: TripletFields = TripletFields.All) + (vprog: (VertexId, VD, A) => VD, + sendMsg: EdgeContext[VD, ED, A] => Unit, + mergeMsg: (A, A) => A) + : Graph[VD, ED] = { + + var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache() + // compute the messages + var messages = g.aggregateMessages(sendMsg, mergeMsg) + var activeMessages = messages.count() + // Loop + var prevG: Graph[VD, ED] = null + var i = 0 + while (activeMessages > 0 && i < maxIterations) { + // Receive the messages. Vertices that didn't get any messages do not appear in newVerts. + val newVerts = g.vertices.innerJoin(messages)(vprog).cache() + // Update the graph with the new vertices. + prevG = g + g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) } + g.cache() + + val oldMessages = messages + // Send new messages. Vertices that didn't get any messages don't appear in newVerts, so don't + // get to send messages. We must cache messages so it can be materialized on the next line, + // allowing us to uncache the previous iteration. + messages = g.aggregateMessages(sendMsg, mergeMsg, tripletFields).cache() + // The call to count() materializes `messages`, `newVerts`, and the vertices of `g`. This + // hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the + // vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g). + activeMessages = messages.count() + + logInfo("Pregel finished iteration " + i) + + // Unpersist the RDDs hidden by newly-materialized RDDs + oldMessages.unpersist(blocking=false) + newVerts.unpersist(blocking=false) + prevG.unpersistVertices(blocking=false) + prevG.edges.unpersist(blocking=false) + if (i == 0) { + graph.unpersist(blocking = false) + graph.unpersistVertices(blocking = false) + } + // count the iteration + i += 1 + } + + g + } // end of apply2 + } // end of class Pregel