[SPARK-5596] [mllib] ML model import/export for GLMs, NaiveBayes #4233
Changes from all commits: 418ba1b, b1fc5ec, 64914a3, 1577d70, 8d46386, 1496852, c495dba, 2935963, d1e5882, 79675d5, ee99228, a34aef5, b4ee064, 12d9059, 87c4eb8
```diff
@@ -17,14 +17,17 @@
 package org.apache.spark.mllib.classification
 
+import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.classification.impl.GLMClassificationModel
 import org.apache.spark.mllib.linalg.BLAS.dot
 import org.apache.spark.mllib.linalg.{DenseVector, Vector}
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.{DataValidators, MLUtils}
+import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
 import org.apache.spark.rdd.RDD
 
 /**
  * Classification model trained using Multinomial/Binary Logistic Regression.
  *
```
```diff
@@ -42,7 +45,22 @@ class LogisticRegressionModel (
     override val intercept: Double,
     val numFeatures: Int,
     val numClasses: Int)
-  extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {
+  extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
+  with Saveable {
+
+  if (numClasses == 2) {
+    require(weights.size == numFeatures,
+      s"LogisticRegressionModel with numClasses = 2 was given non-matching values:" +
+        s" numFeatures = $numFeatures, but weights.size = ${weights.size}")
+  } else {
+    val weightsSizeWithoutIntercept = (numClasses - 1) * numFeatures
+    val weightsSizeWithIntercept = (numClasses - 1) * (numFeatures + 1)
+    require(weights.size == weightsSizeWithoutIntercept || weights.size == weightsSizeWithIntercept,
+      s"LogisticRegressionModel.load with numClasses = $numClasses and numFeatures = $numFeatures" +
+        s" expected weights of length $weightsSizeWithoutIntercept (without intercept)" +
+        s" or $weightsSizeWithIntercept (with intercept)," +
+        s" but was given weights of length ${weights.size}")
+  }
 
   def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2)
```
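As a concrete check of the size requirement above: with numClasses = 3 and numFeatures = 4, the multinomial branch accepts weights of length (3 - 1) * 4 = 8 (no intercept terms) or (3 - 1) * (4 + 1) = 10 (with intercept terms); any other length fails the require.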
```diff
@@ -60,6 +78,13 @@ class LogisticRegressionModel (
     this
   }
 
+  /**
+   * :: Experimental ::
+   * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
+   */
+  @Experimental
+  def getThreshold: Option[Double] = threshold
+
   /**
    * :: Experimental ::
    * Clears the threshold so that `predict` will output raw prediction scores.
```
```diff
@@ -70,7 +95,9 @@ class LogisticRegressionModel (
     this
   }
 
-  override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
+  override protected def predictPoint(
+      dataMatrix: Vector,
+      weightMatrix: Vector,
       intercept: Double) = {
     require(dataMatrix.size == numFeatures)
```
```diff
@@ -126,6 +153,40 @@ class LogisticRegressionModel (
       bestClass.toDouble
     }
   }
+
+  override def save(sc: SparkContext, path: String): Unit = {
+    GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
```
Contributor
Any proposed guidelines about when to change the minor version and when to change the major version? I'm not expecting many versions, so I'm not sure whether it is necessary to have minor versions.

Member (Author)
I was thinking minor versions could be used for format changes and major ones for model changes. But I'm OK with a single version number too.

Contributor
I have no strong preference here. The current versioning is fine with me.
```diff
+      numFeatures, numClasses, weights, intercept, threshold)
+  }
+
+  override protected def formatVersion: String = "1.0"
 }
```
```diff
+
+object LogisticRegressionModel extends Loader[LogisticRegressionModel] {
+
+  override def load(sc: SparkContext, path: String): LogisticRegressionModel = {
+    val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+    // Hard-code class name string in case it changes in the future
+    val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel"
```
Contributor
Maybe we should put a comment here about why the literal string name is used.
```diff
+
+    (loadedClassName, version) match {
+      case (className, "1.0") if className == classNameV1_0 =>
+        val (numFeatures, numClasses) =
+          ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+        val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
+        // numFeatures, numClasses, weights are checked in model initialization
+        val model =
+          new LogisticRegressionModel(data.weights, data.intercept, numFeatures, numClasses)
+        data.threshold match {
+          case Some(t) => model.setThreshold(t)
+          case None => model.clearThreshold()
+        }
+        model
+      case _ => throw new Exception(
+        s"LogisticRegressionModel.load did not recognize model with (className, format version):" +
+          s"($loadedClassName, $version). Supported:\n" +
+          s"  ($classNameV1_0, 1.0)")
+    }
+  }
+}
+
 /**
```
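For reference, a minimal round-trip sketch of the import/export API added in this diff. The SparkContext `sc`, the toy training data, and the output path are illustrative; only save, load, setThreshold, and getThreshold come from the code above.

```scala
import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithLBFGS}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Assumes an existing SparkContext named `sc`.
val training = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
  LabeledPoint(1.0, Vectors.dense(1.0, 0.0))))

val model = new LogisticRegressionWithLBFGS().run(training)
model.setThreshold(0.6)

// save writes metadata (class name, format version, numFeatures, numClasses)
// plus the model data (weights, intercept, threshold) under the given path.
model.save(sc, "myLRModel")

// load checks the class name and format version, rebuilds the model,
// and restores the saved threshold (or clears it if none was stored).
val sameModel = LogisticRegressionModel.load(sc, "myLRModel")
assert(sameModel.getThreshold == Some(0.6))
assert(sameModel.weights == model.weights)
```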
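On the versioning question raised in the review thread above: under the proposed convention, a change to the on-disk layout would bump the minor number (load would then dispatch on, say, "1.1" alongside the existing "1.0" case), while a change to the model itself would bump the major number.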
We merged multinomial logistic regression. LRModel holds numFeatures and numClasses now. We need a specialized implementation and a test for it. Or, for all classification models, we save numFeatures and numClasses.

I'll save numFeatures and numClasses in all classification models' metadata. I'm going for metadata instead of data in case the model data requires multiple RDD rows (e.g., for decision trees).
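The reply above can be made concrete with a rough sketch of the idea: keep per-model facts such as numFeatures and numClasses in a single-record JSON metadata file written next to the model data, so the data itself may span many RDD rows. The helper name and json4s usage below are illustrative, not the PR's actual implementation (that lives behind Loader and GLMClassificationModel.SaveLoadV1_0).

```scala
import org.apache.spark.SparkContext
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

// Hypothetical helper: write classification-model metadata as one JSON record.
def saveClassificationMetadata(
    sc: SparkContext,
    path: String,
    className: String,
    numFeatures: Int,
    numClasses: Int): Unit = {
  val metadata = ("class" -> className) ~
    ("version" -> "1.0") ~
    ("numFeatures" -> numFeatures) ~
    ("numClasses" -> numClasses)
  // The metadata lives alongside (not inside) the model data, so the data can
  // use as many RDD rows as it needs (e.g. one row per decision tree node).
  sc.parallelize(Seq(compact(render(metadata))), 1).saveAsTextFile(path + "/metadata")
}
```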