From 418ba1b0dea610b3e05a6480a91aa9ab0e1ca0dc Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 27 Jan 2015 13:44:02 -0800 Subject: [PATCH 01/14] Added save, load to mllib.classification.LogisticRegressionModel, plus test suite --- .../classification/LogisticRegression.scala | 72 +++++++++++++++- .../spark/mllib/util/modelImportExport.scala | 84 ++++++++++++++++++ .../LogisticRegressionSuite.scala | 86 +++++++++++++------ 3 files changed, 216 insertions(+), 26 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 282fb3ff283f..dde34195163d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -17,13 +17,16 @@ package org.apache.spark.mllib.classification +import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.linalg.{DenseVector, Vector} import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{DataValidators, MLUtils} +import org.apache.spark.mllib.util.{Importable, DataValidators, Exportable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext, SchemaRDD} /** * Classification model trained using Multinomial/Binary Logistic Regression. @@ -42,7 +45,8 @@ class LogisticRegressionModel ( override val intercept: Double, val numFeatures: Int, val numClasses: Int) - extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { + extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable + with Exportable { def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2) @@ -60,6 +64,13 @@ class LogisticRegressionModel ( this } + /** + * :: Experimental :: + * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. + */ + @Experimental + def getThreshold: Option[Double] = threshold + /** * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. @@ -126,6 +137,65 @@ class LogisticRegressionModel ( bestClass.toDouble } } + + override def save(sc: SparkContext, path: String): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext._ + // TODO: Do we need to use a SELECT statement to make the column ordering deterministic? + // Create JSON metadata. + val metadata = + LogisticRegressionModel.Metadata(clazz = this.getClass.getName, version = Exportable.version) + val metadataRDD: SchemaRDD = sc.parallelize(Seq(metadata)) + metadataRDD.toJSON.saveAsTextFile(path + "/metadata") + // Create Parquet data. + val data = LogisticRegressionModel.Data(weights, intercept, threshold) + val dataRDD: SchemaRDD = sc.parallelize(Seq(data)) + dataRDD.saveAsParquetFile(path + "/data") + } +} + +object LogisticRegressionModel extends Importable[LogisticRegressionModel] { + + override def load(sc: SparkContext, path: String): LogisticRegressionModel = { + val sqlContext = new SQLContext(sc) + import sqlContext._ + + // Load JSON metadata. 
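+    // (The metadata directory holds a single JSON record with the model class name and the
+    //  export format version written by save(); both fields are validated below.)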
+ val metadataRDD = sqlContext.jsonFile(path + "/metadata") + val metadataArray = metadataRDD.select("clazz".attr, "version".attr).take(1) + assert(metadataArray.size == 1, + s"Unable to load LogisticRegressionModel metadata from: ${path + "/metadata"}") + metadataArray(0) match { + case Row(clazz: String, version: String) => + assert(clazz == classOf[LogisticRegressionModel].getName, s"LogisticRegressionModel.load" + + s" was given model file with metadata specifying a different model class: $clazz") + assert(version == Importable.version, // only 1 version exists currently + s"LogisticRegressionModel.load did not recognize model format version: $version") + } + + // Load Parquet data. + val dataRDD = sqlContext.parquetFile(path + "/data") + val dataArray = dataRDD.select("weights".attr, "intercept".attr, "threshold".attr).take(1) + assert(dataArray.size == 1, + s"Unable to load LogisticRegressionModel data from: ${path + "/data"}") + val data = dataArray(0) + assert(data.size == 3, s"Unable to load LogisticRegressionModel data from: ${path + "/data"}") + val lr = data match { + case Row(weights: Vector, intercept: Double, _) => + new LogisticRegressionModel(weights, intercept) + } + if (data.isNullAt(2)) { + lr.clearThreshold() + } else { + lr.setThreshold(data.getDouble(2)) + } + lr + } + + private case class Metadata(clazz: String, version: String) + + private case class Data(weights: Vector, intercept: Double, threshold: Option[Double]) + } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala new file mode 100644 index 000000000000..f6f312a18cc5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.DeveloperApi + + +/** + * :: DeveloperApi :: + * + * Trait for models and transformers which may be saved as files. + * This should be inherited by the class which implements model instances. + */ +@DeveloperApi +trait Exportable { + + /** + * Save this model to the given path. + * + * This saves: + * - human-readable (JSON) model metadata to path/metadata/ + * - Parquet formatted data to path/data/ + * + * The model may be loaded using [[Importable.load]]. + * + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * This directory and any intermediate directory will be created if needed. + */ + def save(sc: SparkContext, path: String): Unit + +} + +object Exportable { + + /** Current version of model import/export format. 
*/ + val version: String = "1.0" + +} + +/** + * :: DeveloperApi :: + * + * Trait for models and transformers which may be loaded from files. + * This should be inherited by an object paired with the model class. + */ +@DeveloperApi +trait Importable[Model <: Exportable] { + + /** + * Load a model from the given path. + * + * The model should have been saved by [[Exportable.save]]. + * + * @param sc Spark context used for loading model files. + * @param path Path specifying the directory to which the model was saved. + * @return Model instance + */ + def load(sc: SparkContext, path: String): Model + +} + +object Importable { + + /** Current version of model import/export format. */ + val version: String = Exportable.version + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 3fb45938f75d..82b2fdb4608e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.mllib.classification import scala.util.control.Breaks._ +import org.apache.spark.util.Utils + import scala.util.Random import scala.collection.JavaConversions._ @@ -407,16 +409,16 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M * * First of all, using the following scala code to save the data into `path`. * - * testRDD.map(x => x.label+ ", " + x.features(0) + ", " + x.features(1) + ", " + - * x.features(2) + ", " + x.features(3)).saveAsTextFile("path") + * testRDD.map(x => x.label+ ", " + x.features(0) + ", " + x.features(1) + ", " + + * x.features(2) + ", " + x.features(3)).saveAsTextFile("path") * * Using the following R code to load the data and train the model using glmnet package. * - * library("glmnet") - * data <- read.csv("path", header=FALSE) - * label = factor(data$V1) - * features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * weights = coef(glmnet(features,label, family="multinomial", alpha = 0, lambda = 0)) + * library("glmnet") + * data <- read.csv("path", header=FALSE) + * label = factor(data$V1) + * features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + * weights = coef(glmnet(features,label, family="multinomial", alpha = 0, lambda = 0)) * * The model weights of mutinomial logstic regression in R have `K` set of linear predictors * for `K` classes classification problem; however, only `K-1` set is required if the first @@ -425,25 +427,25 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M * weights. 
The mathematical discussion and proof can be found here: * http://en.wikipedia.org/wiki/Multinomial_logistic_regression * - * weights1 = weights$`1` - weights$`0` - * weights2 = weights$`2` - weights$`0` + * weights1 = weights$`1` - weights$`0` + * weights2 = weights$`2` - weights$`0` * - * > weights1 - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * 2.6228269 - * data.V2 -0.5837166 - * data.V3 0.9285260 - * data.V4 -0.3783612 - * data.V5 -0.8123411 - * > weights2 - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * 4.11197445 - * data.V2 -0.16918650 - * data.V3 -0.81104784 - * data.V4 -0.06463799 - * data.V5 -0.29198337 + * > weights1 + * 5 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * 2.6228269 + * data.V2 -0.5837166 + * data.V3 0.9285260 + * data.V4 -0.3783612 + * data.V5 -0.8123411 + * > weights2 + * 5 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * 4.11197445 + * data.V2 -0.16918650 + * data.V3 -0.81104784 + * data.V4 -0.06463799 + * data.V5 -0.29198337 */ val weightsR = Vectors.dense(Array( @@ -459,7 +461,41 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M // very steep curve in logistic function so that when we draw samples from distribution, it's // very easy to assign to another labels. However, this prediction result is consistent to R. validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData, 0.47) + } + + test("model export/import") { + val nPoints = 20 + val A = 2.0 + val B = -1.5 + val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + val lr = new LogisticRegressionWithLBFGS().setIntercept(true) + lr.optimizer.setNumIterations(1) + val model = lr.run(testRDD) + model.clearThreshold() + assert(model.getThreshold.isEmpty) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model + model.save(sc, path) + val sameModel = LogisticRegressionModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + assert(sameModel.getThreshold.isEmpty) + Utils.deleteRecursively(tempDir) + + // Save model with threshold + model.setThreshold(0.7) + model.save(sc, path) + val sameModel2 = LogisticRegressionModel.load(sc, path) + assert(model.getThreshold.get == sameModel2.getThreshold.get) + + Utils.deleteRecursively(tempDir) } } From b1fc5eca06808f2250fef75aa10c816c889dd5f1 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Tue, 27 Jan 2015 14:08:10 -0800 Subject: [PATCH 02/14] small cleanups --- .../classification/LogisticRegression.scala | 18 +++++++++--------- .../spark/mllib/util/modelImportExport.scala | 10 +--------- .../LogisticRegressionSuite.scala | 6 +++--- 3 files changed, 13 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index dde34195163d..22e4d2ef3af3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -23,11 +23,11 @@ import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.linalg.{DenseVector, Vector} import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{DataValidators, MLUtils} -import org.apache.spark.mllib.util.{Importable, DataValidators, Exportable} +import org.apache.spark.mllib.util.{DataValidators, Exportable, Importable, MLUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext, SchemaRDD} + /** * Classification model trained using Multinomial/Binary Logistic Regression. * @@ -143,8 +143,8 @@ class LogisticRegressionModel ( import sqlContext._ // TODO: Do we need to use a SELECT statement to make the column ordering deterministic? // Create JSON metadata. - val metadata = - LogisticRegressionModel.Metadata(clazz = this.getClass.getName, version = Exportable.version) + val metadata = LogisticRegressionModel.Metadata( + clazz = this.getClass.getName, version = Exportable.latestVersion) val metadataRDD: SchemaRDD = sc.parallelize(Seq(metadata)) metadataRDD.toJSON.saveAsTextFile(path + "/metadata") // Create Parquet data. 
@@ -156,6 +156,10 @@ class LogisticRegressionModel ( object LogisticRegressionModel extends Importable[LogisticRegressionModel] { + private case class Metadata(clazz: String, version: String) + + private case class Data(weights: Vector, intercept: Double, threshold: Option[Double]) + override def load(sc: SparkContext, path: String): LogisticRegressionModel = { val sqlContext = new SQLContext(sc) import sqlContext._ @@ -169,7 +173,7 @@ object LogisticRegressionModel extends Importable[LogisticRegressionModel] { case Row(clazz: String, version: String) => assert(clazz == classOf[LogisticRegressionModel].getName, s"LogisticRegressionModel.load" + s" was given model file with metadata specifying a different model class: $clazz") - assert(version == Importable.version, // only 1 version exists currently + assert(version == Exportable.latestVersion, // only 1 version exists currently s"LogisticRegressionModel.load did not recognize model format version: $version") } @@ -192,10 +196,6 @@ object LogisticRegressionModel extends Importable[LogisticRegressionModel] { lr } - private case class Metadata(clazz: String, version: String) - - private case class Data(weights: Vector, intercept: Double, threshold: Option[Double]) - } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala index f6f312a18cc5..c6ac6a45edae 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala @@ -20,7 +20,6 @@ package org.apache.spark.mllib.util import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi - /** * :: DeveloperApi :: * @@ -50,7 +49,7 @@ trait Exportable { object Exportable { /** Current version of model import/export format. */ - val version: String = "1.0" + val latestVersion: String = "1.0" } @@ -75,10 +74,3 @@ trait Importable[Model <: Exportable] { def load(sc: SparkContext, path: String): Model } - -object Importable { - - /** Current version of model import/export format. */ - val version: String = Exportable.version - -} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 82b2fdb4608e..18deca92e06a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils object LogisticRegressionSuite { @@ -481,7 +482,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString - // Save model + // Save model, load it back, and compare. model.save(sc, path) val sameModel = LogisticRegressionModel.load(sc, path) assert(model.weights == sameModel.weights) @@ -489,12 +490,11 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M assert(sameModel.getThreshold.isEmpty) Utils.deleteRecursively(tempDir) - // Save model with threshold + // Save model with threshold. 
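+    // (Loading the saved model back should restore the same threshold, as asserted below.)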
model.setThreshold(0.7) model.save(sc, path) val sameModel2 = LogisticRegressionModel.load(sc, path) assert(model.getThreshold.get == sameModel2.getThreshold.get) - Utils.deleteRecursively(tempDir) } From 64914a3194298345a1efb9f4fa60550973facdab Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 27 Jan 2015 17:33:16 -0800 Subject: [PATCH 03/14] added getThreshold to SVMModel --- .../scala/org/apache/spark/mllib/classification/SVM.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index dd514ff8a37f..2e9e77dc7c05 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -49,6 +49,13 @@ class SVMModel ( this } + /** + * :: Experimental :: + * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. + */ + @Experimental + def getThreshold: Option[Double] = threshold + /** * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. From 1577d70f2fd25c1a2de4e0dfca07f46180a00385 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 30 Jan 2015 12:53:29 -0800 Subject: [PATCH 04/14] fixed issues after rebasing on master (DataFrame patch) --- .../mllib/classification/LogisticRegression.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 22e4d2ef3af3..f2a2ffb6005b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -25,7 +25,7 @@ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{DataValidators, Exportable, Importable, MLUtils} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext, SchemaRDD} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} /** @@ -141,15 +141,15 @@ class LogisticRegressionModel ( override def save(sc: SparkContext, path: String): Unit = { val sqlContext = new SQLContext(sc) import sqlContext._ - // TODO: Do we need to use a SELECT statement to make the column ordering deterministic? + // Create JSON metadata. val metadata = LogisticRegressionModel.Metadata( clazz = this.getClass.getName, version = Exportable.latestVersion) - val metadataRDD: SchemaRDD = sc.parallelize(Seq(metadata)) + val metadataRDD: DataFrame = sc.parallelize(Seq(metadata)) metadataRDD.toJSON.saveAsTextFile(path + "/metadata") // Create Parquet data. val data = LogisticRegressionModel.Data(weights, intercept, threshold) - val dataRDD: SchemaRDD = sc.parallelize(Seq(data)) + val dataRDD: DataFrame = sc.parallelize(Seq(data)) dataRDD.saveAsParquetFile(path + "/data") } } @@ -166,7 +166,7 @@ object LogisticRegressionModel extends Importable[LogisticRegressionModel] { // Load JSON metadata. 
val metadataRDD = sqlContext.jsonFile(path + "/metadata") - val metadataArray = metadataRDD.select("clazz".attr, "version".attr).take(1) + val metadataArray = metadataRDD.select("clazz", "version").take(1) assert(metadataArray.size == 1, s"Unable to load LogisticRegressionModel metadata from: ${path + "/metadata"}") metadataArray(0) match { @@ -179,7 +179,7 @@ object LogisticRegressionModel extends Importable[LogisticRegressionModel] { // Load Parquet data. val dataRDD = sqlContext.parquetFile(path + "/data") - val dataArray = dataRDD.select("weights".attr, "intercept".attr, "threshold".attr).take(1) + val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1) assert(dataArray.size == 1, s"Unable to load LogisticRegressionModel data from: ${path + "/data"}") val data = dataArray(0) From 8d46386d7c1bb5fa9ba6e0d51b138274c897dce5 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 30 Jan 2015 15:48:40 -0800 Subject: [PATCH 05/14] Added save/load to NaiveBayes --- .../classification/LogisticRegression.scala | 2 + .../mllib/classification/NaiveBayes.scala | 64 ++++++++++++++++++- .../spark/mllib/util/modelImportExport.scala | 36 +++++++++++ .../classification/NaiveBayesSuite.scala | 42 ++++++++++-- 4 files changed, 135 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index f2a2ffb6005b..4497a81e2d90 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -156,8 +156,10 @@ class LogisticRegressionModel ( object LogisticRegressionModel extends Importable[LogisticRegressionModel] { + /** Metadata for model import/export */ private case class Metadata(clazz: String, version: String) + /** Model data for model import/export */ private case class Data(weights: Vector, intercept: Double, threshold: Option[Double]) override def load(sc: SparkContext, path: String): LogisticRegressionModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index a967df857bed..7616d5067e64 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -19,11 +19,15 @@ package org.apache.spark.mllib.classification import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} -import org.apache.spark.{SparkException, Logging} -import org.apache.spark.SparkContext._ +import org.apache.spark.{SparkContext, SparkException, Logging} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.{Importable, Exportable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, DataFrame, SQLContext} + +import scala.collection.mutable.ArrayBuffer + /** * Model for Naive Bayes Classifiers. 
@@ -36,7 +40,7 @@ import org.apache.spark.rdd.RDD class NaiveBayesModel private[mllib] ( val labels: Array[Double], val pi: Array[Double], - val theta: Array[Array[Double]]) extends ClassificationModel with Serializable { + val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Exportable { private val brzPi = new BDV[Double](pi) private val brzTheta = new BDM[Double](theta.length, theta(0).length) @@ -65,6 +69,60 @@ class NaiveBayesModel private[mllib] ( override def predict(testData: Vector): Double = { labels(brzArgmax(brzPi + brzTheta * testData.toBreeze)) } + + override def save(sc: SparkContext, path: String): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext._ + + // Create JSON metadata. + val metadata = NaiveBayesModel.Metadata( + clazz = this.getClass.getName, version = Exportable.latestVersion) + val metadataRDD: DataFrame = sc.parallelize(Seq(metadata)) + metadataRDD.toJSON.saveAsTextFile(path + "/metadata") + // Create Parquet data. + val data = NaiveBayesModel.Data(labels, pi, theta) + val dataRDD: DataFrame = sc.parallelize(Seq(data)) + dataRDD.saveAsParquetFile(path + "/data") + } +} + +object NaiveBayesModel extends Importable[NaiveBayesModel] { + + /** Metadata for model import/export */ + private case class Metadata(clazz: String, version: String) + + /** Model data for model import/export */ + private case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) + + override def load(sc: SparkContext, path: String): NaiveBayesModel = { + val sqlContext = new SQLContext(sc) + import sqlContext._ + + // Load JSON metadata. + val metadataRDD = sqlContext.jsonFile(path + "/metadata") + val metadataArray = metadataRDD.select("clazz", "version").take(1) + assert(metadataArray.size == 1, + s"Unable to load NaiveBayesModel metadata from: ${path + "/metadata"}") + metadataArray(0) match { + case Row(clazz: String, version: String) => + assert(clazz == classOf[NaiveBayesModel].getName, s"NaiveBayesModel.load" + + s" was given model file with metadata specifying a different model class: $clazz") + assert(version == Exportable.latestVersion, // only 1 version exists currently + s"NaiveBayesModel.load did not recognize model format version: $version") + } + + // Load Parquet data. + val dataRDD = sqlContext.parquetFile(path + "/data") + val dataArray = dataRDD.select("labels", "pi", "theta").take(1) + assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${path + "/data"}") + val data = dataArray(0) + assert(data.size == 3, s"Unable to load NaiveBayesModel data from: ${path + "/data"}") + val nb = data match { + case Row(labels: Seq[Double], pi: Seq[Double], theta: Seq[Seq[Double]]) => + new NaiveBayesModel(labels.toArray, pi.toArray, theta.map(_.toArray).toArray) + } + nb + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala index c6ac6a45edae..729ba860728d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala @@ -46,6 +46,10 @@ trait Exportable { } +/** + * :: DeveloperApi :: + */ +@DeveloperApi object Exportable { /** Current version of model import/export format. 
*/ @@ -74,3 +78,35 @@ trait Importable[Model <: Exportable] { def load(sc: SparkContext, path: String): Model } + +/* +/** + * :: DeveloperApi :: + * + * Trait for models and transformers which may be saved as files. + * This should be inherited by the class which implements model instances. + * + * This specializes [[Exportable]] for local models which can be stored on a single machine. + * This provides helper functionality, but developers can choose to use [[Exportable]] instead, + * even for local models. + */ +@DeveloperApi +trait LocalExportable { + + /** + * Save this model to the given path. + * + * This saves: + * - human-readable (JSON) model metadata to path/metadata/ + * - Parquet formatted data to path/data/ + * + * The model may be loaded using [[Importable.load]]. + * + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * This directory and any intermediate directory will be created if needed. + */ + def save(sc: SparkContext, path: String): Unit + +} +*/ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index e68fe89d6cce..f126c213ad3e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.classification +import org.apache.spark.util.Utils + import scala.util.Random import org.scalatest.FunSuite @@ -58,6 +60,14 @@ object NaiveBayesSuite { LabeledPoint(y, Vectors.dense(xi)) } } + + private val smallPi = Array(0.5, 0.3, 0.2).map(math.log) + + private val smallTheta = Array( + Array(0.91, 0.03, 0.03, 0.03), // label 0 + Array(0.03, 0.91, 0.03, 0.03), // label 1 + Array(0.03, 0.03, 0.91, 0.03) // label 2 + ).map(_.map(math.log)) } class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { @@ -74,12 +84,8 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { test("Naive Bayes") { val nPoints = 10000 - val pi = Array(0.5, 0.3, 0.2).map(math.log) - val theta = Array( - Array(0.91, 0.03, 0.03, 0.03), // label 0 - Array(0.03, 0.91, 0.03, 0.03), // label 1 - Array(0.03, 0.03, 0.91, 0.03) // label 2 - ).map(_.map(math.log)) + val pi = NaiveBayesSuite.smallPi + val theta = NaiveBayesSuite.smallTheta val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42) val testRDD = sc.parallelize(testData, 2) @@ -123,6 +129,30 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { NaiveBayes.train(sc.makeRDD(nan, 2)) } } + + test("model export/import") { + val nPoints = 10 + + val pi = NaiveBayesSuite.smallPi + val theta = NaiveBayesSuite.smallTheta + + val data = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42) + val rdd = sc.parallelize(data, 2) + rdd.cache() + + val model = NaiveBayes.train(rdd) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + model.save(sc, path) + val sameModel = NaiveBayesModel.load(sc, path) + assert(model.labels === sameModel.labels) + assert(model.pi === sameModel.pi) + assert(model.theta === sameModel.theta) + Utils.deleteRecursively(tempDir) + } } class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext { From 149685247cb53f213f505fdbea527c42a085ba5d Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Sat, 31 Jan 2015 02:00:00 -0800 Subject: [PATCH 06/14] Added save/load for NaiveBayes --- .../classification/LogisticRegression.scala | 1 + .../mllib/classification/NaiveBayes.scala | 14 +++-- .../spark/mllib/util/modelImportExport.scala | 57 +++++++++---------- 3 files changed, 38 insertions(+), 34 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 4497a81e2d90..093aa391dfaa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -147,6 +147,7 @@ class LogisticRegressionModel ( clazz = this.getClass.getName, version = Exportable.latestVersion) val metadataRDD: DataFrame = sc.parallelize(Seq(metadata)) metadataRDD.toJSON.saveAsTextFile(path + "/metadata") + // Create Parquet data. val data = LogisticRegressionModel.Data(weights, intercept, threshold) val dataRDD: DataFrame = sc.parallelize(Seq(data)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 7616d5067e64..fa46b64e80fb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -18,6 +18,8 @@ package org.apache.spark.mllib.classification import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.types.{ArrayType, DataType, DoubleType, StructField, StructType} import org.apache.spark.{SparkContext, SparkException, Logging} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} @@ -79,6 +81,7 @@ class NaiveBayesModel private[mllib] ( clazz = this.getClass.getName, version = Exportable.latestVersion) val metadataRDD: DataFrame = sc.parallelize(Seq(metadata)) metadataRDD.toJSON.saveAsTextFile(path + "/metadata") + // Create Parquet data. val data = NaiveBayesModel.Data(labels, pi, theta) val dataRDD: DataFrame = sc.parallelize(Seq(data)) @@ -117,11 +120,12 @@ object NaiveBayesModel extends Importable[NaiveBayesModel] { assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${path + "/data"}") val data = dataArray(0) assert(data.size == 3, s"Unable to load NaiveBayesModel data from: ${path + "/data"}") - val nb = data match { - case Row(labels: Seq[Double], pi: Seq[Double], theta: Seq[Seq[Double]]) => - new NaiveBayesModel(labels.toArray, pi.toArray, theta.map(_.toArray).toArray) - } - nb + // Check schema explicitly since erasure makes it hard to use match-case for checking. 
+ Importable.checkSchema[Data](dataRDD.schema) + val labels = data.getAs[Seq[Double]](0).toArray + val pi = data.getAs[Seq[Double]](1).toArray + val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray + new NaiveBayesModel(labels, pi, theta) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala index 729ba860728d..06cd822afff5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala @@ -17,8 +17,13 @@ package org.apache.spark.mllib.util +import scala.reflect.runtime.universe.TypeTag + import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.types.{DataType, StructType, StructField} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} /** * :: DeveloperApi :: @@ -46,11 +51,7 @@ trait Exportable { } -/** - * :: DeveloperApi :: - */ -@DeveloperApi -object Exportable { +private[mllib] object Exportable { /** Current version of model import/export format. */ val latestVersion: String = "1.0" @@ -79,34 +80,32 @@ trait Importable[Model <: Exportable] { } -/* -/** - * :: DeveloperApi :: - * - * Trait for models and transformers which may be saved as files. - * This should be inherited by the class which implements model instances. - * - * This specializes [[Exportable]] for local models which can be stored on a single machine. - * This provides helper functionality, but developers can choose to use [[Exportable]] instead, - * even for local models. - */ -@DeveloperApi -trait LocalExportable { +private[mllib] object Importable { /** - * Save this model to the given path. - * - * This saves: - * - human-readable (JSON) model metadata to path/metadata/ - * - Parquet formatted data to path/data/ + * Check the schema of loaded model data. * - * The model may be loaded using [[Importable.load]]. + * This checks every field in the expected schema to make sure that a field with the same + * name and DataType appears in the loaded schema. Note that this does NOT check metadata + * or containsNull. * - * @param sc Spark context used to save model data. - * @param path Path specifying the directory in which to save this model. - * This directory and any intermediate directory will be created if needed. + * @param loadedSchema Schema for model data loaded from file. + * @tparam Data Expected data type from which an expected schema can be derived. */ - def save(sc: SparkContext, path: String): Unit + def checkSchema[Data: TypeTag](loadedSchema: StructType): Unit = { + // Check schema explicitly since erasure makes it hard to use match-case for checking. + val expectedFields: Array[StructField] = + ScalaReflection.schemaFor[Data].dataType.asInstanceOf[StructType].fields + val loadedFields: Map[String, DataType] = + loadedSchema.map(field => field.name -> field.dataType).toMap + expectedFields.foreach { field => + assert(loadedFields.contains(field.name), s"Unable to parse model data." + + s" Expected field with name ${field.name} was missing in loaded schema:" + + s" ${loadedFields.mkString(", ")}") + assert(loadedFields(field.name) == field.dataType, + s"Unable to parse model data. 
Expected field $field but found field" + + s" with different type: ${loadedFields(field.name)}") + } + } } -*/ From c495dba32f69d26f886e55f8232746e181bcbbc1 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sun, 1 Feb 2015 16:57:42 -0800 Subject: [PATCH 07/14] made version for model import/export local to each model --- .../spark/mllib/classification/LogisticRegression.scala | 9 +++++++-- .../apache/spark/mllib/classification/NaiveBayes.scala | 8 ++++++-- .../org/apache/spark/mllib/util/modelImportExport.scala | 9 ++++----- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 093aa391dfaa..1b23576d59bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -144,7 +144,7 @@ class LogisticRegressionModel ( // Create JSON metadata. val metadata = LogisticRegressionModel.Metadata( - clazz = this.getClass.getName, version = Exportable.latestVersion) + clazz = this.getClass.getName, version = latestVersion) val metadataRDD: DataFrame = sc.parallelize(Seq(metadata)) metadataRDD.toJSON.saveAsTextFile(path + "/metadata") @@ -153,6 +153,9 @@ class LogisticRegressionModel ( val dataRDD: DataFrame = sc.parallelize(Seq(data)) dataRDD.saveAsParquetFile(path + "/data") } + + override protected def latestVersion: String = LogisticRegressionModel.latestVersion + } object LogisticRegressionModel extends Importable[LogisticRegressionModel] { @@ -176,7 +179,7 @@ object LogisticRegressionModel extends Importable[LogisticRegressionModel] { case Row(clazz: String, version: String) => assert(clazz == classOf[LogisticRegressionModel].getName, s"LogisticRegressionModel.load" + s" was given model file with metadata specifying a different model class: $clazz") - assert(version == Exportable.latestVersion, // only 1 version exists currently + assert(version == latestVersion, // only 1 version exists currently s"LogisticRegressionModel.load did not recognize model format version: $version") } @@ -199,6 +202,8 @@ object LogisticRegressionModel extends Importable[LogisticRegressionModel] { lr } + override protected def latestVersion: String = "1.0" + } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index fa46b64e80fb..704bda6b5578 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -78,7 +78,7 @@ class NaiveBayesModel private[mllib] ( // Create JSON metadata. 
val metadata = NaiveBayesModel.Metadata( - clazz = this.getClass.getName, version = Exportable.latestVersion) + clazz = this.getClass.getName, version = latestVersion) val metadataRDD: DataFrame = sc.parallelize(Seq(metadata)) metadataRDD.toJSON.saveAsTextFile(path + "/metadata") @@ -87,6 +87,8 @@ class NaiveBayesModel private[mllib] ( val dataRDD: DataFrame = sc.parallelize(Seq(data)) dataRDD.saveAsParquetFile(path + "/data") } + + override protected def latestVersion: String = NaiveBayesModel.latestVersion } object NaiveBayesModel extends Importable[NaiveBayesModel] { @@ -110,7 +112,7 @@ object NaiveBayesModel extends Importable[NaiveBayesModel] { case Row(clazz: String, version: String) => assert(clazz == classOf[NaiveBayesModel].getName, s"NaiveBayesModel.load" + s" was given model file with metadata specifying a different model class: $clazz") - assert(version == Exportable.latestVersion, // only 1 version exists currently + assert(version == latestVersion, // only 1 version exists currently s"NaiveBayesModel.load did not recognize model format version: $version") } @@ -127,6 +129,8 @@ object NaiveBayesModel extends Importable[NaiveBayesModel] { val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray new NaiveBayesModel(labels, pi, theta) } + + override protected def latestVersion: String = "1.0" } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala index 06cd822afff5..66490ca5ad5c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala @@ -49,12 +49,8 @@ trait Exportable { */ def save(sc: SparkContext, path: String): Unit -} - -private[mllib] object Exportable { - /** Current version of model import/export format. */ - val latestVersion: String = "1.0" + protected def latestVersion: String } @@ -78,6 +74,9 @@ trait Importable[Model <: Exportable] { */ def load(sc: SparkContext, path: String): Model + /** Current version of model import/export format. */ + protected def latestVersion: String + } private[mllib] object Importable { From 29359635d2578b5da217cb48329c3ae16c7377f7 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Mon, 2 Feb 2015 18:13:26 -0800 Subject: [PATCH 08/14] Added save/load and tests for most classification and regression models --- .../classification/LogisticRegression.scala | 62 ++---------- .../mllib/classification/NaiveBayes.scala | 76 ++++++++------- .../spark/mllib/classification/SVM.scala | 30 +++++- .../impl/GLMClassificationModel.scala | 94 +++++++++++++++++++ .../apache/spark/mllib/regression/Lasso.scala | 24 ++++- .../mllib/regression/LinearRegression.scala | 26 ++++- .../mllib/regression/RidgeRegression.scala | 29 +++++- .../regression/impl/GLMRegressionModel.scala | 87 +++++++++++++++++ .../mllib/tree/model/DecisionTreeModel.scala | 1 - .../spark/mllib/util/modelImportExport.scala | 36 ++++++- .../LogisticRegressionSuite.scala | 1 + .../spark/mllib/classification/SVMSuite.scala | 37 ++++++++ .../spark/mllib/regression/LassoSuite.scala | 25 +++++ .../regression/LinearRegressionSuite.scala | 22 +++++ .../regression/RidgeRegressionSuite.scala | 24 +++++ 15 files changed, 466 insertions(+), 108 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 1b23576d59bd..98ab76d88791 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -21,11 +21,11 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.linalg.{DenseVector, Vector} +import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{DataValidators, Exportable, Importable, MLUtils} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SQLContext} /** @@ -139,70 +139,26 @@ class LogisticRegressionModel ( } override def save(sc: SparkContext, path: String): Unit = { - val sqlContext = new SQLContext(sc) - import sqlContext._ - - // Create JSON metadata. - val metadata = LogisticRegressionModel.Metadata( - clazz = this.getClass.getName, version = latestVersion) - val metadataRDD: DataFrame = sc.parallelize(Seq(metadata)) - metadataRDD.toJSON.saveAsTextFile(path + "/metadata") - - // Create Parquet data. 
- val data = LogisticRegressionModel.Data(weights, intercept, threshold) - val dataRDD: DataFrame = sc.parallelize(Seq(data)) - dataRDD.saveAsParquetFile(path + "/data") + GLMClassificationModel.save(sc, path, this.getClass.getName, weights, intercept, threshold) } - override protected def latestVersion: String = LogisticRegressionModel.latestVersion + override protected def formatVersion: String = LogisticRegressionModel.formatVersion } object LogisticRegressionModel extends Importable[LogisticRegressionModel] { - /** Metadata for model import/export */ - private case class Metadata(clazz: String, version: String) - - /** Model data for model import/export */ - private case class Data(weights: Vector, intercept: Double, threshold: Option[Double]) - override def load(sc: SparkContext, path: String): LogisticRegressionModel = { - val sqlContext = new SQLContext(sc) - import sqlContext._ - - // Load JSON metadata. - val metadataRDD = sqlContext.jsonFile(path + "/metadata") - val metadataArray = metadataRDD.select("clazz", "version").take(1) - assert(metadataArray.size == 1, - s"Unable to load LogisticRegressionModel metadata from: ${path + "/metadata"}") - metadataArray(0) match { - case Row(clazz: String, version: String) => - assert(clazz == classOf[LogisticRegressionModel].getName, s"LogisticRegressionModel.load" + - s" was given model file with metadata specifying a different model class: $clazz") - assert(version == latestVersion, // only 1 version exists currently - s"LogisticRegressionModel.load did not recognize model format version: $version") - } - - // Load Parquet data. - val dataRDD = sqlContext.parquetFile(path + "/data") - val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1) - assert(dataArray.size == 1, - s"Unable to load LogisticRegressionModel data from: ${path + "/data"}") - val data = dataArray(0) - assert(data.size == 3, s"Unable to load LogisticRegressionModel data from: ${path + "/data"}") - val lr = data match { - case Row(weights: Vector, intercept: Double, _) => - new LogisticRegressionModel(weights, intercept) - } - if (data.isNullAt(2)) { - lr.clearThreshold() - } else { - lr.setThreshold(data.getDouble(2)) + val data = GLMClassificationModel.loadData(sc, path, classOf[LogisticRegressionModel].getName) + val lr = new LogisticRegressionModel(data.weights, data.intercept) + data.threshold match { + case Some(t) => lr.setThreshold(t) + case None => lr.clearThreshold() } lr } - override protected def latestVersion: String = "1.0" + override protected def formatVersion: String = GLMClassificationModel.formatVersion } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 704bda6b5578..0deb58175386 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -77,60 +77,64 @@ class NaiveBayesModel private[mllib] ( import sqlContext._ // Create JSON metadata. - val metadata = NaiveBayesModel.Metadata( - clazz = this.getClass.getName, version = latestVersion) - val metadataRDD: DataFrame = sc.parallelize(Seq(metadata)) - metadataRDD.toJSON.saveAsTextFile(path + "/metadata") + val metadataRDD = + sc.parallelize(Seq((this.getClass.getName, formatVersion))).toDataFrame("class", "version") + metadataRDD.toJSON.repartition(1).saveAsTextFile(path + "/metadata") // Create Parquet data. 
val data = NaiveBayesModel.Data(labels, pi, theta) val dataRDD: DataFrame = sc.parallelize(Seq(data)) - dataRDD.saveAsParquetFile(path + "/data") + dataRDD.repartition(1).saveAsParquetFile(path + "/data") } - override protected def latestVersion: String = NaiveBayesModel.latestVersion + override protected def formatVersion: String = NaiveBayesModel.formatVersion + } object NaiveBayesModel extends Importable[NaiveBayesModel] { - /** Metadata for model import/export */ - private case class Metadata(clazz: String, version: String) - /** Model data for model import/export */ private case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) - override def load(sc: SparkContext, path: String): NaiveBayesModel = { - val sqlContext = new SQLContext(sc) - import sqlContext._ + private object ImporterV1 extends Importer { + + override def load(sc: SparkContext, path: String): NaiveBayesModel = { + val sqlContext = new SQLContext(sc) + // Load Parquet data. + val dataRDD = sqlContext.parquetFile(path + "/data") + // Check schema explicitly since erasure makes it hard to use match-case for checking. + Importable.checkSchema[Data](dataRDD.schema) + val dataArray = dataRDD.select("labels", "pi", "theta").take(1) + assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${path + "/data"}") + val data = dataArray(0) + val labels = data.getAs[Seq[Double]](0).toArray + val pi = data.getAs[Seq[Double]](1).toArray + val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray + new NaiveBayesModel(labels, pi, theta) + } + } + + protected object Importer { - // Load JSON metadata. - val metadataRDD = sqlContext.jsonFile(path + "/metadata") - val metadataArray = metadataRDD.select("clazz", "version").take(1) - assert(metadataArray.size == 1, - s"Unable to load NaiveBayesModel metadata from: ${path + "/metadata"}") - metadataArray(0) match { - case Row(clazz: String, version: String) => - assert(clazz == classOf[NaiveBayesModel].getName, s"NaiveBayesModel.load" + - s" was given model file with metadata specifying a different model class: $clazz") - assert(version == latestVersion, // only 1 version exists currently - s"NaiveBayesModel.load did not recognize model format version: $version") + def get(clazz: String, version: String): Importer = { + assert(clazz == classOf[NaiveBayesModel].getName, s"NaiveBayesModel.load" + + s" was given model file with metadata specifying a different model class: $clazz") + version match { + case "1.0" => ImporterV1 + case _ => throw new Exception( + s"NaiveBayesModel.load did not recognize model format version: $version." + + s" Supported versions: 1.0.") + } } + } - // Load Parquet data. - val dataRDD = sqlContext.parquetFile(path + "/data") - val dataArray = dataRDD.select("labels", "pi", "theta").take(1) - assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${path + "/data"}") - val data = dataArray(0) - assert(data.size == 3, s"Unable to load NaiveBayesModel data from: ${path + "/data"}") - // Check schema explicitly since erasure makes it hard to use match-case for checking. 
- Importable.checkSchema[Data](dataRDD.schema) - val labels = data.getAs[Seq[Double]](0).toArray - val pi = data.getAs[Seq[Double]](1).toArray - val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray - new NaiveBayesModel(labels, pi, theta) + override def load(sc: SparkContext, path: String): NaiveBayesModel = { + val (clazz, version, metadata) = Importable.loadMetadata(sc, path) + val importer = Importer.get(clazz, version) + importer.load(sc, path) } - override protected def latestVersion: String = "1.0" + override protected def formatVersion: String = "1.0" } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 2e9e77dc7c05..f892ff55ea4d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -17,13 +17,16 @@ package org.apache.spark.mllib.classification +import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.DataValidators +import org.apache.spark.mllib.util.{Importable, Exportable, DataValidators} import org.apache.spark.rdd.RDD + /** * Model for Support Vector Machines (SVMs). * @@ -33,7 +36,8 @@ import org.apache.spark.rdd.RDD class SVMModel ( override val weights: Vector, override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { + extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable + with Exportable { private var threshold: Option[Double] = Some(0.0) @@ -76,6 +80,28 @@ class SVMModel ( case None => margin } } + + override def save(sc: SparkContext, path: String): Unit = { + GLMClassificationModel.save(sc, path, this.getClass.getName, weights, intercept, threshold) + } + + override protected def formatVersion: String = SVMModel.formatVersion +} + +object SVMModel extends Importable[SVMModel] { + + override def load(sc: SparkContext, path: String): SVMModel = { + val data = GLMClassificationModel.loadData(sc, path, classOf[SVMModel].getName) + val lr = new SVMModel(data.weights, data.intercept) + data.threshold match { + case Some(t) => lr.setThreshold(t) + case None => lr.clearThreshold() + } + lr + } + + override protected def formatVersion: String = GLMClassificationModel.formatVersion + } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala new file mode 100644 index 000000000000..e3d1ddd7e96b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification.impl + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.util.Importable +import org.apache.spark.sql.{Row, DataFrame, SQLContext} + +/** + * Helper methods for import/export of GLM classification models. + */ +private[classification] object GLMClassificationModel { + + /** Model data for model import/export */ + case class Data(weights: Vector, intercept: Double, threshold: Option[Double]) + + def save( + sc: SparkContext, + path: String, + modelClass: String, + weights: Vector, + intercept: Double, + threshold: Option[Double]): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext._ + + // Create JSON metadata. + val metadataRDD = + sc.parallelize(Seq((modelClass, formatVersion))).toDataFrame("class", "version") + metadataRDD.toJSON.repartition(1).saveAsTextFile(path + "/metadata") + + // Create Parquet data. + val data = Data(weights, intercept, threshold) + val dataRDD: DataFrame = sc.parallelize(Seq(data)) + // TODO: repartition with 1 partition after SPARK-5532 gets fixed + dataRDD.saveAsParquetFile(path + "/data") + } + + private object ImporterV1 { + + def load(sc: SparkContext, path: String, modelClass: String): Data = { + val sqlContext = new SQLContext(sc) + val dataRDD = sqlContext.parquetFile(path + "/data") + val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1) + assert(dataArray.size == 1, s"Unable to load $modelClass data from: ${path + "/data"}") + val data = dataArray(0) + assert(data.size == 3, s"Unable to load $modelClass data from: ${path + "/data"}") + val (weights, intercept) = data match { + case Row(weights: Vector, intercept: Double, _) => + (weights, intercept) + } + val threshold = if (data.isNullAt(2)) { + None + } else { + Some(data.getDouble(2)) + } + Data(weights, intercept, threshold) + } + } + + def formatVersion: String = "1.0" + + def loadData(sc: SparkContext, path: String, modelClass: String): Data = { + val (clazz, version, metadata) = Importable.loadMetadata(sc, path) + // Note: This check of the class name should happen here since we may eventually want to load + // other classes (such as deprecated versions). + assert(clazz == modelClass, s"$modelClass.load" + + s" was given model file with metadata specifying a different model class: $clazz") + version match { + case "1.0" => + ImporterV1.load(sc, path, modelClass) + case _ => throw new Exception( + s"$modelClass.load did not recognize model format version: $version." 
+ + s" Supported versions: 1.0.") + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index 8ecd5c6ad93c..6f59da360db5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -17,9 +17,11 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.Experimental +import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.regression.impl.GLMRegressionModel +import org.apache.spark.mllib.util.{Importable, Exportable} import org.apache.spark.rdd.RDD /** @@ -32,7 +34,7 @@ class LassoModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable { + with RegressionModel with Serializable with Exportable { override protected def predictPoint( dataMatrix: Vector, @@ -40,12 +42,28 @@ class LassoModel ( intercept: Double): Double = { weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + + override def save(sc: SparkContext, path: String): Unit = { + GLMRegressionModel.save(sc, path, this.getClass.getName, weights, intercept) + } + + override protected def formatVersion: String = LassoModel.formatVersion +} + +object LassoModel extends Importable[LassoModel] { + + override def load(sc: SparkContext, path: String): LassoModel = { + val data = GLMRegressionModel.loadData(sc, path, classOf[LassoModel].getName) + new LassoModel(data.weights, data.intercept) + } + + override protected def formatVersion: String = LassoModel.formatVersion } /** * Train a regression model with L1-regularization using Stochastic Gradient Descent. * This solves the l1-regularized least squares regression formulation - * f(weights) = 1/2n ||A weights-y||^2 + regParam ||weights||_1 + * f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1 * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y. * See also the documentation for the precise formulation. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 81b6598377ff..2953fe589d83 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -17,9 +17,12 @@ package org.apache.spark.mllib.regression -import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.regression.impl.GLMRegressionModel +import org.apache.spark.mllib.util.{Exportable, Importable} +import org.apache.spark.rdd.RDD /** * Regression model trained using LinearRegression. 
@@ -30,7 +33,8 @@ import org.apache.spark.mllib.optimization._ class LinearRegressionModel ( override val weights: Vector, override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { + extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable + with Exportable { override protected def predictPoint( dataMatrix: Vector, @@ -38,12 +42,28 @@ class LinearRegressionModel ( intercept: Double): Double = { weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + + override def save(sc: SparkContext, path: String): Unit = { + GLMRegressionModel.save(sc, path, this.getClass.getName, weights, intercept) + } + + override protected def formatVersion: String = LinearRegressionModel.formatVersion +} + +object LinearRegressionModel extends Importable[LinearRegressionModel] { + + override def load(sc: SparkContext, path: String): LinearRegressionModel = { + val data = GLMRegressionModel.loadData(sc, path, classOf[LinearRegressionModel].getName) + new LinearRegressionModel(data.weights, data.intercept) + } + + override protected def formatVersion: String = LinearRegressionModel.formatVersion } /** * Train a linear regression model with no regularization using Stochastic Gradient Descent. * This solves the least squares regression formulation - * f(weights) = 1/n ||A weights-y||^2 + * f(weights) = 1/n ||A weights-y||^2^ * (which is the mean squared error). * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 076ba35051c9..a852421f0877 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -17,10 +17,13 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.Experimental -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.optimization._ +import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.regression.impl.GLMRegressionModel +import org.apache.spark.mllib.util.{Importable, Exportable} +import org.apache.spark.rdd.RDD + /** * Regression model trained using RidgeRegression. 
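For reference, the on-disk layout produced by the GLM save/load helpers in this patch looks roughly as follows; the example values are illustrative, and the exact part-file names are chosen by Hadoop rather than fixed here:

  path/
    metadata/   one JSON record per model (written as a single text partition), e.g.
                {"class":"org.apache.spark.mllib.regression.RidgeRegressionModel","version":"1.0"}
    data/       Parquet data with columns weights and intercept
                (the classification models additionally store threshold)
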
@@ -32,7 +35,7 @@ class RidgeRegressionModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable { + with RegressionModel with Serializable with Exportable { override protected def predictPoint( dataMatrix: Vector, @@ -40,12 +43,28 @@ class RidgeRegressionModel ( intercept: Double): Double = { weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + + override def save(sc: SparkContext, path: String): Unit = { + GLMRegressionModel.save(sc, path, this.getClass.getName, weights, intercept) + } + + override protected def formatVersion: String = RidgeRegressionModel.formatVersion +} + +object RidgeRegressionModel extends Importable[RidgeRegressionModel] { + + override def load(sc: SparkContext, path: String): RidgeRegressionModel = { + val data = GLMRegressionModel.loadData(sc, path, classOf[RidgeRegressionModel].getName) + new RidgeRegressionModel(data.weights, data.intercept) + } + + override protected def formatVersion: String = GLMRegressionModel.formatVersion } /** * Train a regression model with L2-regularization using Stochastic Gradient Descent. * This solves the l1-regularized least squares regression formulation - * f(weights) = 1/2n ||A weights-y||^2 + regParam/2 ||weights||^2 + * f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^ * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y. * See also the documentation for the precise formulation. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala new file mode 100644 index 000000000000..659c4bce3e0e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression.impl + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.util.Importable +import org.apache.spark.sql.{Row, DataFrame, SQLContext} + +/** + * Helper methods for import/export of GLM regression models. + */ +private[regression] object GLMRegressionModel { + + /** Model data for model import/export */ + case class Data(weights: Vector, intercept: Double) + + def save( + sc: SparkContext, + path: String, + modelClass: String, + weights: Vector, + intercept: Double): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext._ + + // Create JSON metadata. 
+ val metadataRDD = + sc.parallelize(Seq((modelClass, formatVersion))).toDataFrame("class", "version") + metadataRDD.toJSON.repartition(1).saveAsTextFile(path + "/metadata") + + // Create Parquet data. + val data = Data(weights, intercept) + val dataRDD: DataFrame = sc.parallelize(Seq(data)) + // TODO: repartition with 1 partition after SPARK-5532 gets fixed + dataRDD.saveAsParquetFile(path + "/data") + } + + private object ImporterV1 { + + def load(sc: SparkContext, path: String, modelClass: String): Data = { + val sqlContext = new SQLContext(sc) + val dataRDD = sqlContext.parquetFile(path + "/data") + val dataArray = dataRDD.select("weights", "intercept").take(1) + assert(dataArray.size == 1, s"Unable to load $modelClass data from: ${path + "/data"}") + val data = dataArray(0) + assert(data.size == 2, s"Unable to load $modelClass data from: ${path + "/data"}") + data match { + case Row(weights: Vector, intercept: Double) => + Data(weights, intercept) + } + } + } + + def formatVersion: String = "1.0" + + def loadData(sc: SparkContext, path: String, modelClass: String): Data = { + val (clazz, version, metadata) = Importable.loadMetadata(sc, path) + // Note: This check of the class name should happen here since we may eventually want to load + // other classes (such as deprecated versions). + assert(clazz == modelClass, s"$modelClass.load" + + s" was given model file with metadata specifying a different model class: $clazz") + version match { + case "1.0" => + ImporterV1.load(sc, path, modelClass) + case _ => throw new Exception( + s"$modelClass.load did not recognize model format version: $version." + + s" Supported versions: 1.0.") + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index a5760963068c..a25e625a4017 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -53,7 +53,6 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable features.map(x => predict(x)) } - /** * Predict values for the given data set using the model trained. * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala index 66490ca5ad5c..c70ca3f1c8e6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.util -import scala.reflect.runtime.universe.TypeTag +import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi @@ -50,7 +50,7 @@ trait Exportable { def save(sc: SparkContext, path: String): Unit /** Current version of model import/export format. */ - protected def latestVersion: String + protected def formatVersion: String } @@ -61,7 +61,11 @@ trait Exportable { * This should be inherited by an object paired with the model class. */ @DeveloperApi -trait Importable[Model <: Exportable] { +trait Importable[M <: Exportable] { + + protected abstract class Importer { + def load(sc: SparkContext, path: String): M + } /** * Load a model from the given path. @@ -72,10 +76,12 @@ trait Importable[Model <: Exportable] { * @param path Path specifying the directory to which the model was saved. 
* @return Model instance */ - def load(sc: SparkContext, path: String): Model + def load(sc: SparkContext, path: String): M + + //def loadWithSchema(sc: SparkContext, path: String): (M, StructType) /** Current version of model import/export format. */ - protected def latestVersion: String + protected def formatVersion: String } @@ -107,4 +113,24 @@ private[mllib] object Importable { } } + /** + * Load metadata from the given path. + * @return (class name, version, metadata) + */ + def loadMetadata(sc: SparkContext, path: String): (String, String, DataFrame) = { + val sqlContext = new SQLContext(sc) + val metadata = sqlContext.jsonFile(path + "/metadata") + val (clazz, version) = try { + val metadataArray = metadata.select("class", "version").take(1) + assert(metadataArray.size == 1) + metadataArray(0) match { + case Row(clazz: String, version: String) => (clazz, version) + } + } catch { + case e: Exception => + throw new Exception(s"Unable to load model metadata from: ${path + "/metadata"}") + } + (clazz, version, metadata) + } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 18deca92e06a..a603deb2a5dc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -476,6 +476,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M val lr = new LogisticRegressionWithLBFGS().setIntercept(true) lr.optimizer.setNumIterations(1) val model = lr.run(testRDD) + model.clearThreshold() assert(model.getThreshold.isEmpty) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index a2de7fbd4138..55d431585787 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} +import org.apache.spark.util.Utils object SVMSuite { @@ -191,6 +192,42 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { // Turning off data validation should not throw an exception new SVMWithSGD().setValidateData(false).run(testRDDInvalid) } + + test("model export/import") { + val nPoints = 10 + val A = 0.01 + val B = -1.5 + val C = 1.0 + + val data = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) + val rdd = sc.parallelize(data, 2) + rdd.cache() + + val svm = new SVMWithSGD() + svm.optimizer.setNumIterations(1) + val model = svm.run(rdd) + + model.clearThreshold() + assert(model.getThreshold.isEmpty) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + model.save(sc, path) + val sameModel = SVMModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + assert(sameModel.getThreshold.isEmpty) + Utils.deleteRecursively(tempDir) + + // Save model with threshold. 
+ model.setThreshold(0.7) + model.save(sc, path) + val sameModel2 = SVMModel.load(sc, path) + assert(model.getThreshold.get == sameModel2.getThreshold.get) + Utils.deleteRecursively(tempDir) + } } class SVMClusterSuite extends FunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index 2668dcc14a84..70912fa5c065 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -24,6 +24,7 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.util.Utils class LassoSuite extends FunSuite with MLlibTestSparkContext { @@ -115,6 +116,30 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + test("model export/import") { + // Create dataset + val nPoints = 10 + val A = 2.0 + val B = -1.5 + val C = 1.0e-2 + val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B, C), nPoints, 42) + val testRDD = sc.parallelize(testData, 2).cache() + + // Train model + val ls = new LassoWithSGD() + ls.optimizer.setNumIterations(1) + val model = ls.run(testRDD) + + // Save model, load it back, and compare. + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + model.save(sc, path) + val sameModel = LassoModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + Utils.deleteRecursively(tempDir) + } } class LassoClusterSuite extends FunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index 864622a9296a..a9ef2d691885 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -24,6 +24,7 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.util.Utils class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { @@ -124,6 +125,27 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { validatePrediction( sparseValidationData.map(row => model.predict(row.features)), sparseValidationData) } + + test("model export/import") { + // Create dataset + val rdd = sc.parallelize( + LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 10, 42), 2) + + // Train model + val linReg = new LinearRegressionWithSGD().setIntercept(false) + linReg.optimizer.setNumIterations(1) + val model = linReg.run(rdd) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. 
+ model.save(sc, path) + val sameModel = LinearRegressionModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + Utils.deleteRecursively(tempDir) + } } class LinearRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index 18d3bf5ea4ec..66391d59d719 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.util.Utils class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { @@ -75,6 +76,29 @@ class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(ridgeErr < linearErr, "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")") } + + test("model export/import") { + // Create dataset + val numExamples = 20 + val numFeatures = 4 + val w = DoubleMatrix.rand(numFeatures, 1).subi(0.5) + val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, numExamples, 42, 10.0) + val rdd = sc.parallelize(data, 2).cache() + + // Train model + val lr = new RidgeRegressionWithSGD() + lr.optimizer.setNumIterations(1) + val model = lr.run(rdd) + + // Save model, load it back, and compare. + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + model.save(sc, path) + val sameModel = RidgeRegressionModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + Utils.deleteRecursively(tempDir) + } } class RidgeRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { From d1e58827c01d3bb6533bef3dceb22e72896802b7 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Mon, 2 Feb 2015 18:18:33 -0800 Subject: [PATCH 09/14] organized imports --- .../org/apache/spark/mllib/classification/NaiveBayes.scala | 6 +----- .../scala/org/apache/spark/mllib/classification/SVM.scala | 2 +- .../scala/org/apache/spark/mllib/regression/Lasso.scala | 2 +- .../org/apache/spark/mllib/regression/RidgeRegression.scala | 2 +- .../spark/mllib/regression/impl/GLMRegressionModel.scala | 2 +- .../org/apache/spark/mllib/util/modelImportExport.scala | 5 +++-- .../apache/spark/mllib/classification/NaiveBayesSuite.scala | 4 ++-- 7 files changed, 10 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 0deb58175386..91fac686542a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -18,17 +18,13 @@ package org.apache.spark.mllib.classification import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.types.{ArrayType, DataType, DoubleType, StructField, StructType} import org.apache.spark.{SparkContext, SparkException, Logging} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{Importable, Exportable} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, DataFrame, SQLContext} - -import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.{DataFrame, SQLContext} /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index f892ff55ea4d..cadc57bce9f4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -23,7 +23,7 @@ import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{Importable, Exportable, DataValidators} +import org.apache.spark.mllib.util.{DataValidators, Exportable, Importable} import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index 6f59da360db5..e741929bf7c7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression.impl.GLMRegressionModel -import org.apache.spark.mllib.util.{Importable, Exportable} +import org.apache.spark.mllib.util.{Exportable, Importable} import org.apache.spark.rdd.RDD /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index a852421f0877..4e63b43a8335 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression.impl.GLMRegressionModel -import org.apache.spark.mllib.util.{Importable, Exportable} +import org.apache.spark.mllib.util.{Exportable, Importable} import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala index 659c4bce3e0e..f475953263f3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.regression.impl import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.Importable -import org.apache.spark.sql.{Row, DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} /** * Helper methods for import/export of GLM regression models. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala index c70ca3f1c8e6..71a99dfa2820 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala @@ -17,13 +17,14 @@ package org.apache.spark.mllib.util -import scala.reflect.runtime.universe.{TypeTag, typeTag} +import scala.reflect.runtime.universe.TypeTag import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types.{DataType, StructType, StructField} -import org.apache.spark.sql.{DataFrame, Row, SQLContext} + /** * :: DeveloperApi :: diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index f126c213ad3e..276c3e5468f8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.mllib.classification -import org.apache.spark.util.Utils - import scala.util.Random import org.scalatest.FunSuite @@ -27,6 +25,8 @@ import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} +import org.apache.spark.util.Utils + object NaiveBayesSuite { From 79675d53717dfa66b13cdbddf9356b61946f8a87 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Mon, 2 Feb 2015 18:57:34 -0800 Subject: [PATCH 10/14] cleanups in LogisticRegression after rebasing after multinomial PR --- .../spark/mllib/classification/LogisticRegression.scala | 4 ++-- .../mllib/classification/LogisticRegressionSuite.scala | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 98ab76d88791..802d10999b5c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -19,12 +19,12 @@ package org.apache.spark.mllib.classification import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.linalg.{DenseVector, Vector} -import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{DataValidators, Exportable, Importable, MLUtils} +import org.apache.spark.mllib.util.{DataValidators, Exportable, Importable} import org.apache.spark.rdd.RDD diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index a603deb2a5dc..9d6a507f4d60 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.mllib.classification -import scala.util.control.Breaks._ -import org.apache.spark.util.Utils - -import scala.util.Random import scala.collection.JavaConversions._ +import scala.util.Random +import scala.util.control.Breaks._ import org.scalatest.FunSuite import org.scalatest.Matchers @@ -32,6 +30,7 @@ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkCont import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils + object LogisticRegressionSuite { def generateLogisticInputAsList( From ee99228ac7904f410210b5e9118f5e88ea280164 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 2 Feb 2015 19:59:14 -0800 Subject: [PATCH 11/14] scala style fix --- .../scala/org/apache/spark/mllib/util/modelImportExport.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala index 71a99dfa2820..b1201ff54721 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala @@ -79,7 +79,7 @@ trait Importable[M <: Exportable] { */ def load(sc: SparkContext, path: String): M - //def loadWithSchema(sc: SparkContext, path: String): (M, StructType) + // def loadWithSchema(sc: SparkContext, path: String): (M, StructType) /** Current version of model import/export format. */ protected def formatVersion: String From b4ee0643b2981517ba02f036d4a053883ca1fe37 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Tue, 3 Feb 2015 20:03:16 -0800 Subject: [PATCH 12/14] Reorganized save/load for regression and classification. Renamed concepts to Saveable, Loader --- .../classification/LogisticRegression.scala | 36 +++++---- .../mllib/classification/NaiveBayes.scala | 80 +++++++++---------- .../spark/mllib/classification/SVM.scala | 35 ++++---- .../impl/GLMClassificationModel.scala | 70 +++++++--------- .../apache/spark/mllib/regression/Lasso.scala | 25 +++--- .../mllib/regression/LinearRegression.scala | 25 +++--- .../mllib/regression/RidgeRegression.scala | 25 +++--- .../regression/impl/GLMRegressionModel.scala | 64 ++++++--------- ...ImportExport.scala => modelSaveLoad.scala} | 26 +++--- .../LogisticRegressionSuite.scala | 54 +++++++------ .../classification/NaiveBayesSuite.scala | 2 +- .../spark/mllib/classification/SVMSuite.scala | 3 +- .../spark/mllib/regression/LassoSuite.scala | 2 +- .../regression/LinearRegressionSuite.scala | 2 +- .../regression/RidgeRegressionSuite.scala | 2 +- 15 files changed, 224 insertions(+), 227 deletions(-) rename mllib/src/main/scala/org/apache/spark/mllib/util/{modelImportExport.scala => modelSaveLoad.scala} (87%) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index f4fac8cd2d5c..b2fa23decd0c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -24,7 +24,7 @@ import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.linalg.{DenseVector, Vector} import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{DataValidators, Exportable, Importable} +import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader} import org.apache.spark.rdd.RDD @@ -46,7 +46,7 @@ class LogisticRegressionModel ( val numFeatures: Int, val numClasses: Int) extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable - with Exportable { + with Saveable { def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2) @@ -139,27 +139,33 @@ class LogisticRegressionModel ( } override def save(sc: SparkContext, path: String): Unit = { - GLMClassificationModel.save(sc, path, this.getClass.getName, weights, intercept, threshold) + GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, + weights, intercept, threshold) } - override protected def formatVersion: String = LogisticRegressionModel.formatVersion - + override protected def formatVersion: String = "1.0" } -object LogisticRegressionModel extends Importable[LogisticRegressionModel] { +object LogisticRegressionModel extends Loader[LogisticRegressionModel] { override def load(sc: SparkContext, path: String): LogisticRegressionModel = { - val data = GLMClassificationModel.loadData(sc, path, classOf[LogisticRegressionModel].getName) - val lr = new LogisticRegressionModel(data.weights, data.intercept) - data.threshold match { - case Some(t) => lr.setThreshold(t) - case None => lr.clearThreshold() + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val data = 
GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0) + val model = new LogisticRegressionModel(data.weights, data.intercept) + data.threshold match { + case Some(t) => model.setThreshold(t) + case None => model.clearThreshold() + } + model + case _ => throw new Exception( + s"LogisticRegressionModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") } - lr } - - override protected def formatVersion: String = GLMClassificationModel.formatVersion - } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 91fac686542a..7412759cdb33 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -22,7 +22,7 @@ import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgma import org.apache.spark.{SparkContext, SparkException, Logging} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.{Importable, Exportable} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext} @@ -38,7 +38,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext} class NaiveBayesModel private[mllib] ( val labels: Array[Double], val pi: Array[Double], - val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Exportable { + val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Saveable { private val brzPi = new BDV[Double](pi) private val brzTheta = new BDM[Double](theta.length, theta(0).length) @@ -69,37 +69,44 @@ class NaiveBayesModel private[mllib] ( } override def save(sc: SparkContext, path: String): Unit = { - val sqlContext = new SQLContext(sc) - import sqlContext._ - - // Create JSON metadata. - val metadataRDD = - sc.parallelize(Seq((this.getClass.getName, formatVersion))).toDataFrame("class", "version") - metadataRDD.toJSON.repartition(1).saveAsTextFile(path + "/metadata") - - // Create Parquet data. - val data = NaiveBayesModel.Data(labels, pi, theta) - val dataRDD: DataFrame = sc.parallelize(Seq(data)) - dataRDD.repartition(1).saveAsParquetFile(path + "/data") + val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta) + NaiveBayesModel.SaveLoadV1_0.save(sc, path, data) } - override protected def formatVersion: String = NaiveBayesModel.formatVersion - + override protected def formatVersion: String = "1.0" } -object NaiveBayesModel extends Importable[NaiveBayesModel] { +object NaiveBayesModel extends Loader[NaiveBayesModel] { + + private object SaveLoadV1_0 { + + def thisFormatVersion = "1.0" + + def thisClassName = "org.apache.spark.mllib.classification.NaiveBayesModel" - /** Model data for model import/export */ - private case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) + /** Model data for model import/export */ + case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) - private object ImporterV1 extends Importer { + def save(sc: SparkContext, path: String, data: Data): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext._ + + // Create JSON metadata. 
+ val metadataRDD = + sc.parallelize(Seq((thisClassName, thisFormatVersion))).toDataFrame("class", "version") + metadataRDD.toJSON.repartition(1).saveAsTextFile(path + "/metadata") + + // Create Parquet data. + val dataRDD: DataFrame = sc.parallelize(Seq(data)) + dataRDD.repartition(1).saveAsParquetFile(path + "/data") + } - override def load(sc: SparkContext, path: String): NaiveBayesModel = { + def load(sc: SparkContext, path: String): NaiveBayesModel = { val sqlContext = new SQLContext(sc) // Load Parquet data. val dataRDD = sqlContext.parquetFile(path + "/data") // Check schema explicitly since erasure makes it hard to use match-case for checking. - Importable.checkSchema[Data](dataRDD.schema) + Loader.checkSchema[Data](dataRDD.schema) val dataArray = dataRDD.select("labels", "pi", "theta").take(1) assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${path + "/data"}") val data = dataArray(0) @@ -110,27 +117,18 @@ object NaiveBayesModel extends Importable[NaiveBayesModel] { } } - protected object Importer { - - def get(clazz: String, version: String): Importer = { - assert(clazz == classOf[NaiveBayesModel].getName, s"NaiveBayesModel.load" + - s" was given model file with metadata specifying a different model class: $clazz") - version match { - case "1.0" => ImporterV1 - case _ => throw new Exception( - s"NaiveBayesModel.load did not recognize model format version: $version." + - s" Supported versions: 1.0.") - } - } - } - override def load(sc: SparkContext, path: String): NaiveBayesModel = { - val (clazz, version, metadata) = Importable.loadMetadata(sc, path) - val importer = Importer.get(clazz, version) - importer.load(sc, path) + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + SaveLoadV1_0.load(sc, path) + case _ => throw new Exception( + s"NaiveBayesModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). 
Supported:\n" + + s" ($classNameV1_0, 1.0)") + } } - - override protected def formatVersion: String = "1.0" } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index cadc57bce9f4..baa3d27f99b8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -23,7 +23,7 @@ import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{DataValidators, Exportable, Importable} +import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader} import org.apache.spark.rdd.RDD @@ -37,7 +37,7 @@ class SVMModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable - with Exportable { + with Saveable { private var threshold: Option[Double] = Some(0.0) @@ -82,26 +82,33 @@ class SVMModel ( } override def save(sc: SparkContext, path: String): Unit = { - GLMClassificationModel.save(sc, path, this.getClass.getName, weights, intercept, threshold) + GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, + weights, intercept, threshold) } - override protected def formatVersion: String = SVMModel.formatVersion + override protected def formatVersion: String = "1.0" } -object SVMModel extends Importable[SVMModel] { +object SVMModel extends Loader[SVMModel] { override def load(sc: SparkContext, path: String): SVMModel = { - val data = GLMClassificationModel.loadData(sc, path, classOf[SVMModel].getName) - val lr = new SVMModel(data.weights, data.intercept) - data.threshold match { - case Some(t) => lr.setThreshold(t) - case None => lr.clearThreshold() + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + val classNameV1_0 = "org.apache.spark.mllib.classification.SVMModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0) + val model = new SVMModel(data.weights, data.intercept) + data.threshold match { + case Some(t) => model.setThreshold(t) + case None => model.clearThreshold() + } + model + case _ => throw new Exception( + s"SVMModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). 
Supported:\n" + + s" ($classNameV1_0, 1.0)") } - lr } - - override protected def formatVersion: String = GLMClassificationModel.formatVersion - } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala index e3d1ddd7e96b..eb7f20dcf2a3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala @@ -19,42 +19,43 @@ package org.apache.spark.mllib.classification.impl import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.util.Importable -import org.apache.spark.sql.{Row, DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} /** - * Helper methods for import/export of GLM classification models. + * Helper class for import/export of GLM classification models. */ private[classification] object GLMClassificationModel { - /** Model data for model import/export */ - case class Data(weights: Vector, intercept: Double, threshold: Option[Double]) + object SaveLoadV1_0 { - def save( - sc: SparkContext, - path: String, - modelClass: String, - weights: Vector, - intercept: Double, - threshold: Option[Double]): Unit = { - val sqlContext = new SQLContext(sc) - import sqlContext._ + def thisFormatVersion = "1.0" - // Create JSON metadata. - val metadataRDD = - sc.parallelize(Seq((modelClass, formatVersion))).toDataFrame("class", "version") - metadataRDD.toJSON.repartition(1).saveAsTextFile(path + "/metadata") + /** Model data for model import/export */ + case class Data(weights: Vector, intercept: Double, threshold: Option[Double]) - // Create Parquet data. - val data = Data(weights, intercept, threshold) - val dataRDD: DataFrame = sc.parallelize(Seq(data)) - // TODO: repartition with 1 partition after SPARK-5532 gets fixed - dataRDD.saveAsParquetFile(path + "/data") - } + def save( + sc: SparkContext, + path: String, + modelClass: String, + weights: Vector, + intercept: Double, + threshold: Option[Double]): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext._ - private object ImporterV1 { + // Create JSON metadata. + val metadataRDD = + sc.parallelize(Seq((modelClass, thisFormatVersion))).toDataFrame("class", "version") + metadataRDD.toJSON.repartition(1).saveAsTextFile(path + "/metadata") + + // Create Parquet data. + val data = Data(weights, intercept, threshold) + val dataRDD: DataFrame = sc.parallelize(Seq(data)) + // TODO: repartition with 1 partition after SPARK-5532 gets fixed + dataRDD.saveAsParquetFile(path + "/data") + } - def load(sc: SparkContext, path: String, modelClass: String): Data = { + def loadData(sc: SparkContext, path: String, modelClass: String): Data = { val sqlContext = new SQLContext(sc) val dataRDD = sqlContext.parquetFile(path + "/data") val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1) @@ -74,21 +75,4 @@ private[classification] object GLMClassificationModel { } } - def formatVersion: String = "1.0" - - def loadData(sc: SparkContext, path: String, modelClass: String): Data = { - val (clazz, version, metadata) = Importable.loadMetadata(sc, path) - // Note: This check of the class name should happen here since we may eventually want to load - // other classes (such as deprecated versions). 
- assert(clazz == modelClass, s"$modelClass.load" + - s" was given model file with metadata specifying a different model class: $clazz") - version match { - case "1.0" => - ImporterV1.load(sc, path, modelClass) - case _ => throw new Exception( - s"$modelClass.load did not recognize model format version: $version." + - s" Supported versions: 1.0.") - } - } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index e741929bf7c7..fd674d5e0701 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression.impl.GLMRegressionModel -import org.apache.spark.mllib.util.{Exportable, Importable} +import org.apache.spark.mllib.util.{Saveable, Loader} import org.apache.spark.rdd.RDD /** @@ -34,7 +34,7 @@ class LassoModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable with Exportable { + with RegressionModel with Serializable with Saveable { override protected def predictPoint( dataMatrix: Vector, @@ -44,20 +44,27 @@ class LassoModel ( } override def save(sc: SparkContext, path: String): Unit = { - GLMRegressionModel.save(sc, path, this.getClass.getName, weights, intercept) + GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) } - override protected def formatVersion: String = LassoModel.formatVersion + override protected def formatVersion: String = "1.0" } -object LassoModel extends Importable[LassoModel] { +object LassoModel extends Loader[LassoModel] { override def load(sc: SparkContext, path: String): LassoModel = { - val data = GLMRegressionModel.loadData(sc, path, classOf[LassoModel].getName) - new LassoModel(data.weights, data.intercept) + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + val classNameV1_0 = "org.apache.spark.mllib.regression.LassoModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0) + new LassoModel(data.weights, data.intercept) + case _ => throw new Exception( + s"LassoModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). 
Supported:\n" + + s" ($classNameV1_0, 1.0)") + } } - - override protected def formatVersion: String = LassoModel.formatVersion } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 2953fe589d83..0333a97bac9a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression.impl.GLMRegressionModel -import org.apache.spark.mllib.util.{Exportable, Importable} +import org.apache.spark.mllib.util.{Saveable, Loader} import org.apache.spark.rdd.RDD /** @@ -34,7 +34,7 @@ class LinearRegressionModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable - with Exportable { + with Saveable { override protected def predictPoint( dataMatrix: Vector, @@ -44,20 +44,27 @@ class LinearRegressionModel ( } override def save(sc: SparkContext, path: String): Unit = { - GLMRegressionModel.save(sc, path, this.getClass.getName, weights, intercept) + GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) } - override protected def formatVersion: String = LinearRegressionModel.formatVersion + override protected def formatVersion: String = "1.0" } -object LinearRegressionModel extends Importable[LinearRegressionModel] { +object LinearRegressionModel extends Loader[LinearRegressionModel] { override def load(sc: SparkContext, path: String): LinearRegressionModel = { - val data = GLMRegressionModel.loadData(sc, path, classOf[LinearRegressionModel].getName) - new LinearRegressionModel(data.weights, data.intercept) + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + val classNameV1_0 = "org.apache.spark.mllib.regression.LinearRegressionModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0) + new LinearRegressionModel(data.weights, data.intercept) + case _ => throw new Exception( + s"LinearRegressionModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). 
Supported:\n" + + s" ($classNameV1_0, 1.0)") + } } - - override protected def formatVersion: String = LinearRegressionModel.formatVersion } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 4e63b43a8335..6ffe8ec6230c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression.impl.GLMRegressionModel -import org.apache.spark.mllib.util.{Exportable, Importable} +import org.apache.spark.mllib.util.{Saveable, Loader} import org.apache.spark.rdd.RDD @@ -35,7 +35,7 @@ class RidgeRegressionModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable with Exportable { + with RegressionModel with Serializable with Saveable { override protected def predictPoint( dataMatrix: Vector, @@ -45,20 +45,27 @@ class RidgeRegressionModel ( } override def save(sc: SparkContext, path: String): Unit = { - GLMRegressionModel.save(sc, path, this.getClass.getName, weights, intercept) + GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) } - override protected def formatVersion: String = RidgeRegressionModel.formatVersion + override protected def formatVersion: String = "1.0" } -object RidgeRegressionModel extends Importable[RidgeRegressionModel] { +object RidgeRegressionModel extends Loader[RidgeRegressionModel] { override def load(sc: SparkContext, path: String): RidgeRegressionModel = { - val data = GLMRegressionModel.loadData(sc, path, classOf[RidgeRegressionModel].getName) - new RidgeRegressionModel(data.weights, data.intercept) + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + val classNameV1_0 = "org.apache.spark.mllib.regression.RidgeRegressionModel" + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0) + new RidgeRegressionModel(data.weights, data.intercept) + case _ => throw new Exception( + s"RidgeRegressionModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). 
Supported:\n" + + s" ($classNameV1_0, 1.0)") + } } - - override protected def formatVersion: String = GLMRegressionModel.formatVersion } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala index f475953263f3..f5a3ca38f77d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala @@ -19,7 +19,6 @@ package org.apache.spark.mllib.regression.impl import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.util.Importable import org.apache.spark.sql.{DataFrame, Row, SQLContext} /** @@ -27,33 +26,35 @@ import org.apache.spark.sql.{DataFrame, Row, SQLContext} */ private[regression] object GLMRegressionModel { - /** Model data for model import/export */ - case class Data(weights: Vector, intercept: Double) + object SaveLoadV1_0 { - def save( - sc: SparkContext, - path: String, - modelClass: String, - weights: Vector, - intercept: Double): Unit = { - val sqlContext = new SQLContext(sc) - import sqlContext._ + def thisFormatVersion = "1.0" - // Create JSON metadata. - val metadataRDD = - sc.parallelize(Seq((modelClass, formatVersion))).toDataFrame("class", "version") - metadataRDD.toJSON.repartition(1).saveAsTextFile(path + "/metadata") + /** Model data for model import/export */ + case class Data(weights: Vector, intercept: Double) - // Create Parquet data. - val data = Data(weights, intercept) - val dataRDD: DataFrame = sc.parallelize(Seq(data)) - // TODO: repartition with 1 partition after SPARK-5532 gets fixed - dataRDD.saveAsParquetFile(path + "/data") - } + def save( + sc: SparkContext, + path: String, + modelClass: String, + weights: Vector, + intercept: Double): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext._ - private object ImporterV1 { + // Create JSON metadata. + val metadataRDD = + sc.parallelize(Seq((modelClass, thisFormatVersion))).toDataFrame("class", "version") + metadataRDD.toJSON.repartition(1).saveAsTextFile(path + "/metadata") + + // Create Parquet data. + val data = Data(weights, intercept) + val dataRDD: DataFrame = sc.parallelize(Seq(data)) + // TODO: repartition with 1 partition after SPARK-5532 gets fixed + dataRDD.saveAsParquetFile(path + "/data") + } - def load(sc: SparkContext, path: String, modelClass: String): Data = { + def loadData(sc: SparkContext, path: String, modelClass: String): Data = { val sqlContext = new SQLContext(sc) val dataRDD = sqlContext.parquetFile(path + "/data") val dataArray = dataRDD.select("weights", "intercept").take(1) @@ -67,21 +68,4 @@ private[regression] object GLMRegressionModel { } } - def formatVersion: String = "1.0" - - def loadData(sc: SparkContext, path: String, modelClass: String): Data = { - val (clazz, version, metadata) = Importable.loadMetadata(sc, path) - // Note: This check of the class name should happen here since we may eventually want to load - // other classes (such as deprecated versions). - assert(clazz == modelClass, s"$modelClass.load" + - s" was given model file with metadata specifying a different model class: $clazz") - version match { - case "1.0" => - ImporterV1.load(sc, path, modelClass) - case _ => throw new Exception( - s"$modelClass.load did not recognize model format version: $version." 
+ - s" Supported versions: 1.0.") - } - } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala similarity index 87% rename from mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala rename to mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala index b1201ff54721..20d4a9537f3e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DataType, StructType, StructField} * This should be inherited by the class which implements model instances. */ @DeveloperApi -trait Exportable { +trait Saveable { /** * Save this model to the given path. @@ -42,7 +42,7 @@ trait Exportable { * - human-readable (JSON) model metadata to path/metadata/ * - Parquet formatted data to path/data/ * - * The model may be loaded using [[Importable.load]]. + * The model may be loaded using [[Loader.load]]. * * @param sc Spark context used to save model data. * @param path Path specifying the directory in which to save this model. @@ -50,7 +50,7 @@ trait Exportable { */ def save(sc: SparkContext, path: String): Unit - /** Current version of model import/export format. */ + /** Current version of model save/load format. */ protected def formatVersion: String } @@ -58,20 +58,16 @@ trait Exportable { /** * :: DeveloperApi :: * - * Trait for models and transformers which may be loaded from files. + * Trait for classes which can load models and transformers from files. * This should be inherited by an object paired with the model class. */ @DeveloperApi -trait Importable[M <: Exportable] { - - protected abstract class Importer { - def load(sc: SparkContext, path: String): M - } +trait Loader[M <: Saveable] { /** * Load a model from the given path. * - * The model should have been saved by [[Exportable.save]]. + * The model should have been saved by [[Saveable.save]]. * * @param sc Spark context used for loading model files. * @param path Path specifying the directory to which the model was saved. @@ -79,14 +75,12 @@ trait Importable[M <: Exportable] { */ def load(sc: SparkContext, path: String): M - // def loadWithSchema(sc: SparkContext, path: String): (M, StructType) - - /** Current version of model import/export format. */ - protected def formatVersion: String - } -private[mllib] object Importable { +/** + * Helper methods for loading models from files. + */ +private[mllib] object Loader { /** * Check the schema of loaded model data. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 9d6a507f4d60..fd33f7c0daa3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -409,16 +409,16 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M * * First of all, using the following scala code to save the data into `path`. 
* - * testRDD.map(x => x.label+ ", " + x.features(0) + ", " + x.features(1) + ", " + - * x.features(2) + ", " + x.features(3)).saveAsTextFile("path") + * testRDD.map(x => x.label+ ", " + x.features(0) + ", " + x.features(1) + ", " + + * x.features(2) + ", " + x.features(3)).saveAsTextFile("path") * * Using the following R code to load the data and train the model using glmnet package. * - * library("glmnet") - * data <- read.csv("path", header=FALSE) - * label = factor(data$V1) - * features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * weights = coef(glmnet(features,label, family="multinomial", alpha = 0, lambda = 0)) + * library("glmnet") + * data <- read.csv("path", header=FALSE) + * label = factor(data$V1) + * features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + * weights = coef(glmnet(features,label, family="multinomial", alpha = 0, lambda = 0)) * * The model weights of mutinomial logstic regression in R have `K` set of linear predictors * for `K` classes classification problem; however, only `K-1` set is required if the first @@ -427,25 +427,25 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M * weights. The mathematical discussion and proof can be found here: * http://en.wikipedia.org/wiki/Multinomial_logistic_regression * - * weights1 = weights$`1` - weights$`0` - * weights2 = weights$`2` - weights$`0` + * weights1 = weights$`1` - weights$`0` + * weights2 = weights$`2` - weights$`0` * - * > weights1 - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * 2.6228269 - * data.V2 -0.5837166 - * data.V3 0.9285260 - * data.V4 -0.3783612 - * data.V5 -0.8123411 - * > weights2 - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * 4.11197445 - * data.V2 -0.16918650 - * data.V3 -0.81104784 - * data.V4 -0.06463799 - * data.V5 -0.29198337 + * > weights1 + * 5 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * 2.6228269 + * data.V2 -0.5837166 + * data.V3 0.9285260 + * data.V4 -0.3783612 + * data.V5 -0.8123411 + * > weights2 + * 5 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * 4.11197445 + * data.V2 -0.16918650 + * data.V3 -0.81104784 + * data.V4 -0.06463799 + * data.V5 -0.29198337 */ val weightsR = Vectors.dense(Array( @@ -461,9 +461,11 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M // very steep curve in logistic function so that when we draw samples from distribution, it's // very easy to assign to another labels. However, this prediction result is consistent to R. validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData, 0.47) + } - test("model export/import") { + test("model save/load") { + // NOTE: This will need to be generalized once there are multiple model format versions. 
val nPoints = 20 val A = 2.0 val B = -1.5 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 276c3e5468f8..a86d7d5131e1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -130,7 +130,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { } } - test("model export/import") { + test("model save/load") { val nPoints = 10 val pi = NaiveBayesSuite.smallPi diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index 55d431585787..41d8d8aa682e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -193,7 +193,8 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { new SVMWithSGD().setValidateData(false).run(testRDDInvalid) } - test("model export/import") { + test("model save/load") { + // NOTE: This will need to be generalized once there are multiple model format versions. val nPoints = 10 val A = 0.01 val B = -1.5 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index 70912fa5c065..2c1153ef0599 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -117,7 +117,7 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } - test("model export/import") { + test("model save/load") { // Create dataset val nPoints = 10 val A = 2.0 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index a9ef2d691885..bbf4a1226fe5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -126,7 +126,7 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { sparseValidationData.map(row => model.predict(row.features)), sparseValidationData) } - test("model export/import") { + test("model save/load") { // Create dataset val rdd = sc.parallelize( LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 10, 42), 2) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index 66391d59d719..6ed36fe3f7d7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -77,7 +77,7 @@ class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")") } - test("model export/import") { + test("model save/load") { // Create dataset val numExamples = 20 val numFeatures = 4 From 12d905970899a5466f1b23a4ccafe2887438c861 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Wed, 4 Feb 2015 15:24:01 -0800 Subject: [PATCH 13/14] Many cleanups after code review. Major changes: Storing numFeatures, numClasses in model metadata. Improvements to unit tests --- .../classification/ClassificationModel.scala | 20 ++++++ .../classification/LogisticRegression.scala | 27 ++++++- .../mllib/classification/NaiveBayes.scala | 36 +++++++--- .../spark/mllib/classification/SVM.scala | 9 ++- .../impl/GLMClassificationModel.scala | 33 ++++++--- .../apache/spark/mllib/regression/Lasso.scala | 4 +- .../mllib/regression/LinearRegression.scala | 4 +- .../mllib/regression/RegressionModel.scala | 20 ++++++ .../mllib/regression/RidgeRegression.scala | 4 +- .../regression/impl/GLMRegressionModel.scala | 31 +++++--- .../spark/mllib/util/modelSaveLoad.scala | 12 +++- .../LogisticRegressionSuite.scala | 72 +++++++++++++------ .../classification/NaiveBayesSuite.scala | 30 ++++---- .../spark/mllib/classification/SVMSuite.scala | 44 ++++++------ .../spark/mllib/regression/LassoSuite.scala | 35 +++++---- .../regression/LinearRegressionSuite.scala | 28 ++++---- .../regression/RidgeRegressionSuite.scala | 34 ++++----- 17 files changed, 299 insertions(+), 144 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index b7a1d90d24d7..348c1e8760a6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -20,7 +20,9 @@ package org.apache.spark.mllib.classification import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.util.Loader import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} /** * :: Experimental :: @@ -53,3 +55,21 @@ trait ClassificationModel extends Serializable { def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } + +private[mllib] object ClassificationModel { + + /** + * Helper method for loading GLM classification model metadata. 
+ * + * @param modelClass String name for model class (used for error messages) + * @return (numFeatures, numClasses) + */ + def getNumFeaturesClasses(metadata: DataFrame, modelClass: String, path: String): (Int, Int) = { + metadata.select("numFeatures", "numClasses").take(1)(0) match { + case Row(nFeatures: Int, nClasses: Int) => (nFeatures, nClasses) + case _ => throw new Exception(s"$modelClass unable to load" + + s" numFeatures, numClasses from metadata: ${Loader.metadataPath(path)}") + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index b2fa23decd0c..5c9feb6fb269 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -48,6 +48,20 @@ class LogisticRegressionModel ( extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable with Saveable { + if (numClasses == 2) { + require(weights.size == numFeatures, + s"LogisticRegressionModel with numClasses = 2 was given non-matching values:" + + s" numFeatures = $numFeatures, but weights.size = ${weights.size}") + } else { + val weightsSizeWithoutIntercept = (numClasses - 1) * numFeatures + val weightsSizeWithIntercept = (numClasses - 1) * (numFeatures + 1) + require(weights.size == weightsSizeWithoutIntercept || weights.size == weightsSizeWithIntercept, + s"LogisticRegressionModel.load with numClasses = $numClasses and numFeatures = $numFeatures" + + s" expected weights of length $weightsSizeWithoutIntercept (without intercept)" + + s" or $weightsSizeWithIntercept (with intercept)," + + s" but was given weights of length ${weights.size}") + } + def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2) private var threshold: Option[Double] = Some(0.5) @@ -81,7 +95,9 @@ class LogisticRegressionModel ( this } - override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, + override protected def predictPoint( + dataMatrix: Vector, + weightMatrix: Vector, intercept: Double) = { require(dataMatrix.size == numFeatures) @@ -140,7 +156,7 @@ class LogisticRegressionModel ( override def save(sc: SparkContext, path: String): Unit = { GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, - weights, intercept, threshold) + numFeatures, numClasses, weights, intercept, threshold) } override protected def formatVersion: String = "1.0" @@ -150,11 +166,16 @@ object LogisticRegressionModel extends Loader[LogisticRegressionModel] { override def load(sc: SparkContext, path: String): LogisticRegressionModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel" (loadedClassName, version) match { case (className, "1.0") if className == classNameV1_0 => + val (numFeatures, numClasses) = + ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path) val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0) - val model = new LogisticRegressionModel(data.weights, data.intercept) + // numFeatures, numClasses, weights are checked in model initialization + val model = + new LogisticRegressionModel(data.weights, data.intercept, numFeatures, numClasses) data.threshold match { case Some(t) => 
model.setThreshold(t) case None => model.clearThreshold() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 7412759cdb33..c8fe19855dda 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.classification import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} +import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.{SparkContext, SparkException, Logging} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} @@ -78,10 +79,13 @@ class NaiveBayesModel private[mllib] ( object NaiveBayesModel extends Loader[NaiveBayesModel] { + import Loader._ + private object SaveLoadV1_0 { def thisFormatVersion = "1.0" + /** Hard-code class name string in case it changes in the future */ def thisClassName = "org.apache.spark.mllib.classification.NaiveBayesModel" /** Model data for model import/export */ @@ -93,22 +97,23 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { // Create JSON metadata. val metadataRDD = - sc.parallelize(Seq((thisClassName, thisFormatVersion))).toDataFrame("class", "version") - metadataRDD.toJSON.repartition(1).saveAsTextFile(path + "/metadata") + sc.parallelize(Seq((thisClassName, thisFormatVersion, data.theta(0).size, data.pi.size)), 1) + .toDataFrame("class", "version", "numFeatures", "numClasses") + metadataRDD.toJSON.saveAsTextFile(metadataPath(path)) // Create Parquet data. - val dataRDD: DataFrame = sc.parallelize(Seq(data)) - dataRDD.repartition(1).saveAsParquetFile(path + "/data") + val dataRDD: DataFrame = sc.parallelize(Seq(data), 1) + dataRDD.saveAsParquetFile(dataPath(path)) } def load(sc: SparkContext, path: String): NaiveBayesModel = { val sqlContext = new SQLContext(sc) // Load Parquet data. - val dataRDD = sqlContext.parquetFile(path + "/data") + val dataRDD = sqlContext.parquetFile(dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. 
- Loader.checkSchema[Data](dataRDD.schema) + checkSchema[Data](dataRDD.schema) val dataArray = dataRDD.select("labels", "pi", "theta").take(1) - assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${path + "/data"}") + assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") val data = dataArray(0) val labels = data.getAs[Seq[Double]](0).toArray val pi = data.getAs[Seq[Double]](1).toArray @@ -118,11 +123,24 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { } override def load(sc: SparkContext, path: String): NaiveBayesModel = { - val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + val (loadedClassName, version, metadata) = loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName (loadedClassName, version) match { case (className, "1.0") if className == classNameV1_0 => - SaveLoadV1_0.load(sc, path) + val (numFeatures, numClasses) = + ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path) + val model = SaveLoadV1_0.load(sc, path) + assert(model.pi.size == numClasses, + s"NaiveBayesModel.load expected $numClasses classes," + + s" but class priors vector pi had ${model.pi.size} elements") + assert(model.theta.size == numClasses, + s"NaiveBayesModel.load expected $numClasses classes," + + s" but class conditionals array theta had ${model.theta.size} elements") + assert(model.theta.forall(_.size == numFeatures), + s"NaiveBayesModel.load expected $numFeatures features," + + s" but class conditionals array theta had elements of size:" + + s" ${model.theta.map(_.size).mkString(",")}") + model case _ => throw new Exception( s"NaiveBayesModel.load did not recognize model with (className, format version):" + s"($loadedClassName, $version). Supported:\n" + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index baa3d27f99b8..24d31e62ba50 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -83,7 +83,7 @@ class SVMModel ( override def save(sc: SparkContext, path: String): Unit = { GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, - weights, intercept, threshold) + numFeatures = weights.size, numClasses = 2, weights, intercept, threshold) } override protected def formatVersion: String = "1.0" @@ -93,11 +93,18 @@ object SVMModel extends Loader[SVMModel] { override def load(sc: SparkContext, path: String): SVMModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future val classNameV1_0 = "org.apache.spark.mllib.classification.SVMModel" (loadedClassName, version) match { case (className, "1.0") if className == classNameV1_0 => + val (numFeatures, numClasses) = + ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path) val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0) val model = new SVMModel(data.weights, data.intercept) + assert(model.weights.size == numFeatures, s"SVMModel.load with numFeatures=$numFeatures" + + s" was given non-matching weights vector of size ${model.weights.size}") + assert(numClasses == 2, + s"SVMModel.load was given numClasses=$numClasses but only supports 2 classes") data.threshold match { case Some(t) => model.setThreshold(t) case None => model.clearThreshold() diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala index eb7f20dcf2a3..b60c0cdd0ab7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.classification.impl import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.util.Loader import org.apache.spark.sql.{DataFrame, Row, SQLContext} /** @@ -30,13 +31,20 @@ private[classification] object GLMClassificationModel { def thisFormatVersion = "1.0" - /** Model data for model import/export */ + /** Model data for import/export */ case class Data(weights: Vector, intercept: Double, threshold: Option[Double]) + /** + * Helper method for saving GLM classification model metadata and data. + * @param modelClass String name for model class, to be saved with metadata + * @param numClasses Number of classes label can take, to be saved with metadata + */ def save( sc: SparkContext, path: String, modelClass: String, + numFeatures: Int, + numClasses: Int, weights: Vector, intercept: Double, threshold: Option[Double]): Unit = { @@ -45,23 +53,32 @@ private[classification] object GLMClassificationModel { // Create JSON metadata. val metadataRDD = - sc.parallelize(Seq((modelClass, thisFormatVersion))).toDataFrame("class", "version") - metadataRDD.toJSON.repartition(1).saveAsTextFile(path + "/metadata") + sc.parallelize(Seq((modelClass, thisFormatVersion, numFeatures, numClasses)), 1) + .toDataFrame("class", "version", "numFeatures", "numClasses") + metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path)) // Create Parquet data. val data = Data(weights, intercept, threshold) - val dataRDD: DataFrame = sc.parallelize(Seq(data)) + val dataRDD: DataFrame = sc.parallelize(Seq(data), 1) // TODO: repartition with 1 partition after SPARK-5532 gets fixed - dataRDD.saveAsParquetFile(path + "/data") + dataRDD.saveAsParquetFile(Loader.dataPath(path)) } + /** + * Helper method for loading GLM classification model data. + * + * NOTE: Callers of this method should check numClasses, numFeatures on their own. 
+ * + * @param modelClass String name for model class (used for error messages) + */ def loadData(sc: SparkContext, path: String, modelClass: String): Data = { + val datapath = Loader.dataPath(path) val sqlContext = new SQLContext(sc) - val dataRDD = sqlContext.parquetFile(path + "/data") + val dataRDD = sqlContext.parquetFile(datapath) val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1) - assert(dataArray.size == 1, s"Unable to load $modelClass data from: ${path + "/data"}") + assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") val data = dataArray(0) - assert(data.size == 3, s"Unable to load $modelClass data from: ${path + "/data"}") + assert(data.size == 3, s"Unable to load $modelClass data from: $datapath") val (weights, intercept) = data match { case Row(weights: Vector, intercept: Double, _) => (weights, intercept) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index fd674d5e0701..1159e59fff5f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -54,10 +54,12 @@ object LassoModel extends Loader[LassoModel] { override def load(sc: SparkContext, path: String): LassoModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future val classNameV1_0 = "org.apache.spark.mllib.regression.LassoModel" (loadedClassName, version) match { case (className, "1.0") if className == classNameV1_0 => - val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0) + val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path) + val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures) new LassoModel(data.weights, data.intercept) case _ => throw new Exception( s"LassoModel.load did not recognize model with (className, format version):" + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 0333a97bac9a..0136dcfdceae 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -54,10 +54,12 @@ object LinearRegressionModel extends Loader[LinearRegressionModel] { override def load(sc: SparkContext, path: String): LinearRegressionModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future val classNameV1_0 = "org.apache.spark.mllib.regression.LinearRegressionModel" (loadedClassName, version) match { case (className, "1.0") if className == classNameV1_0 => - val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0) + val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path) + val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures) new LinearRegressionModel(data.weights, data.intercept) case _ => throw new Exception( s"LinearRegressionModel.load did not recognize model with (className, format version):" + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala index 64b02f7a6e7a..d6bbe7bbf440 100644 
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala @@ -19,8 +19,10 @@ package org.apache.spark.mllib.regression import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.util.Loader import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.sql.{DataFrame, Row} @Experimental trait RegressionModel extends Serializable { @@ -48,3 +50,21 @@ trait RegressionModel extends Serializable { def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } + +private[mllib] object RegressionModel { + + /** + * Helper method for loading GLM regression model metadata. + * + * @param modelClass String name for model class (used for error messages) + * @return numFeatures + */ + def getNumFeatures(metadata: DataFrame, modelClass: String, path: String): Int = { + metadata.select("numFeatures").take(1)(0) match { + case Row(nFeatures: Int) => nFeatures + case _ => throw new Exception(s"$modelClass unable to load" + + s" numFeatures from metadata: ${Loader.metadataPath(path)}") + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 6ffe8ec6230c..32a40b9a51d8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -55,10 +55,12 @@ object RidgeRegressionModel extends Loader[RidgeRegressionModel] { override def load(sc: SparkContext, path: String): RidgeRegressionModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + // Hard-code class name string in case it changes in the future val classNameV1_0 = "org.apache.spark.mllib.regression.RidgeRegressionModel" (loadedClassName, version) match { case (className, "1.0") if className == classNameV1_0 => - val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0) + val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path) + val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures) new RidgeRegressionModel(data.weights, data.intercept) case _ => throw new Exception( s"RidgeRegressionModel.load did not recognize model with (className, format version):" + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala index f5a3ca38f77d..00f25a8be939 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.regression.impl import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.util.Loader import org.apache.spark.sql.{DataFrame, Row, SQLContext} /** @@ -33,6 +34,10 @@ private[regression] object GLMRegressionModel { /** Model data for model import/export */ case class Data(weights: Vector, intercept: Double) + /** + * Helper method for saving GLM regression model metadata and data. 
+ * @param modelClass String name for model class, to be saved with metadata + */ def save( sc: SparkContext, path: String, @@ -44,25 +49,35 @@ private[regression] object GLMRegressionModel { // Create JSON metadata. val metadataRDD = - sc.parallelize(Seq((modelClass, thisFormatVersion))).toDataFrame("class", "version") - metadataRDD.toJSON.repartition(1).saveAsTextFile(path + "/metadata") + sc.parallelize(Seq((modelClass, thisFormatVersion, weights.size)), 1) + .toDataFrame("class", "version", "numFeatures") + metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path)) // Create Parquet data. val data = Data(weights, intercept) - val dataRDD: DataFrame = sc.parallelize(Seq(data)) + val dataRDD: DataFrame = sc.parallelize(Seq(data), 1) // TODO: repartition with 1 partition after SPARK-5532 gets fixed - dataRDD.saveAsParquetFile(path + "/data") + dataRDD.saveAsParquetFile(Loader.dataPath(path)) } - def loadData(sc: SparkContext, path: String, modelClass: String): Data = { + /** + * Helper method for loading GLM regression model data. + * @param modelClass String name for model class (used for error messages) + * @param numFeatures Number of features, to be checked against loaded data. + * The length of the weights vector should equal numFeatures. + */ + def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = { + val datapath = Loader.dataPath(path) val sqlContext = new SQLContext(sc) - val dataRDD = sqlContext.parquetFile(path + "/data") + val dataRDD = sqlContext.parquetFile(datapath) val dataArray = dataRDD.select("weights", "intercept").take(1) - assert(dataArray.size == 1, s"Unable to load $modelClass data from: ${path + "/data"}") + assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") val data = dataArray(0) - assert(data.size == 2, s"Unable to load $modelClass data from: ${path + "/data"}") + assert(data.size == 2, s"Unable to load $modelClass data from: $datapath") data match { case Row(weights: Vector, intercept: Double) => + assert(weights.size == numFeatures, s"Expected $numFeatures features, but" + + s" found ${weights.size} features when loading $modelClass weights from $datapath") Data(weights, intercept) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala index 20d4a9537f3e..56b77a7d12e8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala @@ -19,6 +19,8 @@ package org.apache.spark.mllib.util import scala.reflect.runtime.universe.TypeTag +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.{DataFrame, Row, SQLContext} @@ -82,6 +84,12 @@ trait Loader[M <: Saveable] { */ private[mllib] object Loader { + /** Returns URI for path/data using the Hadoop filesystem */ + def dataPath(path: String): String = new Path(path, "data").toUri.toString + + /** Returns URI for path/metadata using the Hadoop filesystem */ + def metadataPath(path: String): String = new Path(path, "metadata").toUri.toString + /** * Check the schema of loaded model data. 
* @@ -114,7 +122,7 @@ private[mllib] object Loader { */ def loadMetadata(sc: SparkContext, path: String): (String, String, DataFrame) = { val sqlContext = new SQLContext(sc) - val metadata = sqlContext.jsonFile(path + "/metadata") + val metadata = sqlContext.jsonFile(metadataPath(path)) val (clazz, version) = try { val metadataArray = metadata.select("class", "version").take(1) assert(metadataArray.size == 1) @@ -123,7 +131,7 @@ private[mllib] object Loader { } } catch { case e: Exception => - throw new Exception(s"Unable to load model metadata from: ${path + "/metadata"}") + throw new Exception(s"Unable to load model metadata from: ${metadataPath(path)}") } (clazz, version, metadata) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index fd33f7c0daa3..6be1b290a9b6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -149,6 +149,14 @@ object LogisticRegressionSuite { val testData = (0 until nPoints).map(i => LabeledPoint(y(i), x(i))) testData } + + /** Binary labels, 3 features */ + private val binaryModel = new LogisticRegressionModel( + weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5, numFeatures = 3, numClasses = 2) + + /** 3 classes, 2 features */ + private val multiclassModel = new LogisticRegressionModel( + weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, numFeatures = 2, numClasses = 3) } class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { @@ -464,19 +472,9 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M } - test("model save/load") { + test("model save/load: binary classification") { // NOTE: This will need to be generalized once there are multiple model format versions. - val nPoints = 20 - val A = 2.0 - val B = -1.5 - - val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) - val testRDD = sc.parallelize(testData, 2) - testRDD.cache() - - val lr = new LogisticRegressionWithLBFGS().setIntercept(true) - lr.optimizer.setNumIterations(1) - val model = lr.run(testRDD) + val model = LogisticRegressionSuite.binaryModel model.clearThreshold() assert(model.getThreshold.isEmpty) @@ -485,19 +483,47 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M val path = tempDir.toURI.toString // Save model, load it back, and compare. - model.save(sc, path) - val sameModel = LogisticRegressionModel.load(sc, path) - assert(model.weights == sameModel.weights) - assert(model.intercept == sameModel.intercept) - assert(sameModel.getThreshold.isEmpty) - Utils.deleteRecursively(tempDir) + try { + model.save(sc, path) + val sameModel = LogisticRegressionModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + assert(model.numClasses == sameModel.numClasses) + assert(model.numFeatures == sameModel.numFeatures) + assert(sameModel.getThreshold.isEmpty) + } finally { + Utils.deleteRecursively(tempDir) + } // Save model with threshold. 
- model.setThreshold(0.7) - model.save(sc, path) - val sameModel2 = LogisticRegressionModel.load(sc, path) - assert(model.getThreshold.get == sameModel2.getThreshold.get) - Utils.deleteRecursively(tempDir) + try { + model.setThreshold(0.7) + model.save(sc, path) + val sameModel2 = LogisticRegressionModel.load(sc, path) + assert(model.getThreshold.get == sameModel2.getThreshold.get) + } finally { + Utils.deleteRecursively(tempDir) + } + } + + test("model save/load: multiclass classification") { + // NOTE: This will need to be generalized once there are multiple model format versions. + val model = LogisticRegressionSuite.multiclassModel + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = LogisticRegressionModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + assert(model.numClasses == sameModel.numClasses) + assert(model.numFeatures == sameModel.numFeatures) + } finally { + Utils.deleteRecursively(tempDir) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index a86d7d5131e1..64dcc0fb9f82 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -68,6 +68,10 @@ object NaiveBayesSuite { Array(0.03, 0.91, 0.03, 0.03), // label 1 Array(0.03, 0.03, 0.91, 0.03) // label 2 ).map(_.map(math.log)) + + /** Binary labels, 3 features */ + private val binaryModel = new NaiveBayesModel(labels = Array(0.0, 1.0), pi = Array(0.2, 0.8), + theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4))) } class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { @@ -131,27 +135,21 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { } test("model save/load") { - val nPoints = 10 - - val pi = NaiveBayesSuite.smallPi - val theta = NaiveBayesSuite.smallTheta - - val data = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42) - val rdd = sc.parallelize(data, 2) - rdd.cache() - - val model = NaiveBayes.train(rdd) + val model = NaiveBayesSuite.binaryModel val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString // Save model, load it back, and compare. 
- model.save(sc, path) - val sameModel = NaiveBayesModel.load(sc, path) - assert(model.labels === sameModel.labels) - assert(model.pi === sameModel.pi) - assert(model.theta === sameModel.theta) - Utils.deleteRecursively(tempDir) + try { + model.save(sc, path) + val sameModel = NaiveBayesModel.load(sc, path) + assert(model.labels === sameModel.labels) + assert(model.pi === sameModel.pi) + assert(model.theta === sameModel.theta) + } finally { + Utils.deleteRecursively(tempDir) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index 41d8d8aa682e..6de098b383ba 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -57,6 +57,9 @@ object SVMSuite { y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2))) } + /** Binary labels, 3 features */ + private val binaryModel = new SVMModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) + } class SVMSuite extends FunSuite with MLlibTestSparkContext { @@ -195,18 +198,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { test("model save/load") { // NOTE: This will need to be generalized once there are multiple model format versions. - val nPoints = 10 - val A = 0.01 - val B = -1.5 - val C = 1.0 - - val data = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) - val rdd = sc.parallelize(data, 2) - rdd.cache() - - val svm = new SVMWithSGD() - svm.optimizer.setNumIterations(1) - val model = svm.run(rdd) + val model = SVMSuite.binaryModel model.clearThreshold() assert(model.getThreshold.isEmpty) @@ -215,19 +207,25 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val path = tempDir.toURI.toString // Save model, load it back, and compare. - model.save(sc, path) - val sameModel = SVMModel.load(sc, path) - assert(model.weights == sameModel.weights) - assert(model.intercept == sameModel.intercept) - assert(sameModel.getThreshold.isEmpty) - Utils.deleteRecursively(tempDir) + try { + model.save(sc, path) + val sameModel = SVMModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + assert(sameModel.getThreshold.isEmpty) + } finally { + Utils.deleteRecursively(tempDir) + } // Save model with threshold. 
- model.setThreshold(0.7) - model.save(sc, path) - val sameModel2 = SVMModel.load(sc, path) - assert(model.getThreshold.get == sameModel2.getThreshold.get) - Utils.deleteRecursively(tempDir) + try { + model.setThreshold(0.7) + model.save(sc, path) + val sameModel2 = SVMModel.load(sc, path) + assert(model.getThreshold.get == sameModel2.getThreshold.get) + } finally { + Utils.deleteRecursively(tempDir) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index 2c1153ef0599..c9f5dc069ef2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -26,6 +26,12 @@ import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerato MLlibTestSparkContext} import org.apache.spark.util.Utils +private object LassoSuite { + + /** 3 features */ + val model = new LassoModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) +} + class LassoSuite extends FunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { @@ -118,27 +124,20 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { } test("model save/load") { - // Create dataset - val nPoints = 10 - val A = 2.0 - val B = -1.5 - val C = 1.0e-2 - val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B, C), nPoints, 42) - val testRDD = sc.parallelize(testData, 2).cache() + val model = LassoSuite.model - // Train model - val ls = new LassoWithSGD() - ls.optimizer.setNumIterations(1) - val model = ls.run(testRDD) - - // Save model, load it back, and compare. val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString - model.save(sc, path) - val sameModel = LassoModel.load(sc, path) - assert(model.weights == sameModel.weights) - assert(model.intercept == sameModel.intercept) - Utils.deleteRecursively(tempDir) + + // Save model, load it back, and compare. 
+ try { + model.save(sc, path) + val sameModel = LassoModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + } finally { + Utils.deleteRecursively(tempDir) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index bbf4a1226fe5..3781931c2f81 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -26,6 +26,12 @@ import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerato MLlibTestSparkContext} import org.apache.spark.util.Utils +private object LinearRegressionSuite { + + /** 3 features */ + val model = new LinearRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) +} + class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { @@ -127,24 +133,20 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { } test("model save/load") { - // Create dataset - val rdd = sc.parallelize( - LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 10, 42), 2) - - // Train model - val linReg = new LinearRegressionWithSGD().setIntercept(false) - linReg.optimizer.setNumIterations(1) - val model = linReg.run(rdd) + val model = LinearRegressionSuite.model val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString // Save model, load it back, and compare. - model.save(sc, path) - val sameModel = LinearRegressionModel.load(sc, path) - assert(model.weights == sameModel.weights) - assert(model.intercept == sameModel.intercept) - Utils.deleteRecursively(tempDir) + try { + model.save(sc, path) + val sameModel = LinearRegressionModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + } finally { + Utils.deleteRecursively(tempDir) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index 6ed36fe3f7d7..43d61151e247 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -27,6 +27,12 @@ import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerato MLlibTestSparkContext} import org.apache.spark.util.Utils +private object RidgeRegressionSuite { + + /** 3 features */ + val model = new RidgeRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) +} + class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = { @@ -78,26 +84,20 @@ class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { } test("model save/load") { - // Create dataset - val numExamples = 20 - val numFeatures = 4 - val w = DoubleMatrix.rand(numFeatures, 1).subi(0.5) - val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, numExamples, 42, 10.0) - val rdd = sc.parallelize(data, 2).cache() + val model = RidgeRegressionSuite.model - // Train model - val lr = new RidgeRegressionWithSGD() - lr.optimizer.setNumIterations(1) - val model = lr.run(rdd) - - // Save model, load it back, and compare. 
val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString - model.save(sc, path) - val sameModel = RidgeRegressionModel.load(sc, path) - assert(model.weights == sameModel.weights) - assert(model.intercept == sameModel.intercept) - Utils.deleteRecursively(tempDir) + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = RidgeRegressionModel.load(sc, path) + assert(model.weights == sameModel.weights) + assert(model.intercept == sameModel.intercept) + } finally { + Utils.deleteRecursively(tempDir) + } } } From 87c4eb8e2e030cf033418901fd7c0533efb090f3 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 4 Feb 2015 15:32:29 -0800 Subject: [PATCH 14/14] small cleanups --- .../mllib/classification/NaiveBayes.scala | 1 - .../mllib/regression/RegressionModel.scala | 2 +- .../mllib/regression/RidgeRegression.scala | 2 +- .../LogisticRegressionSuite.scala | 24 ++++++++++--------- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index c8fe19855dda..4bafd495f90b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -18,7 +18,6 @@ package org.apache.spark.mllib.classification import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} -import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.{SparkContext, SparkException, Logging} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala index d6bbe7bbf440..843e59bdfbdd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala @@ -19,9 +19,9 @@ package org.apache.spark.mllib.regression import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.Loader import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{DataFrame, Row} @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 32a40b9a51d8..f2a5f1db1ece 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression.impl.GLMRegressionModel -import org.apache.spark.mllib.util.{Saveable, Loader} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 6be1b290a9b6..d2b40f2cae02 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -157,8 +157,17 @@ object LogisticRegressionSuite { /** 3 classes, 2 features */ private val multiclassModel = new LogisticRegressionModel( weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, numFeatures = 2, numClasses = 3) + + private def checkModelsEqual(a: LogisticRegressionModel, b: LogisticRegressionModel): Unit = { + assert(a.weights == b.weights) + assert(a.intercept == b.intercept) + assert(a.numClasses == b.numClasses) + assert(a.numFeatures == b.numFeatures) + assert(a.getThreshold == b.getThreshold) + } } + class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { def validatePrediction( predictions: Seq[Double], @@ -486,11 +495,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M try { model.save(sc, path) val sameModel = LogisticRegressionModel.load(sc, path) - assert(model.weights == sameModel.weights) - assert(model.intercept == sameModel.intercept) - assert(model.numClasses == sameModel.numClasses) - assert(model.numFeatures == sameModel.numFeatures) - assert(sameModel.getThreshold.isEmpty) + LogisticRegressionSuite.checkModelsEqual(model, sameModel) } finally { Utils.deleteRecursively(tempDir) } @@ -499,8 +504,8 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M try { model.setThreshold(0.7) model.save(sc, path) - val sameModel2 = LogisticRegressionModel.load(sc, path) - assert(model.getThreshold.get == sameModel2.getThreshold.get) + val sameModel = LogisticRegressionModel.load(sc, path) + LogisticRegressionSuite.checkModelsEqual(model, sameModel) } finally { Utils.deleteRecursively(tempDir) } @@ -517,10 +522,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M try { model.save(sc, path) val sameModel = LogisticRegressionModel.load(sc, path) - assert(model.weights == sameModel.weights) - assert(model.intercept == sameModel.intercept) - assert(model.numClasses == sameModel.numClasses) - assert(model.numFeatures == sameModel.numFeatures) + LogisticRegressionSuite.checkModelsEqual(model, sameModel) } finally { Utils.deleteRecursively(tempDir) }
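Illustrative usage (not part of the patch series): a minimal sketch of how the Saveable/Loader API exercised by the tests above would be called from application code. The SparkContext `sc`, the output path, and the weight/intercept values are placeholder assumptions; only the constructor, save, load, and getThreshold calls are defined by this patch.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.LogisticRegressionModel
import org.apache.spark.mllib.linalg.Vectors

def saveLoadRoundTrip(sc: SparkContext): Unit = {
  // Binary model with 3 features; the values are arbitrary examples.
  val model = new LogisticRegressionModel(Vectors.dense(0.1, 0.2, 0.3), 0.5)
  val path = "/tmp/lrModelExample"  // hypothetical output directory

  // save() writes JSON metadata to path/metadata and Parquet data to path/data.
  model.save(sc, path)

  // load() checks the class name and format version ("1.0") recorded in the metadata,
  // then rebuilds the model from the Parquet data, including the stored threshold.
  val restored = LogisticRegressionModel.load(sc, path)
  assert(restored.weights == model.weights)
  assert(restored.intercept == model.intercept)
  assert(restored.getThreshold == model.getThreshold)
}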