From acf7455553d9e30abaf85519896554fffd1264c9 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sun, 10 Apr 2016 16:43:53 +0800 Subject: [PATCH 1/2] Add unit test for EM LDA disable checkpointing --- .../org/apache/spark/ml/clustering/LDASuite.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index a1c93891c78b..aa3bef1c56ec 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -289,4 +289,14 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead assert(model.getCheckpointFiles.isEmpty) } + + test("EM LDA disable checkpointing") { + // Checkpoint dir is set by MLlibTestSparkContext + val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3).setCheckpointInterval(-1) + val model_ = lda.fit(dataset) + assert(model_.isInstanceOf[DistributedLDAModel]) + val model = model_.asInstanceOf[DistributedLDAModel] + + assert(model.getCheckpointFiles.isEmpty) + } } From fa36cde8daa6b8893c7495b4b26054fd67b170c1 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sun, 10 Apr 2016 16:58:57 +0800 Subject: [PATCH 2/2] truncate long line --- .../test/scala/org/apache/spark/ml/clustering/LDASuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index aa3bef1c56ec..eea03c930dfd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -292,7 +292,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead test("EM LDA disable checkpointing") { // Checkpoint dir is set by MLlibTestSparkContext - val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3).setCheckpointInterval(-1) + val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3) + .setCheckpointInterval(-1) val model_ = lda.fit(dataset) assert(model_.isInstanceOf[DistributedLDAModel]) val model = model_.asInstanceOf[DistributedLDAModel]