
Commit 12c8c21

wangmiao1981 authored and yanboliang committed
[SPARK-19066][SPARKR] SparkR LDA doesn't set optimizer correctly
## What changes were proposed in this pull request?

spark.lda passes the optimizer ("em" or "online") as a string to the backend. However, LDAWrapper does not set the optimizer from the value passed in from R. As a result, for optimizer "em" the `isDistributed` field is FALSE, while it should be TRUE according to the Scala code. In addition, the `summary` method should return the results specific to `DistributedLDAModel`.

## How was this patch tested?

Manual tests, comparing against the Scala example. The existing unit tests were also modified: the incorrect assertion is fixed, and the necessary tests for the `summary` method are added.

Author: [email protected] <[email protected]>

Closes #16464 from wangmiao1981/new.
1 parent e635cbb commit 12c8c21
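
As a usage illustration (not part of the commit): with this change, the optimizer chosen in R is forwarded to the backend, so training with "em" yields a distributed model and the new summary fields are populated. A minimal SparkR sketch, assuming a running session and the sample libsvm corpus that ships with Spark:

    library(SparkR)
    sparkR.session()

    # Hypothetical input path; any DataFrame with a "features" vector column works.
    df <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm")

    model <- spark.lda(df, k = 3, optimizer = "em")
    stats <- summary(model)

    stats$isDistributed          # TRUE after this fix (was incorrectly FALSE)
    stats$trainingLogLikelihood  # finite value <= 0 for optimizer = "em"
    stats$logPrior               # finite value <= 0 for optimizer = "em"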

File tree

4 files changed: +42, -5 lines

R/pkg/R/mllib_clustering.R

Lines changed: 19 additions & 1 deletion
@@ -397,6 +397,13 @@ setMethod("spark.lda", signature(data = "SparkDataFrame"),
 #' \item{\code{topics}}{top 10 terms and their weights of all topics}
 #' \item{\code{vocabulary}}{whole terms of the training corpus, NULL if libsvm format file
 #'       used as training set}
+#' \item{\code{trainingLogLikelihood}}{Log likelihood of the observed tokens in the training set,
+#'       given the current parameter estimates:
+#'       log P(docs | topics, topic distributions for docs, Dirichlet hyperparameters)
+#'       It is only for distributed LDA model (i.e., optimizer = "em")}
+#' \item{\code{logPrior}}{Log probability of the current parameter estimate:
+#'       log P(topics, topic distributions for docs | Dirichlet hyperparameters)
+#'       It is only for distributed LDA model (i.e., optimizer = "em")}
 #' @rdname spark.lda
 #' @aliases summary,LDAModel-method
 #' @export
@@ -413,11 +420,22 @@ setMethod("summary", signature(object = "LDAModel"),
             vocabSize <- callJMethod(jobj, "vocabSize")
             topics <- dataFrame(callJMethod(jobj, "topics", maxTermsPerTopic))
             vocabulary <- callJMethod(jobj, "vocabulary")
+            trainingLogLikelihood <- if (isDistributed) {
+              callJMethod(jobj, "trainingLogLikelihood")
+            } else {
+              NA
+            }
+            logPrior <- if (isDistributed) {
+              callJMethod(jobj, "logPrior")
+            } else {
+              NA
+            }
             list(docConcentration = unlist(docConcentration),
                  topicConcentration = topicConcentration,
                  logLikelihood = logLikelihood, logPerplexity = logPerplexity,
                  isDistributed = isDistributed, vocabSize = vocabSize,
-                 topics = topics, vocabulary = unlist(vocabulary))
+                 topics = topics, vocabulary = unlist(vocabulary),
+                 trainingLogLikelihood = trainingLogLikelihood, logPrior = logPrior)
           })
 
 # Returns the log perplexity of a Latent Dirichlet Allocation model produced by \code{spark.lda}
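
For callers, the contract added here is that trainingLogLikelihood and logPrior are real numbers only when the model is distributed; with the default "online" optimizer both come back as NA rather than raising an error. A small consumer-side sketch of that guard, assuming model came from an earlier spark.lda call:

    stats <- summary(model)
    if (stats$isDistributed) {
      # Defined only for optimizer = "em".
      cat("training log likelihood:", stats$trainingLogLikelihood, "\n")
      cat("log prior:", stats$logPrior, "\n")
    } else {
      # For optimizer = "online" both fields are NA by design.
      stopifnot(is.na(stats$trainingLogLikelihood), is.na(stats$logPrior))
    }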

R/pkg/inst/tests/testthat/test_mllib_clustering.R

Lines changed: 14 additions & 2 deletions
@@ -166,12 +166,16 @@ test_that("spark.lda with libsvm", {
   topics <- stats$topicTopTerms
   weights <- stats$topicTopTermsWeights
   vocabulary <- stats$vocabulary
+  trainingLogLikelihood <- stats$trainingLogLikelihood
+  logPrior <- stats$logPrior
 
-  expect_false(isDistributed)
+  expect_true(isDistributed)
   expect_true(logLikelihood <= 0 & is.finite(logLikelihood))
   expect_true(logPerplexity >= 0 & is.finite(logPerplexity))
   expect_equal(vocabSize, 11)
   expect_true(is.null(vocabulary))
+  expect_true(trainingLogLikelihood <= 0 & !is.na(trainingLogLikelihood))
+  expect_true(logPrior <= 0 & !is.na(logPrior))
 
   # Test model save/load
   modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp")
@@ -181,11 +185,13 @@
   model2 <- read.ml(modelPath)
   stats2 <- summary(model2)
 
-  expect_false(stats2$isDistributed)
+  expect_true(stats2$isDistributed)
   expect_equal(logLikelihood, stats2$logLikelihood)
   expect_equal(logPerplexity, stats2$logPerplexity)
   expect_equal(vocabSize, stats2$vocabSize)
   expect_equal(vocabulary, stats2$vocabulary)
+  expect_equal(trainingLogLikelihood, stats2$trainingLogLikelihood)
+  expect_equal(logPrior, stats2$logPrior)
 
   unlink(modelPath)
 })
@@ -202,12 +208,16 @@ test_that("spark.lda with text input", {
   topics <- stats$topicTopTerms
   weights <- stats$topicTopTermsWeights
   vocabulary <- stats$vocabulary
+  trainingLogLikelihood <- stats$trainingLogLikelihood
+  logPrior <- stats$logPrior
 
   expect_false(isDistributed)
   expect_true(logLikelihood <= 0 & is.finite(logLikelihood))
   expect_true(logPerplexity >= 0 & is.finite(logPerplexity))
   expect_equal(vocabSize, 10)
   expect_true(setequal(stats$vocabulary, c("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")))
+  expect_true(is.na(trainingLogLikelihood))
+  expect_true(is.na(logPrior))
 
   # Test model save/load
   modelPath <- tempfile(pattern = "spark-lda-text", fileext = ".tmp")
@@ -222,6 +232,8 @@
   expect_equal(logPerplexity, stats2$logPerplexity)
   expect_equal(vocabSize, stats2$vocabSize)
   expect_true(all.equal(vocabulary, stats2$vocabulary))
+  expect_true(is.na(stats2$trainingLogLikelihood))
+  expect_true(is.na(stats2$logPrior))
 
   unlink(modelPath)
 })

R/pkg/inst/tests/testthat/test_mllib_tree.R

Lines changed: 0 additions & 1 deletion
@@ -126,7 +126,6 @@ test_that("spark.randomForest", {
                  63.53160, 64.05470, 65.12710, 64.30450,
                  66.70910, 67.86125, 68.08700, 67.21865,
                  68.89275, 69.53180, 69.39640, 69.68250),
-
                tolerance = 1e-4)
   stats <- summary(model)
   expect_equal(stats$numTrees, 20)

mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala

Lines changed: 9 additions & 1 deletion
@@ -26,7 +26,7 @@ import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.SparkException
 import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
-import org.apache.spark.ml.clustering.{LDA, LDAModel}
+import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA, LDAModel}
 import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover}
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param.ParamPair
@@ -45,6 +45,13 @@ private[r] class LDAWrapper private (
   import LDAWrapper._
 
   private val lda: LDAModel = pipeline.stages.last.asInstanceOf[LDAModel]
+
+  // The following fields are called from the R side only when the LDA model is distributed.
+  lazy private val distributedModel =
+    pipeline.stages.last.asInstanceOf[DistributedLDAModel]
+  lazy val trainingLogLikelihood: Double = distributedModel.trainingLogLikelihood
+  lazy val logPrior: Double = distributedModel.logPrior
+
   private val preprocessor: PipelineModel =
     new PipelineModel(s"${Identifiable.randomUID(pipeline.uid)}", pipeline.stages.dropRight(1))
 
@@ -122,6 +129,7 @@ private[r] object LDAWrapper extends MLReadable[LDAWrapper] {
       .setK(k)
       .setMaxIter(maxIter)
       .setSubsamplingRate(subsamplingRate)
+      .setOptimizer(optimizer)
 
     val featureSchema = data.schema(features)
     val stages = featureSchema.dataType match {
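
A note on the design: the cast to DistributedLDAModel lives in lazy vals, so wrapping an online model never triggers a ClassCastException at construction; the cast is forced only if the R side actually asks for the distributed-only statistics, which the R summary method guards with isDistributed. Observed from R (a sketch, reusing the hypothetical df from the example above):

    model <- spark.lda(df, k = 3, optimizer = "online")

    # summary() checks isDistributed before calling the backend's
    # trainingLogLikelihood/logPrior, so the lazy cast in LDAWrapper
    # is never forced for an online model.
    stats <- summary(model)
    stats$isDistributed                 # FALSE
    is.na(stats$trainingLogLikelihood)  # TRUE
    is.na(stats$logPrior)               # TRUE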
