Skip to content

Commit f971ce5

Browse files
ding authored and
Felix Cheung committed
[SPARK-5484][GRAPHX] Periodically do checkpoint in Pregel
## What changes were proposed in this pull request? Pregel-based iterative algorithms with more than ~50 iterations begin to slow down and eventually fail with a StackOverflowError due to Spark's lack of support for long lineage chains. This PR causes Pregel to checkpoint the graph periodically if the checkpoint directory is set. This PR also moves PeriodicGraphCheckpointer.scala from mllib to graphx, and moves PeriodicRDDCheckpointer.scala and PeriodicCheckpointer.scala from mllib to core. ## How was this patch tested? unit tests, manual tests (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Author: ding <[email protected]> Author: dding3 <[email protected]> Author: Michael Allman <[email protected]> Closes #15125 from dding3/cp2_pregel. (cherry picked from commit 0a7f5f2) Signed-off-by: Felix Cheung <[email protected]>
1 parent 55834a8 commit f971ce5

File tree

13 files changed

+128
-76
lines changed

13 files changed

+128
-76
lines changed

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ import org.apache.spark.partial.GroupedCountEvaluator
4141
import org.apache.spark.partial.PartialResult
4242
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
4343
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
44-
import org.apache.spark.util.collection.OpenHashMap
44+
import org.apache.spark.util.collection.{OpenHashMap, Utils => collectionUtils}
4545
import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler,
4646
SamplingUtils}
4747

@@ -1420,7 +1420,7 @@ abstract class RDD[T: ClassTag](
14201420
val mapRDDs = mapPartitions { items =>
14211421
// Priority keeps the largest elements, so let's reverse the ordering.
14221422
val queue = new BoundedPriorityQueue[T](num)(ord.reverse)
1423-
queue ++= util.collection.Utils.takeOrdered(items, num)(ord)
1423+
queue ++= collectionUtils.takeOrdered(items, num)(ord)
14241424
Iterator.single(queue)
14251425
}
14261426
if (mapRDDs.partitions.length == 0) {

mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala renamed to core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
* limitations under the License.
1616
*/
1717

18-
package org.apache.spark.mllib.impl
18+
package org.apache.spark.rdd.util
1919

2020
import org.apache.spark.SparkContext
2121
import org.apache.spark.rdd.RDD
2222
import org.apache.spark.storage.StorageLevel
23+
import org.apache.spark.util.PeriodicCheckpointer
2324

2425

2526
/**

mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala renamed to core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* limitations under the License.
1616
*/
1717

18-
package org.apache.spark.mllib.impl
18+
package org.apache.spark.util
1919

2020
import scala.collection.mutable
2121

@@ -58,7 +58,7 @@ import org.apache.spark.storage.StorageLevel
5858
* @param sc SparkContext for the Datasets given to this checkpointer
5959
* @tparam T Dataset type, such as RDD[Double]
6060
*/
61-
private[mllib] abstract class PeriodicCheckpointer[T](
61+
private[spark] abstract class PeriodicCheckpointer[T](
6262
val checkpointInterval: Int,
6363
val sc: SparkContext) extends Logging {
6464

@@ -127,6 +127,16 @@ private[mllib] abstract class PeriodicCheckpointer[T](
127127
/** Get list of checkpoint files for this given Dataset */
128128
protected def getCheckpointFiles(data: T): Iterable[String]
129129

130+
/**
131+
* Call this to unpersist the Dataset.
132+
*/
133+
def unpersistDataSet(): Unit = {
134+
while (persistedQueue.nonEmpty) {
135+
val dataToUnpersist = persistedQueue.dequeue()
136+
unpersist(dataToUnpersist)
137+
}
138+
}
139+
130140
/**
131141
* Call this at the end to delete any remaining checkpoint files.
132142
*/

core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers w
135135
}
136136

137137
test("get a range of elements in an array not partitioned by a range partitioner") {
138-
val pairArr = util.Random.shuffle((1 to 1000).toList).map(x => (x, x))
138+
val pairArr = scala.util.Random.shuffle((1 to 1000).toList).map(x => (x, x))
139139
val pairs = sc.parallelize(pairArr, 10)
140140
val range = pairs.filterByRange(200, 800).collect()
141141
assert((800 to 200 by -1).toArray.sorted === range.map(_._1).sorted)
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,18 @@
1515
* limitations under the License.
1616
*/
1717

18-
package org.apache.spark.mllib.impl
18+
package org.apache.spark.utils
1919

2020
import org.apache.hadoop.fs.Path
2121

22-
import org.apache.spark.{SparkContext, SparkFunSuite}
23-
import org.apache.spark.mllib.util.MLlibTestSparkContext
22+
import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite}
2423
import org.apache.spark.rdd.RDD
24+
import org.apache.spark.rdd.util.PeriodicRDDCheckpointer
2525
import org.apache.spark.storage.StorageLevel
2626
import org.apache.spark.util.Utils
2727

2828

29-
class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {
29+
class PeriodicRDDCheckpointerSuite extends SparkFunSuite with SharedSparkContext {
3030

3131
import PeriodicRDDCheckpointerSuite._
3232

docs/configuration.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2149,6 +2149,20 @@ showDF(properties, numRows = 200, truncate = FALSE)
21492149

21502150
</table>
21512151

2152+
### GraphX
2153+
2154+
<table class="table">
2155+
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
2156+
<tr>
2157+
<td><code>spark.graphx.pregel.checkpointInterval</code></td>
2158+
<td>-1</td>
2159+
<td>
2160+
Checkpoint interval for the graph and messages in Pregel. It is used to avoid a StackOverflowError due to long lineage chains
2161+
after lots of iterations. The checkpoint is disabled by default.
2162+
</td>
2163+
</tr>
2164+
</table>
2165+
21522166
### Deploy
21532167

21542168
<table class="table">

docs/graphx-programming-guide.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -708,7 +708,9 @@ messages remaining.
708708
> messaging function. These constraints allow additional optimization within GraphX.
709709
710710
The following is the type signature of the [Pregel operator][GraphOps.pregel] as well as a *sketch*
711-
of its implementation (note calls to graph.cache have been removed):
711+
of its implementation (note: to avoid a StackOverflowError due to long lineage chains, Pregel supports periodically
712+
checkpointing the graph and messages when "spark.graphx.pregel.checkpointInterval" is set to a positive number,
713+
say 10, and a checkpoint directory is set using SparkContext.setCheckpointDir(directory: String)):
712714

713715
{% highlight scala %}
714716
class GraphOps[VD, ED] {
@@ -722,6 +724,7 @@ class GraphOps[VD, ED] {
722724
: Graph[VD, ED] = {
723725
// Receive the initial message at each vertex
724726
var g = mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg) ).cache()
727+
725728
// compute the messages
726729
var messages = g.mapReduceTriplets(sendMsg, mergeMsg)
727730
var activeMessages = messages.count()
@@ -734,8 +737,8 @@ class GraphOps[VD, ED] {
734737
// Send new messages, skipping edges where neither side received a message. We must cache
735738
// messages so it can be materialized on the next line, allowing us to uncache the previous
736739
// iteration.
737-
messages = g.mapReduceTriplets(
738-
sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache()
740+
messages = GraphXUtils.mapReduceTriplets(
741+
g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache()
739742
activeMessages = messages.count()
740743
i += 1
741744
}

graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ package org.apache.spark.graphx
1919

2020
import scala.reflect.ClassTag
2121

22+
import org.apache.spark.graphx.util.PeriodicGraphCheckpointer
2223
import org.apache.spark.internal.Logging
24+
import org.apache.spark.rdd.RDD
25+
import org.apache.spark.rdd.util.PeriodicRDDCheckpointer
2326

2427
/**
2528
* Implements a Pregel-like bulk-synchronous message-passing API.
@@ -122,27 +125,39 @@ object Pregel extends Logging {
122125
require(maxIterations > 0, s"Maximum number of iterations must be greater than 0," +
123126
s" but got ${maxIterations}")
124127

125-
var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()
128+
val checkpointInterval = graph.vertices.sparkContext.getConf
129+
.getInt("spark.graphx.pregel.checkpointInterval", -1)
130+
var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg))
131+
val graphCheckpointer = new PeriodicGraphCheckpointer[VD, ED](
132+
checkpointInterval, graph.vertices.sparkContext)
133+
graphCheckpointer.update(g)
134+
126135
// compute the messages
127136
var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg)
137+
val messageCheckpointer = new PeriodicRDDCheckpointer[(VertexId, A)](
138+
checkpointInterval, graph.vertices.sparkContext)
139+
messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]])
128140
var activeMessages = messages.count()
141+
129142
// Loop
130143
var prevG: Graph[VD, ED] = null
131144
var i = 0
132145
while (activeMessages > 0 && i < maxIterations) {
133146
// Receive the messages and update the vertices.
134147
prevG = g
135-
g = g.joinVertices(messages)(vprog).cache()
148+
g = g.joinVertices(messages)(vprog)
149+
graphCheckpointer.update(g)
136150

137151
val oldMessages = messages
138152
// Send new messages, skipping edges where neither side received a message. We must cache
139153
// messages so it can be materialized on the next line, allowing us to uncache the previous
140154
// iteration.
141155
messages = GraphXUtils.mapReduceTriplets(
142-
g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache()
156+
g, sendMsg, mergeMsg, Some((oldMessages, activeDirection)))
143157
// The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages
144158
// (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages
145159
// and the vertices of g).
160+
messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]])
146161
activeMessages = messages.count()
147162

148163
logInfo("Pregel finished iteration " + i)
@@ -154,7 +169,9 @@ object Pregel extends Logging {
154169
// count the iteration
155170
i += 1
156171
}
157-
messages.unpersist(blocking = false)
172+
messageCheckpointer.unpersistDataSet()
173+
graphCheckpointer.deleteAllCheckpoints()
174+
messageCheckpointer.deleteAllCheckpoints()
158175
g
159176
} // end of apply
160177

mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala renamed to graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
* limitations under the License.
1616
*/
1717

18-
package org.apache.spark.mllib.impl
18+
package org.apache.spark.graphx.util
1919

2020
import org.apache.spark.SparkContext
2121
import org.apache.spark.graphx.Graph
2222
import org.apache.spark.storage.StorageLevel
23+
import org.apache.spark.util.PeriodicCheckpointer
2324

2425

2526
/**
@@ -74,9 +75,8 @@ import org.apache.spark.storage.StorageLevel
7475
* @tparam VD Vertex descriptor type
7576
* @tparam ED Edge descriptor type
7677
*
77-
* TODO: Move this out of MLlib?
7878
*/
79-
private[mllib] class PeriodicGraphCheckpointer[VD, ED](
79+
private[spark] class PeriodicGraphCheckpointer[VD, ED](
8080
checkpointInterval: Int,
8181
sc: SparkContext)
8282
extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) {
@@ -87,10 +87,13 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED](
8787

8888
override protected def persist(data: Graph[VD, ED]): Unit = {
8989
if (data.vertices.getStorageLevel == StorageLevel.NONE) {
90-
data.vertices.persist()
90+
/* We need to use cache because persist does not honor the default storage level requested
91+
* when constructing the graph. Only cache does that.
92+
*/
93+
data.vertices.cache()
9194
}
9295
if (data.edges.getStorageLevel == StorageLevel.NONE) {
93-
data.edges.persist()
96+
data.edges.cache()
9497
}
9598
}
9699

0 commit comments

Comments
 (0)