|
18 | 18 | package org.apache.spark.mllib.classification |
19 | 19 |
|
20 | 20 | import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis} |
21 | | -import org.apache.spark.mllib.classification.NaiveBayesModels.NaiveBayesModels |
| 21 | +import breeze.numerics.{exp => brzExp, log => brzLog} |
22 | 22 |
|
23 | 23 | import org.apache.spark.{SparkException, Logging} |
24 | 24 | import org.apache.spark.SparkContext._ |
25 | 25 | import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} |
26 | 26 | import org.apache.spark.mllib.regression.LabeledPoint |
| 27 | +import org.apache.spark.mllib.classification.NaiveBayesModels.NaiveBayesModels |
27 | 28 | import org.apache.spark.rdd.RDD |
28 | 29 |
|
29 | 30 |
|
@@ -52,29 +53,14 @@ class NaiveBayesModel private[mllib] ( |
52 | 53 | val theta: Array[Array[Double]], |
53 | 54 | val model: NaiveBayesModels) extends ClassificationModel with Serializable { |
54 | 55 |
|
55 | | - def populateMatrix(arrayIn: Array[Array[Double]], |
56 | | - matrixIn: BDM[Double], |
57 | | - transformation: (Double) => Double = (x) => x) = { |
58 | | - var i = 0 |
59 | | - while (i < arrayIn.length) { |
60 | | - var j = 0 |
61 | | - while (j < arrayIn(i).length) { |
62 | | - matrixIn(i, j) = transformation(theta(i)(j)) |
63 | | - j += 1 |
64 | | - } |
65 | | - i += 1 |
66 | | - } |
67 | | - } |
68 | | - |
69 | 56 | private val brzPi = new BDV[Double](pi) |
70 | | - private val brzTheta = new BDM[Double](theta.length, theta(0).length) |
71 | | - populateMatrix(theta, brzTheta) |
| 57 | + private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t |
72 | 58 |
|
73 | 59 | private val brzNegTheta: Option[BDM[Double]] = model match { |
74 | 60 | case NaiveBayesModels.Multinomial => None |
75 | 61 | case NaiveBayesModels.Bernoulli => |
76 | | - val negTheta = new BDM[Double](theta.length, theta(0).length) |
77 | | - populateMatrix(theta, negTheta, (x) => math.log(1.0 - math.exp(x))) |
| 62 | + val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) |
| 63 | + //((x) => math.log(1.0 - math.exp(x)) |
78 | 64 | Option(negTheta) |
79 | 65 | } |
80 | 66 |
|
@@ -244,7 +230,7 @@ object NaiveBayes { |
244 | 230 | * @param model The type of NB model to fit from the enumeration NaiveBayesModels, can be |
245 | 231 | * Multinomial or Bernoulli |
246 | 232 | */ |
247 | | - def train(input: RDD[LabeledPoint], lambda: Double, model: NaiveBayesModels): NaiveBayesModel = { |
248 | | - new NaiveBayes(lambda, model).run(input) |
| 233 | + def train(input: RDD[LabeledPoint], lambda: Double, model: String): NaiveBayesModel = { |
| 234 | + new NaiveBayes(lambda, NaiveBayesModels.withName(model)).run(input) |
249 | 235 | } |
250 | 236 | } |
0 commit comments