From 63eff48a945d37c29e0385c263189cc1ee0a2104 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 10 Mar 2015 15:15:26 -0700 Subject: [PATCH 1/2] Added getLambda to Scala NaiveBayes --- .../apache/spark/mllib/classification/NaiveBayes.scala | 3 +++ .../spark/mllib/classification/NaiveBayesSuite.scala | 8 ++++++++ 2 files changed, 11 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index b11fd4f128c56..2ebc7fa5d4234 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -166,6 +166,9 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with this } + /** Get the smoothing parameter. Default: 1.0. */ + def getLambda: Double = lambda + /** * Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries. * diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 64dcc0fb9f82c..2504b615cd7fc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -85,6 +85,14 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { assert(numOfPredictions < input.length / 5) } + test("get, set params") { + val nb = new NaiveBayes() + nb.setLambda(2.0) + assert(nb.getLambda == 2.0) + nb.setLambda(3.0) + assert(nb.getLambda == 3.0) + } + test("Naive Bayes") { val nPoints = 10000 From a471d907b120357ecf576ab44982fc2135041389 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 13 Mar 2015 01:55:07 -0700 Subject: [PATCH 2/2] small edits from review --- .../apache/spark/mllib/classification/NaiveBayesSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 2504b615cd7fc..5a27c7d2309c5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -88,9 +88,9 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { test("get, set params") { val nb = new NaiveBayes() nb.setLambda(2.0) - assert(nb.getLambda == 2.0) + assert(nb.getLambda === 2.0) nb.setLambda(3.0) - assert(nb.getLambda == 3.0) + assert(nb.getLambda === 3.0) } test("Naive Bayes") {