Skip to content

Commit dc4abd4

Browse files
jkbradleymengxr
authored andcommitted
[SPARK-6252] [mllib] Added getLambda to Scala NaiveBayes
Note: not relevant for Python API since it only has a static train method Author: Joseph K. Bradley <[email protected]> Author: Joseph K. Bradley <[email protected]> Closes #4969 from jkbradley/SPARK-6252 and squashes the following commits: a471d90 [Joseph K. Bradley] small edits from review 63eff48 [Joseph K. Bradley] Added getLambda to Scala NaiveBayes
1 parent ea3d2ee commit dc4abd4

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,9 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
166166
this
167167
}
168168

169+
/** Get the smoothing parameter. Default: 1.0. */
170+
def getLambda: Double = lambda
171+
169172
/**
170173
* Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
171174
*

mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,14 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
8585
assert(numOfPredictions < input.length / 5)
8686
}
8787

88+
test("get, set params") {
89+
val nb = new NaiveBayes()
90+
nb.setLambda(2.0)
91+
assert(nb.getLambda === 2.0)
92+
nb.setLambda(3.0)
93+
assert(nb.getLambda === 3.0)
94+
}
95+
8896
test("Naive Bayes") {
8997
val nPoints = 10000
9098

0 commit comments

Comments
 (0)