From 04f0d3c6732ce503de95c0b3e8bcf87f16767877 Mon Sep 17 00:00:00 2001 From: leahmcguire Date: Tue, 5 May 2015 10:40:09 -0700 Subject: [PATCH 01/11] Added stats from cross validation as a val in the cross validation model to save them for user access --- .../scala/org/apache/spark/ml/tuning/CrossValidator.scala | 5 +++-- .../org/apache/spark/ml/tuning/CrossValidatorSuite.scala | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index ac0d1fed84b2e..c352f3c6f6372 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -136,7 +136,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP logInfo(s"Best set of parameters:\n${epm(bestIndex)}") logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] - copyValues(new CrossValidatorModel(this, bestModel)) + copyValues(new CrossValidatorModel(this, bestModel, metrics)) } override def transformSchema(schema: StructType): StructType = { @@ -151,7 +151,8 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP @AlphaComponent class CrossValidatorModel private[ml] ( override val parent: CrossValidator, - val bestModel: Model[_]) + val bestModel: Model[_], + val crossValidationMetrics: Array[Double]) extends Model[CrossValidatorModel] with CrossValidatorParams { override def validateParams(paramMap: ParamMap): Unit = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 05313d440fbf6..651b2efe2c77e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -52,5 +52,6 @@ class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) + assert(cvModel.crossValidationMetrics.length == 4) } } From 58d060b518133b1e64ef86ca7aee61b76d6c6990 Mon Sep 17 00:00:00 2001 From: leahmcguire Date: Thu, 7 May 2015 18:16:27 -0700 Subject: [PATCH 02/11] changed param name and test according to comments --- .../main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 2 +- .../scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index c352f3c6f6372..4be424929324e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -152,7 +152,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP class CrossValidatorModel private[ml] ( override val parent: CrossValidator, val bestModel: Model[_], - val crossValidationMetrics: Array[Double]) + val avgMetrics: Array[Double]) extends Model[CrossValidatorModel] with CrossValidatorParams { override def validateParams(paramMap: ParamMap): Unit = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 651b2efe2c77e..9b577b3c1bd4b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -52,6 +52,6 @@ class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) - assert(cvModel.crossValidationMetrics.length == 4) + assert(cvModel.crossValidationMetrics.length == lrParamMaps.length) } } From f191c71afcfe1b9a0d989669c152fad58d4bab89 Mon Sep 17 00:00:00 2001 From: leahmcguire Date: Thu, 7 May 2015 18:20:55 -0700 Subject: [PATCH 03/11] fixed name --- .../scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 9b577b3c1bd4b..1a38adf0c7536 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -52,6 +52,6 @@ class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) - assert(cvModel.crossValidationMetrics.length == lrParamMaps.length) + assert(cvModel.avgMetrics.length == lrParamMaps.length) } } From 67253f08cdf97a32c7caf2c6e65fee495e218aad Mon Sep 17 00:00:00 2001 From: leahmcguire Date: Mon, 11 May 2015 20:52:53 -0700 Subject: [PATCH 04/11] added check to bernoulli to ensure feature values are zero or one --- .../mllib/classification/NaiveBayes.scala | 10 +++++-- .../classification/NaiveBayesSuite.scala | 30 +++++++++++++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index c9b3ff0172e2e..da1240c292ce4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -87,12 +87,15 @@ class NaiveBayesModel private[mllib] ( } override def predict(testData: Vector): Double = { + val brzData = testData.toBreeze modelType match { case "Multinomial" => - labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) ) + labels (brzArgmax (brzPi + brzTheta * brzData) ) case "Bernoulli" => + brzData.foreach(v => if (!(v == 0.0 || v == 1.0)) + throw new SparkException(s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.")) labels (brzArgmax (brzPi + - (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get)) + (brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get)) case _ => // This should never happen. throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") @@ -291,6 +294,9 @@ class NaiveBayes private ( if (!values.forall(_ >= 0.0)) { throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.") } + if (modelType == "Bernoulli" && (!values.forall(v => v == 0.0 || v == 1.0) )) { + throw new SparkException(s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.") + } } // Aggregates term frequencies per label. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index ea89b17b7c08f..5322aed029eb7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -208,6 +208,36 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { } } + test("detect non zero or one values in Bernoulli") { + val bad = Seq( + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(0.0))) + + intercept[SparkException] { + NaiveBayes.train(sc.makeRDD(bad, 2), 1.0, "Bernoulli") + } + + val okTrain = Seq( + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(0.0))) + + val badPredict = Seq( + Vectors.dense(1.0), + Vectors.dense(2.0), + Vectors.dense(1.0), + Vectors.dense(0.0)) + + val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, "Bernoulli") + + intercept[SparkException] { + model.predict(sc.makeRDD(badPredict, 2)) + } + } + test("model save/load: 2.0 to 2.0") { val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString From f44bb3c39c0d73e7d8a67a6e79f6bd741cdb0425 Mon Sep 17 00:00:00 2001 From: leahmcguire Date: Mon, 11 May 2015 21:07:00 -0700 Subject: [PATCH 05/11] removed changes from CV branch --- .../scala/org/apache/spark/ml/tuning/CrossValidator.scala | 5 ++--- .../org/apache/spark/ml/tuning/CrossValidatorSuite.scala | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 4be424929324e..ac0d1fed84b2e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -136,7 +136,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP logInfo(s"Best set of parameters:\n${epm(bestIndex)}") logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] - copyValues(new CrossValidatorModel(this, bestModel, metrics)) + copyValues(new CrossValidatorModel(this, bestModel)) } override def transformSchema(schema: StructType): StructType = { @@ -151,8 +151,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP @AlphaComponent class CrossValidatorModel private[ml] ( override val parent: CrossValidator, - val bestModel: Model[_], - val avgMetrics: Array[Double]) + val bestModel: Model[_]) extends Model[CrossValidatorModel] with CrossValidatorParams { override def validateParams(paramMap: ParamMap): Unit = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 1a38adf0c7536..05313d440fbf6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -52,6 +52,5 @@ class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) - assert(cvModel.avgMetrics.length == lrParamMaps.length) } } From 831fd279e16a97711b30346c19a1dcde16728f19 Mon Sep 17 00:00:00 2001 From: leahmcguire Date: Mon, 11 May 2015 22:28:51 -0700 Subject: [PATCH 06/11] got test working --- .../apache/spark/mllib/classification/NaiveBayesSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 5322aed029eb7..945066cfcce01 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -232,9 +232,8 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { Vectors.dense(0.0)) val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, "Bernoulli") - intercept[SparkException] { - model.predict(sc.makeRDD(badPredict, 2)) + model.predict(sc.makeRDD(badPredict, 2)).collect() } } From 3f3b32cb9514da5642a38ca37737d7cb237d6799 Mon Sep 17 00:00:00 2001 From: leahmcguire Date: Tue, 12 May 2015 10:47:09 -0700 Subject: [PATCH 07/11] fixed zero one check so only called in combiner --- .../mllib/classification/NaiveBayes.scala | 22 +++++++++++++++---- .../classification/NaiveBayesSuite.scala | 10 ++++++--- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index da1240c292ce4..514fa91bba642 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -92,8 +92,10 @@ class NaiveBayesModel private[mllib] ( case "Multinomial" => labels (brzArgmax (brzPi + brzTheta * brzData) ) case "Bernoulli" => - brzData.foreach(v => if (!(v == 0.0 || v == 1.0)) - throw new SparkException(s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.")) + if (!brzData.forall(v => v == 0.0 || v == 1.0)) { + throw new SparkException( + s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.") + } labels (brzArgmax (brzPi + (brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get)) case _ => @@ -294,8 +296,19 @@ class NaiveBayes private ( if (!values.forall(_ >= 0.0)) { throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.") } - if (modelType == "Bernoulli" && (!values.forall(v => v == 0.0 || v == 1.0) )) { - throw new SparkException(s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.") + } + + val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => { + if (modelType == "Bernoulli") { + val values = v match { + case SparseVector(size, indices, values) => + values + case DenseVector(values) => + values + } + if (!values.forall(v => v == 0.0 || v == 1.0)) { + throw new SparkException(s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.") + } } } @@ -305,6 +318,7 @@ class NaiveBayes private ( val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])]( createCombiner = (v: Vector) => { requireNonnegativeValues(v) + requireZeroOneBernoulliValues(v) (1L, v.toBreeze.toDenseVector) }, mergeValue = (c: (Long, BDV[Double]), v: Vector) => { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 945066cfcce01..40a79a1f19bd9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -209,21 +209,25 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { } test("detect non zero or one values in Bernoulli") { - val bad = Seq( + val badTrain = Seq( LabeledPoint(1.0, Vectors.dense(1.0)), LabeledPoint(0.0, Vectors.dense(2.0)), LabeledPoint(1.0, Vectors.dense(1.0)), LabeledPoint(1.0, Vectors.dense(0.0))) intercept[SparkException] { - NaiveBayes.train(sc.makeRDD(bad, 2), 1.0, "Bernoulli") + NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, "Bernoulli") } val okTrain = Seq( LabeledPoint(1.0, Vectors.dense(1.0)), LabeledPoint(0.0, Vectors.dense(0.0)), LabeledPoint(1.0, Vectors.dense(1.0)), - LabeledPoint(1.0, Vectors.dense(0.0))) + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(1.0)) + ) val badPredict = Seq( Vectors.dense(1.0), From 9ee9e843825eb3d8507fa320d62ab46fe058853e Mon Sep 17 00:00:00 2001 From: leahmcguire Date: Tue, 12 May 2015 13:31:33 -0700 Subject: [PATCH 08/11] fixed style error --- .../org/apache/spark/mllib/classification/NaiveBayes.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 514fa91bba642..aaa6b219ace5b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -307,7 +307,8 @@ class NaiveBayes private ( values } if (!values.forall(v => v == 0.0 || v == 1.0)) { - throw new SparkException(s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.") + throw new SparkException( + s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.") } } } From 4eedf1e3e6b8ed1763861ad7e75a0db8542c4c78 Mon Sep 17 00:00:00 2001 From: leahmcguire Date: Tue, 12 May 2015 17:31:46 -0700 Subject: [PATCH 09/11] moved bernoulli check --- .../mllib/classification/NaiveBayes.scala | 58 +++++++++---------- 1 file changed, 28 insertions(+), 30 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index aaa6b219ace5b..ead1747f2e14e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -46,10 +46,10 @@ import org.apache.spark.sql.{DataFrame, SQLContext} * @param modelType The type of NB model to fit can be "Multinomial" or "Bernoulli" */ class NaiveBayesModel private[mllib] ( - val labels: Array[Double], - val pi: Array[Double], - val theta: Array[Array[Double]], - val modelType: String) + val labels: Array[Double], + val pi: Array[Double], + val theta: Array[Array[Double]], + val modelType: String) extends ClassificationModel with Serializable with Saveable { private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) = @@ -57,9 +57,9 @@ class NaiveBayesModel private[mllib] ( /** A Java-friendly constructor that takes three Iterable parameters. */ private[mllib] def this( - labels: JIterable[Double], - pi: JIterable[Double], - theta: JIterable[JIterable[Double]]) = + labels: JIterable[Double], + pi: JIterable[Double], + theta: JIterable[JIterable[Double]]) = this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray)) private val brzPi = new BDV[Double](pi) @@ -125,10 +125,10 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { /** Model data for model import/export */ case class Data( - labels: Array[Double], - pi: Array[Double], - theta: Array[Array[Double]], - modelType: String) + labels: Array[Double], + pi: Array[Double], + theta: Array[Array[Double]], + modelType: String) def save(sc: SparkContext, path: String, data: Data): Unit = { val sqlContext = new SQLContext(sc) @@ -172,9 +172,9 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { /** Model data for model import/export */ case class Data( - labels: Array[Double], - pi: Array[Double], - theta: Array[Array[Double]]) + labels: Array[Double], + pi: Array[Double], + theta: Array[Array[Double]]) def save(sc: SparkContext, path: String, data: Data): Unit = { val sqlContext = new SQLContext(sc) @@ -222,8 +222,8 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { (model, numFeatures, numClasses) case _ => throw new Exception( s"NaiveBayesModel.load did not recognize model with (className, format version):" + - s"($loadedClassName, $version). Supported:\n" + - s" ($classNameV1_0, 1.0)") + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") } assert(model.pi.size == numClasses, s"NaiveBayesModel.load expected $numClasses classes," + @@ -249,8 +249,8 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { */ class NaiveBayes private ( - private var lambda: Double, - private var modelType: String) extends Serializable with Logging { + private var lambda: Double, + private var modelType: String) extends Serializable with Logging { def this(lambda: Double) = this(lambda, "Multinomial") @@ -299,17 +299,15 @@ class NaiveBayes private ( } val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => { - if (modelType == "Bernoulli") { - val values = v match { - case SparseVector(size, indices, values) => - values - case DenseVector(values) => - values - } - if (!values.forall(v => v == 0.0 || v == 1.0)) { - throw new SparkException( - s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.") - } + val values = v match { + case SparseVector(size, indices, values) => + values + case DenseVector(values) => + values + } + if (!values.forall(v => v == 0.0 || v == 1.0)) { + throw new SparkException( + s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.") } } @@ -319,7 +317,7 @@ class NaiveBayes private ( val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])]( createCombiner = (v: Vector) => { requireNonnegativeValues(v) - requireZeroOneBernoulliValues(v) + if (modelType == "Bernoulli") requireZeroOneBernoulliValues(v) (1L, v.toBreeze.toDenseVector) }, mergeValue = (c: (Long, BDV[Double]), v: Vector) => { From 911bf83fe9947f8ea1b844a68db44d6784299bb3 Mon Sep 17 00:00:00 2001 From: leahmcguire Date: Tue, 12 May 2015 17:34:54 -0700 Subject: [PATCH 10/11] undid reformat --- .../mllib/classification/NaiveBayes.scala | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index ead1747f2e14e..ac38fa5a71b73 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -46,10 +46,10 @@ import org.apache.spark.sql.{DataFrame, SQLContext} * @param modelType The type of NB model to fit can be "Multinomial" or "Bernoulli" */ class NaiveBayesModel private[mllib] ( - val labels: Array[Double], - val pi: Array[Double], - val theta: Array[Array[Double]], - val modelType: String) + val labels: Array[Double], + val pi: Array[Double], + val theta: Array[Array[Double]], + val modelType: String) extends ClassificationModel with Serializable with Saveable { private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) = @@ -57,9 +57,9 @@ class NaiveBayesModel private[mllib] ( /** A Java-friendly constructor that takes three Iterable parameters. */ private[mllib] def this( - labels: JIterable[Double], - pi: JIterable[Double], - theta: JIterable[JIterable[Double]]) = + labels: JIterable[Double], + pi: JIterable[Double], + theta: JIterable[JIterable[Double]]) = this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray)) private val brzPi = new BDV[Double](pi) @@ -125,10 +125,10 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { /** Model data for model import/export */ case class Data( - labels: Array[Double], - pi: Array[Double], - theta: Array[Array[Double]], - modelType: String) + labels: Array[Double], + pi: Array[Double], + theta: Array[Array[Double]], + modelType: String) def save(sc: SparkContext, path: String, data: Data): Unit = { val sqlContext = new SQLContext(sc) @@ -172,9 +172,9 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { /** Model data for model import/export */ case class Data( - labels: Array[Double], - pi: Array[Double], - theta: Array[Array[Double]]) + labels: Array[Double], + pi: Array[Double], + theta: Array[Array[Double]]) def save(sc: SparkContext, path: String, data: Data): Unit = { val sqlContext = new SQLContext(sc) @@ -222,8 +222,8 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { (model, numFeatures, numClasses) case _ => throw new Exception( s"NaiveBayesModel.load did not recognize model with (className, format version):" + - s"($loadedClassName, $version). Supported:\n" + - s" ($classNameV1_0, 1.0)") + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") } assert(model.pi.size == numClasses, s"NaiveBayesModel.load expected $numClasses classes," + @@ -249,8 +249,8 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { */ class NaiveBayes private ( - private var lambda: Double, - private var modelType: String) extends Serializable with Logging { + private var lambda: Double, + private var modelType: String) extends Serializable with Logging { def this(lambda: Double) = this(lambda, "Multinomial") From b8442c290db565cc73f3fd7c2581af8e4e067140 Mon Sep 17 00:00:00 2001 From: leahmcguire Date: Tue, 12 May 2015 19:42:27 -0700 Subject: [PATCH 11/11] changed to if else for value checks --- .../org/apache/spark/mllib/classification/NaiveBayes.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index ac38fa5a71b73..b381dc2cb0140 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -316,8 +316,11 @@ class NaiveBayes private ( // TODO: similar to reduceByKeyLocally to save one stage. val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])]( createCombiner = (v: Vector) => { - requireNonnegativeValues(v) - if (modelType == "Bernoulli") requireZeroOneBernoulliValues(v) + if (modelType == "Bernoulli") { + requireZeroOneBernoulliValues(v) + } else { + requireNonnegativeValues(v) + } (1L, v.toBreeze.toDenseVector) }, mergeValue = (c: (Long, BDV[Double]), v: Vector) => {