@@ -72,17 +72,22 @@ object SVDPlusPlus {
7272
7373 // construct graph
7474 var g = Graph .fromEdges(edges, defaultF(conf.rank)).cache()
75+ materialize(g)
76+ edges.unpersist()
7577
7678 // Calculate initial bias and norm
7779 val t0 = g.aggregateMessages[(Long , Double )](
7880 ctx => { ctx.sendToSrc((1L , ctx.attr)); ctx.sendToDst((1L , ctx.attr)) },
7981 (g1, g2) => (g1._1 + g2._1, g1._2 + g2._2))
8082
81- g = g.outerJoinVertices(t0) {
83+ val gJoinT0 = g.outerJoinVertices(t0) {
8284 (vid : VertexId , vd : (DoubleMatrix , DoubleMatrix , Double , Double ),
8385 msg : Option [(Long , Double )]) =>
8486 (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1))
85- }
87+ }.cache()
88+ materialize(gJoinT0)
89+ g.unpersist()
90+ g = gJoinT0
8691
8792 def sendMsgTrainF (conf : Conf , u : Double )
8893 (ctx : EdgeContext [
@@ -114,26 +119,32 @@ object SVDPlusPlus {
114119 val t1 = g.aggregateMessages[DoubleMatrix ](
115120 ctx => ctx.sendToSrc(ctx.dstAttr._2),
116121 (g1, g2) => g1.addColumnVector(g2))
117- g = g.outerJoinVertices(t1) {
122+ val gJoinT1 = g.outerJoinVertices(t1) {
118123 (vid : VertexId , vd : (DoubleMatrix , DoubleMatrix , Double , Double ),
119124 msg : Option [DoubleMatrix ]) =>
120125 if (msg.isDefined) (vd._1, vd._1
121126 .addColumnVector(msg.get.mul(vd._4)), vd._3, vd._4) else vd
122- }
127+ }.cache()
128+ materialize(gJoinT1)
129+ g.unpersist()
130+ g = gJoinT1
123131
124132 // Phase 2, update p for user nodes and q, y for item nodes
125133 g.cache()
126134 val t2 = g.aggregateMessages(
127135 sendMsgTrainF(conf, u),
128136 (g1 : (DoubleMatrix , DoubleMatrix , Double ), g2 : (DoubleMatrix , DoubleMatrix , Double )) =>
129137 (g1._1.addColumnVector(g2._1), g1._2.addColumnVector(g2._2), g1._3 + g2._3))
130- g = g.outerJoinVertices(t2) {
138+ val gJoinT2 = g.outerJoinVertices(t2) {
131139 (vid : VertexId ,
132140 vd : (DoubleMatrix , DoubleMatrix , Double , Double ),
133141 msg : Option [(DoubleMatrix , DoubleMatrix , Double )]) =>
134142 (vd._1.addColumnVector(msg.get._1), vd._2.addColumnVector(msg.get._2),
135143 vd._3 + msg.get._3, vd._4)
136- }
144+ }.cache()
145+ materialize(gJoinT2)
146+ g.unpersist()
147+ g = gJoinT2
137148 }
138149
139150 // calculate error on training set
@@ -147,13 +158,26 @@ object SVDPlusPlus {
147158 val err = (ctx.attr - pred) * (ctx.attr - pred)
148159 ctx.sendToDst(err)
149160 }
161+
150162 g.cache()
151163 val t3 = g.aggregateMessages[Double ](sendMsgTestF(conf, u), _ + _)
152- g = g.outerJoinVertices(t3) {
164+ val gJoinT3 = g.outerJoinVertices(t3) {
153165 (vid : VertexId , vd : (DoubleMatrix , DoubleMatrix , Double , Double ), msg : Option [Double ]) =>
154166 if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd
155- }
167+ }.cache()
168+ materialize(gJoinT3)
169+ g.unpersist()
170+ g = gJoinT3
156171
157172 (g, u)
158173 }
174+
175+ /**
176+ * Forces materialization of a Graph by running count() on both its vertices and its edges RDDs; the counts themselves are discarded (returns Unit).
177+ */
178+ private def materialize (g : Graph [_,_]): Unit = {
179+ g.vertices.count()
180+ g.edges.count()
181+ }
182+
159183}
0 commit comments