Commit c19152c

[SPARK-5604][MLLIB] remove checkpointDir from LDA
`checkpointDir` is a Spark global configuration. Users should set it outside LDA. This PR also hides some methods under `private[clustering] object LDA`, so they don't show up in the generated Java doc (SPARK-5610).

jkbradley

Author: Xiangrui Meng <[email protected]>

Closes apache#4390 from mengxr/SPARK-5604 and squashes the following commits:

a34bb39 [Xiangrui Meng] remove checkpointDir from LDA
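
With this change, checkpointing for LDA is driven entirely by the SparkContext. A minimal sketch of the new usage, assuming a local context, a writable `/tmp/lda-checkpoints` directory, and a `corpus` of `(docId, termCounts)` pairs prepared elsewhere (all of these are placeholders, not part of this commit):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

val sc = new SparkContext(new SparkConf().setAppName("LDACheckpointSketch").setMaster("local[2]"))

// The checkpoint directory is a global Spark setting; set it on the context, not on LDA.
sc.setCheckpointDir("/tmp/lda-checkpoints")   // placeholder path

val corpus: RDD[(Long, Vector)] = ???         // (docId, term-count vector) pairs, built elsewhere

val ldaModel = new LDA()
  .setK(10)
  .setMaxIterations(50)
  .setCheckpointInterval(10)                  // checkpoint every 10 iterations
  .run(corpus)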
1 parent 62371ad commit c19152c

File tree

4 files changed: 24 additions, 65 deletions


examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@ object LDAExample {
         .setTopicConcentration(params.topicConcentration)
         .setCheckpointInterval(params.checkpointInterval)
       if (params.checkpointDir.nonEmpty) {
-        lda.setCheckpointDir(params.checkpointDir.get)
+        sc.setCheckpointDir(params.checkpointDir.get)
       }
       val startTime = System.nanoTime()
       val ldaModel = lda.run(corpus)

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala

Lines changed: 20 additions & 53 deletions
@@ -52,6 +52,9 @@ import org.apache.spark.util.Utils
  *  - Paper which clearly explains several algorithms, including EM:
  *    Asuncion, Welling, Smyth, and Teh.
  *    "On Smoothing and Inference for Topic Models." UAI, 2009.
+ *
+ * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation
+ *       (Wikipedia)]]
  */
 @Experimental
 class LDA private (
@@ -60,11 +63,10 @@ class LDA private (
     private var docConcentration: Double,
     private var topicConcentration: Double,
     private var seed: Long,
-    private var checkpointDir: Option[String],
     private var checkpointInterval: Int) extends Logging {
 
   def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1,
-    seed = Utils.random.nextLong(), checkpointDir = None, checkpointInterval = 10)
+    seed = Utils.random.nextLong(), checkpointInterval = 10)
 
   /**
    * Number of topics to infer. I.e., the number of soft cluster centers.
@@ -200,50 +202,18 @@ class LDA private (
     this
   }
 
-  /**
-   * Directory for storing checkpoint files during learning.
-   * This is not necessary, but checkpointing helps with recovery (when nodes fail).
-   * It also helps with eliminating temporary shuffle files on disk, which can be important when
-   * LDA is run for many iterations.
-   */
-  def getCheckpointDir: Option[String] = checkpointDir
-
-  /**
-   * Directory for storing checkpoint files during learning.
-   * This is not necessary, but checkpointing helps with recovery (when nodes fail).
-   * It also helps with eliminating temporary shuffle files on disk, which can be important when
-   * LDA is run for many iterations.
-   *
-   * NOTE: If the [[org.apache.spark.SparkContext.checkpointDir]] is already set, then the value
-   *       given to LDA is ignored, and the existing directory is kept.
-   *
-   * (default = None)
-   */
-  def setCheckpointDir(checkpointDir: String): this.type = {
-    this.checkpointDir = Some(checkpointDir)
-    this
-  }
-
-  /**
-   * Clear the directory for storing checkpoint files during learning.
-   * If one is already set in the [[org.apache.spark.SparkContext]], then checkpointing will still
-   * occur; otherwise, no checkpointing will be used.
-   */
-  def clearCheckpointDir(): this.type = {
-    this.checkpointDir = None
-    this
-  }
-
   /**
    * Period (in iterations) between checkpoints.
-   * @see [[getCheckpointDir]]
    */
   def getCheckpointInterval: Int = checkpointInterval
 
   /**
-   * Period (in iterations) between checkpoints.
-   * (default = 10)
-   * @see [[getCheckpointDir]]
+   * Period (in iterations) between checkpoints (default = 10). Checkpointing helps with recovery
+   * (when nodes fail). It also helps with eliminating temporary shuffle files on disk, which can
+   * be important when LDA is run for many iterations. If the checkpoint directory is not set in
+   * [[org.apache.spark.SparkContext]], this setting is ignored.
+   *
+   * @see [[org.apache.spark.SparkContext#setCheckpointDir]]
    */
   def setCheckpointInterval(checkpointInterval: Int): this.type = {
     this.checkpointInterval = checkpointInterval
@@ -261,7 +231,7 @@ class LDA private (
    */
   def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = {
     val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
-      checkpointDir, checkpointInterval)
+      checkpointInterval)
     var iter = 0
     val iterationTimes = Array.fill[Double](maxIterations)(0)
     while (iter < maxIterations) {
@@ -337,18 +307,18 @@ private[clustering] object LDA {
   * Vector over topics (length k) of token counts.
   * The meaning of these counts can vary, and it may or may not be normalized to be a distribution.
   */
-  type TopicCounts = BDV[Double]
+  private[clustering] type TopicCounts = BDV[Double]
 
-  type TokenCount = Double
+  private[clustering] type TokenCount = Double
 
   /** Term vertex IDs are {-1, -2, ..., -vocabSize} */
-  def term2index(term: Int): Long = -(1 + term.toLong)
+  private[clustering] def term2index(term: Int): Long = -(1 + term.toLong)
 
-  def index2term(termIndex: Long): Int = -(1 + termIndex).toInt
+  private[clustering] def index2term(termIndex: Long): Int = -(1 + termIndex).toInt
 
-  def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0
+  private[clustering] def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0
 
-  def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0
+  private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0
 
   /**
    * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters.
@@ -360,17 +330,16 @@ private[clustering] object LDA {
   * @param docConcentration  "alpha"
   * @param topicConcentration  "beta" or "eta"
   */
-  class EMOptimizer(
+  private[clustering] class EMOptimizer(
       var graph: Graph[TopicCounts, TokenCount],
       val k: Int,
       val vocabSize: Int,
       val docConcentration: Double,
       val topicConcentration: Double,
-      checkpointDir: Option[String],
       checkpointInterval: Int) {
 
     private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount](
-      graph, checkpointDir, checkpointInterval)
+      graph, checkpointInterval)
 
     def next(): EMOptimizer = {
       val eta = topicConcentration
@@ -468,7 +437,6 @@ private[clustering] object LDA {
       docConcentration: Double,
       topicConcentration: Double,
       randomSeed: Long,
-      checkpointDir: Option[String],
       checkpointInterval: Int): EMOptimizer = {
     // For each document, create an edge (Document -> Term) for each unique term in the document.
     val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) =>
@@ -512,8 +480,7 @@ private[clustering] object LDA {
     val graph = Graph(docVertices ++ termVertices, edges)
       .partitionBy(PartitionStrategy.EdgePartition1D)
 
-    new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointDir,
-      checkpointInterval)
+    new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval)
   }
 
 }
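
The rewritten `setCheckpointInterval` doc above encodes the new contract: the interval only has an effect when a checkpoint directory has already been set on the SparkContext. A hedged sketch of the two situations, reusing the `sc` and `corpus` placeholders from the earlier example:

// Directory set on the context: LDA checkpoints its graph every 5 iterations.
sc.setCheckpointDir("hdfs:///tmp/spark-checkpoints")   // placeholder path
val model = new LDA().setK(20).setCheckpointInterval(5).run(corpus)

// If sc.setCheckpointDir(...) is never called, the interval is silently ignored:
// no checkpoints are written and temporary shuffle files accumulate over long runs.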

mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala

Lines changed: 0 additions & 8 deletions
@@ -74,7 +74,6 @@ import org.apache.spark.storage.StorageLevel
  * }}}
  *
  * @param currentGraph  Initial graph
- * @param checkpointDir  The directory for storing checkpoint files
  * @param checkpointInterval  Graphs will be checkpointed at this interval
  * @tparam VD  Vertex descriptor type
  * @tparam ED  Edge descriptor type
@@ -83,7 +82,6 @@ import org.apache.spark.storage.StorageLevel
  */
 private[mllib] class PeriodicGraphCheckpointer[VD, ED](
     var currentGraph: Graph[VD, ED],
-    val checkpointDir: Option[String],
     val checkpointInterval: Int) extends Logging {
 
   /** FIFO queue of past checkpointed RDDs */
@@ -101,12 +99,6 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED](
    */
   private val sc = currentGraph.vertices.sparkContext
 
-  // If a checkpoint directory is given, and there's no prior checkpoint directory,
-  // then set the checkpoint directory with the given one.
-  if (checkpointDir.nonEmpty && sc.getCheckpointDir.isEmpty) {
-    sc.setCheckpointDir(checkpointDir.get)
-  }
-
   updateGraph(currentGraph)
 
   /**
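
With the constructor parameter and the directory-setting block removed, the checkpointer depends entirely on whatever directory the caller has already configured on the SparkContext. The underlying pattern can be sketched with the public RDD API (a simplified illustration, not the private[mllib] class itself; the path, data, and helper name are assumptions, again reusing the `sc` placeholder from the first sketch):

import org.apache.spark.rdd.RDD

// Must be set by the caller beforehand; the checkpointer no longer sets it.
sc.setCheckpointDir("/tmp/graph-checkpoints")   // placeholder path

def runWithPeriodicCheckpoints(data: RDD[Double], iterations: Int, interval: Int): RDD[Double] = {
  var current = data
  for (i <- 1 to iterations) {
    current = current.map(_ * 0.9).persist()    // stand-in for one iteration of real work
    // Checkpoint only when a directory exists and the interval has elapsed,
    // mirroring what PeriodicGraphCheckpointer now assumes about its context.
    if (sc.getCheckpointDir.isDefined && i % interval == 0) {
      current.checkpoint()
    }
    current.count()                             // materialize so the checkpoint is written
  }
  current
}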

mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala

Lines changed: 3 additions & 3 deletions
@@ -38,7 +38,7 @@ class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext
     var graphsToCheck = Seq.empty[GraphToCheck]
 
     val graph1 = createGraph(sc)
-    val checkpointer = new PeriodicGraphCheckpointer(graph1, None, 10)
+    val checkpointer = new PeriodicGraphCheckpointer(graph1, 10)
     graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
     checkPersistence(graphsToCheck, 1)
 
@@ -57,9 +57,9 @@ class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext
     val path = tempDir.toURI.toString
     val checkpointInterval = 2
     var graphsToCheck = Seq.empty[GraphToCheck]
-
+    sc.setCheckpointDir(path)
     val graph1 = createGraph(sc)
-    val checkpointer = new PeriodicGraphCheckpointer(graph1, Some(path), checkpointInterval)
+    val checkpointer = new PeriodicGraphCheckpointer(graph1, checkpointInterval)
     graph1.edges.count()
     graph1.vertices.count()
     graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
