diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 1f2fae9c813f..3891f0044d4f 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -860,7 +860,7 @@ test_that("spark.lda with libsvm", { weights <- stats$topicTopTermsWeights vocabulary <- stats$vocabulary - expect_false(isDistributed) + expect_true(isDistributed) expect_true(logLikelihood <= 0 & is.finite(logLikelihood)) expect_true(logPerplexity >= 0 & is.finite(logPerplexity)) expect_equal(vocabSize, 11) @@ -874,7 +874,7 @@ test_that("spark.lda with libsvm", { model2 <- read.ml(modelPath) stats2 <- summary(model2) - expect_false(stats2$isDistributed) + expect_true(stats2$isDistributed) expect_equal(logLikelihood, stats2$logLikelihood) expect_equal(logPerplexity, stats2$logPerplexity) expect_equal(vocabSize, stats2$vocabSize) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala index cbe6a705007d..e7851e148855 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala @@ -122,6 +122,7 @@ private[r] object LDAWrapper extends MLReadable[LDAWrapper] { .setK(k) .setMaxIter(maxIter) .setSubsamplingRate(subsamplingRate) + .setOptimizer(optimizer) val featureSchema = data.schema(features) val stages = featureSchema.dataType match {