|
| 1 | +package org.apache.spark.graphx |
| 2 | + |
| 3 | +import org.apache.spark.Logging |
| 4 | + |
| 5 | +import scala.reflect.ClassTag |
| 6 | + |
| 7 | +/** |
| 8 | + * Contains additional functionality for [[Pregel]] of partially sending message. |
| 9 | + * |
| 10 | + * Execute a Pregel-like iterative vertex-parallel abstraction with current iterative number. |
| 11 | + * Part of the vertexes(called `ActiveVertexes`) send messages to their neighbours |
| 12 | + * in each iteration. |
| 13 | + * |
| 14 | + * In some cases, `ActiveVertexes` are the vertexes that their attributes do not change |
| 15 | + * between the previous and current iteration, so they need not to send message. |
| 16 | + * At first, user can set Int value(eg. `flag:Int`) with `vprog`'s first parameter `curIter` |
| 17 | + * to vertex's attribute in function `vprog`. |
| 18 | + * Then in `sendMsg`, compare the Int value (`flag`) of Vertex attribute with `curIter` of |
| 19 | + * `sendMsg`'s first parameter. |
| 20 | + * In this way, it can determine whether sending message in current iteration. |
| 21 | + * |
| 22 | + * @example sample: |
| 23 | + * {{{ |
| 24 | + * |
| 25 | + * // invoke |
| 26 | + * PregelOps[(Int, Int), Int, Map[Int, Int]](graph, isTerminal = isTerminal)( |
| 27 | + * vprog, sendMessage, mergeMessage) |
| 28 | + * |
| 29 | + * // set a `flag:Int` value of vertex attribute object |
| 30 | + * def vprog(curIter: Int, vid: VertexId, attr: (Int, Int), |
| 31 | + * messages: Map[Int, Int]): (Int, Int) = { |
| 32 | + * if (attr > 1024) { |
| 33 | + * // logic code... |
| 34 | + * // assign the curIter, the vertex can send message to its neighbors in sendMsg |
| 35 | + * (curIter, xxxx) |
| 36 | + * } else { |
| 37 | + * (0, xxxx) |
| 38 | + * } |
| 39 | + * } |
| 40 | + * |
| 41 | + * def sendMessage(curIter: Int, |
| 42 | + * ctx: EdgeContext[(Int, Int), Int, Map[Int, Int]]): Unit = { |
| 43 | + * if (curIter == 0) { |
| 44 | + * ctx.sendToDst(Map(ctx.srcAttr._2 -> -1, ctx.srcAttr.xx -> 1)) |
| 45 | + * ctx.sendToSrc(Map(ctx.dstAttr._2 -> -1, ctx.dstAttr.xx -> 1)) |
| 46 | + * } else if (curIter == ctx.srcAttr._1) { |
| 47 | + * // determine whether sending message |
| 48 | + * ctx.sendToDst(Map(ctx.srcAttr.preKCore -> -1, ctx.srcAttr.curKCore -> 1)) |
| 49 | + * ctx.sendToSrc(Map(ctx.dstAttr.preKCore -> -1, ctx.dstAttr.curKCore -> 1)) |
| 50 | + * } |
| 51 | + * } |
| 52 | + * |
| 53 | + * def isTerminal(curIter: Int, messageCount: Long): Boolean = { |
| 54 | + * if (messageCount < 10 || curIter > 1000) false else true |
| 55 | + * } |
| 56 | + * |
| 57 | + * // mergeMessage |
| 58 | + * def mergeMessage(source: Map[Int, Int], target: Map[Int, Int]): Map[Int, Int] = { |
| 59 | + * // logic code... |
| 60 | + * |
| 61 | + * target |
| 62 | + * } |
| 63 | + * |
| 64 | + * }}} |
| 65 | + * |
| 66 | + */ |
| 67 | +object PregelOps extends Logging { |
| 68 | + |
| 69 | + /** |
| 70 | + * Implementing Part of the vertexes(we call them ActiveVertexes) send messages to their |
| 71 | + * neighbours in each iteration. |
| 72 | + * |
| 73 | + * Provide a `isTerminal` to determine end up the loop with Int value `curIter` and the number |
| 74 | + * of message count number previous iterate. |
| 75 | + * |
| 76 | + * @tparam VD the vertex data type |
| 77 | + * @tparam ED the edge data type |
| 78 | + * @tparam A the Pregel message type |
| 79 | + * |
| 80 | + * @param originGraph the input graph. |
| 81 | + * |
| 82 | + * @param initialMsg the message each vertex will receive at the on |
| 83 | + * the first iteration. default is [[None]] |
| 84 | + * |
| 85 | + * @param isTerminal checking whether can finish loop |
| 86 | + * Parameter Int is the current iteration variable `curIter` |
| 87 | + * Parameter Long is the aggregate message number of previous iteration |
| 88 | + * |
| 89 | + * @param tripletFields which fields should be included in the [[EdgeContext]] passed to the |
| 90 | + * `sendMsg` function. If not all fields are needed, specifying this can improve performance. |
| 91 | + * default is [[TripletFields.All]] |
| 92 | + * |
| 93 | + * @param vprog the user-defined vertex program which runs on each |
| 94 | + * vertex and receives the inbound message and computes a new vertex |
| 95 | + * value. On the first iteration the vertex program is invoked on |
| 96 | + * all vertices and is passed the default message. On subsequent |
| 97 | + * iterations the vertex program is only invoked on those vertices |
| 98 | + * that receive messages. |
| 99 | + * |
| 100 | + * @param sendMsg a user supplied function that is applied to out |
| 101 | + * edges of vertices that received messages in the current iteration |
| 102 | + * |
| 103 | + * @param mergeMsg a user supplied function that takes two incoming |
| 104 | + * messages of type A and merges them into a single message of type A. |
| 105 | + * ''This function must be commutative and associative and |
| 106 | + * ideally the size of A should not increase.'' |
| 107 | + * |
| 108 | + * @return the resulting graph at the end of the computation |
| 109 | + */ |
| 110 | + def apply[VD: ClassTag, ED: ClassTag, A: ClassTag] |
| 111 | + (originGraph: Graph[VD, ED], |
| 112 | + initialMsg: Option[A] = None, |
| 113 | + isTerminal: (Int, Long) => Boolean = defaultTerminal, |
| 114 | + tripletFields: TripletFields) |
| 115 | + (vprog: (Int, VertexId, VD, A) => VD, |
| 116 | + sendMsg: (Int, EdgeContext[VD, ED, A]) => Unit, |
| 117 | + mergeMsg: (A, A) => A): Graph[VD, ED] = { |
| 118 | + |
| 119 | + // init iterate 0 |
| 120 | + val initIter = 0 |
| 121 | + var graph = initialMsg match { |
| 122 | + case None => originGraph.cache() |
| 123 | + case _ => originGraph.mapVertices((vid, vdata) => vprog(initIter, vid, vdata, |
| 124 | + initialMsg.get)).cache() |
| 125 | + } |
| 126 | + |
| 127 | + // compute the messages |
| 128 | + var messageRDD = graph.aggregateMessages(sendMsg(initIter, _: EdgeContext[VD, ED, A]), mergeMsg) |
| 129 | + var activeMsgCount = messageRDD.count() |
| 130 | + |
| 131 | + // Loop, from i = 1 |
| 132 | + var i = 1 |
| 133 | + while (activeMsgCount > 0 && isTerminal(i, activeMsgCount)) { |
| 134 | + val ct = System.currentTimeMillis() |
| 135 | + val curIter = i |
| 136 | + |
| 137 | + // 1. Receive the messages. Vertices that didn't get any messages do not appear in newVerts. |
| 138 | + val newVerts = graph.vertices.innerJoin(messageRDD)( |
| 139 | + vprog(curIter, _: VertexId, _: VD, _: A)).cache() |
| 140 | + |
| 141 | + // 2. Update the graph with the new vertices. |
| 142 | + val preGraph: Graph[VD, ED] = graph |
| 143 | + graph = graph.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old)} |
| 144 | + graph.cache() |
| 145 | + |
| 146 | + val oldMessages = messageRDD |
| 147 | + // 3. aggregate message |
| 148 | + // Send new messages. Vertices that didn't get any messages don't appear in newVerts, so don't |
| 149 | + // get to send messages. We must cache messages so it can be materialized on the next line, |
| 150 | + // allowing us to uncache the previous iteration. |
| 151 | + messageRDD = graph.aggregateMessages(sendMsg(curIter, _: EdgeContext[VD, ED, A]), mergeMsg, |
| 152 | + tripletFields).cache() |
| 153 | + // The call to count() materializes `messages`, `newVerts`, and the vertices of `g`. This |
| 154 | + // hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the |
| 155 | + // vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g). |
| 156 | + activeMsgCount = messageRDD.count() |
| 157 | + |
| 158 | + // Unpersist the RDDs hidden by newly-materialized RDDs |
| 159 | + oldMessages.unpersist(blocking = false) |
| 160 | + newVerts.unpersist(blocking = false) |
| 161 | + preGraph.unpersistVertices(blocking = false) |
| 162 | + preGraph.edges.unpersist(blocking = false) |
| 163 | + if (i == 1) { |
| 164 | + originGraph.unpersistVertices(blocking = false) |
| 165 | + originGraph.edges.unpersist(blocking = false) |
| 166 | + } |
| 167 | + |
| 168 | + i += 1 |
| 169 | + |
| 170 | + logInfo("{\"name\":\"pregel\", \"iterate\":" + i + ",\"cost\":" |
| 171 | + + (System.currentTimeMillis() - ct) + "}") |
| 172 | + } |
| 173 | + |
| 174 | + graph |
| 175 | + } // end of apply |
| 176 | + |
| 177 | + /** |
| 178 | + * default terminal function |
| 179 | + * @return |
| 180 | + */ |
| 181 | + private def defaultTerminal(curIter: Int, msgCount: Long): Boolean = true |
| 182 | +} |
0 commit comments