mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
@@ -20,7 +20,9 @@ package org.apache.spark.mllib.classification
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.Loader
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}

/**
* :: Experimental ::
@@ -53,3 +55,21 @@ trait ClassificationModel extends Serializable {
  def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
    predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
}

private[mllib] object ClassificationModel {

  /**
   * Helper method for loading GLM classification model metadata.
   *
   * @param metadata  metadata DataFrame loaded from the model's metadata path
   * @param modelClass  String name for model class (used for error messages)
   * @param path  path from which the model is being loaded (used for error messages)
   * @return (numFeatures, numClasses)
   */
  def getNumFeaturesClasses(metadata: DataFrame, modelClass: String, path: String): (Int, Int) = {
    metadata.select("numFeatures", "numClasses").take(1)(0) match {
      case Row(nFeatures: Int, nClasses: Int) => (nFeatures, nClasses)
      case _ => throw new Exception(s"$modelClass unable to load" +
        s" numFeatures, numClasses from metadata: ${Loader.metadataPath(path)}")
    }
  }

}
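For context, a minimal usage sketch of this helper; it mirrors the load() implementations later in this diff (`sc` and `path` here are assumed to be in scope):

  // Sketch only: read back the JSON metadata written at save time, then
  // extract the model dimensions from it.
  val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
  val (numFeatures, numClasses) =
    ClassificationModel.getNumFeaturesClasses(metadata, loadedClassName, path)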
mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -17,14 +17,17 @@

package org.apache.spark.mllib.classification

import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.classification.impl.GLMClassificationModel
import org.apache.spark.mllib.linalg.BLAS.dot
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.{DataValidators, MLUtils}
+import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
import org.apache.spark.rdd.RDD


/**
* Classification model trained using Multinomial/Binary Logistic Regression.
*
@@ -42,7 +45,22 @@ class LogisticRegressionModel (
    override val intercept: Double,
    val numFeatures: Int,
    val numClasses: Int)
-  extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {
+  extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
+  with Saveable {

  if (numClasses == 2) {
    require(weights.size == numFeatures,
      s"LogisticRegressionModel with numClasses = 2 was given non-matching values:" +
      s" numFeatures = $numFeatures, but weights.size = ${weights.size}")
  } else {
    val weightsSizeWithoutIntercept = (numClasses - 1) * numFeatures
    val weightsSizeWithIntercept = (numClasses - 1) * (numFeatures + 1)
    require(weights.size == weightsSizeWithoutIntercept || weights.size == weightsSizeWithIntercept,
      s"LogisticRegressionModel.load with numClasses = $numClasses and numFeatures = $numFeatures" +
      s" expected weights of length $weightsSizeWithoutIntercept (without intercept)" +
      s" or $weightsSizeWithIntercept (with intercept)," +
      s" but was given weights of length ${weights.size}")
  }
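  // Illustrative example (not part of the patch): with numClasses = 3 and
  // numFeatures = 4, the accepted weight vector sizes are
  //   (3 - 1) * 4 = 8        (without intercept terms), or
  //   (3 - 1) * (4 + 1) = 10 (with intercept terms);
  // any other weights.size fails the require check above.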

  def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2)

@@ -60,6 +78,13 @@ class LogisticRegressionModel (
    this
  }

  /**
   * :: Experimental ::
   * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
   */
  @Experimental
  def getThreshold: Option[Double] = threshold

  /**
   * :: Experimental ::
   * Clears the threshold so that `predict` will output raw prediction scores.
@@ -70,7 +95,9 @@
    this
  }

-  override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
+  override protected def predictPoint(
+      dataMatrix: Vector,
+      weightMatrix: Vector,
      intercept: Double) = {
    require(dataMatrix.size == numFeatures)

@@ -126,6 +153,40 @@ class LogisticRegressionModel (
      bestClass.toDouble
    }
  }

  override def save(sc: SparkContext, path: String): Unit = {
    GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
      numFeatures, numClasses, weights, intercept, threshold)
  }

Review thread on save():
Contributor: We merged multinomial logistic regression. LRModel holds numFeatures and numClasses now. We need a specialized implementation and a test for it. Or, for all classification models, we save numFeatures and numClasses.
Member Author: I'll save numFeatures and numClasses in all classification models' metadata. I'm going for metadata instead of data in case the model data requires multiple RDD rows (e.g., for decision tree).
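A sketch of what that choice implies on disk, mirroring the NaiveBayes save path later in this diff (the concrete values and the sqlContext implicits are assumptions for illustration):

  // Illustrative only: one JSON line written under Loader.metadataPath(path), e.g.
  //   {"class":"...LogisticRegressionModel","version":"1.0","numFeatures":692,"numClasses":2}
  // Assumes `import sqlContext._` for the implicit RDD-to-DataFrame conversion.
  val metadataRDD = sc.parallelize(
      Seq((this.getClass.getName, "1.0", numFeatures, numClasses)), 1)
    .toDataFrame("class", "version", "numFeatures", "numClasses")
  metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))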

Review thread on the SaveLoadV1_0.save call (versioning):
Contributor: Any proposed guidelines about when to change the minor version and when to change the major version? I'm not expecting many versions, so I'm not sure whether it is necessary to have minor versions.
Member Author: I was thinking minor versions could be used for format changes and major ones for model changes. But I'm OK with a single version number too.
Contributor: I have no strong preference here. It is okay with the current versioning.

  override protected def formatVersion: String = "1.0"
}

object LogisticRegressionModel extends Loader[LogisticRegressionModel] {

  override def load(sc: SparkContext, path: String): LogisticRegressionModel = {
    val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
    // Hard-code class name string in case it changes in the future
    val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel"

Review comment (Contributor): Maybe we should put a comment here about why we use the literal string name.

    (loadedClassName, version) match {
      case (className, "1.0") if className == classNameV1_0 =>
        val (numFeatures, numClasses) =
          ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
        val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
        // numFeatures, numClasses, weights are checked in model initialization
        val model =
          new LogisticRegressionModel(data.weights, data.intercept, numFeatures, numClasses)
        data.threshold match {
          case Some(t) => model.setThreshold(t)
          case None => model.clearThreshold()
        }
        model
      case _ => throw new Exception(
        s"LogisticRegressionModel.load did not recognize model with (className, format version):" +
        s" ($loadedClassName, $version). Supported:\n" +
        s"  ($classNameV1_0, 1.0)")
    }
  }
}
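A minimal usage sketch of the new save/load API for this model ("myLRModelPath" and `sc` are illustrative, not from the PR):

  // Sketch: round-trip a trained LogisticRegressionModel through save/load.
  model.save(sc, "myLRModelPath")
  val sameModel = LogisticRegressionModel.load(sc, "myLRModelPath")
  // load() replays the saved threshold via setThreshold/clearThreshold,
  // so the loaded model scores and thresholds identically.
  assert(sameModel.getThreshold == model.getThreshold)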

/**
mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -19,11 +19,13 @@ package org.apache.spark.mllib.classification

import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}

-import org.apache.spark.{SparkException, Logging}
-import org.apache.spark.SparkContext._
+import org.apache.spark.{SparkContext, SparkException, Logging}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}


/**
* Model for Naive Bayes Classifiers.
@@ -36,7 +38,7 @@ import org.apache.spark.rdd.RDD
class NaiveBayesModel private[mllib] (
    val labels: Array[Double],
    val pi: Array[Double],
-    val theta: Array[Array[Double]]) extends ClassificationModel with Serializable {
+    val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Saveable {

  private val brzPi = new BDV[Double](pi)
  private val brzTheta = new BDM[Double](theta.length, theta(0).length)
@@ -65,6 +67,85 @@ class NaiveBayesModel private[mllib] (
  override def predict(testData: Vector): Double = {
    labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))
  }

  override def save(sc: SparkContext, path: String): Unit = {
    val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta)
    NaiveBayesModel.SaveLoadV1_0.save(sc, path, data)
  }

  override protected def formatVersion: String = "1.0"
}

object NaiveBayesModel extends Loader[NaiveBayesModel] {

  import Loader._

  private object SaveLoadV1_0 {

    def thisFormatVersion = "1.0"

    /** Hard-code class name string in case it changes in the future */
    def thisClassName = "org.apache.spark.mllib.classification.NaiveBayesModel"

    /** Model data for model import/export */
    case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]])

    def save(sc: SparkContext, path: String, data: Data): Unit = {
      val sqlContext = new SQLContext(sc)
      import sqlContext._

      // Create JSON metadata.
      val metadataRDD =
        sc.parallelize(Seq((thisClassName, thisFormatVersion, data.theta(0).size, data.pi.size)), 1)
          .toDataFrame("class", "version", "numFeatures", "numClasses")
      metadataRDD.toJSON.saveAsTextFile(metadataPath(path))

      // Create Parquet data.
      val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
      dataRDD.saveAsParquetFile(dataPath(path))
    }

    def load(sc: SparkContext, path: String): NaiveBayesModel = {
      val sqlContext = new SQLContext(sc)
      // Load Parquet data.
      val dataRDD = sqlContext.parquetFile(dataPath(path))
      // Check schema explicitly since erasure makes it hard to use match-case for checking.
      checkSchema[Data](dataRDD.schema)
      val dataArray = dataRDD.select("labels", "pi", "theta").take(1)
      assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
      val data = dataArray(0)
      val labels = data.getAs[Seq[Double]](0).toArray
      val pi = data.getAs[Seq[Double]](1).toArray
      val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
      new NaiveBayesModel(labels, pi, theta)
    }
  }
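  // Illustrative layout (an assumption, not spelled out in the patch): with
  // the save path above, a saved model directory looks roughly like
  //   someModelPath/
  //     metadata/   one-line JSON: class, version, numFeatures, numClasses
  //     data/       Parquet file(s) holding labels, pi, theta
  // assuming Loader.metadataPath/dataPath resolve to these subdirectories.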

  override def load(sc: SparkContext, path: String): NaiveBayesModel = {
    val (loadedClassName, version, metadata) = loadMetadata(sc, path)
    val classNameV1_0 = SaveLoadV1_0.thisClassName
    (loadedClassName, version) match {
      case (className, "1.0") if className == classNameV1_0 =>
        val (numFeatures, numClasses) =
          ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
        val model = SaveLoadV1_0.load(sc, path)
        assert(model.pi.size == numClasses,
          s"NaiveBayesModel.load expected $numClasses classes," +
          s" but class priors vector pi had ${model.pi.size} elements")
        assert(model.theta.size == numClasses,
          s"NaiveBayesModel.load expected $numClasses classes," +
          s" but class conditionals array theta had ${model.theta.size} elements")
        assert(model.theta.forall(_.size == numFeatures),
          s"NaiveBayesModel.load expected $numFeatures features," +
          s" but class conditionals array theta had elements of size:" +
          s" ${model.theta.map(_.size).mkString(",")}")
        model
      case _ => throw new Exception(
        s"NaiveBayesModel.load did not recognize model with (className, format version):" +
        s" ($loadedClassName, $version). Supported:\n" +
        s"  ($classNameV1_0, 1.0)")
    }
  }
}
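The analogous round trip for Naive Bayes, assuming `training: RDD[LabeledPoint]` and an illustrative path:

  // Sketch: train, save, and reload a NaiveBayesModel.
  val model = NaiveBayes.train(training, lambda = 1.0)
  model.save(sc, "myNBModelPath")
  val sameModel = NaiveBayesModel.load(sc, "myNBModelPath")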

/**
mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -17,13 +17,16 @@

package org.apache.spark.mllib.classification

import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.classification.impl.GLMClassificationModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.DataValidators
+import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
import org.apache.spark.rdd.RDD


/**
* Model for Support Vector Machines (SVMs).
*
@@ -33,7 +36,8 @@ import org.apache.spark.rdd.RDD
class SVMModel (
    override val weights: Vector,
    override val intercept: Double)
-  extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {
+  extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
+  with Saveable {

  private var threshold: Option[Double] = Some(0.0)

@@ -49,6 +53,13 @@ class SVMModel (
    this
  }

  /**
   * :: Experimental ::
   * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
   */
  @Experimental
  def getThreshold: Option[Double] = threshold

  /**
   * :: Experimental ::
   * Clears the threshold so that `predict` will output raw prediction scores.
@@ -69,6 +80,42 @@ class SVMModel (
      case None => margin
    }
  }

  override def save(sc: SparkContext, path: String): Unit = {
    GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
      numFeatures = weights.size, numClasses = 2, weights, intercept, threshold)
  }

  override protected def formatVersion: String = "1.0"
}

object SVMModel extends Loader[SVMModel] {

  override def load(sc: SparkContext, path: String): SVMModel = {
    val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
    // Hard-code class name string in case it changes in the future
    val classNameV1_0 = "org.apache.spark.mllib.classification.SVMModel"
    (loadedClassName, version) match {
      case (className, "1.0") if className == classNameV1_0 =>
        val (numFeatures, numClasses) =
          ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
        val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
        val model = new SVMModel(data.weights, data.intercept)
        assert(model.weights.size == numFeatures, s"SVMModel.load with numFeatures=$numFeatures" +
          s" was given non-matching weights vector of size ${model.weights.size}")
        assert(numClasses == 2,
          s"SVMModel.load was given numClasses=$numClasses but only supports 2 classes")
        data.threshold match {
          case Some(t) => model.setThreshold(t)
          case None => model.clearThreshold()
        }
        model
      case _ => throw new Exception(
        s"SVMModel.load did not recognize model with (className, format version):" +
        s" ($loadedClassName, $version). Supported:\n" +
        s"  ($classNameV1_0, 1.0)")
    }
  }
}
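A short sketch of how the threshold round-trips for SVMModel ("mySVMModelPath" is illustrative):

  // Sketch: a cleared threshold survives save/load, since threshold is saved
  // as an Option and load() replays Some/None via setThreshold/clearThreshold.
  model.clearThreshold()
  model.save(sc, "mySVMModelPath")
  val sameModel = SVMModel.load(sc, "mySVMModelPath")
  assert(sameModel.getThreshold.isEmpty)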

/**