Skip to content

Commit bea62af

Browse files
committed
put back in constructor for NaiveBayes
1 parent 01baad7 commit bea62af

File tree

2 files changed

+4
-10
lines changed

2 files changed

+4
-10
lines changed

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,8 @@ class NaiveBayes private (
186186
private var lambda: Double,
187187
private var modelType: NaiveBayes.ModelType) extends Serializable with Logging {
188188

189+
private def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial)
190+
189191
def this() = this(1.0, NaiveBayes.Multinomial)
190192

191193
/** Set the smoothing parameter. Default: 1.0. */

mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -146,23 +146,15 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
146146
).map(_.map(math.log))
147147

148148
val testData = NaiveBayesSuite.generateNaiveBayesInput(
149-
pi,
150-
theta,
151-
nPoints,
152-
45,
153-
NaiveBayes.Bernoulli)
149+
pi, theta, nPoints, 45, NaiveBayes.Bernoulli)
154150
val testRDD = sc.parallelize(testData, 2)
155151
testRDD.cache()
156152

157153
val model = NaiveBayes.train(testRDD, 1.0, "bernoulli")
158154
validateModelFit(pi, theta, model)
159155

160156
val validationData = NaiveBayesSuite.generateNaiveBayesInput(
161-
pi,
162-
theta,
163-
nPoints,
164-
20,
165-
NaiveBayes.Bernoulli)
157+
pi, theta, nPoints, 20, NaiveBayes.Bernoulli)
166158
val validationRDD = sc.parallelize(validationData, 2)
167159

168160
// Test prediction on RDD.

0 commit comments

Comments
 (0)