86 changes: 69 additions & 17 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -18,6 +18,9 @@
package org.apache.spark.ml.clustering

import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
import org.json4s.JsonAST.JObject
import org.json4s.jackson.JsonMethods._

import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
import org.apache.spark.internal.Logging
@@ -26,19 +29,21 @@ import org.apache.spark.ml.linalg.{Matrix, Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed}
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel,
EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
OnlineLDAOptimizer => OldOnlineLDAOptimizer}
import org.apache.spark.mllib.impl.PeriodicCheckpointer
import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Vector => OldVector,
Vectors => OldVectors}
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.MatrixImplicits._
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.VersionUtils


private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter
@@ -80,6 +85,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
* - Values should be >= 0
* - default = uniformly (1.0 / k), following the implementation from
* [[https://github.com/Blei-Lab/onlineldavb]].
*
* @group param
*/
@Since("1.6.0")
Expand Down Expand Up @@ -121,6 +127,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
* - Value should be >= 0
* - default = (1.0 / k), following the implementation from
* [[https://github.com/Blei-Lab/onlineldavb]].
*
* @group param
*/
@Since("1.6.0")
Expand Down Expand Up @@ -354,6 +361,39 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
}
}

private object LDAParams {

/**
* Equivalent to [[DefaultParamsReader.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]]
* formats saved with Spark 1.6, which differ from the formats in Spark 2.0+.
*
* @param model [[LDA]] or [[LDAModel]] instance. This instance will be modified with
* [[Param]] values extracted from metadata.
* @param metadata Loaded model metadata
*/
def getAndSetParams(model: LDAParams, metadata: Metadata): Unit = {
VersionUtils.majorMinorVersion(metadata.sparkVersion) match {
case (1, 6) =>
implicit val format = DefaultFormats
metadata.params match {
case JObject(pairs) =>
pairs.foreach { case (paramName, jsonValue) =>
val origParam =
if (paramName == "topicDistribution") "topicDistributionCol" else paramName
val param = model.getParam(origParam)
val value = param.jsonDecode(compact(render(jsonValue)))
model.set(param, value)
}
case _ =>
throw new IllegalArgumentException(
s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
}
case _ => // 2.0+
DefaultParamsReader.getAndSetParams(model, metadata)
}
}
}

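> A minimal standalone sketch of the 1.6-to-2.0 param mapping done in `LDAParams.getAndSetParams` above. The JSON literal is a hypothetical example of the `params` object in Spark 1.6 metadata; only the `topicDistribution` rename is Spark-specific, the rest is plain json4s.
>
> ```scala
> import org.json4s.DefaultFormats
> import org.json4s.JsonAST.JObject
> import org.json4s.jackson.JsonMethods._
>
> object Lda16ParamMappingSketch {
>   def main(args: Array[String]): Unit = {
>     implicit val format = DefaultFormats
>     // Hypothetical 1.6-era metadata: the output column param was named "topicDistribution".
>     val params = parse("""{"k": 10, "maxIter": 20, "topicDistribution": "topics"}""")
>     params match {
>       case JObject(pairs) =>
>         pairs.foreach { case (paramName, jsonValue) =>
>           // Same rename as LDAParams.getAndSetParams: 1.6 "topicDistribution"
>           // maps to the 2.0+ param name "topicDistributionCol".
>           val current = if (paramName == "topicDistribution") "topicDistributionCol" else paramName
>           println(s"$paramName -> $current = ${compact(render(jsonValue))}")
>         }
>       case _ =>
>         throw new IllegalArgumentException("Cannot recognize JSON metadata.")
>     }
>   }
> }
> ```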

/**
* :: Experimental ::
@@ -414,11 +454,11 @@ sealed abstract class LDAModel private[ml] (
val transformer = oldLocalModel.getTopicDistributionMethod(sparkSession.sparkContext)

val t = udf { (v: Vector) => transformer(OldVectors.fromML(v)).asML }
dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF
dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF()
} else {
logWarning("LDAModel.transform was called without any output columns. Set an output column" +
" such as topicDistributionCol to produce results.")
dataset.toDF
dataset.toDF()
}
}

@@ -574,18 +614,16 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
.select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration",
"gammaShape")
.head()
val vocabSize = data.getAs[Int](0)
val topicsMatrix = data.getAs[Matrix](1)
val docConcentration = data.getAs[Vector](2)
val topicConcentration = data.getAs[Double](3)
val gammaShape = data.getAs[Double](4)
val vectorConverted = MLUtils.convertVectorColumnsToML(data, "docConcentration")
val matrixConverted = MLUtils.convertMatrixColumnsToML(vectorConverted, "topicsMatrix")
val Row(vocabSize: Int, topicsMatrix: Matrix, docConcentration: Vector,
topicConcentration: Double, gammaShape: Double) =
matrixConverted.select("vocabSize", "topicsMatrix", "docConcentration",
"topicConcentration", "gammaShape").head()
val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration,
gammaShape)
val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sparkSession)
DefaultParamsReader.getAndSetParams(model, metadata)
LDAParams.getAndSetParams(model, metadata)
model
}
}
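> Nothing changes for callers; a small usage sketch (the path is hypothetical) of loading a LocalLDAModel that may have been written by Spark 1.6 or 2.0+ — the version branch in `LDAParams.getAndSetParams` picks the right param handling:
>
> ```scala
> import org.apache.spark.ml.clustering.LocalLDAModel
>
> // Hypothetical path; the saved model may come from Spark 1.6 (spark.ml) or 2.0+.
> val restored = LocalLDAModel.load("/tmp/lda-local-model")
> // Params stored under 1.6 names (e.g. "topicDistribution") arrive under their
> // 2.0 names after LDAParams.getAndSetParams runs during load.
> println(restored.explainParams())
> ```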
@@ -731,9 +769,9 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val modelPath = new Path(path, "oldModel").toString
val oldModel = OldDistributedLDAModel.load(sc, modelPath)
val model = new DistributedLDAModel(
metadata.uid, oldModel.vocabSize, oldModel, sparkSession, None)
DefaultParamsReader.getAndSetParams(model, metadata)
val model = new DistributedLDAModel(metadata.uid, oldModel.vocabSize,
oldModel, sparkSession, None)
LDAParams.getAndSetParams(model, metadata)
model
}
}
@@ -881,7 +919,7 @@ class LDA @Since("1.6.0") (
}

@Since("2.0.0")
object LDA extends DefaultParamsReadable[LDA] {
object LDA extends MLReadable[LDA] {

/** Get dataset for spark.mllib LDA */
private[clustering] def getOldDataset(
@@ -896,6 +934,20 @@ object LDA extends DefaultParamsReadable[LDA] {
}
}

private class LDAReader extends MLReader[LDA] {

private val className = classOf[LDA].getName

override def load(path: String): LDA = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val model = new LDA(metadata.uid)
LDAParams.getAndSetParams(model, metadata)
model
}
}

override def read: MLReader[LDA] = new LDAReader

@Since("2.0.0")
override def load(path: String): LDA = super.load(path)
}
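> The estimator itself is now version-aware as well; a usage sketch (path hypothetical) of restoring an LDA estimator saved by 1.6 or 2.0+ through the new `LDAReader`:
>
> ```scala
> import org.apache.spark.ml.clustering.LDA
>
> // Hypothetical path; works for an estimator saved by Spark 1.6 or 2.0+,
> // since LDA.load now routes through LDAReader and LDAParams.getAndSetParams.
> val lda = LDA.load("/tmp/lda-estimator")
> println(lda.getK)
> ```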
3 changes: 3 additions & 0 deletions project/MimaExcludes.scala
@@ -784,6 +784,9 @@ object MimaExcludes {
// SPARK-17096: Improve exception string reported through the StreamingQueryListener
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminated.stackTrace"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminated.this")
) ++ Seq(
// SPARK-16240: ML persistence backward compatibility for LDA
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.clustering.LDA$")
)
}
