Skip to content

Commit 0ce4e43

Browse files
srowen authored and ankurdave committed
SPARK-3290 [GRAPHX] No unpersist calls in SVDPlusPlus
This just unpersist()s each RDD in this code that was cache()ed. Author: Sean Owen <[email protected]> Closes #4234 from srowen/SPARK-3290 and squashes the following commits: 66c1e11 [Sean Owen] unpersist() each RDD that was cache()ed
1 parent d06d5ee commit 0ce4e43

File tree

1 file changed

+32
-8
lines changed

1 file changed

+32
-8
lines changed

graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,17 +72,22 @@ object SVDPlusPlus {
7272

7373
// construct graph
7474
var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache()
75+
materialize(g)
76+
edges.unpersist()
7577

7678
// Calculate initial bias and norm
7779
val t0 = g.aggregateMessages[(Long, Double)](
7880
ctx => { ctx.sendToSrc((1L, ctx.attr)); ctx.sendToDst((1L, ctx.attr)) },
7981
(g1, g2) => (g1._1 + g2._1, g1._2 + g2._2))
8082

81-
g = g.outerJoinVertices(t0) {
83+
val gJoinT0 = g.outerJoinVertices(t0) {
8284
(vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double),
8385
msg: Option[(Long, Double)]) =>
8486
(vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1))
85-
}
87+
}.cache()
88+
materialize(gJoinT0)
89+
g.unpersist()
90+
g = gJoinT0
8691

8792
def sendMsgTrainF(conf: Conf, u: Double)
8893
(ctx: EdgeContext[
@@ -114,26 +119,32 @@ object SVDPlusPlus {
114119
val t1 = g.aggregateMessages[DoubleMatrix](
115120
ctx => ctx.sendToSrc(ctx.dstAttr._2),
116121
(g1, g2) => g1.addColumnVector(g2))
117-
g = g.outerJoinVertices(t1) {
122+
val gJoinT1 = g.outerJoinVertices(t1) {
118123
(vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double),
119124
msg: Option[DoubleMatrix]) =>
120125
if (msg.isDefined) (vd._1, vd._1
121126
.addColumnVector(msg.get.mul(vd._4)), vd._3, vd._4) else vd
122-
}
127+
}.cache()
128+
materialize(gJoinT1)
129+
g.unpersist()
130+
g = gJoinT1
123131

124132
// Phase 2, update p for user nodes and q, y for item nodes
125133
g.cache()
126134
val t2 = g.aggregateMessages(
127135
sendMsgTrainF(conf, u),
128136
(g1: (DoubleMatrix, DoubleMatrix, Double), g2: (DoubleMatrix, DoubleMatrix, Double)) =>
129137
(g1._1.addColumnVector(g2._1), g1._2.addColumnVector(g2._2), g1._3 + g2._3))
130-
g = g.outerJoinVertices(t2) {
138+
val gJoinT2 = g.outerJoinVertices(t2) {
131139
(vid: VertexId,
132140
vd: (DoubleMatrix, DoubleMatrix, Double, Double),
133141
msg: Option[(DoubleMatrix, DoubleMatrix, Double)]) =>
134142
(vd._1.addColumnVector(msg.get._1), vd._2.addColumnVector(msg.get._2),
135143
vd._3 + msg.get._3, vd._4)
136-
}
144+
}.cache()
145+
materialize(gJoinT2)
146+
g.unpersist()
147+
g = gJoinT2
137148
}
138149

139150
// calculate error on training set
@@ -147,13 +158,26 @@ object SVDPlusPlus {
147158
val err = (ctx.attr - pred) * (ctx.attr - pred)
148159
ctx.sendToDst(err)
149160
}
161+
150162
g.cache()
151163
val t3 = g.aggregateMessages[Double](sendMsgTestF(conf, u), _ + _)
152-
g = g.outerJoinVertices(t3) {
164+
val gJoinT3 = g.outerJoinVertices(t3) {
153165
(vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[Double]) =>
154166
if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd
155-
}
167+
}.cache()
168+
materialize(gJoinT3)
169+
g.unpersist()
170+
g = gJoinT3
156171

157172
(g, u)
158173
}
174+
175+
/**
176+
* Forces materialization of a Graph by count()ing its RDDs.
177+
*/
178+
// Accepts a Graph with any vertex/edge attribute types and returns Unit:
// the two counts below are invoked only to force evaluation; their results are discarded.
private def materialize(g: Graph[_,_]): Unit = {
179+
// count() the graph's vertices ...
g.vertices.count()
180+
// ... and then its edges, so both parts of the graph are evaluated.
g.edges.count()
181+
}
182+
159183
}

0 commit comments

Comments
 (0)