
Commit 568918c

Generalized PeriodicGraphCheckpointer to PeriodicCheckpointer, with subclasses for RDDs and Graphs.
1 parent daa1964 commit 568918c

6 files changed: +471, -95 lines

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala

Lines changed: 1 addition & 1 deletion
@@ -188,7 +188,7 @@ final class EMLDAOptimizer extends LDAOptimizer {
     // Update the vertex descriptors with the new counts.
     val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges)
     graph = newGraph
-    graphCheckpointer.updateGraph(newGraph)
+    graphCheckpointer.update(newGraph)
     globalTopicTotals = computeGlobalTopicTotals()
     this
   }
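Note: PeriodicCheckpointer's contract is update-then-materialize: the caller must materialize the Dataset after update() for persisting and checkpointing to actually happen. In this hunk the contract appears to be satisfied by the computeGlobalTopicTotals() call on the following line, which aggregates over the new graph's vertices and thereby materializes it.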
mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala

Lines changed: 184 additions & 0 deletions (new file)
@@ -0,0 +1,184 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.impl

import scala.collection.mutable

import org.apache.hadoop.fs.{Path, FileSystem}

import org.apache.spark.{SparkContext, Logging}
import org.apache.spark.storage.StorageLevel


/**
 * This abstraction helps with persisting and checkpointing RDDs and types derived from RDDs
 * (such as Graphs and DataFrames). In documentation, we use the phrase "Dataset" to refer to
 * the distributed data type (RDD, Graph, etc.).
 *
 * Specifically, this abstraction automatically handles persisting and (optionally) checkpointing,
 * as well as unpersisting and removing checkpoint files.
 *
 * Users should call update() when a new Dataset has been created,
 * before the Dataset has been materialized. After updating [[PeriodicCheckpointer]], users are
 * responsible for materializing the Dataset to ensure that persisting and checkpointing actually
 * occur.
 *
 * When update() is called, this does the following:
 * - Persist new Dataset (if not yet persisted), and put in queue of persisted Datasets.
 * - Unpersist Datasets from queue until there are at most 3 persisted Datasets.
 * - If using checkpointing and the checkpoint interval has been reached,
 *   - Checkpoint the new Dataset, and put in a queue of checkpointed Datasets.
 *   - Remove older checkpoints.
 *
 * WARNINGS:
 * - This class should NOT be copied (since copies may conflict on which Datasets should be
 *   checkpointed).
 * - This class removes checkpoint files once later Datasets have been checkpointed.
 *   However, references to the older Datasets will still return isCheckpointed = true.
 *
 * Example usage:
 * {{{
 *  val (data1, data2, data3, ...) = ...
 *  val cp = new PeriodicCheckpointer(data1, dir, 2)
 *  data1.count();
 *  // persisted: data1
 *  cp.update(data2)
 *  data2.count();
 *  // persisted: data1, data2
 *  // checkpointed: data2
 *  cp.update(data3)
 *  data3.count();
 *  // persisted: data1, data2, data3
 *  // checkpointed: data2
 *  cp.update(data4)
 *  data4.count();
 *  // persisted: data2, data3, data4
 *  // checkpointed: data4
 *  cp.update(data5)
 *  data5.count();
 *  // persisted: data3, data4, data5
 *  // checkpointed: data4
 * }}}
 *
 * @param currentData Initial Dataset
 * @param checkpointInterval Datasets will be checkpointed at this interval
 * @param sc SparkContext for the Datasets given to this checkpointer
 * @tparam T Dataset type, such as RDD[Double]
 */
private[mllib] abstract class PeriodicCheckpointer[T](
    var currentData: T,
    val checkpointInterval: Int,
    val sc: SparkContext) extends Logging {

  /** FIFO queue of past checkpointed Datasets */
  private val checkpointQueue = mutable.Queue[T]()

  /** FIFO queue of past persisted Datasets */
  private val persistedQueue = mutable.Queue[T]()

  /** Number of times [[update()]] has been called */
  private var updateCount = 0

  update(currentData)

  /**
   * Update [[currentData]] with a new Dataset. Handle persistence and checkpointing as needed.
   * Since this handles persistence and checkpointing, this should be called before the Dataset
   * has been materialized.
   *
   * @param newData New Dataset created from previous Datasets in the lineage.
   */
  def update(newData: T): Unit = {
    persist(newData)
    persistedQueue.enqueue(newData)
    // We try to maintain 2 Datasets in persistedQueue to support the semantics of this class:
    // Users should call [[update()]] when a new Dataset has been created,
    // before the Dataset has been materialized.
    while (persistedQueue.size > 3) {
      val dataToUnpersist = persistedQueue.dequeue()
      unpersist(dataToUnpersist)
    }
    updateCount += 1

    // Handle checkpointing (after persisting)
    if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) {
      // Add new checkpoint before removing old checkpoints.
      checkpoint(newData)
      checkpointQueue.enqueue(newData)
      // Remove checkpoints before the latest one.
      var canDelete = true
      while (checkpointQueue.size > 1 && canDelete) {
        // Delete the oldest checkpoint only if the next checkpoint exists.
        if (isCheckpointed(checkpointQueue.get(1).get)) {
          removeCheckpointFile()
        } else {
          canDelete = false
        }
      }
    }

    currentData = newData
  }

  /** Checkpoint the Dataset */
  def checkpoint(data: T): Unit

  /** Return true iff the Dataset is checkpointed */
  def isCheckpointed(data: T): Boolean

  /**
   * Persist the Dataset.
   * Note: This should handle checking the current [[StorageLevel]] of the Dataset.
   */
  def persist(data: T): Unit

  /** Unpersist the Dataset */
  def unpersist(data: T): Unit

  /** Get list of checkpoint files for this given Dataset */
  def getCheckpointFiles(data: T): Iterable[String]

  /**
   * Call this at the end to delete any remaining checkpoint files.
   */
  def deleteAllCheckpoints(): Unit = {
    while (checkpointQueue.nonEmpty) {
      removeCheckpointFile()
    }
  }

  /**
   * Dequeue the oldest checkpointed Dataset, and remove its checkpoint files.
   * This prints a warning but does not fail if the files cannot be removed.
   */
  private def removeCheckpointFile(): Unit = {
    val old = checkpointQueue.dequeue()
    // Since the old checkpoint is not deleted by Spark, we manually delete it.
    val fs = FileSystem.get(sc.hadoopConfiguration)
    getCheckpointFiles(old).foreach { checkpointFile =>
      try {
        fs.delete(new Path(checkpointFile), true)
      } catch {
        case e: Exception =>
          logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
            checkpointFile)
      }
    }
  }

}
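The commit message says this abstraction gained subclasses for both RDDs and Graphs, but only the Graph subclass appears in this excerpt. For orientation, here is a minimal sketch of what an RDD subclass could look like, written against the abstract hooks above; the class name, visibility, and Option-based file listing are assumptions, not the commit's actual file.

// Hypothetical sketch, not the RDD subclass shipped in this commit;
// it mirrors the Graph subclass shown in the diff below.
package org.apache.spark.mllib.impl

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

private[mllib] class PeriodicRDDCheckpointerSketch[T](
    initData: RDD[T],
    checkpointInterval: Int,
    sc: SparkContext)
  extends PeriodicCheckpointer[RDD[T]](initData, checkpointInterval, sc) {

  override def checkpoint(data: RDD[T]): Unit = data.checkpoint()

  override def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed

  override def persist(data: RDD[T]): Unit = {
    // Only persist if the caller has not already chosen a storage level.
    if (data.getStorageLevel == StorageLevel.NONE) {
      data.persist()
    }
  }

  override def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false)

  // An RDD has at most one checkpoint file, exposed as an Option.
  override def getCheckpointFiles(data: RDD[T]): Iterable[String] =
    data.getCheckpointFile.toList
}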

mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala

Lines changed: 15 additions & 89 deletions
@@ -17,11 +17,6 @@
 
 package org.apache.spark.mllib.impl
 
-import scala.collection.mutable
-
-import org.apache.hadoop.fs.{Path, FileSystem}
-
-import org.apache.spark.Logging
 import org.apache.spark.graphx.Graph
 import org.apache.spark.storage.StorageLevel
 
@@ -31,12 +26,12 @@ import org.apache.spark.storage.StorageLevel
  * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as
  * unpersisting and removing checkpoint files.
  *
- * Users should call [[PeriodicGraphCheckpointer.updateGraph()]] when a new graph has been created,
+ * Users should call update() when a new graph has been created,
  * before the graph has been materialized. After updating [[PeriodicGraphCheckpointer]], users are
  * responsible for materializing the graph to ensure that persisting and checkpointing actually
  * occur.
  *
- * When [[PeriodicGraphCheckpointer.updateGraph()]] is called, this does the following:
+ * When update() is called, this does the following:
  * - Persist new graph (if not yet persisted), and put in queue of persisted graphs.
  * - Unpersist graphs from queue until there are at most 3 persisted graphs.
  * - If using checkpointing and the checkpoint interval has been reached,
@@ -73,99 +68,30 @@ import org.apache.spark.storage.StorageLevel
  * // checkpointed: graph4
  * }}}
  *
- * @param currentGraph Initial graph
+ * @param initGraph Initial graph
  * @param checkpointInterval Graphs will be checkpointed at this interval
  * @tparam VD Vertex descriptor type
  * @tparam ED Edge descriptor type
  *
- * TODO: Generalize this for Graphs and RDDs, and move it out of MLlib.
+ * TODO: Move this out of MLlib?
  */
 private[mllib] class PeriodicGraphCheckpointer[VD, ED](
-    var currentGraph: Graph[VD, ED],
-    val checkpointInterval: Int) extends Logging {
-
-  /** FIFO queue of past checkpointed RDDs */
-  private val checkpointQueue = mutable.Queue[Graph[VD, ED]]()
-
-  /** FIFO queue of past persisted RDDs */
-  private val persistedQueue = mutable.Queue[Graph[VD, ED]]()
+    initGraph: Graph[VD, ED],
+    checkpointInterval: Int)
+  extends PeriodicCheckpointer[Graph[VD, ED]](initGraph, checkpointInterval,
+    initGraph.vertices.sparkContext) {
 
-  /** Number of times [[updateGraph()]] has been called */
-  private var updateCount = 0
+  override def checkpoint(data: Graph[VD, ED]): Unit = data.checkpoint()
 
-  /**
-   * Spark Context for the Graphs given to this checkpointer.
-   * NOTE: This code assumes that only one SparkContext is used for the given graphs.
-   */
-  private val sc = currentGraph.vertices.sparkContext
+  override def isCheckpointed(data: Graph[VD, ED]): Boolean = data.isCheckpointed
 
-  updateGraph(currentGraph)
-
-  /**
-   * Update [[currentGraph]] with a new graph. Handle persistence and checkpointing as needed.
-   * Since this handles persistence and checkpointing, this should be called before the graph
-   * has been materialized.
-   *
-   * @param newGraph New graph created from previous graphs in the lineage.
-   */
-  def updateGraph(newGraph: Graph[VD, ED]): Unit = {
-    if (newGraph.vertices.getStorageLevel == StorageLevel.NONE) {
-      newGraph.persist()
-    }
-    persistedQueue.enqueue(newGraph)
-    // We try to maintain 2 Graphs in persistedQueue to support the semantics of this class:
-    // Users should call [[updateGraph()]] when a new graph has been created,
-    // before the graph has been materialized.
-    while (persistedQueue.size > 3) {
-      val graphToUnpersist = persistedQueue.dequeue()
-      graphToUnpersist.unpersist(blocking = false)
-    }
-    updateCount += 1
-
-    // Handle checkpointing (after persisting)
-    if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) {
-      // Add new checkpoint before removing old checkpoints.
-      newGraph.checkpoint()
-      checkpointQueue.enqueue(newGraph)
-      // Remove checkpoints before the latest one.
-      var canDelete = true
-      while (checkpointQueue.size > 1 && canDelete) {
-        // Delete the oldest checkpoint only if the next checkpoint exists.
-        if (checkpointQueue.get(1).get.isCheckpointed) {
-          removeCheckpointFile()
-        } else {
-          canDelete = false
-        }
-      }
-    }
-  }
-
-  /**
-   * Call this at the end to delete any remaining checkpoint files.
-   */
-  def deleteAllCheckpoints(): Unit = {
-    while (checkpointQueue.size > 0) {
-      removeCheckpointFile()
+  override def persist(data: Graph[VD, ED]): Unit = {
+    if (data.vertices.getStorageLevel == StorageLevel.NONE) {
+      data.persist()
     }
   }
 
-  /**
-   * Dequeue the oldest checkpointed Graph, and remove its checkpoint files.
-   * This prints a warning but does not fail if the files cannot be removed.
-   */
-  private def removeCheckpointFile(): Unit = {
-    val old = checkpointQueue.dequeue()
-    // Since the old checkpoint is not deleted by Spark, we manually delete it.
-    val fs = FileSystem.get(sc.hadoopConfiguration)
-    old.getCheckpointFiles.foreach { checkpointFile =>
-      try {
-        fs.delete(new Path(checkpointFile), true)
-      } catch {
-        case e: Exception =>
-          logWarning("PeriodicGraphCheckpointer could not remove old checkpoint file: " +
-            checkpointFile)
-      }
-    }
-  }
+  override def unpersist(data: Graph[VD, ED]): Unit = data.unpersist(blocking = false)
 
+  override def getCheckpointFiles(data: Graph[VD, ED]): Iterable[String] = data.getCheckpointFiles
 }
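For completeness, a hedged usage sketch of PeriodicGraphCheckpointer following the update-then-materialize contract from the scaladoc. The package placement (chosen to satisfy private[mllib] access), checkpoint directory, transformation, interval, and iteration count are all illustrative assumptions, not code from this commit.

// Illustrative driver loop; everything beyond the checkpointer API is assumed.
package org.apache.spark.mllib.example

import org.apache.spark.SparkContext
import org.apache.spark.graphx.Graph
import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer

object GraphCheckpointerUsage {
  def run(sc: SparkContext, initGraph: Graph[Double, Double]): Unit = {
    sc.setCheckpointDir("/tmp/spark-checkpoints")  // assumed directory
    val checkpointer = new PeriodicGraphCheckpointer[Double, Double](initGraph, 10)
    var graph = initGraph
    for (_ <- 1 to 50) {
      // Stand-in for a real iterative update (e.g., an EM step).
      val newGraph = graph.mapVertices((_, attr) => attr + 1.0)
      checkpointer.update(newGraph)  // persist now; checkpoint every 10th update
      newGraph.vertices.count()      // materialize so persist/checkpoint take effect
      graph = newGraph
    }
    checkpointer.deleteAllCheckpoints()  // remove any remaining checkpoint files
  }
}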
