diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index b11fd4f128c56..2ebc7fa5d4234 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -166,6 +166,9 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with this } + /** Get the smoothing parameter. Default: 1.0. */ + def getLambda: Double = lambda + /** * Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries. * diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 64dcc0fb9f82c..5a27c7d2309c5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -85,6 +85,14 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { assert(numOfPredictions < input.length / 5) } + test("get, set params") { + val nb = new NaiveBayes() + nb.setLambda(2.0) + assert(nb.getLambda === 2.0) + nb.setLambda(3.0) + assert(nb.getLambda === 3.0) + } + test("Naive Bayes") { val nPoints = 10000